From f13f4f7555b82b6da4f3680923e0bc004554b7c3 Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Mon, 29 Dec 2025 16:30:48 +0800 Subject: [PATCH] feat(flow/react): add short-term memory example with MessageFuture and tool calls Change-Id: Iff0920fa888b74ee40770ea4e59126bbc01949c5 --- adk/intro/http-sse-service/go.mod | 2 +- adk/intro/http-sse-service/go.sum | 1 + flow/agent/react/memory_example/main.go | 133 +++++++++++++++++++++--- 3 files changed, 120 insertions(+), 16 deletions(-) diff --git a/adk/intro/http-sse-service/go.mod b/adk/intro/http-sse-service/go.mod index 200475a..8711a18 100644 --- a/adk/intro/http-sse-service/go.mod +++ b/adk/intro/http-sse-service/go.mod @@ -5,7 +5,7 @@ go 1.24.9 replace github.com/cloudwego/eino-examples => ../../.. require ( - github.com/cloudwego/eino v0.7.14-0.20251225084034-ff42791c540f + github.com/cloudwego/eino v0.7.14 github.com/cloudwego/eino-examples v0.0.0-00010101000000-000000000000 github.com/cloudwego/hertz v0.10.3 github.com/hertz-contrib/sse v0.1.0 diff --git a/adk/intro/http-sse-service/go.sum b/adk/intro/http-sse-service/go.sum index 96ac8fb..e5b7a20 100644 --- a/adk/intro/http-sse-service/go.sum +++ b/adk/intro/http-sse-service/go.sum @@ -108,6 +108,7 @@ github.com/cloudwego/eino v0.7.8 h1:3a2j1UKZZuQ3SzqDToOI5g6lrlJ7xZEtMlNQkTgIvaI= github.com/cloudwego/eino v0.7.8/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= github.com/cloudwego/eino v0.7.11/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= github.com/cloudwego/eino v0.7.14-0.20251225084034-ff42791c540f/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= +github.com/cloudwego/eino v0.7.14/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= github.com/cloudwego/eino-ext/components/model/ark v0.1.45 h1:LWvSHJVlvS1S/IxN9XUKNw/MI0I7YPePt3LMNxyCrZ0= github.com/cloudwego/eino-ext/components/model/ark v0.1.45/go.mod h1:e8P5dGVI/JMQ1FYNgmu5EFRWA8fivBc6NwNJ9g8FBK8= github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU= diff --git a/flow/agent/react/memory_example/main.go b/flow/agent/react/memory_example/main.go index 90b0607..8225ec2 100644 --- a/flow/agent/react/memory_example/main.go +++ b/flow/agent/react/memory_example/main.go @@ -18,14 +18,21 @@ package main import ( "context" + "errors" "fmt" + "io" "os" + "sync" "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/flow/agent/react" "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino-examples/components/tool/middlewares/errorremover" "github.com/cloudwego/eino-examples/flow/agent/react/memory_example/memory" + "github.com/cloudwego/eino-examples/flow/agent/react/tools" ) func main() { @@ -41,11 +48,17 @@ func main() { panic(err) } - // System prompt is injected at runtime and not persisted. - sys := "You are a concise assistant. Maintain context across turns." + sys := "你是一个简洁的助手。请在多轮对话中保持上下文。当用户询问餐厅或菜品时,请使用工具查询。" + + restaurantTool := tools.GetRestaurantTool() + dishTool := tools.GetDishTool() agent, err := react.NewAgent(ctx, &react.AgentConfig{ Model: model, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{restaurantTool, dishTool}, + ToolCallMiddlewares: []compose.ToolMiddleware{errorremover.Middleware()}, + }, MessageModifier: func(_ context.Context, input []*schema.Message) []*schema.Message { return append([]*schema.Message{schema.SystemMessage(sys)}, input...) }, @@ -54,30 +67,106 @@ func main() { panic(err) } - // Choose your store: InMemoryStore (default) or RedisStore (see README). store := memory.NewInMemoryStore() sessionID := "session:demo" verifyGobRoundTrip() run := func(turn string) { - // 1) restore prior messages, 2) append new input, 3) call agent, 4) persist with output + fmt.Println("\n========== Turn Start ==========") + fmt.Printf("[User Input] %s\n", turn) + prev, _ := store.Read(ctx, sessionID) + fmt.Printf("[Restored %d messages]\n", len(prev)) + for i, m := range prev { + if len(m.ToolCalls) > 0 { + for _, tc := range m.ToolCalls { + fmt.Printf(" [%d] role=%s tool_call=%s args=%s\n", i, m.Role, tc.Function.Name, truncateRunes(tc.Function.Arguments, 60)) + } + } else if m.Role == schema.Tool { + fmt.Printf(" [%d] role=%s tool=%s result=%s\n", i, m.Role, m.ToolName, truncateRunes(m.Content, 60)) + } else { + fmt.Printf(" [%d] role=%s content=%s\n", i, m.Role, truncateRunes(m.Content, 60)) + } + } + eff := append(prev, schema.UserMessage(turn)) - msg, err := agent.Generate(ctx, eff) - if err != nil { - panic(err) + + msgFutureOpt, msgFuture := react.WithMessageFuture() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + sr, err := agent.Stream(ctx, eff, msgFutureOpt) + if err != nil { + panic(err) + } + sr.Close() + }() + + produced := make([]*schema.Message, 0, 4) + iter := msgFuture.GetMessageStreams() + idx := 0 + for { + sr, ok, e := iter.Next() + if e != nil { + panic(e) + } + if !ok { + break + } + var chunks []*schema.Message + for { + m, er := sr.Recv() + if errors.Is(er, io.EOF) { + break + } + if er != nil { + panic(er) + } + chunks = append(chunks, m) + } + full, er := schema.ConcatMessages(chunks) + if er == nil && full != nil { + printMessage(idx, full) + produced = append(produced, full) + } + idx++ } - fmt.Printf("history_before=%d after=%d\n", len(prev), len(eff)+1) - fmt.Println(msg.Content) - _ = store.Write(ctx, sessionID, append(eff, msg)) - hits, _ := store.Query(ctx, sessionID, "AI", 3) - fmt.Printf("query_hits=%d\n", len(hits)) + wg.Wait() + + fmt.Printf("[Produced %d messages this turn]\n", len(produced)) + _ = store.Write(ctx, sessionID, append(eff, produced...)) + + hits, _ := store.Query(ctx, sessionID, "restaurant", 3) + fmt.Printf("[Query 'restaurant' hits=%d]\n", len(hits)) + for i, h := range hits { + fmt.Printf(" hit[%d] role=%s content=%s\n", i, h.Role, truncate(h.Content, 60)) + } + fmt.Println("========== Turn End ==========") } - run("Hello, summarize AI briefly.") - run("Add two more details.") + run("帮我找北京排名前2的餐厅。") + run("第一家餐厅有什么菜?") +} + +func printMessage(idx int, m *schema.Message) { + switch m.Role { + case schema.Assistant: + if len(m.ToolCalls) > 0 { + for _, tc := range m.ToolCalls { + fmt.Printf("[Stream %d] role=%s tool_call=%s args=%s\n", idx, m.Role, tc.Function.Name, truncate(tc.Function.Arguments, 60)) + } + } else { + fmt.Printf("[Stream %d] role=%s content=%s\n", idx, m.Role, truncate(m.Content, 80)) + } + case schema.Tool: + fmt.Printf("[Stream %d] role=%s tool=%s result=%s\n", idx, m.Role, m.ToolName, truncate(m.Content, 80)) + default: + fmt.Printf("[Stream %d] role=%s content=%s\n", idx, m.Role, truncate(m.Content, 80)) + } } func verifyGobRoundTrip() { @@ -85,7 +174,6 @@ func verifyGobRoundTrip() { schema.UserMessage("a"), schema.AssistantMessage("b", nil), } - // Round-trip serialize/deserialize to validate gob setup. b, err := memory.EncodeMessages(msgs) if err != nil { panic(err) @@ -96,3 +184,18 @@ func verifyGobRoundTrip() { } fmt.Printf("gob_round_trip=%d\n", len(out)) } + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func truncateRunes(s string, n int) string { + runes := []rune(s) + if len(runes) <= n { + return s + } + return string(runes[:n]) + "..." +}