diff --git a/adk/common/model/chat_model.go b/adk/common/model/chat_model.go index 1512e89..37f76fb 100644 --- a/adk/common/model/chat_model.go +++ b/adk/common/model/chat_model.go @@ -18,13 +18,17 @@ package model import ( "context" + "fmt" "log" "os" "strings" + "time" "github.com/cloudwego/eino-ext/components/model/ark" "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/model" + cbutils "github.com/cloudwego/eino/utils/callbacks" arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" ) @@ -62,3 +66,25 @@ func NewChatModel() model.ToolCallingChatModel { } return cm } + +func GetInputLoggerCallback() callbacks.Handler { + return cbutils.NewHandlerHelper().ChatModel(&cbutils.ModelCallbackHandler{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *model.CallbackInput) context.Context { + time.Sleep(20 * time.Second) + fmt.Printf("\n========================================\n") + fmt.Printf("[ChatModel Input] Agent: %s\n", info.Name) + fmt.Printf("========================================\n") + for i, msg := range input.Messages { + fmt.Printf(" Message %d [%s]: %s\n", i+1, msg.Role, msg.Content) + if len(msg.ToolCalls) > 0 { + fmt.Printf(" Tool Calls: %d\n", len(msg.ToolCalls)) + for j, tc := range msg.ToolCalls { + fmt.Printf(" %d. %s: %s\n", j+1, tc.Function.Name, tc.Function.Arguments) + } + } + } + fmt.Printf("========================================\n\n") + return ctx + }, + }).Handler() +} diff --git a/adk/multiagent/supervisor/supervisor.go b/adk/multiagent/supervisor/supervisor.go index a4cbb06..9bc02fe 100644 --- a/adk/multiagent/supervisor/supervisor.go +++ b/adk/multiagent/supervisor/supervisor.go @@ -23,7 +23,9 @@ import ( "time" "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino-examples/adk/common/model" "github.com/cloudwego/eino-examples/adk/common/prints" "github.com/cloudwego/eino-examples/adk/common/trace" ) @@ -31,6 +33,8 @@ import ( func main() { ctx := context.Background() + callbacks.AppendGlobalHandlers(model.GetInputLoggerCallback()) + traceCloseFn, startSpanFn := trace.AppendCozeLoopCallbackIfConfigured(ctx) defer traceCloseFn(ctx)