You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
270 lines
7.5 KiB
Go
270 lines
7.5 KiB
Go
/*
|
|
* Copyright 2024 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"time"
|
|
|
|
clc "github.com/cloudwego/eino-ext/callbacks/cozeloop"
|
|
"github.com/cloudwego/eino-ext/components/model/ark"
|
|
"github.com/cloudwego/eino/callbacks"
|
|
"github.com/cloudwego/eino/components/tool"
|
|
"github.com/cloudwego/eino/compose"
|
|
"github.com/cloudwego/eino/flow/agent"
|
|
"github.com/cloudwego/eino/flow/agent/react"
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/coze-dev/cozeloop-go"
|
|
|
|
"github.com/cloudwego/eino-examples/devops/visualize"
|
|
"github.com/cloudwego/eino-examples/flow/agent/react/tools"
|
|
"github.com/cloudwego/eino-examples/internal/logs"
|
|
)
|
|
|
|
func main() {
|
|
arkApiKey := os.Getenv("ARK_API_KEY")
|
|
arkModelName := os.Getenv("ARK_MODEL_NAME")
|
|
cozeloopApiToken := os.Getenv("COZELOOP_API_TOKEN")
|
|
cozeloopWorkspaceID := os.Getenv("COZELOOP_WORKSPACE_ID") // use cozeloop trace, from https://loop.coze.cn/open/docs/cozeloop/go-sdk#4a8c980e
|
|
|
|
ctx := context.Background()
|
|
var handlers []callbacks.Handler
|
|
if cozeloopApiToken != "" && cozeloopWorkspaceID != "" {
|
|
client, err := cozeloop.NewClient(
|
|
cozeloop.WithAPIToken(cozeloopApiToken),
|
|
cozeloop.WithWorkspaceID(cozeloopWorkspaceID),
|
|
)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
defer client.Close(ctx)
|
|
handlers = append(handlers, clc.NewLoopHandler(client))
|
|
}
|
|
callbacks.AppendGlobalHandlers(handlers...)
|
|
|
|
// minimal: we will export graph via API when available and compile a mermaid diagram
|
|
|
|
// Create a new cached ark chat model.
|
|
// arkModel, err = NewCachedARKChatModel(ctx, config)
|
|
|
|
config := &ark.ChatModelConfig{
|
|
APIKey: arkApiKey,
|
|
Model: arkModelName,
|
|
}
|
|
arkModel, err := ark.NewChatModel(ctx, config)
|
|
if err != nil {
|
|
logs.Errorf("failed to create chat model: %v", err)
|
|
return
|
|
}
|
|
|
|
// prepare tools
|
|
restaurantTool := tools.GetRestaurantTool()
|
|
dishTool := tools.GetDishTool()
|
|
|
|
// prepare persona (system prompt) (optional)
|
|
persona := `# Character:
|
|
你是一个帮助用户推荐餐厅和菜品的助手,根据用户的需要,查询餐厅信息并推荐,查询餐厅的菜品并推荐。
|
|
`
|
|
|
|
// replace tool call checker with a custom one: check all trunks until you get a tool call
|
|
// because some models(claude or doubao 1.5-pro 32k) do not return tool call in the first response
|
|
// uncomment the following code to enable it
|
|
/*toolCallChecker := func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) {
|
|
defer sr.Close()
|
|
for {
|
|
msg, err := sr.Recv()
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
// finish
|
|
break
|
|
}
|
|
|
|
return false, err
|
|
}
|
|
|
|
if len(msg.ToolCalls) > 0 {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}*/
|
|
|
|
rAgent, err := react.NewAgent(ctx, &react.AgentConfig{
|
|
ToolCallingModel: arkModel,
|
|
ToolsConfig: compose.ToolsNodeConfig{
|
|
Tools: []tool.BaseTool{restaurantTool, dishTool},
|
|
},
|
|
// StreamToolCallChecker: toolCallChecker, // uncomment it to replace the default tool call checker with custom one
|
|
})
|
|
if err != nil {
|
|
logs.Errorf("failed to create agent: %v", err)
|
|
return
|
|
}
|
|
|
|
// if you want ping/pong, use Generate
|
|
// msg, err := agent.Generate(ctx, []*schema.Message{
|
|
// {
|
|
// Role: schema.User,
|
|
// Content: "我在北京,给我推荐一些菜,需要有口味辣一点的菜,至少推荐有 2 家餐厅",
|
|
// },
|
|
// }, react.WithCallbacks(&myCallback{}))
|
|
// if err != nil {
|
|
// log.Printf("failed to generate: %v\n", err)
|
|
// return
|
|
// }
|
|
// fmt.Println(msg.String())
|
|
|
|
// If you want to use ark caching in react, call ark.WithCache()
|
|
//cacheOption := &ark.CacheOption{
|
|
// APIType: ark.ResponsesAPI,
|
|
// SessionCache: &ark.SessionCacheConfig{
|
|
// EnableCache: true,
|
|
// TTL: 3600,
|
|
// },
|
|
//}
|
|
|
|
opt := []agent.AgentOption{
|
|
agent.WithComposeOptions(compose.WithCallbacks(&LoggerCallback{})),
|
|
// react.WithChatModelOptions(ark.WithCache(cacheOption)),
|
|
}
|
|
|
|
// Export graph and compile with mermaid (non-critical path)
|
|
{
|
|
anyG, opts := rAgent.ExportGraph()
|
|
gen := visualize.NewMermaidGenerator("flow/agent/react")
|
|
g := compose.NewGraph[[]*schema.Message, *schema.Message]()
|
|
_ = g.AddGraphNode("react_agent", anyG, opts...)
|
|
_ = g.AddEdge(compose.START, "react_agent")
|
|
_ = g.AddEdge("react_agent", compose.END)
|
|
_, _ = g.Compile(context.Background(), compose.WithGraphCompileCallbacks(gen))
|
|
}
|
|
|
|
sr, err := rAgent.Stream(ctx, []*schema.Message{
|
|
{
|
|
Role: schema.System,
|
|
Content: persona,
|
|
},
|
|
{
|
|
Role: schema.User,
|
|
Content: "我在北京,给我推荐一些菜,需要有口味辣一点的菜,至少推荐有 2 家餐厅",
|
|
},
|
|
}, opt...)
|
|
if err != nil {
|
|
logs.Errorf("failed to stream: %v", err)
|
|
return
|
|
}
|
|
|
|
defer sr.Close() // remember to close the stream
|
|
|
|
logs.Infof("\n\n===== start streaming =====\n\n")
|
|
|
|
for {
|
|
msg, err := sr.Recv()
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
// finish
|
|
break
|
|
}
|
|
// error
|
|
logs.Infof("failed to recv: %v", err)
|
|
return
|
|
}
|
|
|
|
// 打字机打印
|
|
logs.Tokenf("%v", msg.Content)
|
|
}
|
|
|
|
logs.Infof("\n\n===== finished =====\n")
|
|
time.Sleep(2 * time.Second)
|
|
}
|
|
|
|
// LoggerCallback is a callback handler that prints component lifecycle events
// (start, end, error, and streaming start/end) to stdout.
type LoggerCallback struct {
	callbacks.HandlerBuilder // callbacks.HandlerBuilder can be used to help implement a callback
}
|
|
|
|
func (cb *LoggerCallback) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
|
fmt.Println("==================")
|
|
inputStr, _ := json.MarshalIndent(input, "", " ")
|
|
fmt.Printf("[OnStart] %s\n", string(inputStr))
|
|
return ctx
|
|
}
|
|
|
|
func (cb *LoggerCallback) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
|
fmt.Println("=========[OnEnd]=========")
|
|
outputStr, _ := json.MarshalIndent(output, "", " ")
|
|
fmt.Println(string(outputStr))
|
|
return ctx
|
|
}
|
|
|
|
func (cb *LoggerCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
|
fmt.Println("=========[OnError]=========")
|
|
fmt.Println(err)
|
|
return ctx
|
|
}
|
|
|
|
func (cb *LoggerCallback) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo,
|
|
output *schema.StreamReader[callbacks.CallbackOutput],
|
|
) context.Context {
|
|
graphInfoName := react.GraphName
|
|
|
|
go func() {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
fmt.Println("[OnEndStream] panic err:", err)
|
|
}
|
|
}()
|
|
|
|
defer output.Close() // remember to close the stream in defer
|
|
|
|
fmt.Println("=========[OnEndStream]=========")
|
|
for {
|
|
frame, err := output.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
// finish
|
|
break
|
|
}
|
|
if err != nil {
|
|
fmt.Printf("internal error: %s\n", err)
|
|
return
|
|
}
|
|
|
|
s, err := json.Marshal(frame)
|
|
if err != nil {
|
|
fmt.Printf("internal error: %s\n", err)
|
|
return
|
|
}
|
|
|
|
if info.Name == graphInfoName { // 仅打印 graph 的输出, 否则每个 stream 节点的输出都会打印一遍
|
|
fmt.Printf("%s: %s\n", info.Name, string(s))
|
|
}
|
|
}
|
|
}()
|
|
return ctx
|
|
}
|
|
|
|
func (cb *LoggerCallback) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo,
|
|
input *schema.StreamReader[callbacks.CallbackInput],
|
|
) context.Context {
|
|
defer input.Close()
|
|
return ctx
|
|
}
|