feat: add interrupt

drew/english
Megumin 1 year ago
parent df73bb95e2
commit 7fc033fd1f

@ -0,0 +1,12 @@
This example assumes a ticket booking scenario to demonstrate Eino interrupt and checkpoint practices: The agent receives the user's input such as name and destination, then calls the booking tool to assist the user in booking tickets. Before the tool executes, users can confirm if the information is correct. If it is incorrect, they can modify or complete the information, and then resume the agent's operation.
<img alt="topology of agent" src="topology.png" title="topology of agent" width="300" height="300"/>
The trace of a single call is as follows:
```
will call tool: BookTicket, arguments: {"location":"Beijing","passenger_name":"Megumin","passenger_phone_number":""}
Are the arguments as expected? (y/n): n
Please enter the modified arguments: {"location":"Beijing","passenger_name":"Megumin","passenger_phone_number":"1234567890"}
Updated arguments to: {"location":"Beijing","passenger_name":"Megumin","passenger_phone_number":"1234567890"}
final result: Your ticket to Beijing has been successfully booked, Megumin! Safe travels!
```

@ -0,0 +1,235 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"bufio"
"context"
"fmt"
"log"
"os"
"strings"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/components/tool/utils"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
func main() {
compose.RegisterSerializableType[myState]("state")
ctx := context.Background()
runner, err := composeGraph[map[string]any, *schema.Message](
ctx,
newChatTemplate(ctx),
newChatModel(ctx),
newToolsNode(ctx),
newCheckPointStore(ctx),
)
if err != nil {
log.Fatal(err)
}
var history []*schema.Message
for {
result, err := runner.Invoke(ctx, map[string]any{"name": "Megumin", "location": "Beijing"}, compose.WithCheckPointID("1"), compose.WithStateModifier(func(ctx context.Context, path compose.NodePath, state any) error {
state.(*myState).history = history
return nil
}))
if err == nil {
fmt.Printf("final result: %s", result.Content)
break
}
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
log.Fatal(err)
}
history = info.State.(*myState).history
for i, tc := range history[len(history)-1].ToolCalls {
fmt.Printf("will call tool: %s, arguments: %s\n", tc.Function.Name, tc.Function.Arguments)
fmt.Print("Are the arguments as expected? (y/n): ")
var response string
fmt.Scanln(&response)
if strings.ToLower(response) == "n" {
fmt.Print("Please enter the modified arguments: ")
scanner := bufio.NewScanner(os.Stdin)
var newArguments string
if scanner.Scan() {
newArguments = scanner.Text()
}
// Update the tool call arguments
history[len(history)-1].ToolCalls[i].Function.Arguments = newArguments
fmt.Printf("Updated arguments to: %s\n", newArguments)
}
}
}
}
func newChatTemplate(_ context.Context) prompt.ChatTemplate {
return prompt.FromMessages(schema.FString,
schema.SystemMessage("You are a helpful assistant. If the user asks about the booking, call the \"BookTicket\" tool to book ticket."),
schema.UserMessage("I'm {name}. Help me book a ticket to {location}"),
)
}
func newChatModel(ctx context.Context) model.ChatModel {
cm, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
APIKey: os.Getenv("OPENAI_API_KEY"),
Model: os.Getenv("OPENAI_MODEL"),
BaseURL: os.Getenv("OPENAI_BASE_URL"),
})
if err != nil {
log.Fatal(err)
}
tools := getTools()
var toolsInfo []*schema.ToolInfo
for _, t := range tools {
info, err := t.Info(ctx)
if err != nil {
log.Fatal(err)
}
toolsInfo = append(toolsInfo, info)
}
err = cm.BindTools(toolsInfo)
if err != nil {
log.Fatal(err)
}
return cm
}
type bookInput struct {
Location string `json:"location"`
PassengerName string `json:"passenger_name"`
PassengerPhoneNumber string `json:"passenger_phone_number"`
}
func newToolsNode(ctx context.Context) *compose.ToolsNode {
tools := getTools()
tn, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{Tools: tools})
if err != nil {
log.Fatal(err)
}
return tn
}
func newCheckPointStore(ctx context.Context) compose.CheckPointStore {
return &myStore{buf: make(map[string][]byte)}
}
type myState struct {
history []*schema.Message
}
func composeGraph[I, O any](ctx context.Context, tpl prompt.ChatTemplate, cm model.ChatModel, tn *compose.ToolsNode, store compose.CheckPointStore) (compose.Runnable[I, O], error) {
g := compose.NewGraph[I, O](compose.WithGenLocalState(func(ctx context.Context) *myState {
return &myState{}
}))
err := g.AddChatTemplateNode(
"ChatTemplate",
tpl,
)
if err != nil {
return nil, err
}
err = g.AddChatModelNode(
"ChatModel",
cm,
compose.WithStatePreHandler(func(ctx context.Context, in []*schema.Message, state *myState) ([]*schema.Message, error) {
state.history = append(state.history, in...)
return state.history, nil
}),
compose.WithStatePostHandler(func(ctx context.Context, out *schema.Message, state *myState) (*schema.Message, error) {
state.history = append(state.history, out)
return out, nil
}),
)
if err != nil {
return nil, err
}
err = g.AddToolsNode("ToolsNode", tn, compose.WithStatePreHandler(func(ctx context.Context, in *schema.Message, state *myState) (*schema.Message, error) {
return state.history[len(state.history)-1], nil
}))
if err != nil {
return nil, err
}
err = g.AddEdge(compose.START, "ChatTemplate")
if err != nil {
return nil, err
}
err = g.AddEdge("ChatTemplate", "ChatModel")
if err != nil {
return nil, err
}
err = g.AddEdge("ToolsNode", "ChatModel")
if err != nil {
return nil, err
}
err = g.AddBranch("ChatModel", compose.NewGraphBranch(func(ctx context.Context, in *schema.Message) (endNode string, err error) {
if len(in.ToolCalls) > 0 {
return "ToolsNode", nil
}
return compose.END, nil
}, map[string]bool{"ToolsNode": true, compose.END: true}))
if err != nil {
return nil, err
}
return g.Compile(
ctx,
compose.WithCheckPointStore(store),
compose.WithInterruptBeforeNodes([]string{"ToolsNode"}),
)
}
func getTools() []tool.BaseTool {
getWeather, err := utils.InferTool("BookTicket", "this tool can book ticket of the specific location", func(ctx context.Context, input bookInput) (output string, err error) {
return "success", nil
})
if err != nil {
log.Fatal(err)
}
return []tool.BaseTool{
getWeather,
}
}
type myStore struct {
buf map[string][]byte
}
func (m *myStore) Get(ctx context.Context, checkPointID string) ([]byte, bool, error) {
data, ok := m.buf[checkPointID]
return data, ok, nil
}
func (m *myStore) Set(ctx context.Context, checkPointID string, checkPoint []byte) error {
m.buf[checkPointID] = checkPoint
return nil
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

Loading…
Cancel
Save