You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

438 lines
13 KiB
Go

/*
* Copyright 2026 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/cloudwego/hertz/pkg/app"
hserver "github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/google/uuid"
"github.com/hertz-contrib/sse"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino-examples/adk/common/tool"
"github.com/cloudwego/eino-examples/quickstart/chatwitheino/a2ui"
"github.com/cloudwego/eino-examples/quickstart/chatwitheino/mem"
)
// Config bundles every dependency and setting the HTTP server needs.
type Config struct {
	// Runner executes and resumes agent runs for the chat routes.
	Runner *adk.Runner
	// Store provides the chat sessions the routes create, list, load and delete.
	Store *mem.Store
	// WorkspaceDir is the parent directory of the per-session workspaces;
	// a session's files live in WorkspaceDir/<session id>.
	WorkspaceDir string
	// ProjectRoot is the root of the codebase the agent can explore.
	ProjectRoot string
	// ExamplesDir is the root of the eino-examples repo (for example searches).
	ExamplesDir string
	// Port is the port the server listens on; it is appended to ":" to form
	// the listen address.
	Port string
}
// Server exposes the chat-with-doc routes over a Hertz HTTP server.
// Build one with New and start it with Spin.
type Server struct {
	// cfg holds the dependencies and settings handed to New.
	cfg Config
}
// New returns a Server that uses cfg for all of its dependencies.
// No validation is performed and nothing is started until Spin is called.
func New(cfg Config) *Server {
	srv := new(Server)
	srv.cfg = cfg
	return srv
}
// Spin registers all routes on a new Hertz server listening on ":"+Port and
// then serves requests. The call blocks while the server is running.
func (s *Server) Spin() {
	addr := ":" + s.cfg.Port
	engine := hserver.Default(hserver.WithHostPorts(addr))

	// Front page, read from static/index.html relative to the working directory.
	engine.GET("/", func(_ context.Context, rc *app.RequestContext) {
		page, readErr := os.ReadFile("static/index.html")
		if readErr == nil {
			rc.Data(consts.StatusOK, "text/html; charset=utf-8", page)
			return
		}
		rc.JSON(consts.StatusNotFound, map[string]string{"error": "index.html not found"})
	})

	// Create a session under a freshly generated UUID.
	engine.POST("/sessions", func(_ context.Context, rc *app.RequestContext) {
		sessionID := uuid.New().String()
		_, createErr := s.cfg.Store.GetOrCreate(sessionID)
		if createErr != nil {
			rc.JSON(consts.StatusInternalServerError, map[string]string{"error": createErr.Error()})
			return
		}
		rc.JSON(consts.StatusOK, map[string]string{"id": sessionID})
	})

	// List the known sessions.
	engine.GET("/sessions", func(_ context.Context, rc *app.RequestContext) {
		list, listErr := s.cfg.Store.List()
		switch {
		case listErr != nil:
			rc.JSON(consts.StatusInternalServerError, map[string]string{"error": listErr.Error()})
			return
		case list == nil:
			// Hand the encoder an empty slice instead of a nil one.
			list = []mem.SessionMeta{}
		}
		rc.JSON(consts.StatusOK, list)
	})

	// Delete one session.
	engine.DELETE("/sessions/:id", func(_ context.Context, rc *app.RequestContext) {
		delErr := s.cfg.Store.Delete(rc.Param("id"))
		if delErr != nil {
			rc.JSON(consts.StatusInternalServerError, map[string]string{"error": delErr.Error()})
			return
		}
		rc.Status(consts.StatusNoContent)
	})

	// Per-session operations are implemented by dedicated methods.
	engine.POST("/sessions/:id/chat", s.handleChat)
	engine.GET("/sessions/:id/render", s.handleRender)
	engine.POST("/sessions/:id/approve", s.handleApprove)
	engine.POST("/sessions/:id/docs", s.handleUpload)

	engine.Spin()
}
// chatRequest is the JSON body accepted by POST /sessions/:id/chat.
type chatRequest struct {
	// Message is the user's chat text; handleChat rejects an empty value.
	Message string `json:"message"`
}
// approveRequest is the JSON body accepted by POST /sessions/:id/approve.
type approveRequest struct {
	// Approved carries the user's decision on the pending interrupt.
	Approved bool `json:"approved"`
	// Reason is optional; handleApprove forwards it as the disapprove reason
	// only when it is non-empty.
	Reason string `json:"reason,omitempty"`
}
// handleRender writes the stored message history of the session named by the
// :id path parameter, rendered by a2ui.RenderHistory, as an
// application/x-ndjson response. The session is fetched with GetOrCreate, so
// an unknown ID does not produce an error here. Store or render failures are
// reported as 500 with the error text.
func (s *Server) handleRender(_ context.Context, c *app.RequestContext) {
	sessionID := c.Param("id")

	sess, err := s.cfg.Store.GetOrCreate(sessionID)
	if err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}

	// Render into memory first so a failure can still be answered with a 500.
	rendered := new(bytes.Buffer)
	err = a2ui.RenderHistory(rendered, sessionID, sess.GetMessages())
	if err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}

	c.Data(consts.StatusOK, "application/x-ndjson", rendered.Bytes())
}
// handleChat runs one chat turn for the session named by the :id path
// parameter and streams the agent's output back as server-sent events.
//
// The JSON request body must contain a non-empty "message"; otherwise the
// handler answers 400. The message is appended to the session history (a
// persistence failure is only logged), the runner is started over the history
// plus a transient context message, and every line written by
// a2ui.StreamToWriter is published as one SSE event. When the run reports an
// interrupt, the interrupt ID and message index are stored on the session so
// that handleApprove can resume it. Any final assistant text is appended to
// the history.
func (s *Server) handleChat(ctx context.Context, c *app.RequestContext) {
	id := c.Param("id")
	body, err := c.Body()
	if err != nil {
		// Report a failed read as such instead of as a missing message.
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "failed to read request body"})
		return
	}
	var req chatRequest
	if err := json.Unmarshal(body, &req); err != nil || req.Message == "" {
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "message is required"})
		return
	}
	log.Printf("[chat] session=%s msg=%q", id, req.Message)
	sess, err := s.cfg.Store.GetOrCreate(id)
	if err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	userMsg := schema.UserMessage(req.Message)
	// Best effort: the turn still runs if the message could not be persisted.
	if appendErr := sess.Append(userMsg); appendErr != nil {
		log.Printf("warn: failed to persist user message: %v", appendErr)
	}
	// history is rendered in the UI; runMessages adds workspace context for the agent.
	history := sess.GetMessages()
	runMessages := s.buildRunMessages(id, history)
	log.Printf("[chat] session=%s running agent with %d messages (%d history + %d context)",
		id, len(runMessages), len(history), len(runMessages)-len(history))
	iter := s.cfg.Runner.Run(ctx, runMessages, adk.WithCheckPointID(id))
	stream := sse.NewStream(c)
	defer func() { _ = c.Flush() }()

	// Send a keep-alive ping every 5 s so the SSE connection isn't dropped
	// by Hertz or browser timeouts while the agent is processing tool results.
	// NOTE(review): the pinger and StreamToWriter publish to the same stream
	// from different goroutines — confirm sse.Stream.Publish tolerates that.
	kaStop := make(chan struct{})
	kaDone := make(chan struct{})
	go func() {
		defer close(kaDone)
		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-kaStop:
				return
			case <-ticker.C:
				_ = stream.Publish(&sse.Event{Data: []byte{}})
				log.Printf("[chat] session=%s keep-alive ping sent", id)
			}
		}
	}()
	// stopKeepAlive signals the pinger and waits until it has exited, so it can
	// no longer touch the stream once the handler finishes. It is also deferred
	// (and therefore idempotent) so the goroutine is not leaked if streaming
	// panics; being deferred last, it runs before the final Flush.
	kaStopped := false
	stopKeepAlive := func() {
		if kaStopped {
			return
		}
		kaStopped = true
		close(kaStop)
		<-kaDone
	}
	defer stopKeepAlive()

	lastContent, interruptID, finalMsgIdx, streamErr := a2ui.StreamToWriter(&sseLineWriter{stream: stream}, id, history, iter)
	stopKeepAlive()

	switch {
	case streamErr != nil:
		log.Printf("[chat] session=%s stream error: %v", id, streamErr)
	case interruptID != "":
		// Remember where the run stopped so the approve route can resume it.
		log.Printf("[chat] session=%s interrupted: id=%s", id, interruptID)
		sess.SetPendingInterruptID(interruptID)
		sess.SetMsgIdx(finalMsgIdx)
	default:
		log.Printf("[chat] session=%s done, response=%d chars", id, len(lastContent))
	}
	// Whatever assistant text was produced is kept, regardless of the outcome above.
	if lastContent != "" {
		assistantMsg := schema.AssistantMessage(lastContent, nil)
		if appendErr := sess.Append(assistantMsg); appendErr != nil {
			log.Printf("warn: failed to persist assistant message: %v", appendErr)
		}
	}
}
// buildRunMessages returns history preceded by one synthetic user message that
// tells the agent its ground rules, where the project root and the examples
// directory are, and which files have been uploaded to the session workspace.
// The context message exists only in the returned slice; history itself is
// not modified.
func (s *Server) buildRunMessages(sessionID string, history []*schema.Message) []*schema.Message {
	root, examples := s.cfg.ProjectRoot, s.cfg.ExamplesDir

	preamble := []string{
		"[Context]",
		"IMPORTANT RULES:",
		" 1. Always use filesystem tools to look up real code before answering. Do not guess or make up information.",
		" 2. After using tools (even if they return no results), you MUST write a text response to the user summarizing what you found.",
		" 3. Never end your turn without a text response — tool calls alone are not sufficient.",
		" 4. When asked to build or test code, use the execute tool to run the command.",
		" Each Go example has its own go.mod. To build an example, run:",
		" cd <example-dir> && go build ./...",
		" NEVER assume a build succeeded without actually running it.",
		" 5. When writing or editing a file and then claiming it compiles, you MUST run the build tool to verify.",
	}
	emit := func(text ...string) { preamble = append(preamble, text...) }

	// Project root section, including a listing of its immediate subdirectories.
	if root != "" {
		emit(
			fmt.Sprintf("Project root: %s", root),
			" IMPORTANT: Always pass the project root as the path argument when using filesystem tools.",
			fmt.Sprintf(" - grep(pattern=\"...\", path=\"%s\")", root),
			fmt.Sprintf(" - glob(pattern=\"%s/**/*.go\")", root),
			fmt.Sprintf(" - read_file(file_path=\"%s/some/file.go\")", root),
			" grep and glob recurse into ALL subdirectories under the given path.",
			" Top-level subdirectories of the project root:",
		)
		// The listing is skipped entirely when the root cannot be read.
		subdirs, readErr := os.ReadDir(root)
		if readErr == nil {
			for _, d := range subdirs {
				if !d.IsDir() {
					continue
				}
				emit(" - " + filepath.Join(root, d.Name()) + "/")
			}
		}
		emit(" Use these tools to read actual source code before answering questions about the codebase.")
	}

	// Examples section, only when it points somewhere other than the project root.
	if examples != "" && examples != root {
		emit(
			fmt.Sprintf("eino-examples directory: %s", examples),
			" When the user asks about examples or sample code, search here specifically:",
			fmt.Sprintf(" - grep(pattern=\"...\", path=\"%s\")", examples),
			fmt.Sprintf(" - glob(pattern=\"%s/**/*.go\")", examples),
		)
	}

	// Session workspace section, only when it contains at least one regular entry.
	if workDir, absErr := filepath.Abs(filepath.Join(s.cfg.WorkspaceDir, sessionID)); absErr == nil {
		// The read error is ignored: whatever entries came back are used.
		found, _ := os.ReadDir(workDir)
		var uploads []string
		for _, f := range found {
			if f.IsDir() {
				continue
			}
			uploads = append(uploads, filepath.Join(workDir, f.Name()))
		}
		if len(uploads) > 0 {
			emit(
				fmt.Sprintf("Session workspace: %s", workDir),
				" Uploaded files:",
			)
			for _, u := range uploads {
				emit(" - " + u)
			}
		}
	}

	out := make([]*schema.Message, 0, len(history)+1)
	out = append(out, schema.UserMessage(strings.Join(preamble, "\n")))
	return append(out, history...)
}
// handleUpload stores the multipart upload from form field "file" in the
// workspace directory of the session named by the :id path parameter
// (WorkspaceDir/<id>/<base name>) and answers with the submitted file name and
// the path the file was saved to.
//
// Both the session ID and the client-supplied file name are untrusted and end
// up in a filesystem path, so each must be a single plain path element.
// Anything else is rejected with 400 before the filesystem is touched, which
// also means an invalid request no longer leaves an empty directory behind.
// Filesystem failures are reported as 500 with the error text.
func (s *Server) handleUpload(ctx context.Context, c *app.RequestContext) {
	id := c.Param("id")
	if !isSinglePathElement(id) {
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "invalid session id"})
		return
	}
	fileHeader, err := c.FormFile("file")
	if err != nil {
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "file field is required"})
		return
	}
	// Base drops any directory part; it yields ".", ".." or a bare separator
	// for names that do not identify a file, which the check below rejects.
	name := filepath.Base(fileHeader.Filename)
	if !isSinglePathElement(name) {
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "invalid file name"})
		return
	}
	absWorkDir, err := filepath.Abs(filepath.Join(s.cfg.WorkspaceDir, id))
	if err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if err := os.MkdirAll(absWorkDir, 0o755); err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	dst := filepath.Join(absWorkDir, name)
	if err := c.SaveUploadedFile(fileHeader, dst); err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	c.JSON(consts.StatusOK, map[string]string{
		"name": fileHeader.Filename,
		"path": dst,
	})
}

