You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
232 lines
6.9 KiB
Go
232 lines
6.9 KiB
Go
/*
|
|
* Copyright 2026 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
|
"github.com/cloudwego/eino/adk"
|
|
"github.com/cloudwego/eino/adk/middlewares/skill"
|
|
"github.com/cloudwego/eino/adk/prebuilt/deep"
|
|
"github.com/cloudwego/eino/components/tool"
|
|
"github.com/cloudwego/eino/compose"
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
"github.com/cloudwego/eino-examples/adk/common/model"
|
|
commontool "github.com/cloudwego/eino-examples/adk/common/tool"
|
|
"github.com/cloudwego/eino-examples/quickstart/chatwitheino/rag"
|
|
)
|
|
|
|
func buildAgent(ctx context.Context) (adk.Agent, error) {
|
|
cm := model.NewChatModel()
|
|
|
|
backend, err := localbk.NewBackend(ctx, &localbk.Config{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ragTool, err := rag.BuildTool(ctx, cm)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build rag tool: %w", err)
|
|
}
|
|
|
|
var handlers []adk.ChatModelAgentMiddleware
|
|
if skillsDir, ok := resolveSkillsDir(); ok {
|
|
skillBackend, sbErr := skill.NewBackendFromFilesystem(ctx, &skill.BackendFromFilesystemConfig{
|
|
Backend: backend,
|
|
BaseDir: skillsDir,
|
|
})
|
|
if sbErr != nil {
|
|
return nil, sbErr
|
|
}
|
|
skillMiddleware, smErr := skill.NewMiddleware(ctx, &skill.Config{
|
|
Backend: skillBackend,
|
|
})
|
|
if smErr != nil {
|
|
return nil, smErr
|
|
}
|
|
handlers = append(handlers, skillMiddleware)
|
|
}
|
|
handlers = append(handlers, &approvalMiddleware{}, &safeToolMiddleware{})
|
|
|
|
return deep.New(ctx, &deep.Config{
|
|
Name: "ChatWithDocAgent",
|
|
Description: "An agent that reads and answers questions about documents.",
|
|
ChatModel: cm,
|
|
Backend: backend,
|
|
StreamingShell: backend,
|
|
MaxIteration: 50,
|
|
Handlers: handlers,
|
|
ToolsConfig: adk.ToolsConfig{
|
|
ToolsNodeConfig: compose.ToolsNodeConfig{
|
|
Tools: []tool.BaseTool{ragTool},
|
|
},
|
|
},
|
|
ModelRetryConfig: &adk.ModelRetryConfig{
|
|
MaxRetries: 5,
|
|
IsRetryAble: func(_ context.Context, err error) bool {
|
|
return strings.Contains(err.Error(), "429") ||
|
|
strings.Contains(err.Error(), "Too Many Requests") ||
|
|
strings.Contains(err.Error(), "qpm limit")
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
// resolveSkillsDir returns the skills directory configured through the
// EINO_EXT_SKILLS_DIR environment variable.
//
// Surrounding whitespace is trimmed, and the path is made absolute when
// filepath.Abs succeeds (otherwise it is used as given). The boolean is false,
// and the string empty, in three cases:
//   - the variable is unset or blank,
//   - the path cannot be stat'ed,
//   - the path is not a directory.
//
// buildAgent treats a false result as "skills disabled".
func resolveSkillsDir() (string, bool) {
	const envKey = "EINO_EXT_SKILLS_DIR"

	dir := strings.TrimSpace(os.Getenv(envKey))
	if len(dir) == 0 {
		return "", false
	}

	// Best effort: keep the raw path if it cannot be made absolute.
	if abs, absErr := filepath.Abs(dir); absErr == nil {
		dir = abs
	}

	if info, statErr := os.Stat(dir); statErr == nil && info.IsDir() {
		return dir, true
	}
	return "", false
}
|
|
|
|
// safeToolMiddleware converts streaming tool errors into error-message strings
|
|
// so that a non-zero exit code or mid-stream failure is returned to the model
|
|
// as a readable tool result instead of aborting the agent pipeline.
|
|
type safeToolMiddleware struct {
|
|
*adk.BaseChatModelAgentMiddleware
|
|
}
|
|
|
|
// approvalMiddleware intercepts calls to the answer_from_document tool and
|
|
// pauses the agent with a human-approval interrupt before executing the RAG
|
|
// workflow. The runner's CheckPointStore must be configured for this to work.
|
|
type approvalMiddleware struct {
|
|
*adk.BaseChatModelAgentMiddleware
|
|
}
|
|
|
|
// WrapInvokableToolCall inserts an approval gate around the answer_from_document
|
|
// tool. All other tools pass through unchanged.
|
|
func (m *approvalMiddleware) WrapInvokableToolCall(
|
|
_ context.Context,
|
|
endpoint adk.InvokableToolCallEndpoint,
|
|
tCtx *adk.ToolContext,
|
|
) (adk.InvokableToolCallEndpoint, error) {
|
|
if tCtx.Name != "answer_from_document" {
|
|
return endpoint, nil
|
|
}
|
|
return func(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
|
wasInterrupted, _, storedArgs := tool.GetInterruptState[string](ctx)
|
|
if !wasInterrupted {
|
|
return "", tool.StatefulInterrupt(ctx, &commontool.ApprovalInfo{
|
|
ToolName: tCtx.Name,
|
|
ArgumentsInJSON: args,
|
|
}, args)
|
|
}
|
|
|
|
isTarget, hasData, data := tool.GetResumeContext[*commontool.ApprovalResult](ctx)
|
|
if isTarget && hasData {
|
|
if data.Approved {
|
|
return endpoint(ctx, storedArgs, opts...)
|
|
}
|
|
if data.DisapproveReason != nil {
|
|
return fmt.Sprintf("tool '%s' disapproved: %s", tCtx.Name, *data.DisapproveReason), nil
|
|
}
|
|
return fmt.Sprintf("tool '%s' disapproved", tCtx.Name), nil
|
|
}
|
|
|
|
// Re-interrupt if this is not the resume target (another tool was resumed instead).
|
|
isTarget, _, _ = tool.GetResumeContext[any](ctx)
|
|
if !isTarget {
|
|
return "", tool.StatefulInterrupt(ctx, &commontool.ApprovalInfo{
|
|
ToolName: tCtx.Name,
|
|
ArgumentsInJSON: storedArgs,
|
|
}, storedArgs)
|
|
}
|
|
|
|
return endpoint(ctx, storedArgs, opts...)
|
|
}, nil
|
|
}
|
|
|
|
func (m *safeToolMiddleware) WrapInvokableToolCall(
|
|
_ context.Context,
|
|
endpoint adk.InvokableToolCallEndpoint,
|
|
_ *adk.ToolContext,
|
|
) (adk.InvokableToolCallEndpoint, error) {
|
|
return func(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
|
result, err := endpoint(ctx, args, opts...)
|
|
if err != nil {
|
|
if _, ok := compose.IsInterruptRerunError(err); ok {
|
|
return "", err
|
|
}
|
|
return fmt.Sprintf("[tool error] %v", err), nil
|
|
}
|
|
return result, nil
|
|
}, nil
|
|
}
|
|
|
|
func (m *safeToolMiddleware) WrapStreamableToolCall(
|
|
_ context.Context,
|
|
endpoint adk.StreamableToolCallEndpoint,
|
|
_ *adk.ToolContext,
|
|
) (adk.StreamableToolCallEndpoint, error) {
|
|
return func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) {
|
|
sr, err := endpoint(ctx, args, opts...)
|
|
if err != nil {
|
|
if _, ok := compose.IsInterruptRerunError(err); ok {
|
|
return nil, err
|
|
}
|
|
return singleChunkReader(fmt.Sprintf("[tool error] %v", err)), nil
|
|
}
|
|
return safeWrapReader(sr), nil
|
|
}, nil
|
|
}
|
|
|
|
// singleChunkReader returns a StreamReader that emits one string then EOF.
|
|
func singleChunkReader(msg string) *schema.StreamReader[string] {
|
|
r, w := schema.Pipe[string](1)
|
|
_ = w.Send(msg, nil)
|
|
w.Close()
|
|
return r
|
|
}
|
|
|
|
// safeWrapReader proxies chunks from sr; on a stream error it emits the error
|
|
// as a final chunk instead of propagating it, so the model sees a complete
|
|
// (if error-annotated) tool result rather than a pipeline failure.
|
|
func safeWrapReader(sr *schema.StreamReader[string]) *schema.StreamReader[string] {
|
|
r, w := schema.Pipe[string](64)
|
|
go func() {
|
|
defer w.Close()
|
|
for {
|
|
chunk, err := sr.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
return
|
|
}
|
|
if err != nil {
|
|
_ = w.Send(fmt.Sprintf("\n[tool error] %v", err), nil)
|
|
return
|
|
}
|
|
_ = w.Send(chunk, nil)
|
|
}
|
|
}()
|
|
return r
|
|
}
|