You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

270 lines
7.5 KiB
Go

/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"time"
clc "github.com/cloudwego/eino-ext/callbacks/cozeloop"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent"
"github.com/cloudwego/eino/flow/agent/react"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/cozeloop-go"
"github.com/cloudwego/eino-examples/devops/visualize"
"github.com/cloudwego/eino-examples/flow/agent/react/tools"
"github.com/cloudwego/eino-examples/internal/logs"
)
// main wires up a ReAct agent demo: optional CozeLoop tracing, an Ark chat
// model, two demo tools (restaurant / dish lookup), a Mermaid export of the
// agent graph, and one streamed restaurant-recommendation conversation.
func main() {
	arkApiKey := os.Getenv("ARK_API_KEY")
	arkModelName := os.Getenv("ARK_MODEL_NAME")
	cozeloopApiToken := os.Getenv("COZELOOP_API_TOKEN")
	cozeloopWorkspaceID := os.Getenv("COZELOOP_WORKSPACE_ID") // use cozeloop trace, from https://loop.coze.cn/open/docs/cozeloop/go-sdk#4a8c980e

	ctx := context.Background()

	// Tracing is optional: only build the CozeLoop handler when both
	// credentials are present in the environment.
	var handlers []callbacks.Handler
	if cozeloopApiToken != "" && cozeloopWorkspaceID != "" {
		client, err := cozeloop.NewClient(
			cozeloop.WithAPIToken(cozeloopApiToken),
			cozeloop.WithWorkspaceID(cozeloopWorkspaceID),
		)
		if err != nil {
			panic(err)
		}
		defer client.Close(ctx)
		handlers = append(handlers, clc.NewLoopHandler(client))
	}
	// Register global handlers before any graph runs so the whole session is traced.
	callbacks.AppendGlobalHandlers(handlers...)

	// minimal: we will export graph via API when available and compile a mermaid diagram

	// Create a new cached ark chat model.
	// arkModel, err = NewCachedARKChatModel(ctx, config)
	config := &ark.ChatModelConfig{
		APIKey: arkApiKey,
		Model:  arkModelName,
	}
	arkModel, err := ark.NewChatModel(ctx, config)
	if err != nil {
		logs.Errorf("failed to create chat model: %v", err)
		return
	}

	// prepare tools
	restaurantTool := tools.GetRestaurantTool()
	dishTool := tools.GetDishTool()

	// prepare persona (system prompt) (optional)
	// NOTE: the persona text is a runtime string sent to the model; it is kept verbatim.
	persona := `# Character:
你是一个帮助用户推荐餐厅和菜品的助手,根据用户的需要,查询餐厅信息并推荐,查询餐厅的菜品并推荐。
`

	// replace tool call checker with a custom one: check all trunks until you get a tool call
	// because some models(claude or doubao 1.5-pro 32k) do not return tool call in the first response
	// uncomment the following code to enable it
	/*toolCallChecker := func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) {
		defer sr.Close()
		for {
			msg, err := sr.Recv()
			if err != nil {
				if errors.Is(err, io.EOF) {
					// finish
					break
				}
				return false, err
			}
			if len(msg.ToolCalls) > 0 {
				return true, nil
			}
		}
		return false, nil
	}*/

	rAgent, err := react.NewAgent(ctx, &react.AgentConfig{
		ToolCallingModel: arkModel,
		ToolsConfig: compose.ToolsNodeConfig{
			Tools: []tool.BaseTool{restaurantTool, dishTool},
		},
		// StreamToolCallChecker: toolCallChecker, // uncomment it to replace the default tool call checker with custom one
	})
	if err != nil {
		logs.Errorf("failed to create agent: %v", err)
		return
	}

	// if you want ping/pong, use Generate
	// msg, err := agent.Generate(ctx, []*schema.Message{
	// 	{
	// 		Role:    schema.User,
	// 		Content: "我在北京,给我推荐一些菜,需要有口味辣一点的菜,至少推荐有 2 家餐厅",
	// 	},
	// }, react.WithCallbacks(&myCallback{}))
	// if err != nil {
	// 	log.Printf("failed to generate: %v\n", err)
	// 	return
	// }
	// fmt.Println(msg.String())

	// If you want to use ark caching in react, call ark.WithCache()
	//cacheOption := &ark.CacheOption{
	//	APIType: ark.ResponsesAPI,
	//	SessionCache: &ark.SessionCacheConfig{
	//		EnableCache: true,
	//		TTL:         3600,
	//	},
	//}
	opt := []agent.AgentOption{
		agent.WithComposeOptions(compose.WithCallbacks(&LoggerCallback{})),
		// react.WithChatModelOptions(ark.WithCache(cacheOption)),
	}

	// Export graph and compile with mermaid (non-critical path).
	// Errors are deliberately ignored here: diagram generation must never block the demo.
	{
		anyG, opts := rAgent.ExportGraph()
		gen := visualize.NewMermaidGenerator("flow/agent/react")
		g := compose.NewGraph[[]*schema.Message, *schema.Message]()
		_ = g.AddGraphNode("react_agent", anyG, opts...)
		_ = g.AddEdge(compose.START, "react_agent")
		_ = g.AddEdge("react_agent", compose.END)
		_, _ = g.Compile(context.Background(), compose.WithGraphCompileCallbacks(gen))
	}

	sr, err := rAgent.Stream(ctx, []*schema.Message{
		{
			Role:    schema.System,
			Content: persona,
		},
		{
			Role:    schema.User,
			Content: "我在北京,给我推荐一些菜,需要有口味辣一点的菜,至少推荐有 2 家餐厅",
		},
	}, opt...)
	if err != nil {
		logs.Errorf("failed to stream: %v", err)
		return
	}

	defer sr.Close() // remember to close the stream

	logs.Infof("\n\n===== start streaming =====\n\n")

	for {
		msg, err := sr.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				// finish
				break
			}
			// error
			logs.Infof("failed to recv: %v", err)
			return
		}

		// typewriter-style printing: emit each chunk's content as it arrives
		logs.Tokenf("%v", msg.Content)
	}

	logs.Infof("\n\n===== finished =====\n")

	// give async callback goroutines (trace upload, stream logging) time to flush before exit
	time.Sleep(2 * time.Second)
}
// LoggerCallback is a demo callbacks handler that prints each node's
// input/output (including streaming output) to stdout.
type LoggerCallback struct {
	callbacks.HandlerBuilder // embedding callbacks.HandlerBuilder supplies no-op defaults for the methods not overridden here
}
// OnStart logs the callback input as indented JSON when a node starts running.
func (cb *LoggerCallback) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
	fmt.Println("==================")
	// Marshal errors are intentionally ignored: logging must never fail the run.
	serialized, _ := json.MarshalIndent(input, "", " ")
	fmt.Printf("[OnStart] %s\n", serialized)
	return ctx
}
// OnEnd logs the callback output as indented JSON when a node finishes.
func (cb *LoggerCallback) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
	fmt.Println("=========[OnEnd]=========")
	// Marshal errors are intentionally ignored; an empty line is printed instead.
	serialized, _ := json.MarshalIndent(output, "", " ")
	fmt.Printf("%s\n", serialized)
	return ctx
}
// OnError logs the error raised by a node.
func (cb *LoggerCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
	fmt.Println("=========[OnError]=========")
	fmt.Printf("%v\n", err)
	return ctx
}
// OnEndWithStreamOutput drains a node's output stream in a background
// goroutine. Only frames produced by the top-level graph node are printed
// (as compact JSON); frames from inner nodes are consumed without being
// serialized, so the stream is always fully drained and closed.
//
// Fix over the previous version: json.Marshal ran for every frame of every
// node even though only the graph node's frames were printed, and a marshal
// error on a never-printed frame aborted draining. The name check is now
// hoisted before serialization.
func (cb *LoggerCallback) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo,
	output *schema.StreamReader[callbacks.CallbackOutput],
) context.Context {
	graphInfoName := react.GraphName
	// Only the graph's own output is printed; otherwise every stream node's
	// output would be printed once each.
	shouldPrint := info.Name == graphInfoName
	go func() {
		defer func() {
			if err := recover(); err != nil {
				fmt.Println("[OnEndStream] panic err:", err)
			}
		}()

		defer output.Close() // remember to close the stream in defer

		fmt.Println("=========[OnEndStream]=========")
		for {
			frame, err := output.Recv()
			if errors.Is(err, io.EOF) {
				// finish
				break
			}
			if err != nil {
				fmt.Printf("internal error: %s\n", err)
				return
			}
			if !shouldPrint {
				// drain inner-node frames without the cost of marshaling them
				continue
			}

			s, err := json.Marshal(frame)
			if err != nil {
				fmt.Printf("internal error: %s\n", err)
				return
			}
			fmt.Printf("%s: %s\n", info.Name, string(s))
		}
	}()
	return ctx
}
// OnStartWithStreamInput closes the input stream without reading it: this
// logger does not inspect streaming inputs, but the reader must still be
// released so the producer side is not blocked.
func (cb *LoggerCallback) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo,
	input *schema.StreamReader[callbacks.CallbackInput],
) context.Context {
	input.Close()
	return ctx
}