feat(flow/react): add short-term memory example with MessageFuture and tool calls

Change-Id: Iff0920fa888b74ee40770ea4e59126bbc01949c5
drew/english
shentong.martin 5 months ago committed by shentongmartin
parent 60eb90bce2
commit f13f4f7555

@ -5,7 +5,7 @@ go 1.24.9
replace github.com/cloudwego/eino-examples => ../../.. replace github.com/cloudwego/eino-examples => ../../..
require ( require (
github.com/cloudwego/eino v0.7.14-0.20251225084034-ff42791c540f github.com/cloudwego/eino v0.7.14
github.com/cloudwego/eino-examples v0.0.0-00010101000000-000000000000 github.com/cloudwego/eino-examples v0.0.0-00010101000000-000000000000
github.com/cloudwego/hertz v0.10.3 github.com/cloudwego/hertz v0.10.3
github.com/hertz-contrib/sse v0.1.0 github.com/hertz-contrib/sse v0.1.0

@ -108,6 +108,7 @@ github.com/cloudwego/eino v0.7.8 h1:3a2j1UKZZuQ3SzqDToOI5g6lrlJ7xZEtMlNQkTgIvaI=
github.com/cloudwego/eino v0.7.8/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= github.com/cloudwego/eino v0.7.8/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ=
github.com/cloudwego/eino v0.7.11/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= github.com/cloudwego/eino v0.7.11/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ=
github.com/cloudwego/eino v0.7.14-0.20251225084034-ff42791c540f/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= github.com/cloudwego/eino v0.7.14-0.20251225084034-ff42791c540f/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ=
github.com/cloudwego/eino v0.7.14/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ=
github.com/cloudwego/eino-ext/components/model/ark v0.1.45 h1:LWvSHJVlvS1S/IxN9XUKNw/MI0I7YPePt3LMNxyCrZ0= github.com/cloudwego/eino-ext/components/model/ark v0.1.45 h1:LWvSHJVlvS1S/IxN9XUKNw/MI0I7YPePt3LMNxyCrZ0=
github.com/cloudwego/eino-ext/components/model/ark v0.1.45/go.mod h1:e8P5dGVI/JMQ1FYNgmu5EFRWA8fivBc6NwNJ9g8FBK8= github.com/cloudwego/eino-ext/components/model/ark v0.1.45/go.mod h1:e8P5dGVI/JMQ1FYNgmu5EFRWA8fivBc6NwNJ9g8FBK8=
github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU= github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU=

@ -18,14 +18,21 @@ package main
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"os" "os"
"sync"
"github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent/react" "github.com/cloudwego/eino/flow/agent/react"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino-examples/components/tool/middlewares/errorremover"
"github.com/cloudwego/eino-examples/flow/agent/react/memory_example/memory" "github.com/cloudwego/eino-examples/flow/agent/react/memory_example/memory"
"github.com/cloudwego/eino-examples/flow/agent/react/tools"
) )
func main() { func main() {
@ -41,11 +48,17 @@ func main() {
panic(err) panic(err)
} }
// System prompt is injected at runtime and not persisted. sys := "你是一个简洁的助手。请在多轮对话中保持上下文。当用户询问餐厅或菜品时,请使用工具查询。"
sys := "You are a concise assistant. Maintain context across turns."
restaurantTool := tools.GetRestaurantTool()
dishTool := tools.GetDishTool()
agent, err := react.NewAgent(ctx, &react.AgentConfig{ agent, err := react.NewAgent(ctx, &react.AgentConfig{
Model: model, Model: model,
ToolsConfig: compose.ToolsNodeConfig{
Tools: []tool.BaseTool{restaurantTool, dishTool},
ToolCallMiddlewares: []compose.ToolMiddleware{errorremover.Middleware()},
},
MessageModifier: func(_ context.Context, input []*schema.Message) []*schema.Message { MessageModifier: func(_ context.Context, input []*schema.Message) []*schema.Message {
return append([]*schema.Message{schema.SystemMessage(sys)}, input...) return append([]*schema.Message{schema.SystemMessage(sys)}, input...)
}, },
@ -54,30 +67,106 @@ func main() {
panic(err) panic(err)
} }
// Choose your store: InMemoryStore (default) or RedisStore (see README).
store := memory.NewInMemoryStore() store := memory.NewInMemoryStore()
sessionID := "session:demo" sessionID := "session:demo"
verifyGobRoundTrip() verifyGobRoundTrip()
run := func(turn string) { run := func(turn string) {
// 1) restore prior messages, 2) append new input, 3) call agent, 4) persist with output fmt.Println("\n========== Turn Start ==========")
fmt.Printf("[User Input] %s\n", turn)
prev, _ := store.Read(ctx, sessionID) prev, _ := store.Read(ctx, sessionID)
fmt.Printf("[Restored %d messages]\n", len(prev))
for i, m := range prev {
if len(m.ToolCalls) > 0 {
for _, tc := range m.ToolCalls {
fmt.Printf(" [%d] role=%s tool_call=%s args=%s\n", i, m.Role, tc.Function.Name, truncateRunes(tc.Function.Arguments, 60))
}
} else if m.Role == schema.Tool {
fmt.Printf(" [%d] role=%s tool=%s result=%s\n", i, m.Role, m.ToolName, truncateRunes(m.Content, 60))
} else {
fmt.Printf(" [%d] role=%s content=%s\n", i, m.Role, truncateRunes(m.Content, 60))
}
}
eff := append(prev, schema.UserMessage(turn)) eff := append(prev, schema.UserMessage(turn))
msg, err := agent.Generate(ctx, eff)
if err != nil { msgFutureOpt, msgFuture := react.WithMessageFuture()
panic(err)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
sr, err := agent.Stream(ctx, eff, msgFutureOpt)
if err != nil {
panic(err)
}
sr.Close()
}()
produced := make([]*schema.Message, 0, 4)
iter := msgFuture.GetMessageStreams()
idx := 0
for {
sr, ok, e := iter.Next()
if e != nil {
panic(e)
}
if !ok {
break
}
var chunks []*schema.Message
for {
m, er := sr.Recv()
if errors.Is(er, io.EOF) {
break
}
if er != nil {
panic(er)
}
chunks = append(chunks, m)
}
full, er := schema.ConcatMessages(chunks)
if er == nil && full != nil {
printMessage(idx, full)
produced = append(produced, full)
}
idx++
} }
fmt.Printf("history_before=%d after=%d\n", len(prev), len(eff)+1)
fmt.Println(msg.Content)
_ = store.Write(ctx, sessionID, append(eff, msg))
hits, _ := store.Query(ctx, sessionID, "AI", 3) wg.Wait()
fmt.Printf("query_hits=%d\n", len(hits))
fmt.Printf("[Produced %d messages this turn]\n", len(produced))
_ = store.Write(ctx, sessionID, append(eff, produced...))
hits, _ := store.Query(ctx, sessionID, "restaurant", 3)
fmt.Printf("[Query 'restaurant' hits=%d]\n", len(hits))
for i, h := range hits {
fmt.Printf(" hit[%d] role=%s content=%s\n", i, h.Role, truncate(h.Content, 60))
}
fmt.Println("========== Turn End ==========")
} }
run("Hello, summarize AI briefly.") run("帮我找北京排名前2的餐厅。")
run("Add two more details.") run("第一家餐厅有什么菜?")
}
func printMessage(idx int, m *schema.Message) {
switch m.Role {
case schema.Assistant:
if len(m.ToolCalls) > 0 {
for _, tc := range m.ToolCalls {
fmt.Printf("[Stream %d] role=%s tool_call=%s args=%s\n", idx, m.Role, tc.Function.Name, truncate(tc.Function.Arguments, 60))
}
} else {
fmt.Printf("[Stream %d] role=%s content=%s\n", idx, m.Role, truncate(m.Content, 80))
}
case schema.Tool:
fmt.Printf("[Stream %d] role=%s tool=%s result=%s\n", idx, m.Role, m.ToolName, truncate(m.Content, 80))
default:
fmt.Printf("[Stream %d] role=%s content=%s\n", idx, m.Role, truncate(m.Content, 80))
}
} }
func verifyGobRoundTrip() { func verifyGobRoundTrip() {
@ -85,7 +174,6 @@ func verifyGobRoundTrip() {
schema.UserMessage("a"), schema.UserMessage("a"),
schema.AssistantMessage("b", nil), schema.AssistantMessage("b", nil),
} }
// Round-trip serialize/deserialize to validate gob setup.
b, err := memory.EncodeMessages(msgs) b, err := memory.EncodeMessages(msgs)
if err != nil { if err != nil {
panic(err) panic(err)
@ -96,3 +184,18 @@ func verifyGobRoundTrip() {
} }
fmt.Printf("gob_round_trip=%d\n", len(out)) fmt.Printf("gob_round_trip=%d\n", len(out))
} }
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
func truncateRunes(s string, n int) string {
runes := []rune(s)
if len(runes) <= n {
return s
}
return string(runes[:n]) + "..."
}

Loading…
Cancel
Save