You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

238 lines
6.8 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"os"
clc "github.com/cloudwego/eino-ext/callbacks/cozeloop"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino-ext/components/tool/duckduckgo/v2"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/components/tool/utils"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/cozeloop-go"
"github.com/cloudwego/eino-examples/internal/gptr"
"github.com/cloudwego/eino-examples/internal/logs"
)
// main wires up a single-turn tool-calling example. An OpenAI-compatible chat
// model has four tools bound to it: add_todo, update_todo, list_todo and a
// DuckDuckGo text search. The model is chained into a tools node that executes
// the tool calls the model emits. main runs the chain once on a sample request
// and logs each message the chain returns.
//
// Environment variables:
//   - OPENAI_API_KEY, OPENAI_MODEL_NAME, OPENAI_BASE_URL configure the chat model.
//   - COZELOOP_API_TOKEN and COZELOOP_WORKSPACE_ID are optional. A CozeLoop
//     callback handler is registered globally only when both are non-empty.
//
// Every setup or run failure is logged via logs.Errorf and ends main early.
func main() {
	openAIAPIKey := os.Getenv("OPENAI_API_KEY")
	openAIModelName := os.Getenv("OPENAI_MODEL_NAME")
	openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
	// CozeLoop tracing credentials, see https://loop.coze.cn/open/docs/cozeloop/go-sdk#4a8c980e
	cozeloopAPIToken := os.Getenv("COZELOOP_API_TOKEN")
	cozeloopWorkspaceID := os.Getenv("COZELOOP_WORKSPACE_ID")

	ctx := context.Background()

	var handlers []callbacks.Handler
	if cozeloopAPIToken != "" && cozeloopWorkspaceID != "" {
		client, err := cozeloop.NewClient(
			cozeloop.WithAPIToken(cozeloopAPIToken),
			cozeloop.WithWorkspaceID(cozeloopWorkspaceID),
		)
		if err != nil {
			// Report this like every other setup failure below instead of panicking.
			logs.Errorf("cozeloop.NewClient failed, err=%v", err)
			return
		}
		// The defer belongs to main, so the client stays open for the rest of the run.
		defer client.Close(ctx)
		handlers = append(handlers, clc.NewLoopHandler(client))
	}
	callbacks.AppendGlobalHandlers(handlers...)

	// update_todo: built with utils.InferTool from UpdateTodoFunc.
	updateTool, err := utils.InferTool("update_todo", "Update a todo item, eg: content,deadline...", UpdateTodoFunc)
	if err != nil {
		logs.Errorf("InferTool failed, err=%v", err)
		return
	}

	// Create the DuckDuckGo text search tool with its default configuration.
	searchTool, err := duckduckgo.NewTextSearchTool(ctx, &duckduckgo.Config{})
	if err != nil {
		logs.Errorf("NewTextSearchTool failed, err=%v", err)
		return
	}

	// Assemble the tools. Each todo tool demonstrates a different construction style.
	todoTools := []tool.BaseTool{
		getAddTodoTool(), // built with utils.NewTool
		updateTool,       // built with utils.InferTool
		&ListTodoTool{},  // hand-written struct; the underlying logic is not implemented here
		searchTool,
	}

	// Create and configure the chat model.
	chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
		BaseURL:     openAIBaseURL,
		Model:       openAIModelName,
		APIKey:      openAIAPIKey,
		Temperature: gptr.Of(float32(0.7)),
	})
	if err != nil {
		logs.Errorf("NewChatModel failed, err=%v", err)
		return
	}

	// Collect every tool's ToolInfo so the tools can be bound to the chat model.
	toolInfos := make([]*schema.ToolInfo, 0, len(todoTools))
	for _, todoTool := range todoTools {
		info, err := todoTool.Info(ctx)
		if err != nil {
			// This is a fatal setup failure, so log it at error level like the others.
			logs.Errorf("get ToolInfo failed, err=%v", err)
			return
		}
		toolInfos = append(toolInfos, info)
	}

	// Bind the tools to the chat model.
	err = chatModel.BindTools(toolInfos)
	if err != nil {
		logs.Errorf("BindTools failed, err=%v", err)
		return
	}

	// Create the tools node that executes the model's tool calls.
	todoToolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{
		Tools: todoTools,
	})
	if err != nil {
		logs.Errorf("NewToolNode failed, err=%v", err)
		return
	}

	// Build the chain: chat model -> tools node.
	chain := compose.NewChain[[]*schema.Message, []*schema.Message]()
	chain.
		AppendChatModel(chatModel, compose.WithNodeName("chat_model")).
		AppendToolsNode(todoToolsNode, compose.WithNodeName("tools"))

	// Compile the chain into a runnable.
	agent, err := chain.Compile(ctx)
	if err != nil {
		logs.Errorf("chain.Compile failed, err=%v", err)
		return
	}

	// Run the example request.
	resp, err := agent.Invoke(ctx, []*schema.Message{
		{
			Role:    schema.User,
			Content: "添加一个学习 Eino 的 TODO同时搜索一下 cloudwego/eino 的仓库地址",
		},
	})
	if err != nil {
		logs.Errorf("agent.Invoke failed, err=%v", err)
		return
	}

	// Print the chain's output, i.e. the messages produced by the final tools node.
	for idx, msg := range resp {
		logs.Infof("\n")
		logs.Infof("message %d: %s: %s", idx, msg.Role, msg.Content)
	}
}
// getAddTodoTool returns the add_todo tool, built with utils.NewTool from an
// explicitly declared schema and AddTodoFunc as the handler. Only "content" is
// required; "started_at" and "deadline" are optional unix-timestamp integers.
func getAddTodoTool() tool.InvokableTool {
	params := map[string]*schema.ParameterInfo{
		"content": {
			Type:     schema.String,
			Desc:     "The content of the todo item",
			Required: true,
		},
		"started_at": {
			Type: schema.Integer,
			Desc: "The started time of the todo item, in unix timestamp",
		},
		"deadline": {
			Type: schema.Integer,
			Desc: "The deadline of the todo item, in unix timestamp",
		},
	}

	return utils.NewTool(&schema.ToolInfo{
		Name:        "add_todo",
		Desc:        "Add a todo item",
		ParamsOneOf: schema.NewParamsOneOfByParams(params),
	}, AddTodoFunc)
}
// ListTodoTool is the list_todo tool, implemented by hand as a struct rather
// than through a utils helper. Its InvokableRun method is defined separately
// in this file.
type ListTodoTool struct{}

