/* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package main import ( "bufio" "bytes" "context" "encoding/json" "errors" "flag" "fmt" "io" "log" "os" "path/filepath" "strings" "time" "github.com/google/uuid" localbk "github.com/cloudwego/eino-ext/adk/backend/local" clc "github.com/cloudwego/eino-ext/callbacks/cozeloop" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/prebuilt/deep" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" "github.com/coze-dev/cozeloop-go" examplemodel "github.com/cloudwego/eino-examples/adk/common/model" "github.com/cloudwego/eino-examples/quickstart/chatwitheino/mem" ) func main() { var sessionID string var instruction string flag.StringVar(&sessionID, "session", "", "session ID (creates new if empty)") flag.StringVar(&instruction, "instruction", "", "custom instruction (empty for default)") flag.Parse() ctx := context.Background() // Setup CozeLoop tracing (optional) // Set COZELOOP_API_TOKEN and COZELOOP_WORKSPACE_ID to enable cozeloopApiToken := os.Getenv("COZELOOP_API_TOKEN") cozeloopWorkspaceID := os.Getenv("COZELOOP_WORKSPACE_ID") if cozeloopApiToken != "" && cozeloopWorkspaceID != "" { client, err := cozeloop.NewClient( cozeloop.WithAPIToken(cozeloopApiToken), cozeloop.WithWorkspaceID(cozeloopWorkspaceID), ) if err != nil { log.Fatalf("cozeloop.NewClient failed: %v", err) } defer func() { time.Sleep(5 * time.Second) client.Close(ctx) }() callbacks.AppendGlobalHandlers(clc.NewLoopHandler(client)) log.Println("CozeLoop tracing enabled") } else { log.Println("CozeLoop tracing disabled (set COZELOOP_API_TOKEN and COZELOOP_WORKSPACE_ID to enable)") } cm := examplemodel.NewChatModel() projectRoot := os.Getenv("PROJECT_ROOT") if projectRoot == "" { if cwd, err := os.Getwd(); err == nil { projectRoot = cwd } } if abs, err := filepath.Abs(projectRoot); err == nil { projectRoot = abs } defaultInstruction := fmt.Sprintf(`You are a helpful assistant that helps users learn the Eino framework. IMPORTANT: When using filesystem tools (ls, read_file, glob, grep, etc.), you MUST use absolute paths. The project root directory is: %s - When the user asks to list files in "current directory", use path: %s - When the user asks to read a file with a relative path, convert it to absolute path by prepending %s - Example: if user says "read main.go", you should call read_file with file_path: "%s/main.go" Always use absolute paths when calling filesystem tools.`, projectRoot, projectRoot, projectRoot, projectRoot) agentInstruction := defaultInstruction if instruction != "" { agentInstruction = instruction } backend, err := localbk.NewBackend(ctx, &localbk.Config{}) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } agent, err := deep.New(ctx, &deep.Config{ Name: "Ch06CallbackAgent", Description: "ChatWithDoc agent with CozeLoop tracing.", ChatModel: cm, Instruction: agentInstruction, Backend: backend, StreamingShell: backend, MaxIteration: 50, Handlers: []adk.ChatModelAgentMiddleware{ &safeToolMiddleware{}, }, ModelRetryConfig: &adk.ModelRetryConfig{ MaxRetries: 5, IsRetryAble: func(_ context.Context, err error) bool { return strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "Too Many Requests") || strings.Contains(err.Error(), "qpm limit") }, }, }) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } runner := adk.NewRunner(ctx, adk.RunnerConfig{ Agent: agent, EnableStreaming: true, }) sessionDir := os.Getenv("SESSION_DIR") if sessionDir == "" { sessionDir = "./data/sessions" } store, err := mem.NewStore(sessionDir) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } if sessionID == "" { sessionID = uuid.New().String() fmt.Printf("Created new session: %s\n", sessionID) } else { fmt.Printf("Resuming session: %s\n", sessionID) } session, err := store.GetOrCreate(sessionID) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } fmt.Printf("Session title: %s\n", session.Title()) fmt.Printf("Project root: %s\n", projectRoot) fmt.Println("Enter your message (empty line to exit):") scanner := bufio.NewScanner(os.Stdin) for { _, _ = fmt.Fprint(os.Stdout, "you> ") if !scanner.Scan() { break } line := strings.TrimSpace(scanner.Text()) if line == "" { break } userMsg := schema.UserMessage(line) if err := session.Append(userMsg); err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } history := session.GetMessages() events := runner.Run(ctx, history) content, err := printAndCollectAssistantFromEvents(events) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } assistantMsg := schema.AssistantMessage(content, nil) if err := session.Append(assistantMsg); err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } if err := scanner.Err(); err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } fmt.Printf("\nSession saved: %s\n", sessionID) fmt.Printf("Resume with: go run ./cmd/ch06 --session %s\n", sessionID) } type safeToolMiddleware struct { *adk.BaseChatModelAgentMiddleware } func (m *safeToolMiddleware) WrapInvokableToolCall( _ context.Context, endpoint adk.InvokableToolCallEndpoint, _ *adk.ToolContext, ) (adk.InvokableToolCallEndpoint, error) { return func(ctx context.Context, args string, opts ...tool.Option) (string, error) { result, err := endpoint(ctx, args, opts...) if err != nil { if _, ok := compose.IsInterruptRerunError(err); ok { return "", err } return fmt.Sprintf("[tool error] %v", err), nil } return result, nil }, nil } func (m *safeToolMiddleware) WrapStreamableToolCall( _ context.Context, endpoint adk.StreamableToolCallEndpoint, _ *adk.ToolContext, ) (adk.StreamableToolCallEndpoint, error) { return func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { sr, err := endpoint(ctx, args, opts...) if err != nil { if _, ok := compose.IsInterruptRerunError(err); ok { return nil, err } return singleChunkReader(fmt.Sprintf("[tool error] %v", err)), nil } return safeWrapReader(sr), nil }, nil } func singleChunkReader(msg string) *schema.StreamReader[string] { r, w := schema.Pipe[string](1) _ = w.Send(msg, nil) w.Close() return r } func safeWrapReader(sr *schema.StreamReader[string]) *schema.StreamReader[string] { r, w := schema.Pipe[string](64) go func() { defer w.Close() for { chunk, err := sr.Recv() if errors.Is(err, io.EOF) { return } if err != nil { _ = w.Send(fmt.Sprintf("\n[tool error] %v", err), nil) return } _ = w.Send(chunk, nil) } }() return r } func printAndCollectAssistantFromEvents(events *adk.AsyncIterator[*adk.AgentEvent]) (string, error) { var sb strings.Builder for { event, ok := events.Next() if !ok { break } if event.Err != nil { return "", event.Err } if event.Output != nil && event.Output.MessageOutput != nil { mv := event.Output.MessageOutput if mv.Role == schema.Tool { content := drainToolResult(mv) fmt.Printf("[tool result] %s\n", truncate(content, 200)) continue } if mv.Role != schema.Assistant && mv.Role != "" { continue } if mv.IsStreaming { mv.MessageStream.SetAutomaticClose() var accumulatedToolCalls []schema.ToolCall for { frame, err := mv.MessageStream.Recv() if errors.Is(err, io.EOF) { break } if err != nil { return "", err } if frame != nil { if frame.Content != "" { sb.WriteString(frame.Content) _, _ = fmt.Fprint(os.Stdout, frame.Content) } if len(frame.ToolCalls) > 0 { accumulatedToolCalls = append(accumulatedToolCalls, frame.ToolCalls...) } } } for _, tc := range accumulatedToolCalls { if tc.Function.Name != "" && tc.Function.Arguments != "" { fmt.Printf("\n[tool call] %s(%s)\n", tc.Function.Name, tc.Function.Arguments) } } _, _ = fmt.Fprintln(os.Stdout) continue } if mv.Message != nil { sb.WriteString(mv.Message.Content) _, _ = fmt.Fprintln(os.Stdout, mv.Message.Content) for _, tc := range mv.Message.ToolCalls { fmt.Printf("[tool call] %s(%s)\n", tc.Function.Name, tc.Function.Arguments) } } } } return sb.String(), nil } func drainToolResult(mo *adk.MessageVariant) string { if mo.IsStreaming && mo.MessageStream != nil { var sb strings.Builder for { chunk, err := mo.MessageStream.Recv() if errors.Is(err, io.EOF) { break } if err != nil { break } if chunk != nil && chunk.Content != "" { sb.WriteString(chunk.Content) } } return sb.String() } if mo.Message != nil { return mo.Message.Content } return "" } func truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } var result bytes.Buffer if err := json.Compact(&result, []byte(s)); err == nil { s = result.String() } if len(s) <= maxLen { return s } return s[:maxLen] + "..." }