diff --git a/adk/common/tool/approval_wrapper.go b/adk/common/tool/approval_wrapper.go index 626714b..08d50a0 100644 --- a/adk/common/tool/approval_wrapper.go +++ b/adk/common/tool/approval_wrapper.go @@ -21,14 +21,12 @@ import ( "fmt" "github.com/cloudwego/eino/components/tool" - "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) type ApprovalInfo struct { ToolName string ArgumentsInJSON string - ToolCallID string } type ApprovalResult struct { @@ -62,16 +60,15 @@ func (i InvokableApprovableTool) InvokableRun(ctx context.Context, argumentsInJS return "", err } - wasInterrupted, _, storedArguments := compose.GetInterruptState[string](ctx) - if !wasInterrupted { // initial invocation, interrupt and wait for approval - return "", compose.StatefulInterrupt(ctx, &ApprovalInfo{ + wasInterrupted, _, storedArguments := tool.GetInterruptState[string](ctx) + if !wasInterrupted { + return "", tool.StatefulInterrupt(ctx, &ApprovalInfo{ ToolName: toolInfo.Name, ArgumentsInJSON: argumentsInJSON, - ToolCallID: compose.GetToolCallID(ctx), }, argumentsInJSON) } - isResumeTarget, hasData, data := compose.GetResumeContext[*ApprovalResult](ctx) + isResumeTarget, hasData, data := tool.GetResumeContext[*ApprovalResult](ctx) if isResumeTarget && hasData { if data.Approved { return i.InvokableTool.InvokableRun(ctx, storedArguments, opts...) @@ -84,12 +81,11 @@ func (i InvokableApprovableTool) InvokableRun(ctx context.Context, argumentsInJS return fmt.Sprintf("tool '%s' disapproved", toolInfo.Name), nil } - isResumeTarget, _, _ = compose.GetResumeContext[any](ctx) + isResumeTarget, _, _ = tool.GetResumeContext[any](ctx) if !isResumeTarget { - return "", compose.StatefulInterrupt(ctx, &ApprovalInfo{ + return "", tool.StatefulInterrupt(ctx, &ApprovalInfo{ ToolName: toolInfo.Name, ArgumentsInJSON: storedArguments, - ToolCallID: compose.GetToolCallID(ctx), }, storedArguments) } diff --git a/adk/common/tool/graphtool/graph_tool.go b/adk/common/tool/graphtool/graph_tool.go index 707761a..e780abb 100644 --- a/adk/common/tool/graphtool/graph_tool.go +++ b/adk/common/tool/graphtool/graph_tool.go @@ -86,7 +86,7 @@ func (g *InvokableGraphTool[I, O]) InvokableRun(ctx context.Context, input strin callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID)) - wasInterrupted, hasState, state := compose.GetInterruptState[*graphToolInterruptState](ctx) + wasInterrupted, hasState, state := tool.GetInterruptState[*graphToolInterruptState](ctx) if wasInterrupted && hasState { input = state.ToolInput @@ -130,7 +130,7 @@ func (g *InvokableGraphTool[I, O]) InvokableRun(ctx context.Context, input strin return "", fmt.Errorf("interrupt has happened, but checkpoint not exist in store") } - return "", compose.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ + return "", tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ Data: data, ToolInput: input, }, interruptErr) @@ -181,7 +181,7 @@ func (g *StreamableGraphTool[I, O]) StreamableRun(ctx context.Context, input str callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID)) - wasInterrupted, hasState, state := compose.GetInterruptState[*graphToolInterruptState](ctx) + wasInterrupted, hasState, state := tool.GetInterruptState[*graphToolInterruptState](ctx) if wasInterrupted && hasState { input = state.ToolInput @@ -233,7 +233,7 @@ func (g *StreamableGraphTool[I, O]) StreamableRun(ctx context.Context, input str return } - sw.Send("", compose.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ + sw.Send("", tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ Data: data, ToolInput: input, }, interruptErr)) @@ -264,7 +264,7 @@ func (g *StreamableGraphTool[I, O]) StreamableRun(ctx context.Context, input str return } - sw.Send("", compose.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ + sw.Send("", tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ Data: data, ToolInput: input, }, interruptErr))