refactor: use tool package interrupt APIs instead of compose package

- approval_wrapper.go: use tool.GetInterruptState, tool.GetResumeContext, tool.StatefulInterrupt
- graphtool/graph_tool.go: use tool.GetInterruptState, tool.CompositeInterrupt
- Keep compose.ExtractInterruptInfo for checking graph execution errors
- Remove ToolCallID from ApprovalInfo (compose.GetToolCallID no longer used)

Change-Id: I7c9a7f73a0e0036ab384478d1ea6770ea0031c4b
drew/english
shentong.martin 4 months ago committed by shentongmartin
parent f2f7dbb918
commit f7ed18dd7e

@ -21,14 +21,12 @@ import (
"fmt" "fmt"
"github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
) )
type ApprovalInfo struct { type ApprovalInfo struct {
ToolName string ToolName string
ArgumentsInJSON string ArgumentsInJSON string
ToolCallID string
} }
type ApprovalResult struct { type ApprovalResult struct {
@ -62,16 +60,15 @@ func (i InvokableApprovableTool) InvokableRun(ctx context.Context, argumentsInJS
return "", err return "", err
} }
wasInterrupted, _, storedArguments := compose.GetInterruptState[string](ctx) wasInterrupted, _, storedArguments := tool.GetInterruptState[string](ctx)
if !wasInterrupted { // initial invocation, interrupt and wait for approval if !wasInterrupted {
return "", compose.StatefulInterrupt(ctx, &ApprovalInfo{ return "", tool.StatefulInterrupt(ctx, &ApprovalInfo{
ToolName: toolInfo.Name, ToolName: toolInfo.Name,
ArgumentsInJSON: argumentsInJSON, ArgumentsInJSON: argumentsInJSON,
ToolCallID: compose.GetToolCallID(ctx),
}, argumentsInJSON) }, argumentsInJSON)
} }
isResumeTarget, hasData, data := compose.GetResumeContext[*ApprovalResult](ctx) isResumeTarget, hasData, data := tool.GetResumeContext[*ApprovalResult](ctx)
if isResumeTarget && hasData { if isResumeTarget && hasData {
if data.Approved { if data.Approved {
return i.InvokableTool.InvokableRun(ctx, storedArguments, opts...) return i.InvokableTool.InvokableRun(ctx, storedArguments, opts...)
@ -84,12 +81,11 @@ func (i InvokableApprovableTool) InvokableRun(ctx context.Context, argumentsInJS
return fmt.Sprintf("tool '%s' disapproved", toolInfo.Name), nil return fmt.Sprintf("tool '%s' disapproved", toolInfo.Name), nil
} }
isResumeTarget, _, _ = compose.GetResumeContext[any](ctx) isResumeTarget, _, _ = tool.GetResumeContext[any](ctx)
if !isResumeTarget { if !isResumeTarget {
return "", compose.StatefulInterrupt(ctx, &ApprovalInfo{ return "", tool.StatefulInterrupt(ctx, &ApprovalInfo{
ToolName: toolInfo.Name, ToolName: toolInfo.Name,
ArgumentsInJSON: storedArguments, ArgumentsInJSON: storedArguments,
ToolCallID: compose.GetToolCallID(ctx),
}, storedArguments) }, storedArguments)
} }

@ -86,7 +86,7 @@ func (g *InvokableGraphTool[I, O]) InvokableRun(ctx context.Context, input strin
callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts
callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID)) callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID))
wasInterrupted, hasState, state := compose.GetInterruptState[*graphToolInterruptState](ctx) wasInterrupted, hasState, state := tool.GetInterruptState[*graphToolInterruptState](ctx)
if wasInterrupted && hasState { if wasInterrupted && hasState {
input = state.ToolInput input = state.ToolInput
@ -130,7 +130,7 @@ func (g *InvokableGraphTool[I, O]) InvokableRun(ctx context.Context, input strin
return "", fmt.Errorf("interrupt has happened, but checkpoint not exist in store") return "", fmt.Errorf("interrupt has happened, but checkpoint not exist in store")
} }
return "", compose.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ return "", tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{
Data: data, Data: data,
ToolInput: input, ToolInput: input,
}, interruptErr) }, interruptErr)
@ -181,7 +181,7 @@ func (g *StreamableGraphTool[I, O]) StreamableRun(ctx context.Context, input str
callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts callOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts
callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID)) callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID))
wasInterrupted, hasState, state := compose.GetInterruptState[*graphToolInterruptState](ctx) wasInterrupted, hasState, state := tool.GetInterruptState[*graphToolInterruptState](ctx)
if wasInterrupted && hasState { if wasInterrupted && hasState {
input = state.ToolInput input = state.ToolInput
@ -233,7 +233,7 @@ func (g *StreamableGraphTool[I, O]) StreamableRun(ctx context.Context, input str
return return
} }
sw.Send("", compose.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ sw.Send("", tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{
Data: data, Data: data,
ToolInput: input, ToolInput: input,
}, interruptErr)) }, interruptErr))
@ -264,7 +264,7 @@ func (g *StreamableGraphTool[I, O]) StreamableRun(ctx context.Context, input str
return return
} }
sw.Send("", compose.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{ sw.Send("", tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{
Data: data, Data: data,
ToolInput: input, ToolInput: input,
}, interruptErr)) }, interruptErr))

Loading…
Cancel
Save