// Info describes the list_todo tool. It takes a single optional boolean
// parameter, "finished". It never returns an error.
func (lt *ListTodoTool) Info(_ context.Context) (*schema.ToolInfo, error) {
	finished := &schema.ParameterInfo{
		Type: schema.Boolean,
		Desc: "filter todo items if finished",
	}

	info := &schema.ToolInfo{
		Name: "list_todo",
		Desc: "List all todo items",
		ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
			"finished": finished,
		}),
	}
	return info, nil
}
// TodoUpdateParams is the argument type of UpdateTodoFunc, which backs the
// update_todo tool created with utils.InferTool in main. The json and
// jsonschema_description tags are presumably what InferTool derives the
// tool's parameter schema from — confirm against the eino utils docs.
// Every field except ID is a pointer, so (assuming the arguments are decoded
// from JSON) a field the model omits stays nil rather than taking a zero value.
type TodoUpdateParams struct {
ID string `json:"id" jsonschema_description:"id of the todo"`
Content *string `json:"content,omitempty" jsonschema_description:"content of the todo"`
StartedAt *int64 `json:"started_at,omitempty" jsonschema_description:"start time in unix timestamp"`
Deadline *int64 `json:"deadline,omitempty" jsonschema_description:"deadline of the todo in unix timestamp"`
Done *bool `json:"done,omitempty" jsonschema_description:"done status"`
}
// TodoAddParams is the argument type of AddTodoFunc, which backs the add_todo
// tool built in getAddTodoTool. Its json tags match the parameter names
// declared in that tool's hand-written schema (content, started_at, deadline).
// NOTE(review): the Go field is StartAt here but StartedAt in
// TodoUpdateParams for the same "started_at" key.
type TodoAddParams struct {
Content string `json:"content"`
StartAt *int64 `json:"started_at,omitempty"` // start time; the add_todo schema describes it as a unix timestamp
Deadline *int64 `json:"deadline,omitempty"`
}
// InvokableRun executes the list_todo tool. It logs the raw JSON arguments and
// returns a fixed, single-item todo list. The "finished" filter is not applied
// because no real storage is implemented in this example.
func (lt *ListTodoTool) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) {
	logs.Infof("invoke tool list_todo: %s", argumentsInJSON)

	// Placeholder result; real tool logic would go here.
	const mockTodos = `{"todos": [{"id": "1", "content": "在2024年12月10日之前完成Eino项目演示文稿的准备工作", "started_at": 1717401600, "deadline": 1717488000, "done": false}]}`
	return mockTodos, nil
}
// AddTodoFunc handles add_todo calls. It logs the decoded parameters and
// always reports success; nothing is actually stored in this example.
func AddTodoFunc(_ context.Context, params *TodoAddParams) (string, error) {
	logs.Infof("invoke tool add_todo: %+v", params)

	// Placeholder: real tool logic would persist the todo item here.
	const reply = `{"msg": "add todo success"}`
	return reply, nil
}
// UpdateTodoFunc handles update_todo calls. It logs the decoded parameters
// and always reports success; no todo item is actually modified.
func UpdateTodoFunc(_ context.Context, params *TodoUpdateParams) (string, error) {
	logs.Infof("invoke tool update_todo: %+v", params)

	// Placeholder: real tool logic would apply the non-nil fields here.
	const reply = `{"msg": "update todo success"}`
	return reply, nil
}