// isSinglePathElement reports whether elem can be joined onto a directory as
// exactly one child entry: it is non-empty, is not "." or "..", and contains
// neither a forward slash nor the operating system's path separator.
func isSinglePathElement(elem string) bool {
	switch elem {
	case "", ".", "..":
		return false
	}
	return !strings.Contains(elem, "/") && !strings.ContainsRune(elem, filepath.Separator)
}
// handleApprove resumes an interrupted agent run with the user's approval
// decision and streams the continued output back as server-sent events.
//
// The session must have a pending interrupt ID (stored by handleChat or by a
// previous handleApprove); otherwise the handler answers 400. The JSON body is
// decoded into an approveRequest and passed to the runner as the resume target
// for that interrupt, with the session ID used as the checkpoint ID. If the
// resumed run is interrupted again, the new interrupt ID and message index are
// stored on the session. Any final assistant text is appended to the history.
func (s *Server) handleApprove(ctx context.Context, c *app.RequestContext) {
	id := c.Param("id")
	sess, err := s.cfg.Store.GetOrCreate(id)
	if err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	interruptID := sess.GetPendingInterruptID()
	if interruptID == "" {
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "no pending interrupt for this session"})
		return
	}
	body, err := c.Body()
	if err != nil {
		// Report a failed read as such instead of as a malformed body.
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "failed to read request body"})
		return
	}
	var req approveRequest
	if err := json.Unmarshal(body, &req); err != nil {
		c.JSON(consts.StatusBadRequest, map[string]string{"error": "invalid request body"})
		return
	}
	// The reason is forwarded only when the client actually supplied one.
	var reason *string
	if req.Reason != "" {
		reason = &req.Reason
	}
	result := &tool.ApprovalResult{Approved: req.Approved, DisapproveReason: reason}
	iter, err := s.cfg.Runner.ResumeWithParams(ctx, id, &adk.ResumeParams{
		Targets: map[string]any{interruptID: result},
	})
	if err != nil {
		c.JSON(consts.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	// Clear the pending interrupt immediately so a double-approve returns 400.
	// NOTE(review): the check above and this clear are separate steps; confirm
	// whether two concurrent approvals for one session can both get past the check.
	sess.SetPendingInterruptID("")
	log.Printf("[approve] session=%s interruptID=%s approved=%v", id, interruptID, req.Approved)
	stream := sse.NewStream(c)
	defer func() { _ = c.Flush() }()

	// Publish an empty event every 5 s while the resumed run is streaming.
	// NOTE(review): the pinger and StreamContinue publish to the same stream
	// from different goroutines — confirm sse.Stream.Publish tolerates that.
	kaStop := make(chan struct{})
	kaDone := make(chan struct{})
	go func() {
		defer close(kaDone)
		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-kaStop:
				return
			case <-ticker.C:
				_ = stream.Publish(&sse.Event{Data: []byte{}})
			}
		}
	}()
	// stopKeepAlive signals the pinger and waits until it has exited, so it can
	// no longer touch the stream once the handler finishes. It is also deferred
	// (and therefore idempotent) so the goroutine is not leaked if streaming
	// panics; being deferred last, it runs before the final Flush.
	kaStopped := false
	stopKeepAlive := func() {
		if kaStopped {
			return
		}
		kaStopped = true
		close(kaStop)
		<-kaDone
	}
	defer stopKeepAlive()

	lastContent, newInterruptID, finalMsgIdx, streamErr := a2ui.StreamContinue(&sseLineWriter{stream: stream}, id, sess.GetMsgIdx(), iter)
	stopKeepAlive()

	switch {
	case streamErr != nil:
		log.Printf("[approve] session=%s stream error: %v", id, streamErr)
	case newInterruptID != "":
		// The run stopped again; remember where so the next approval can resume it.
		log.Printf("[approve] session=%s re-interrupted: id=%s", id, newInterruptID)
		sess.SetPendingInterruptID(newInterruptID)
		sess.SetMsgIdx(finalMsgIdx)
	default:
		log.Printf("[approve] session=%s done, response=%d chars", id, len(lastContent))
	}
	// Whatever assistant text was produced is kept, regardless of the outcome above.
	if lastContent != "" {
		assistantMsg := schema.AssistantMessage(lastContent, nil)
		if appendErr := sess.Append(assistantMsg); appendErr != nil {
			log.Printf("warn: failed to persist assistant message: %v", appendErr)
		}
	}
}
// sseLineWriter adapts an SSE stream to io.Writer. Written bytes are
// accumulated until a newline arrives; each complete, non-empty line is then
// published as one SSE event with the newline removed.
type sseLineWriter struct {
	// stream receives one event per completed line.
	stream *sse.Stream
	// buf holds bytes written so far that do not yet form a complete line.
	buf []byte
}
// Write appends p to the pending data and publishes every complete line it now
// contains as a separate SSE event, without the trailing newline. Empty lines
// are dropped, and bytes after the last newline stay buffered for a later call.
//
// On success it returns len(p) and a nil error. If publishing fails it returns
// 0 and that error; lines already removed from the buffer, including the one
// that failed, are not put back.
func (w *sseLineWriter) Write(p []byte) (int, error) {
	w.buf = append(w.buf, p...)
	for nl := bytes.IndexByte(w.buf, '\n'); nl >= 0; nl = bytes.IndexByte(w.buf, '\n') {
		// Split off the line and advance past its newline before publishing.
		line := w.buf[:nl]
		w.buf = w.buf[nl+1:]
		if len(line) == 0 {
			continue
		}
		if err := w.stream.Publish(&sse.Event{Data: line}); err != nil {
			return 0, err
		}
	}
	return len(p), nil
}