You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
307 lines
7.3 KiB
Go
307 lines
7.3 KiB
Go
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
|
|
"github.com/cloudwego/hertz/pkg/app"
|
|
"github.com/cloudwego/hertz/pkg/app/server"
|
|
"github.com/cloudwego/hertz/pkg/protocol/consts"
|
|
"github.com/hertz-contrib/sse"
|
|
|
|
"github.com/cloudwego/eino/adk"
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
"github.com/cloudwego/eino-examples/adk/common/model"
|
|
)
|
|
|
|
type SSEEvent struct {
|
|
Type string `json:"type"`
|
|
AgentName string `json:"agent_name,omitempty"`
|
|
RunPath string `json:"run_path,omitempty"`
|
|
Content string `json:"content,omitempty"`
|
|
ToolCalls []schema.ToolCall `json:"tool_calls,omitempty"`
|
|
ActionType string `json:"action_type,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
}
|
|
|
|
func main() {
|
|
ctx := context.Background()
|
|
|
|
agent, err := createAgent(ctx)
|
|
if err != nil {
|
|
log.Fatalf("Failed to create agent: %v", err)
|
|
}
|
|
|
|
runner := adk.NewRunner(ctx, adk.RunnerConfig{
|
|
EnableStreaming: true,
|
|
Agent: agent,
|
|
})
|
|
|
|
h := server.Default(server.WithHostPorts(":8080"))
|
|
|
|
h.GET("/chat", func(ctx context.Context, c *app.RequestContext) {
|
|
handleChat(ctx, c, runner)
|
|
})
|
|
|
|
log.Println("Server starting on http://localhost:8080")
|
|
log.Println("Try: curl -N 'http://localhost:8080/chat?query=tell me a short story'")
|
|
h.Spin()
|
|
}
|
|
|
|
func createAgent(ctx context.Context) (adk.Agent, error) {
|
|
// add sub-agents if you want to.
|
|
// for demonstration purpose we use a simple ChatModelAgent
|
|
return adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
|
Name: "SSEAgent",
|
|
Description: "An agent that responds via Server-Sent Events",
|
|
Instruction: `You are a helpful assistant. Provide clear and concise responses to user queries.`,
|
|
Model: model.NewChatModel(),
|
|
// add tools if you want to
|
|
})
|
|
}
|
|
|
|
func formatRunPath(runPath []adk.RunStep) string {
|
|
return fmt.Sprintf("%v", runPath)
|
|
}
|
|
|
|
func handleChat(ctx context.Context, c *app.RequestContext, runner *adk.Runner) {
|
|
query := c.Query("query")
|
|
if query == "" {
|
|
c.JSON(consts.StatusBadRequest, map[string]string{
|
|
"error": "query parameter is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
log.Printf("Received query: %s", query)
|
|
|
|
iter := runner.Query(ctx, query)
|
|
|
|
s := sse.NewStream(c)
|
|
defer func(c *app.RequestContext) {
|
|
_ = c.Flush()
|
|
}(c)
|
|
|
|
for {
|
|
event, ok := iter.Next()
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
if err := processAgentEvent(ctx, s, event); err != nil {
|
|
log.Printf("Error processing event: %v", err)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func processAgentEvent(ctx context.Context, s *sse.Stream, event *adk.AgentEvent) error {
|
|
if event.Err != nil {
|
|
return sendSSEEvent(s, SSEEvent{
|
|
Type: "error",
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
Error: event.Err.Error(),
|
|
})
|
|
}
|
|
|
|
if event.Output != nil && event.Output.MessageOutput != nil {
|
|
if err := handleMessageOutput(ctx, s, event); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if event.Action != nil {
|
|
if err := handleAction(s, event); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func handleMessageOutput(ctx context.Context, s *sse.Stream, event *adk.AgentEvent) error {
|
|
msgOutput := event.Output.MessageOutput
|
|
|
|
if msg := msgOutput.Message; msg != nil {
|
|
return handleRegularMessage(s, event, msg)
|
|
}
|
|
|
|
if stream := msgOutput.MessageStream; stream != nil {
|
|
return handleStreamingMessage(ctx, s, event, stream)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func handleRegularMessage(s *sse.Stream, event *adk.AgentEvent, msg *schema.Message) error {
|
|
eventType := "message"
|
|
if msg.Role == schema.Tool {
|
|
eventType = "tool_result"
|
|
}
|
|
|
|
sseEvent := SSEEvent{
|
|
Type: eventType,
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
Content: msg.Content,
|
|
}
|
|
|
|
if len(msg.ToolCalls) > 0 {
|
|
sseEvent.ToolCalls = msg.ToolCalls
|
|
}
|
|
|
|
return sendSSEEvent(s, sseEvent)
|
|
}
|
|
|
|
func handleStreamingMessage(ctx context.Context, s *sse.Stream, event *adk.AgentEvent, stream *schema.StreamReader[*schema.Message]) error {
|
|
toolCallsMap := make(map[int][]*schema.Message)
|
|
|
|
for {
|
|
chunk, err := stream.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return sendSSEEvent(s, SSEEvent{
|
|
Type: "error",
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
Error: fmt.Sprintf("stream error: %v", err),
|
|
})
|
|
}
|
|
|
|
if chunk.Content != "" {
|
|
eventType := "stream_chunk"
|
|
if chunk.Role == schema.Tool {
|
|
eventType = "tool_result_chunk"
|
|
}
|
|
|
|
if err := sendSSEEvent(s, SSEEvent{
|
|
Type: eventType,
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
Content: chunk.Content,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if len(chunk.ToolCalls) > 0 {
|
|
for _, tc := range chunk.ToolCalls {
|
|
if tc.Index != nil {
|
|
toolCallsMap[*tc.Index] = append(toolCallsMap[*tc.Index], &schema.Message{
|
|
Role: chunk.Role,
|
|
ToolCalls: []schema.ToolCall{
|
|
{
|
|
ID: tc.ID,
|
|
Type: tc.Type,
|
|
Index: tc.Index,
|
|
Function: schema.FunctionCall{
|
|
Name: tc.Function.Name,
|
|
Arguments: tc.Function.Arguments,
|
|
},
|
|
},
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, msgs := range toolCallsMap {
|
|
concatenatedMsg, err := schema.ConcatMessages(msgs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := sendSSEEvent(s, SSEEvent{
|
|
Type: "tool_calls",
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
ToolCalls: concatenatedMsg.ToolCalls,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func handleAction(s *sse.Stream, event *adk.AgentEvent) error {
|
|
action := event.Action
|
|
|
|
if action.TransferToAgent != nil {
|
|
return sendSSEEvent(s, SSEEvent{
|
|
Type: "action",
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
ActionType: "transfer",
|
|
Content: fmt.Sprintf("Transfer to agent: %s", action.TransferToAgent.DestAgentName),
|
|
})
|
|
}
|
|
|
|
if action.Interrupted != nil {
|
|
for _, ic := range action.Interrupted.InterruptContexts {
|
|
content := fmt.Sprintf("%v", ic.Info)
|
|
if stringer, ok := ic.Info.(fmt.Stringer); ok {
|
|
content = stringer.String()
|
|
}
|
|
|
|
if err := sendSSEEvent(s, SSEEvent{
|
|
Type: "action",
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
ActionType: "interrupted",
|
|
Content: content,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if action.Exit {
|
|
return sendSSEEvent(s, SSEEvent{
|
|
Type: "action",
|
|
AgentName: event.AgentName,
|
|
RunPath: formatRunPath(event.RunPath),
|
|
ActionType: "exit",
|
|
Content: "Agent execution completed",
|
|
})
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func sendSSEEvent(s *sse.Stream, event SSEEvent) error {
|
|
data, err := json.Marshal(event)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal SSE event: %w", err)
|
|
}
|
|
|
|
return s.Publish(&sse.Event{
|
|
Data: data,
|
|
})
|
|
}
|