From 25b11b1a4019aee17524793e7cb4f1825d626e4f Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Fri, 5 Dec 2025 14:01:50 +0800 Subject: [PATCH] feat: graph tool Change-Id: Ic523afc41e82af04c578968cc2b535c4b76ce828 --- adk/common/tool/graph_tool.go | 170 +++++++++++++++++++++++ adk/human-in-the-loop/1_approval/main.go | 31 ++++- 2 files changed, 198 insertions(+), 3 deletions(-) create mode 100644 adk/common/tool/graph_tool.go diff --git a/adk/common/tool/graph_tool.go b/adk/common/tool/graph_tool.go new file mode 100644 index 0000000..bd79690 --- /dev/null +++ b/adk/common/tool/graph_tool.go @@ -0,0 +1,170 @@ +package tool + +import ( + "context" + "fmt" + "reflect" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/components/tool/utils" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +type InvokableGraphTool[I, O any] struct { + graph compose.Graph[I, O] + compileOptions []compose.GraphCompileOption + tInfo *schema.ToolInfo +} + +func NewInvokableGraphTool[I, O any](graph compose.Graph[I, O], + name, desc string, + opts ...compose.GraphCompileOption, +) (*InvokableGraphTool[I, O], error) { + tInfo, err := utils.GoStruct2ToolInfo[I](name, desc) + if err != nil { + return nil, err + } + + return &InvokableGraphTool[I, O]{ + graph: graph, + compileOptions: opts, + tInfo: tInfo, + }, nil +} + +type graphToolOptions struct { + composeOpts []compose.Option +} + +func WithGraphToolOption(opts ...compose.Option) tool.Option { + return tool.WrapImplSpecificOptFn(func(opt *graphToolOptions) { + opt.composeOpts = opts + }) +} + +func (g *InvokableGraphTool[I, O]) InvokableRun(ctx context.Context, input string, + opts ...tool.Option) (output string, err error) { + var ( + checkpointStore *graphToolStore + inputParams I + originOutput O + runnable compose.Runnable[I, O] + ) + + compileOptions := make([]compose.GraphCompileOption, len(g.compileOptions)+1) + copy(compileOptions, g.compileOptions) + compileOptions[len(g.compileOptions)] = compose.WithCheckPointStore(checkpointStore) + + callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts + callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID)) + + wasInterrupted, hasState, state := compose.GetInterruptState[[]byte](ctx) + if !wasInterrupted { + checkpointStore = newEmptyStore() + + if runnable, err = g.graph.Compile(ctx, compileOptions...); err != nil { + return "", err + } + + inputParams = NewInstance[I]() + if err = sonic.UnmarshalString(input, &inputParams); err != nil { + return "", err + } + } else { + if !hasState { + return "", fmt.Errorf("graph tool interrupt has happened, but cannot find interrupt state") + } + + checkpointStore = newResumeStore(state) + if runnable, err = g.graph.Compile(ctx, compileOptions...); err != nil { + return "", err + } + } + + originOutput, err = runnable.Invoke(ctx, inputParams, callOpts...) + if err != nil { + _, ok := compose.ExtractInterruptInfo(err) + if !ok { + return "", err + } + data, existed, err := checkpointStore.Get(ctx, graphToolCheckPointID) + if err != nil { + return "", err + } + if !existed { + return "", fmt.Errorf("interrupt has happened, but checkpoint not exist in store") + } + + return "", compose.CompositeInterrupt(ctx, "graph tool interrupt", data, + err) + } + + return sonic.MarshalString(originOutput) +} + +func (g *InvokableGraphTool[I, O]) Info(_ context.Context) (*schema.ToolInfo, error) { + return g.tInfo, nil +} + +const graphToolCheckPointID = "graph_tool_checkpoint_id" + +func newEmptyStore() *graphToolStore { + return &graphToolStore{} +} + +func newResumeStore(data []byte) *graphToolStore { + return &graphToolStore{ + Data: data, + Valid: true, + } +} + +type graphToolStore struct { + Data []byte + Valid bool +} + +func (m *graphToolStore) Get(_ context.Context, _ string) ([]byte, bool, error) { + if m.Valid { + return m.Data, true, nil + } + return nil, false, nil +} + +func (m *graphToolStore) Set(_ context.Context, _ string, checkPoint []byte) error { + m.Data = checkPoint + m.Valid = true + return nil +} + +func NewInstance[T any]() T { + typ := TypeOf[T]() + + switch typ.Kind() { + case reflect.Map: + return reflect.MakeMap(typ).Interface().(T) + case reflect.Slice, reflect.Array: + return reflect.MakeSlice(typ, 0, 0).Interface().(T) + case reflect.Ptr: + typ = typ.Elem() + origin := reflect.New(typ) + inst := origin + + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + inst = inst.Elem() + inst.Set(reflect.New(typ)) + } + + return origin.Interface().(T) + default: + var t T + return t + } +} + +func TypeOf[T any]() reflect.Type { + return reflect.TypeOf((*T)(nil)).Elem() +} diff --git a/adk/human-in-the-loop/1_approval/main.go b/adk/human-in-the-loop/1_approval/main.go index 44997be..1668920 100644 --- a/adk/human-in-the-loop/1_approval/main.go +++ b/adk/human-in-the-loop/1_approval/main.go @@ -23,8 +23,12 @@ import ( "log" "os" "strings" + "time" + clc "github.com/cloudwego/eino-ext/callbacks/cozeloop" "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/callbacks" + "github.com/coze-dev/cozeloop-go" "github.com/cloudwego/eino-examples/adk/common/prints" "github.com/cloudwego/eino-examples/adk/common/store" @@ -33,8 +37,29 @@ import ( func main() { ctx := context.Background() + + cozeloopApiToken := os.Getenv("COZELOOP_API_TOKEN") + cozeloopWorkspaceID := os.Getenv("COZELOOP_WORKSPACE_ID") // use cozeloop trace, from https://loop.coze.cn/open/docs/cozeloop/go-sdk#4a8c980e + + var handlers []callbacks.Handler + if cozeloopApiToken != "" && cozeloopWorkspaceID != "" { + client, err := cozeloop.NewClient( + cozeloop.WithAPIToken(cozeloopApiToken), + cozeloop.WithWorkspaceID(cozeloopWorkspaceID), + ) + if err != nil { + panic(err) + } + defer func() { + time.Sleep(5 * time.Second) + client.Close(ctx) + }() + handlers = append(handlers, clc.NewLoopHandler(client)) + } + callbacks.AppendGlobalHandlers(handlers...) + a := NewTicketBookingAgent() - runner := adk.NewRunner(ctx, adk.RunnerConfig{ + runner := adk.NewRunner(context.Background(), adk.RunnerConfig{ EnableStreaming: true, // you can disable streaming here Agent: a, @@ -43,7 +68,7 @@ func main() { // In the real world, you can use a distributed store like Redis to persist the checkpoints. CheckPointStore: store.NewInMemoryStore(), }) - iter := runner.Query(ctx, "book a ticket for Martin, to Beijing, on 2025-12-01, the phone number is 1234567. directly call tool.", adk.WithCheckPointID("1")) + iter := runner.Query(context.Background(), "book a ticket for Martin, to Beijing, on 2025-12-01, the phone number is 1234567. directly call tool.", adk.WithCheckPointID("1")) var lastEvent *adk.AgentEvent for { @@ -98,7 +123,7 @@ func main() { // In the real world, the original `Runner.Run/Query` and the subsequent `Runner.ResumeWithParams` // can happen in different processes or machines, as long as you use the same `CheckPointID`, // and you provided a distributed `CheckPointStore` when creating the `Runner` instance. - iter, err := runner.ResumeWithParams(ctx, "1", &adk.ResumeParams{ + iter, err := runner.ResumeWithParams(context.Background(), "1", &adk.ResumeParams{ Targets: map[string]any{ interruptID: apResult, },