|
|
|
@ -24,7 +24,7 @@ import (
|
|
|
|
"io"
|
|
|
|
"io"
|
|
|
|
"os"
|
|
|
|
"os"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
|
|
|
"github.com/cloudwego/eino-ext/components/model/ark"
|
|
|
|
"github.com/cloudwego/eino/callbacks"
|
|
|
|
"github.com/cloudwego/eino/callbacks"
|
|
|
|
"github.com/cloudwego/eino/components/tool"
|
|
|
|
"github.com/cloudwego/eino/components/tool"
|
|
|
|
"github.com/cloudwego/eino/compose"
|
|
|
|
"github.com/cloudwego/eino/compose"
|
|
|
|
@ -37,17 +37,13 @@ import (
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
func main() {
|
|
|
|
openAIAPIKey := os.Getenv("OPENAI_API_KEY")
|
|
|
|
arkAPIKey := os.Getenv("ARK_API_KEY")
|
|
|
|
openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
|
|
|
|
arkModelName := os.Getenv("ARK_MODEL_NAME")
|
|
|
|
openAIModelName := os.Getenv("OPENAI_MODEL_NAME")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx := context.Background()
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
|
|
arkModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
|
|
|
// prepare ChatModel
|
|
|
|
APIKey: arkAPIKey,
|
|
|
|
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
|
|
|
|
Model: arkModelName,
|
|
|
|
BaseURL: openAIBaseURL,
|
|
|
|
|
|
|
|
APIKey: openAIAPIKey,
|
|
|
|
|
|
|
|
Model: openAIModelName,
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
logs.Errorf("failed to create chat model: %v", err)
|
|
|
|
logs.Errorf("failed to create chat model: %v", err)
|
|
|
|
@ -63,13 +59,37 @@ func main() {
|
|
|
|
你是一个帮助用户推荐餐厅和菜品的助手,根据用户的需要,查询餐厅信息并推荐,查询餐厅的菜品并推荐。
|
|
|
|
你是一个帮助用户推荐餐厅和菜品的助手,根据用户的需要,查询餐厅信息并推荐,查询餐厅的菜品并推荐。
|
|
|
|
`
|
|
|
|
`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// replace tool call checker with a custom one: check all trunks until you get a tool call
|
|
|
|
|
|
|
|
// because some models(claude or doubao 1.5-pro 32k) do not return tool call in the first response
|
|
|
|
|
|
|
|
// uncomment the following code to enable it
|
|
|
|
|
|
|
|
/*toolCallChecker := func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) {
|
|
|
|
|
|
|
|
defer sr.Close()
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
|
|
|
msg, err := sr.Recv()
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
if errors.Is(err, io.EOF) {
|
|
|
|
|
|
|
|
// finish
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return false, err
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(msg.ToolCalls) > 0 {
|
|
|
|
|
|
|
|
return true, nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false, nil
|
|
|
|
|
|
|
|
}*/
|
|
|
|
|
|
|
|
|
|
|
|
ragent, err := react.NewAgent(ctx, &react.AgentConfig{
|
|
|
|
ragent, err := react.NewAgent(ctx, &react.AgentConfig{
|
|
|
|
Model: chatModel,
|
|
|
|
Model: arkModel,
|
|
|
|
ToolsConfig: compose.ToolsNodeConfig{
|
|
|
|
ToolsConfig: compose.ToolsNodeConfig{
|
|
|
|
Tools: []tool.BaseTool{restaurantTool, dishTool},
|
|
|
|
Tools: []tool.BaseTool{restaurantTool, dishTool},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
|
|
MessageModifier: react.NewPersonaModifier(persona),
|
|
|
|
MessageModifier: react.NewPersonaModifier(persona),
|
|
|
|
|
|
|
|
// StreamToolCallChecker: toolCallChecker, // uncomment it to replace the default tool call checker with custom one
|
|
|
|
})
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
logs.Errorf("failed to create agent: %v", err)
|
|
|
|
logs.Errorf("failed to create agent: %v", err)
|
|
|
|
|