feat: graph tool

Change-Id: Ic523afc41e82af04c578968cc2b535c4b76ce828
drew/english
shentong.martin 5 months ago committed by shentongmartin
parent f13f4f7555
commit 25b11b1a40

@ -0,0 +1,170 @@
package tool
import (
"context"
"fmt"
"reflect"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/components/tool/utils"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
type InvokableGraphTool[I, O any] struct {
graph compose.Graph[I, O]
compileOptions []compose.GraphCompileOption
tInfo *schema.ToolInfo
}
func NewInvokableGraphTool[I, O any](graph compose.Graph[I, O],
name, desc string,
opts ...compose.GraphCompileOption,
) (*InvokableGraphTool[I, O], error) {
tInfo, err := utils.GoStruct2ToolInfo[I](name, desc)
if err != nil {
return nil, err
}
return &InvokableGraphTool[I, O]{
graph: graph,
compileOptions: opts,
tInfo: tInfo,
}, nil
}
type graphToolOptions struct {
composeOpts []compose.Option
}
func WithGraphToolOption(opts ...compose.Option) tool.Option {
return tool.WrapImplSpecificOptFn(func(opt *graphToolOptions) {
opt.composeOpts = opts
})
}
func (g *InvokableGraphTool[I, O]) InvokableRun(ctx context.Context, input string,
opts ...tool.Option) (output string, err error) {
var (
checkpointStore *graphToolStore
inputParams I
originOutput O
runnable compose.Runnable[I, O]
)
compileOptions := make([]compose.GraphCompileOption, len(g.compileOptions)+1)
copy(compileOptions, g.compileOptions)
compileOptions[len(g.compileOptions)] = compose.WithCheckPointStore(checkpointStore)
callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts
callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID))
wasInterrupted, hasState, state := compose.GetInterruptState[[]byte](ctx)
if !wasInterrupted {
checkpointStore = newEmptyStore()
if runnable, err = g.graph.Compile(ctx, compileOptions...); err != nil {
return "", err
}
inputParams = NewInstance[I]()
if err = sonic.UnmarshalString(input, &inputParams); err != nil {
return "", err
}
} else {
if !hasState {
return "", fmt.Errorf("graph tool interrupt has happened, but cannot find interrupt state")
}
checkpointStore = newResumeStore(state)
if runnable, err = g.graph.Compile(ctx, compileOptions...); err != nil {
return "", err
}
}
originOutput, err = runnable.Invoke(ctx, inputParams, callOpts...)
if err != nil {
_, ok := compose.ExtractInterruptInfo(err)
if !ok {
return "", err
}
data, existed, err := checkpointStore.Get(ctx, graphToolCheckPointID)
if err != nil {
return "", err
}
if !existed {
return "", fmt.Errorf("interrupt has happened, but checkpoint not exist in store")
}
return "", compose.CompositeInterrupt(ctx, "graph tool interrupt", data,
err)
}
return sonic.MarshalString(originOutput)
}
func (g *InvokableGraphTool[I, O]) Info(_ context.Context) (*schema.ToolInfo, error) {
return g.tInfo, nil
}
const graphToolCheckPointID = "graph_tool_checkpoint_id"
func newEmptyStore() *graphToolStore {
return &graphToolStore{}
}
func newResumeStore(data []byte) *graphToolStore {
return &graphToolStore{
Data: data,
Valid: true,
}
}
type graphToolStore struct {
Data []byte
Valid bool
}
func (m *graphToolStore) Get(_ context.Context, _ string) ([]byte, bool, error) {
if m.Valid {
return m.Data, true, nil
}
return nil, false, nil
}
func (m *graphToolStore) Set(_ context.Context, _ string, checkPoint []byte) error {
m.Data = checkPoint
m.Valid = true
return nil
}
func NewInstance[T any]() T {
typ := TypeOf[T]()
switch typ.Kind() {
case reflect.Map:
return reflect.MakeMap(typ).Interface().(T)
case reflect.Slice, reflect.Array:
return reflect.MakeSlice(typ, 0, 0).Interface().(T)
case reflect.Ptr:
typ = typ.Elem()
origin := reflect.New(typ)
inst := origin
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
inst = inst.Elem()
inst.Set(reflect.New(typ))
}
return origin.Interface().(T)
default:
var t T
return t
}
}
func TypeOf[T any]() reflect.Type {
return reflect.TypeOf((*T)(nil)).Elem()
}

@ -23,8 +23,12 @@ import (
"log"
"os"
"strings"
"time"
clc "github.com/cloudwego/eino-ext/callbacks/cozeloop"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/callbacks"
"github.com/coze-dev/cozeloop-go"
"github.com/cloudwego/eino-examples/adk/common/prints"
"github.com/cloudwego/eino-examples/adk/common/store"
@ -33,8 +37,29 @@ import (
func main() {
ctx := context.Background()
cozeloopApiToken := os.Getenv("COZELOOP_API_TOKEN")
cozeloopWorkspaceID := os.Getenv("COZELOOP_WORKSPACE_ID") // use cozeloop trace, from https://loop.coze.cn/open/docs/cozeloop/go-sdk#4a8c980e
var handlers []callbacks.Handler
if cozeloopApiToken != "" && cozeloopWorkspaceID != "" {
client, err := cozeloop.NewClient(
cozeloop.WithAPIToken(cozeloopApiToken),
cozeloop.WithWorkspaceID(cozeloopWorkspaceID),
)
if err != nil {
panic(err)
}
defer func() {
time.Sleep(5 * time.Second)
client.Close(ctx)
}()
handlers = append(handlers, clc.NewLoopHandler(client))
}
callbacks.AppendGlobalHandlers(handlers...)
a := NewTicketBookingAgent()
runner := adk.NewRunner(ctx, adk.RunnerConfig{
runner := adk.NewRunner(context.Background(), adk.RunnerConfig{
EnableStreaming: true, // you can disable streaming here
Agent: a,
@ -43,7 +68,7 @@ func main() {
// In the real world, you can use a distributed store like Redis to persist the checkpoints.
CheckPointStore: store.NewInMemoryStore(),
})
iter := runner.Query(ctx, "book a ticket for Martin, to Beijing, on 2025-12-01, the phone number is 1234567. directly call tool.", adk.WithCheckPointID("1"))
iter := runner.Query(context.Background(), "book a ticket for Martin, to Beijing, on 2025-12-01, the phone number is 1234567. directly call tool.", adk.WithCheckPointID("1"))
var lastEvent *adk.AgentEvent
for {
@ -98,7 +123,7 @@ func main() {
// In the real world, the original `Runner.Run/Query` and the subsequent `Runner.ResumeWithParams`
// can happen in different processes or machines, as long as you use the same `CheckPointID`,
// and you provided a distributed `CheckPointStore` when creating the `Runner` instance.
iter, err := runner.ResumeWithParams(ctx, "1", &adk.ResumeParams{
iter, err := runner.ResumeWithParams(context.Background(), "1", &adk.ResumeParams{
Targets: map[string]any{
interruptID: apResult,
},

Loading…
Cancel
Save