You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
283 lines
8.7 KiB
Go
283 lines
8.7 KiB
Go
/*
|
|
* Copyright 2025 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/cloudwego/eino/adk"
|
|
"github.com/cloudwego/eino/components/tool"
|
|
"github.com/cloudwego/eino/compose"
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
"github.com/cloudwego/eino-examples/adk/common/model"
|
|
"github.com/cloudwego/eino-examples/adk/common/prints"
|
|
"github.com/cloudwego/eino-examples/adk/common/store"
|
|
tool2 "github.com/cloudwego/eino-examples/adk/common/tool"
|
|
"github.com/cloudwego/eino-examples/adk/common/tool/graphtool"
|
|
)
|
|
|
|
// TransferInput holds the arguments of a transfer_funds call.
// The jsonschema_description tags presumably feed the tool's parameter schema
// that the model sees — confirm against graphtool.NewInvokableGraphTool.
type TransferInput struct {
	FromAccount string  `json:"from_account" jsonschema_description:"Source account ID"`
	ToAccount   string  `json:"to_account" jsonschema_description:"Destination account ID"`
	Amount      float64 `json:"amount" jsonschema_description:"Amount to transfer"`
}
|
|
|
|
// TransferOutput is the result returned by the transfer_funds workflow.
// Rejected transfers in this file set only Status and Message, so
// TransactionID and both balances are left at their zero values in that case.
type TransferOutput struct {
	TransactionID string  `json:"transaction_id"`
	Status        string  `json:"status"`
	Message       string  `json:"message"`
	FromBalance   float64 `json:"from_balance"`
	ToBalance     float64 `json:"to_balance"`
}
|
|
|
|
// InternalApprovalInfo is the interrupt payload raised by the workflow's
// risk-check node when a transfer needs risk-team approval.
type InternalApprovalInfo struct {
	Step, Message string
}

// String renders the approval request as a multi-line prompt ending with a
// Y/N question.
func (info *InternalApprovalInfo) String() string {
	return "\n[INTERNAL WORKFLOW APPROVAL]\nStep: " + info.Step +
		"\nMessage: " + info.Message +
		"\nApprove? (Y/N):"
}
|
|
|
|
// InternalApprovalResult is the resume payload that answers an
// InternalApprovalInfo interrupt: the decision plus a free-form comment.
type InternalApprovalResult struct {
	Approved bool
	Comment  string
}
|
|
|
|
func init() {
|
|
schema.Register[*InternalApprovalInfo]()
|
|
schema.Register[*InternalApprovalResult]()
|
|
schema.Register[*validationResult]()
|
|
}
|
|
|
|
// validationResult is passed from the "validate" node to the
// "risk_check_and_execute" node. It is also saved as interrupt state so a
// resumed run can recover the transfer details.
type validationResult struct {
	Valid                  bool
	FromAccount, ToAccount string
	Amount                 float64
}
|
|
|
|
func NewTransferToolWithInternalInterrupt(ctx context.Context) (tool.InvokableTool, error) {
|
|
workflow := compose.NewWorkflow[*TransferInput, *TransferOutput]()
|
|
|
|
workflow.AddLambdaNode("validate", compose.InvokableLambda(func(ctx context.Context, input *TransferInput) (*validationResult, error) {
|
|
fmt.Println(" [Workflow] Validating transfer...")
|
|
if input.Amount <= 0 {
|
|
return &validationResult{Valid: false}, nil
|
|
}
|
|
return &validationResult{
|
|
Valid: true,
|
|
FromAccount: input.FromAccount,
|
|
ToAccount: input.ToAccount,
|
|
Amount: input.Amount,
|
|
}, nil
|
|
})).AddInput(compose.START)
|
|
|
|
workflow.AddLambdaNode("risk_check_and_execute", compose.InvokableLambda(func(ctx context.Context, validation *validationResult) (*TransferOutput, error) {
|
|
wasInterrupted, _, storedValidation := compose.GetInterruptState[*validationResult](ctx)
|
|
|
|
if wasInterrupted {
|
|
fmt.Println(" [Workflow] Resuming from interrupt...")
|
|
isTarget, hasData, data := compose.GetResumeContext[*InternalApprovalResult](ctx)
|
|
|
|
if !isTarget {
|
|
fmt.Println(" [Workflow] Not resume target, re-interrupting...")
|
|
return nil, compose.StatefulInterrupt(ctx, &InternalApprovalInfo{
|
|
Step: "risk_check",
|
|
Message: fmt.Sprintf("High-value transfer of $%.2f requires risk team approval", storedValidation.Amount),
|
|
}, storedValidation)
|
|
}
|
|
|
|
if !hasData {
|
|
return nil, fmt.Errorf("resumed without approval data")
|
|
}
|
|
|
|
if !data.Approved {
|
|
return &TransferOutput{
|
|
Status: "rejected",
|
|
Message: fmt.Sprintf("Transfer rejected by risk team: %s", data.Comment),
|
|
}, nil
|
|
}
|
|
|
|
fmt.Printf(" [Workflow] Risk team approved with comment: %s\n", data.Comment)
|
|
fmt.Println(" [Workflow] Executing transfer...")
|
|
return &TransferOutput{
|
|
TransactionID: "TXN-12345",
|
|
Status: "completed",
|
|
Message: fmt.Sprintf("Transfer of $%.2f completed (risk approved)", storedValidation.Amount),
|
|
FromBalance: 10000 - storedValidation.Amount,
|
|
ToBalance: 5000 + storedValidation.Amount,
|
|
}, nil
|
|
}
|
|
|
|
if !validation.Valid {
|
|
return &TransferOutput{
|
|
Status: "rejected",
|
|
Message: "Invalid transfer: validation failed",
|
|
}, nil
|
|
}
|
|
|
|
fmt.Println(" [Workflow] Performing risk check...")
|
|
if validation.Amount > 1000 {
|
|
fmt.Println(" [Workflow] High-value transfer detected, triggering INTERNAL interrupt...")
|
|
return nil, compose.StatefulInterrupt(ctx, &InternalApprovalInfo{
|
|
Step: "risk_check",
|
|
Message: fmt.Sprintf("High-value transfer of $%.2f requires risk team approval", validation.Amount),
|
|
}, validation)
|
|
}
|
|
|
|
fmt.Println(" [Workflow] Low-value transfer, executing directly...")
|
|
return &TransferOutput{
|
|
TransactionID: "TXN-12345",
|
|
Status: "completed",
|
|
Message: fmt.Sprintf("Transfer of $%.2f completed", validation.Amount),
|
|
FromBalance: 10000 - validation.Amount,
|
|
ToBalance: 5000 + validation.Amount,
|
|
}, nil
|
|
})).AddInput("validate")
|
|
|
|
workflow.End().AddInput("risk_check_and_execute")
|
|
|
|
return graphtool.NewInvokableGraphTool[*TransferInput, *TransferOutput](
|
|
workflow,
|
|
"transfer_funds",
|
|
"Transfer funds between accounts. High-value transfers (>$1000) require internal risk approval.",
|
|
)
|
|
}
|
|
|
|
func main() {
|
|
ctx := context.Background()
|
|
|
|
innerTool, err := NewTransferToolWithInternalInterrupt(ctx)
|
|
if err != nil {
|
|
log.Fatalf("failed to create transfer tool: %v", err)
|
|
}
|
|
|
|
transferTool := tool2.InvokableApprovableTool{InvokableTool: innerTool}
|
|
|
|
agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
|
Name: "TransferAssistant",
|
|
Description: "An assistant that can transfer funds between accounts",
|
|
Instruction: `You are a helpful banking assistant.
|
|
When the user wants to transfer funds, IMMEDIATELY use the transfer_funds tool without asking for confirmation.
|
|
All transfers require initial approval. High-value transfers (>$1000) also require internal risk team approval.`,
|
|
Model: model.NewChatModel(),
|
|
ToolsConfig: adk.ToolsConfig{
|
|
ToolsNodeConfig: compose.ToolsNodeConfig{
|
|
Tools: []tool.BaseTool{transferTool},
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
log.Fatalf("failed to create agent: %v", err)
|
|
}
|
|
|
|
checkpointStore := store.NewInMemoryStore()
|
|
runner := adk.NewRunner(ctx, adk.RunnerConfig{
|
|
EnableStreaming: true,
|
|
Agent: agent,
|
|
CheckPointStore: checkpointStore,
|
|
})
|
|
|
|
query := "Transfer $1500 from account A001 to account B002"
|
|
|
|
fmt.Println("=== Nested Interrupt Test ===")
|
|
fmt.Println()
|
|
fmt.Println("This example tests:")
|
|
fmt.Println("1. InvokableApprovableTool wraps InvokableGraphTool")
|
|
fmt.Println("2. The inner workflow has its own interrupt (risk check)")
|
|
fmt.Println("3. Both interrupts should work independently")
|
|
fmt.Println()
|
|
fmt.Printf("User Query: %s\n\n", query)
|
|
|
|
checkpointID := "nested-interrupt-test"
|
|
iter := runner.Query(ctx, query, adk.WithCheckPointID(checkpointID))
|
|
|
|
interruptCount := 0
|
|
for {
|
|
var lastEvent *adk.AgentEvent
|
|
for {
|
|
event, ok := iter.Next()
|
|
if !ok {
|
|
break
|
|
}
|
|
if event.Err != nil {
|
|
log.Fatalf("error: %v", event.Err)
|
|
}
|
|
prints.Event(event)
|
|
lastEvent = event
|
|
}
|
|
|
|
if lastEvent == nil {
|
|
break
|
|
}
|
|
|
|
if lastEvent.Action != nil && lastEvent.Action.Interrupted != nil {
|
|
interruptCount++
|
|
fmt.Printf("\n--- Interrupt #%d detected ---\n", interruptCount)
|
|
|
|
interruptID := lastEvent.Action.Interrupted.InterruptContexts[0].ID
|
|
fmt.Printf("Interrupt ID: %s\n, Address: %v\n, Info: %v\n", interruptID,
|
|
lastEvent.Action.Interrupted.InterruptContexts[0].Address,
|
|
lastEvent.Action.Interrupted.InterruptContexts[0].Info)
|
|
|
|
scanner := bufio.NewScanner(os.Stdin)
|
|
fmt.Print("\nYour decision (Y/N): ")
|
|
scanner.Scan()
|
|
input := strings.TrimSpace(scanner.Text())
|
|
|
|
var resumeData any
|
|
if strings.ToUpper(input) == "Y" {
|
|
if interruptCount == 1 {
|
|
resumeData = &tool2.ApprovalResult{Approved: true}
|
|
} else {
|
|
resumeData = &InternalApprovalResult{Approved: true, Comment: "Risk approved by manager"}
|
|
}
|
|
} else {
|
|
if interruptCount == 1 {
|
|
reason := "User rejected"
|
|
resumeData = &tool2.ApprovalResult{Approved: false, DisapproveReason: &reason}
|
|
} else {
|
|
resumeData = &InternalApprovalResult{Approved: false, Comment: "Risk team rejected"}
|
|
}
|
|
}
|
|
|
|
fmt.Printf("\n--- Resuming (interrupt #%d) ---\n\n", interruptCount)
|
|
|
|
iter, err = runner.ResumeWithParams(ctx, checkpointID, &adk.ResumeParams{
|
|
Targets: map[string]any{
|
|
interruptID: resumeData,
|
|
},
|
|
})
|
|
if err != nil {
|
|
log.Fatalf("failed to resume: %v", err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
fmt.Printf("\n=== Test Complete (Total interrupts: %d) ===\n", interruptCount)
|
|
}
|