feat: add tool_call_once and two_model_chat to graph examples (#15)

Change-Id: I4bc86918954cae88cc8f5b4fb4deea0d0895bfc7
drew/english
shentongmartin 1 year ago committed by GitHub
parent 75de62639c
commit f7a1f0f526
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -34,9 +34,9 @@ import (
func main() { func main() {
openAIBaseURL := os.Getenv("OPENAI_BASE_URL") //openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
openAIAPIKey := os.Getenv("OPENAI_API_KEY") openAIAPIKey := os.Getenv("OPENAI_API_KEY")
modelName := os.Getenv("MODEL_NAME") modelName := os.Getenv("OPENAI_MODEL_NAME")
ctx := context.Background() ctx := context.Background()
@ -51,9 +51,9 @@ func main() {
) )
modelConf := &openai.ChatModelConfig{ modelConf := &openai.ChatModelConfig{
BaseURL: openAIBaseURL, //BaseURL: openAIBaseURL,
APIKey: openAIAPIKey, APIKey: openAIAPIKey,
ByAzure: true, //ByAzure: true,
Model: modelName, Model: modelName,
Temperature: gptr.Of(float32(0.7)), Temperature: gptr.Of(float32(0.7)),
APIVersion: "2024-06-01", APIVersion: "2024-06-01",

@ -0,0 +1,170 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"errors"
"os"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/components/tool/utils"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino-examples/internal/gptr"
"github.com/cloudwego/eino-examples/internal/logs"
)
// main assembles and runs a tool-calling graph:
//
//	START -> template -> model --(tool calls)--> tools -> converter -> END
//	                           \--(no tool calls)--------------------> END
//
// The chat model is bound to a single fake "user_info" tool; when the
// model emits tool calls, the tools node executes the tool and a lambda
// picks the first resulting message as the graph output.
//
// Required env vars: OPENAI_API_KEY, OPENAI_MODEL_NAME.
func main() {
	//openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
	openAIAPIKey := os.Getenv("OPENAI_API_KEY")
	modelName := os.Getenv("OPENAI_MODEL_NAME")

	ctx := context.Background()

	// System prompt (Chinese): act as a real-estate agent, use the
	// user_info API with the user's salary/job; the email is required.
	systemTpl := `你是一名房产经纪人,结合用户的薪酬和工作,使用 user_info API为其提供相关的房产信息。邮箱是必须的`
	chatTpl := prompt.FromMessages(schema.FString,
		schema.SystemMessage(systemTpl),
		schema.MessagesPlaceholder("message_histories", true), // optional history slot
		schema.UserMessage("{query}"),
	)

	modelConf := &openai.ChatModelConfig{
		//BaseURL: openAIBaseURL,
		APIKey: openAIAPIKey,
		//ByAzure: true,
		Model:       modelName,
		Temperature: gptr.Of(float32(0.7)),
		APIVersion:  "2024-06-01",
	}

	chatModel, err := openai.NewChatModel(ctx, modelConf)
	if err != nil {
		logs.Fatalf("NewChatModel failed, err=%v", err)
	}

	// Fake lookup tool: echoes name/email back with hard-coded company data.
	// Tool name/desc (Chinese): look up company/position/salary by name and email.
	userInfoTool := utils.NewTool(
		&schema.ToolInfo{
			Name: "user_info",
			Desc: "根据用户的姓名和邮箱,查询用户的公司、职位、薪酬信息",
			ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
				"name": {
					Type: "string",
					Desc: "用户的姓名",
				},
				"email": {
					Type: "string",
					Desc: "用户的邮箱",
				},
			}),
		},
		func(ctx context.Context, input *userInfoRequest) (output *userInfoResponse, err error) {
			return &userInfoResponse{
				Name:     input.Name,
				Email:    input.Email,
				Company:  "Bytedance",
				Position: "CEO",
				Salary:   "9999",
			}, nil
		})

	info, err := userInfoTool.Info(ctx)
	if err != nil {
		logs.Fatalf("Get ToolInfo failed, err=%v", err)
	}
	// Expose the tool's schema to the model so it can emit tool calls.
	err = chatModel.BindTools([]*schema.ToolInfo{info})
	if err != nil {
		logs.Fatalf("BindTools failed, err=%v", err)
	}

	toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{
		Tools: []tool.BaseTool{userInfoTool},
	})
	if err != nil {
		logs.Fatalf("NewToolNode failed, err=%v", err)
	}

	// Converter: the tools node outputs a message slice; keep only the first.
	takeOne := compose.InvokableLambda(func(ctx context.Context, input []*schema.Message) (*schema.Message, error) {
		if len(input) == 0 {
			return nil, errors.New("input is empty")
		}
		return input[0], nil
	})

	const (
		nodeModel     = "node_model"
		nodeTools     = "node_tools"
		nodeTemplate  = "node_template"
		nodeConverter = "node_converter"
	)

	// Branch on the first streamed chunk: tool calls route to the tools
	// node, otherwise the model's direct answer ends the run.
	branch := compose.NewStreamGraphBranch(func(ctx context.Context, input *schema.StreamReader[*schema.Message]) (string, error) {
		defer input.Close()
		msg, err := input.Recv()
		if err != nil {
			return "", err
		}
		if len(msg.ToolCalls) > 0 {
			return nodeTools, nil
		}
		return compose.END, nil
	}, map[string]bool{compose.END: true, nodeTools: true})

	graph := compose.NewGraph[map[string]any, *schema.Message]()
	_ = graph.AddChatTemplateNode(nodeTemplate, chatTpl)
	_ = graph.AddChatModelNode(nodeModel, chatModel)
	_ = graph.AddToolsNode(nodeTools, toolsNode)
	_ = graph.AddLambdaNode(nodeConverter, takeOne)
	_ = graph.AddEdge(compose.START, nodeTemplate)
	_ = graph.AddEdge(nodeTemplate, nodeModel)
	_ = graph.AddBranch(nodeModel, branch)
	_ = graph.AddEdge(nodeTools, nodeConverter)
	_ = graph.AddEdge(nodeConverter, compose.END)

	r, err := graph.Compile(ctx)
	if err != nil {
		logs.Fatalf("Compile failed, err=%v", err)
	}

	// Query (Chinese): "I'm zhangsan, email zhangsan@bytedance.com,
	// recommend me a property."
	out, err := r.Invoke(ctx, map[string]any{"query": "我叫 zhangsan, 邮箱是 zhangsan@bytedance.com, 帮我推荐一处房产"})
	if err != nil {
		logs.Fatalf("Invoke failed, err=%v", err)
	}

	logs.Infof("result content: %v", out.Content)
}
// userInfoRequest is the JSON argument payload the model supplies when it
// calls the user_info tool.
type userInfoRequest struct {
	Name  string `json:"name"`  // user's name
	Email string `json:"email"` // user's email
}
// userInfoResponse is the JSON result the user_info tool returns; the
// example fills Company/Position/Salary with hard-coded values.
type userInfoResponse struct {
	Name     string `json:"name"`
	Email    string `json:"email"`
	Company  string `json:"company"`
	Position string `json:"position"`
	Salary   string `json:"salary"`
}

@ -0,0 +1,147 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"errors"
"fmt"
"io"
"os"
"github.com/cloudwego/eino-ext/components/model/openai"
callbacks2 "github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino/utils/callbacks"
"github.com/cloudwego/eino-examples/internal/gptr"
"github.com/cloudwego/eino-examples/internal/logs"
)
// main runs a two-model "writer vs critic" conversation as a cyclic graph:
//
//	START -> writer --(round >= 3)--> END
//	              \--> toList1 -> critic -> toList2 -> writer (loop)
//
// Shared graph state accumulates the message history and counts writer
// rounds; a stream callback handler pipes every model chunk to a printer
// goroutine so output appears as it is generated.
//
// Required env vars: OPENAI_API_KEY, OPENAI_MODEL_NAME.
func main() {
	//openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
	openAIAPIKey := os.Getenv("OPENAI_API_KEY")
	modelName := os.Getenv("OPENAI_MODEL_NAME")

	ctx := context.Background()

	modelConf := &openai.ChatModelConfig{
		//BaseURL: openAIBaseURL,
		APIKey: openAIAPIKey,
		//ByAzure: true,
		Model:       modelName,
		Temperature: gptr.Of(float32(0.7)),
		APIVersion:  "2024-06-01",
	}

	// Per-run state shared by both model nodes via state pre-handlers.
	type state struct {
		currentRound int               // writer invocations so far
		msgs         []*schema.Message // accumulated conversation
	}

	llm, err := openai.NewChatModel(ctx, modelConf)
	if err != nil {
		logs.Fatalf("new chat model failed: %v", err)
	}

	g := compose.NewGraph[[]*schema.Message, *schema.Message](compose.WithGenLocalState(func(ctx context.Context) *state { return &state{} }))

	// Writer node: prepends its own system prompt to the shared history.
	_ = g.AddChatModelNode("writer", llm, compose.WithStatePreHandler[[]*schema.Message, *state](func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
		state.currentRound++
		state.msgs = append(state.msgs, input...)
		input = append([]*schema.Message{schema.SystemMessage("you are a writer who writes jokes and revise it according to the critic's feedback. Prepend your joke with your name which is \"writer: \"")}, state.msgs...)
		return input, nil
	}), compose.WithNodeName("writer"))

	// Critic node: same history, different system prompt.
	_ = g.AddChatModelNode("critic", llm, compose.WithStatePreHandler[[]*schema.Message, *state](func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
		state.msgs = append(state.msgs, input...)
		input = append([]*schema.Message{schema.SystemMessage("you are a critic who ONLY gives feedback about jokes, emphasizing on funniness. Prepend your feedback with your name which is \"critic: \"")}, state.msgs...)
		return input, nil
	}), compose.WithNodeName("critic"))

	// Model nodes emit one message but consume a slice; bridge with ToList.
	_ = g.AddLambdaNode("toList1", compose.ToList[*schema.Message]())
	_ = g.AddLambdaNode("toList2", compose.ToList[*schema.Message]())
	_ = g.AddEdge(compose.START, "writer")

	// Stop after three writer rounds; otherwise loop through the critic.
	_ = g.AddBranch("writer", compose.NewStreamGraphBranch(func(ctx context.Context, input *schema.StreamReader[*schema.Message]) (string, error) {
		input.Close()
		s, err := compose.GetState[*state](ctx)
		if err != nil {
			return "", err
		}
		if s.currentRound >= 3 {
			return compose.END, nil
		}
		return "toList1", nil
	}, map[string]bool{compose.END: true, "toList1": true}))
	_ = g.AddEdge("toList1", "critic")
	_ = g.AddEdge("critic", "toList2")
	_ = g.AddEdge("toList2", "writer")

	runner, err := g.Compile(ctx)
	if err != nil {
		logs.Fatalf("compile error: %v", err)
	}

	sResponse := &streamResponse{
		ch: make(chan string),
	}

	// Printer goroutine. done lets main wait for it before exiting so the
	// tail of the stream is not lost at process exit.
	done := make(chan struct{})
	go func() {
		defer close(done)
		for m := range sResponse.ch {
			fmt.Print(m)
		}
	}()

	handler := callbacks.NewHandlerHelper().ChatModel(&callbacks.ModelCallbackHandler{
		OnEndWithStreamOutput: sResponse.OnStreamStart,
	}).Handler()

	outStream, err := runner.Stream(ctx, []*schema.Message{schema.UserMessage("write a funny line about robot, in 20 words.")},
		compose.WithCallbacks(handler))
	if err != nil {
		logs.Fatalf("stream error: %v", err)
	}

	// Drain the final output stream. NOTE: the previous loop only exited on
	// io.EOF, so any other Recv error made it spin forever; treat every
	// error as terminal and report non-EOF ones.
	for {
		if _, err := outStream.Recv(); err != nil {
			close(sResponse.ch)
			if !errors.Is(err, io.EOF) {
				logs.Fatalf("recv error: %v", err)
			}
			break
		}
	}

	// Wait for the printer goroutine to flush everything it received.
	<-done
}
// streamResponse forwards streamed model output chunks to a consumer
// (here, a printing goroutine) over ch.
type streamResponse struct {
	ch chan string // carries chunk contents; closed by the producer side
}
// OnStreamStart is the model stream-output callback: it writes a separator
// to s.ch, then forwards every received chunk's content until the stream
// reports io.EOF. The reader is always closed; any other receive error is
// fatal.
func (s *streamResponse) OnStreamStart(ctx context.Context, runInfo *callbacks2.RunInfo, input *schema.StreamReader[*model.CallbackOutput]) context.Context {
	defer input.Close()

	s.ch <- "\n=======\n"
	for {
		chunk, recvErr := input.Recv()
		if errors.Is(recvErr, io.EOF) {
			return ctx
		}
		if recvErr != nil {
			logs.Fatalf("internal error: %s\n", recvErr)
		}
		s.ch <- chunk.Message.Content
	}
}

@ -27,8 +27,8 @@ func newHost(ctx context.Context, baseURL, apiKey, modelName string) (*host.Host
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
BaseURL: baseURL, BaseURL: baseURL,
Model: modelName, Model: modelName,
ByAzure: true, //ByAzure: true,
APIKey: apiKey, APIKey: apiKey,
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -18,6 +18,7 @@ package logs
import ( import (
"fmt" "fmt"
"os"
"time" "time"
) )
@ -47,3 +48,11 @@ func Tokenf(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...) message := fmt.Sprintf(format, args...)
fmt.Printf("%s%s%s", colorBrown, message, colorReset) fmt.Printf("%s%s%s", colorBrown, message, colorReset)
} }
// Fatalf prints a red, timestamped "[FATAL]" message to stdout and then
// terminates the process with exit status 1.
func Fatalf(format string, args ...interface{}) {
	now := time.Now().Format("2006-01-02 15:04:05")
	fmt.Printf("%s[FATAL] %s %s%s\n", colorRed, now, fmt.Sprintf(format, args...), colorReset)
	os.Exit(1)
}

Loading…
Cancel
Save