/* * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package main import ( "bufio" "bytes" "context" "encoding/json" "errors" "flag" "fmt" "io" "log" "os" "path/filepath" "strings" "time" "github.com/google/uuid" localbk "github.com/cloudwego/eino-ext/adk/backend/local" clc "github.com/cloudwego/eino-ext/callbacks/cozeloop" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/middlewares/skill" "github.com/cloudwego/eino/adk/prebuilt/deep" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" "github.com/coze-dev/cozeloop-go" examplemodel "github.com/cloudwego/eino-examples/adk/common/model" adkstore "github.com/cloudwego/eino-examples/adk/common/store" commontool "github.com/cloudwego/eino-examples/adk/common/tool" "github.com/cloudwego/eino-examples/quickstart/chatwitheino/mem" "github.com/cloudwego/eino-examples/quickstart/chatwitheino/rag" ) func main() { var sessionID string var instruction string flag.StringVar(&sessionID, "session", "", "session ID (creates new if empty)") flag.StringVar(&instruction, "instruction", "", "custom instruction (empty for default)") flag.Parse() ctx := context.Background() cozeloopApiToken := os.Getenv("COZELOOP_API_TOKEN") cozeloopWorkspaceID := os.Getenv("COZELOOP_WORKSPACE_ID") if cozeloopApiToken != "" && cozeloopWorkspaceID != "" { client, err := cozeloop.NewClient( cozeloop.WithAPIToken(cozeloopApiToken), cozeloop.WithWorkspaceID(cozeloopWorkspaceID), ) if err != nil { log.Fatalf("cozeloop.NewClient failed: %v", err) } defer func() { time.Sleep(5 * time.Second) client.Close(ctx) }() callbacks.AppendGlobalHandlers(clc.NewLoopHandler(client)) log.Println("CozeLoop tracing enabled") } else { log.Println("CozeLoop tracing disabled (set COZELOOP_API_TOKEN and COZELOOP_WORKSPACE_ID to enable)") } cm := examplemodel.NewChatModel() projectRoot := os.Getenv("PROJECT_ROOT") if projectRoot == "" { if cwd, err := os.Getwd(); err == nil { projectRoot = cwd } } if abs, err := filepath.Abs(projectRoot); err == nil { projectRoot = abs } defaultInstruction := fmt.Sprintf(`You are a helpful assistant that helps users learn the Eino framework. IMPORTANT: When using filesystem tools (ls, read_file, glob, grep, etc.), you MUST use absolute paths. The project root directory is: %s - When the user asks to list files in "current directory", use path: %s - When the user asks to read a file with a relative path, convert it to absolute path by prepending %s - Example: if user says "read main.go", you should call read_file with file_path: "%s/main.go" Always use absolute paths when calling filesystem tools.`, projectRoot, projectRoot, projectRoot, projectRoot) agentInstruction := defaultInstruction if instruction != "" { agentInstruction = instruction } backend, err := localbk.NewBackend(ctx, &localbk.Config{}) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } ragTool, err := rag.BuildTool(ctx, cm) if err != nil { _, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("build rag tool: %w", err)) os.Exit(1) } var handlers []adk.ChatModelAgentMiddleware skillsDir, found := resolveSkillsDir() if found { skillBackend, sbErr := skill.NewBackendFromFilesystem(ctx, &skill.BackendFromFilesystemConfig{ Backend: backend, BaseDir: skillsDir, }) if sbErr != nil { _, _ = fmt.Fprintln(os.Stderr, sbErr) os.Exit(1) } skillMiddleware, smErr := skill.NewMiddleware(ctx, &skill.Config{ Backend: skillBackend, }) if smErr != nil { _, _ = fmt.Fprintln(os.Stderr, smErr) os.Exit(1) } handlers = append(handlers, skillMiddleware) } handlers = append(handlers, &approvalMiddleware{}, &safeToolMiddleware{}) agent, err := deep.New(ctx, &deep.Config{ Name: "Ch09RAGSkillAgent", Description: "ChatWithDoc agent with RAG tool and skill middleware.", ChatModel: cm, Instruction: agentInstruction, Backend: backend, StreamingShell: backend, MaxIteration: 50, Handlers: handlers, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: []tool.BaseTool{ragTool}, }, }, ModelRetryConfig: &adk.ModelRetryConfig{ MaxRetries: 5, IsRetryAble: func(_ context.Context, err error) bool { return strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "Too Many Requests") || strings.Contains(err.Error(), "qpm limit") }, }, }) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } runner := adk.NewRunner(ctx, adk.RunnerConfig{ Agent: agent, EnableStreaming: true, CheckPointStore: adkstore.NewInMemoryStore(), }) sessionDir := os.Getenv("SESSION_DIR") if sessionDir == "" { sessionDir = "./data/sessions" } store, err := mem.NewStore(sessionDir) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } if sessionID == "" { sessionID = uuid.New().String() fmt.Printf("Created new session: %s\n", sessionID) } else { fmt.Printf("Resuming session: %s\n", sessionID) } session, err := store.GetOrCreate(sessionID) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } fmt.Printf("Session title: %s\n", session.Title()) fmt.Printf("Project root: %s\n", projectRoot) if found { fmt.Printf("Skills dir: %s\n", skillsDir) } else { fmt.Println("Skills dir: (not configured) set EINO_EXT_SKILLS_DIR=/path/to/skills") } fmt.Println("Enter your message (empty line to exit):") reader := bufio.NewReader(os.Stdin) checkPointID := sessionID for { _, _ = fmt.Fprint(os.Stdout, "you> ") line, readErr := reader.ReadString('\n') if errors.Is(readErr, io.EOF) { break } if readErr != nil { _, _ = fmt.Fprintln(os.Stderr, readErr) os.Exit(1) } line = strings.TrimSpace(line) if line == "" { break } userMsg := schema.UserMessage(line) if err := session.Append(userMsg); err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } history := session.GetMessages() events := runner.Run(ctx, history, adk.WithCheckPointID(checkPointID)) content, interruptInfo, err := printAndCollectAssistantFromEvents(events) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } if interruptInfo != nil { content, err = handleInterrupt(ctx, runner, checkPointID, interruptInfo, reader) if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } assistantMsg := schema.AssistantMessage(content, nil) if err := session.Append(assistantMsg); err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } fmt.Printf("\nSession saved: %s\n", sessionID) fmt.Printf("Resume with: go run ./cmd/ch09 --session %s\n", sessionID) } func resolveSkillsDir() (string, bool) { skillsDir := strings.TrimSpace(os.Getenv("EINO_EXT_SKILLS_DIR")) if skillsDir == "" { return "", false } if absSkillsDir, absErr := filepath.Abs(skillsDir); absErr == nil { skillsDir = absSkillsDir } fi, err := os.Stat(skillsDir) if err != nil || !fi.IsDir() { return "", false } return skillsDir, true } type approvalMiddleware struct { *adk.BaseChatModelAgentMiddleware } func (m *approvalMiddleware) WrapInvokableToolCall( _ context.Context, endpoint adk.InvokableToolCallEndpoint, tCtx *adk.ToolContext, ) (adk.InvokableToolCallEndpoint, error) { if tCtx.Name != "answer_from_document" { return endpoint, nil } return func(ctx context.Context, args string, opts ...tool.Option) (string, error) { wasInterrupted, _, storedArgs := tool.GetInterruptState[string](ctx) if !wasInterrupted { return "", tool.StatefulInterrupt(ctx, &commontool.ApprovalInfo{ ToolName: tCtx.Name, ArgumentsInJSON: args, }, args) } isTarget, hasData, data := tool.GetResumeContext[*commontool.ApprovalResult](ctx) if isTarget && hasData { if data.Approved { return endpoint(ctx, storedArgs, opts...) } if data.DisapproveReason != nil { return fmt.Sprintf("tool '%s' disapproved: %s", tCtx.Name, *data.DisapproveReason), nil } return fmt.Sprintf("tool '%s' disapproved", tCtx.Name), nil } isTarget2, _, _ := tool.GetResumeContext[any](ctx) if !isTarget2 { return "", tool.StatefulInterrupt(ctx, &commontool.ApprovalInfo{ ToolName: tCtx.Name, ArgumentsInJSON: storedArgs, }, storedArgs) } return endpoint(ctx, storedArgs, opts...) }, nil } func (m *approvalMiddleware) WrapStreamableToolCall( _ context.Context, endpoint adk.StreamableToolCallEndpoint, tCtx *adk.ToolContext, ) (adk.StreamableToolCallEndpoint, error) { if tCtx.Name != "answer_from_document" { return endpoint, nil } return func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { wasInterrupted, _, storedArgs := tool.GetInterruptState[string](ctx) if !wasInterrupted { return nil, tool.StatefulInterrupt(ctx, &commontool.ApprovalInfo{ ToolName: tCtx.Name, ArgumentsInJSON: args, }, args) } isTarget, hasData, data := tool.GetResumeContext[*commontool.ApprovalResult](ctx) if isTarget && hasData { if data.Approved { return endpoint(ctx, storedArgs, opts...) } if data.DisapproveReason != nil { return singleChunkReader(fmt.Sprintf("tool '%s' disapproved: %s", tCtx.Name, *data.DisapproveReason)), nil } return singleChunkReader(fmt.Sprintf("tool '%s' disapproved", tCtx.Name)), nil } isTarget2, _, _ := tool.GetResumeContext[any](ctx) if !isTarget2 { return nil, tool.StatefulInterrupt(ctx, &commontool.ApprovalInfo{ ToolName: tCtx.Name, ArgumentsInJSON: storedArgs, }, storedArgs) } return endpoint(ctx, storedArgs, opts...) }, nil } type safeToolMiddleware struct { *adk.BaseChatModelAgentMiddleware } func (m *safeToolMiddleware) WrapInvokableToolCall( _ context.Context, endpoint adk.InvokableToolCallEndpoint, _ *adk.ToolContext, ) (adk.InvokableToolCallEndpoint, error) { return func(ctx context.Context, args string, opts ...tool.Option) (string, error) { result, err := endpoint(ctx, args, opts...) if err != nil { if _, ok := compose.IsInterruptRerunError(err); ok { return "", err } return fmt.Sprintf("[tool error] %v", err), nil } return result, nil }, nil } func (m *safeToolMiddleware) WrapStreamableToolCall( _ context.Context, endpoint adk.StreamableToolCallEndpoint, _ *adk.ToolContext, ) (adk.StreamableToolCallEndpoint, error) { return func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { sr, err := endpoint(ctx, args, opts...) if err != nil { if _, ok := compose.IsInterruptRerunError(err); ok { return nil, err } return singleChunkReader(fmt.Sprintf("[tool error] %v", err)), nil } return safeWrapReader(sr), nil }, nil } func singleChunkReader(msg string) *schema.StreamReader[string] { r, w := schema.Pipe[string](1) _ = w.Send(msg, nil) w.Close() return r } func safeWrapReader(sr *schema.StreamReader[string]) *schema.StreamReader[string] { r, w := schema.Pipe[string](64) go func() { defer w.Close() for { chunk, err := sr.Recv() if errors.Is(err, io.EOF) { return } if err != nil { _ = w.Send(fmt.Sprintf("\n[tool error] %v", err), nil) return } _ = w.Send(chunk, nil) } }() return r } func printAndCollectAssistantFromEvents(events *adk.AsyncIterator[*adk.AgentEvent]) (string, *adk.InterruptInfo, error) { var sb strings.Builder var interruptInfo *adk.InterruptInfo for { event, ok := events.Next() if !ok { break } if event.Err != nil { return "", nil, event.Err } if event.Action != nil && event.Action.Interrupted != nil { interruptInfo = event.Action.Interrupted continue } if event.Output != nil && event.Output.MessageOutput != nil { mv := event.Output.MessageOutput if mv.Role == schema.Tool { content := drainToolResult(mv) fmt.Printf("[tool result] %s\n", truncate(content, 200)) continue } if mv.Role != schema.Assistant && mv.Role != "" { continue } if mv.IsStreaming { mv.MessageStream.SetAutomaticClose() var accumulatedToolCalls []schema.ToolCall for { frame, err := mv.MessageStream.Recv() if errors.Is(err, io.EOF) { break } if err != nil { return "", nil, err } if frame != nil { if frame.Content != "" { sb.WriteString(frame.Content) _, _ = fmt.Fprint(os.Stdout, frame.Content) } if len(frame.ToolCalls) > 0 { accumulatedToolCalls = append(accumulatedToolCalls, frame.ToolCalls...) } } } for _, tc := range accumulatedToolCalls { if tc.Function.Name != "" && tc.Function.Arguments != "" { fmt.Printf("\n[tool call] %s(%s)\n", tc.Function.Name, tc.Function.Arguments) } } _, _ = fmt.Fprintln(os.Stdout) continue } if mv.Message != nil { sb.WriteString(mv.Message.Content) _, _ = fmt.Fprintln(os.Stdout, mv.Message.Content) for _, tc := range mv.Message.ToolCalls { fmt.Printf("[tool call] %s(%s)\n", tc.Function.Name, tc.Function.Arguments) } } } } return sb.String(), interruptInfo, nil } func drainToolResult(mo *adk.MessageVariant) string { if mo.IsStreaming && mo.MessageStream != nil { var sb strings.Builder for { chunk, err := mo.MessageStream.Recv() if errors.Is(err, io.EOF) { break } if err != nil { break } if chunk != nil && chunk.Content != "" { sb.WriteString(chunk.Content) } } return sb.String() } if mo.Message != nil { return mo.Message.Content } return "" } func truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } var result bytes.Buffer if err := json.Compact(&result, []byte(s)); err == nil { s = result.String() } if len(s) <= maxLen { return s } return s[:maxLen] + "..." } func handleInterrupt(ctx context.Context, runner *adk.Runner, checkPointID string, interruptInfo *adk.InterruptInfo, reader *bufio.Reader) (string, error) { for _, ic := range interruptInfo.InterruptContexts { if !ic.IsRootCause { continue } info, ok := ic.Info.(*commontool.ApprovalInfo) if !ok { continue } fmt.Printf("\n⚠️ Approval Required ⚠️\n") fmt.Printf("Tool: %s\n", info.ToolName) fmt.Printf("Arguments: %s\n", info.ArgumentsInJSON) fmt.Print("\nApprove this action? (y/n): ") response, err := reader.ReadString('\n') if err != nil { return "", fmt.Errorf("failed to read user input: %w", err) } response = strings.TrimSpace(strings.ToLower(response)) var resumeData *commontool.ApprovalResult if response == "y" || response == "yes" { resumeData = &commontool.ApprovalResult{Approved: true} fmt.Println("✓ Approved, executing...") } else { resumeData = &commontool.ApprovalResult{Approved: false} fmt.Println("✗ Rejected") } events, err := runner.ResumeWithParams(ctx, checkPointID, &adk.ResumeParams{ Targets: map[string]any{ ic.ID: resumeData, }, }) if err != nil { return "", fmt.Errorf("failed to resume: %w", err) } content, newInterruptInfo, err := printAndCollectAssistantFromEvents(events) if err != nil { return "", err } if newInterruptInfo != nil { return handleInterrupt(ctx, runner, checkPointID, newInterruptInfo, reader) } return content, nil } return "", fmt.Errorf("no root cause interrupt context found") }