feat(adk): add agent_with_summarization based on agent_middleware (#135)
parent
d41b497bdc
commit
85592f893f
@ -0,0 +1,118 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/cloudwego/eino-examples/adk/common/model"
|
||||
"github.com/cloudwego/eino-examples/adk/common/prints"
|
||||
"github.com/cloudwego/eino-examples/adk/common/trace"
|
||||
"github.com/cloudwego/eino-examples/adk/intro/agent_with_summarization/summarization"
|
||||
"github.com/cloudwego/eino-examples/internal/logs"
|
||||
)
|
||||
|
||||
// Tuning knobs for the example agent.
const (
	// summaryMaxTokensBefore is the MaxTokensBeforeSummary handed to the
	// summarization middleware (10 Ki tokens).
	summaryMaxTokensBefore = 10 << 10
	// summaryMaxTokensRecent is the MaxTokensForRecentMessages handed to the
	// summarization middleware (2 Ki tokens).
	summaryMaxTokensRecent = 2 << 10
	// agentMaxIterations is the MaxIterations setting of the chat model agent.
	agentMaxIterations = 30
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
a, err := newAgent(ctx)
|
||||
if err != nil {
|
||||
logs.Fatalf("create agent failed, err=%v", err)
|
||||
}
|
||||
|
||||
traceCloseFn, startSpanFn := trace.AppendCozeLoopCallbackIfConfigured(ctx)
|
||||
defer traceCloseFn(ctx)
|
||||
|
||||
runner := adk.NewRunner(ctx, adk.RunnerConfig{
|
||||
EnableStreaming: true, // you can disable streaming here
|
||||
Agent: a,
|
||||
})
|
||||
|
||||
query := `Write a very long report on the history of artificial intelligence.`
|
||||
ctx, endSpanFn := startSpanFn(ctx, "Agent", query)
|
||||
|
||||
iter := runner.Query(ctx, query)
|
||||
|
||||
var lastMessage adk.Message
|
||||
for {
|
||||
event, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if event.Err != nil {
|
||||
log.Fatal(event.Err)
|
||||
}
|
||||
|
||||
prints.Event(event)
|
||||
|
||||
if event.Output != nil {
|
||||
lastMessage, _, err = adk.GetMessage(event)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
endSpanFn(ctx, lastMessage)
|
||||
|
||||
// wait for all span to be ended
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
|
||||
func newAgent(ctx context.Context) (adk.Agent, error) {
|
||||
sumMW, err := summarization.New(ctx, &summarization.Config{
|
||||
Model: model.NewChatModel(),
|
||||
MaxTokensBeforeSummary: summaryMaxTokensBefore,
|
||||
MaxTokensForRecentMessages: summaryMaxTokensRecent,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||
Name: "main_agent",
|
||||
Description: "A long-form report assistant",
|
||||
Instruction: `You are a long-form report writer working in ReAct mode.
|
||||
Think step by step, call tools to expand content by repeating paragraphs, then synthesize a cohesive response.
|
||||
one time call one tool, do not call multiple tools in one turn.
|
||||
Each tool call should indicate the call number. After 20 tool calls, produce a final summary.`,
|
||||
Model: model.NewChatModel(),
|
||||
Middlewares: []adk.AgentMiddleware{sumMW},
|
||||
ToolsConfig: adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
Tools: []tool.BaseTool{
|
||||
NewRepeatSectionsTool(),
|
||||
},
|
||||
},
|
||||
},
|
||||
MaxIterations: agentMaxIterations,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
@ -0,0 +1,403 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package summarization
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// TokenCounter computes a token count for every message in msgs.
// The returned slice must contain exactly one entry per input message, in the
// same order: the middleware indexes it in parallel with the message list and
// fails when the two lengths differ.
type TokenCounter func(ctx context.Context, msgs []adk.Message) (tokenNum []int64, err error)
|
||||
|
||||
// Config defines parameters for the conversation summarization middleware.
|
||||
// It controls when summarization is triggered and how much recent context is retained.
|
||||
// Required: Model. Optional: SystemPrompt, Counter, and token budgets.
|
||||
type Config struct {
|
||||
// MaxTokensBeforeSummary is the max token threshold to trigger summarization based on total context
|
||||
// (system prompt + history). Uses DefaultMaxTokensBeforeSummary when <= 0.
|
||||
MaxTokensBeforeSummary int
|
||||
|
||||
// MaxTokensForRecentMessages is the max token budget reserved for recent messages after summarization.
|
||||
// Uses DefaultMaxTokensForRecentMessages when <= 0.
|
||||
MaxTokensForRecentMessages int
|
||||
|
||||
// Counter custom token counter.
|
||||
// Optional
|
||||
Counter TokenCounter
|
||||
|
||||
// Model used to generate the summary. Must be provided.
|
||||
// Required.
|
||||
Model model.BaseChatModel
|
||||
|
||||
// SystemPrompt is the system prompt for the summarizer.
|
||||
// Optional. If empty, PromptOfSummary is used.
|
||||
SystemPrompt string
|
||||
}
|
||||
|
||||
// New creates an AgentMiddleware that compacts long conversation history
|
||||
// into a single summary message when the token threshold is exceeded.
|
||||
// The summarizer chain is: ChatTemplate(SystemPrompt) -> ChatModel(Model).
|
||||
// It applies defaults for token budgets and allows a custom Counter.
|
||||
func New(ctx context.Context, cfg *Config) (adk.AgentMiddleware, error) {
|
||||
if cfg == nil {
|
||||
return adk.AgentMiddleware{}, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
systemPrompt := cfg.SystemPrompt
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = PromptOfSummary
|
||||
}
|
||||
maxBefore := DefaultMaxTokensBeforeSummary
|
||||
if cfg.MaxTokensBeforeSummary > 0 {
|
||||
maxBefore = cfg.MaxTokensBeforeSummary
|
||||
}
|
||||
maxRecent := DefaultMaxTokensForRecentMessages
|
||||
if cfg.MaxTokensForRecentMessages > 0 {
|
||||
maxRecent = cfg.MaxTokensForRecentMessages
|
||||
}
|
||||
|
||||
tpl := prompt.FromMessages(schema.FString,
|
||||
schema.SystemMessage(systemPrompt),
|
||||
schema.UserMessage("summarize 'older_messages': "))
|
||||
|
||||
summarizer, err := compose.NewChain[map[string]any, *schema.Message]().
|
||||
AppendChatTemplate(tpl).
|
||||
AppendChatModel(cfg.Model).
|
||||
Compile(ctx, compose.WithGraphName("Summarizer"))
|
||||
if err != nil {
|
||||
return adk.AgentMiddleware{}, fmt.Errorf("compile summarizer failed, err=%w", err)
|
||||
}
|
||||
|
||||
sm := &summaryMiddleware{
|
||||
counter: defaultCounterToken,
|
||||
maxBefore: maxBefore,
|
||||
maxRecent: maxRecent,
|
||||
summarizer: summarizer,
|
||||
}
|
||||
if cfg.Counter != nil {
|
||||
sm.counter = cfg.Counter
|
||||
}
|
||||
return adk.AgentMiddleware{BeforeChatModel: sm.BeforeModel}, nil
|
||||
}
|
||||
|
||||
// summaryMessageFlag is the Extra key set on the assistant message that holds
// a generated summary, so that BeforeModel can recognize it as the previous
// summary on a later pass.
const summaryMessageFlag = "_agent_middleware_summary_message"
|
||||
|
||||
type summaryMiddleware struct {
|
||||
counter TokenCounter
|
||||
maxBefore int
|
||||
maxRecent int
|
||||
|
||||
summarizer compose.Runnable[map[string]any, *schema.Message]
|
||||
}
|
||||
|
||||
func (s *summaryMiddleware) BeforeModel(ctx context.Context, state *adk.ChatModelAgentState) (err error) {
|
||||
if state == nil || len(state.Messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
messages := state.Messages
|
||||
msgsToken, err := s.counter(ctx, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count token failed, err=%w", err)
|
||||
}
|
||||
if len(messages) != len(msgsToken) {
|
||||
return fmt.Errorf("token count mismatch, msgNum=%d, tokenCountNum=%d", len(messages), len(msgsToken))
|
||||
}
|
||||
|
||||
var total int64
|
||||
for _, t := range msgsToken {
|
||||
total += t
|
||||
}
|
||||
// Trigger summarization only when exceeding threshold
|
||||
if total <= int64(s.maxBefore) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build blocks with user-messages, summary-message, tool-call pairings
|
||||
type block struct {
|
||||
msgs []*schema.Message
|
||||
tokens int64
|
||||
}
|
||||
idx := 0
|
||||
|
||||
systemBlock := block{}
|
||||
if idx < len(messages) {
|
||||
m := messages[idx]
|
||||
if m != nil && m.Role == schema.System {
|
||||
systemBlock.msgs = append(systemBlock.msgs, m)
|
||||
systemBlock.tokens += msgsToken[idx]
|
||||
idx++
|
||||
}
|
||||
}
|
||||
userBlock := block{}
|
||||
for idx < len(messages) {
|
||||
m := messages[idx]
|
||||
if m == nil {
|
||||
idx++
|
||||
continue
|
||||
}
|
||||
if m.Role != schema.User {
|
||||
break
|
||||
}
|
||||
userBlock.msgs = append(userBlock.msgs, m)
|
||||
userBlock.tokens += msgsToken[idx]
|
||||
idx++
|
||||
}
|
||||
summaryBlock := block{}
|
||||
if idx < len(messages) {
|
||||
m := messages[idx]
|
||||
if m != nil && m.Role == schema.Assistant {
|
||||
if _, ok := m.Extra[summaryMessageFlag]; ok {
|
||||
summaryBlock.msgs = append(summaryBlock.msgs, m)
|
||||
summaryBlock.tokens += msgsToken[idx]
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolBlocks := make([]block, 0)
|
||||
for i := idx; i < len(messages); i++ {
|
||||
m := messages[i]
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.Assistant && len(m.ToolCalls) > 0 {
|
||||
b := block{msgs: []*schema.Message{m}, tokens: msgsToken[i]}
|
||||
// Collect subsequent tool messages matching any tool call id
|
||||
callIDs := make(map[string]struct{}, len(m.ToolCalls))
|
||||
for _, tc := range m.ToolCalls {
|
||||
callIDs[tc.ID] = struct{}{}
|
||||
}
|
||||
j := i + 1
|
||||
for j < len(messages) {
|
||||
nm := messages[j]
|
||||
if nm == nil || nm.Role != schema.Tool {
|
||||
break
|
||||
}
|
||||
// Match by ToolCallID when available; if empty, include but keep boundary
|
||||
if nm.ToolCallID == "" {
|
||||
b.msgs = append(b.msgs, nm)
|
||||
b.tokens += msgsToken[j]
|
||||
} else {
|
||||
if _, ok := callIDs[nm.ToolCallID]; !ok {
|
||||
// Tool message not belonging to this assistant call -> end pairing
|
||||
break
|
||||
}
|
||||
b.msgs = append(b.msgs, nm)
|
||||
b.tokens += msgsToken[j]
|
||||
}
|
||||
j++
|
||||
}
|
||||
toolBlocks = append(toolBlocks, b)
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
toolBlocks = append(toolBlocks, block{msgs: []*schema.Message{m}, tokens: msgsToken[i]})
|
||||
}
|
||||
|
||||
// Split into recent and older within token budget, from newest to oldest
|
||||
var recentBlocks []block
|
||||
var olderBlocks []block
|
||||
var recentTokens int64
|
||||
for i := len(toolBlocks) - 1; i >= 0; i-- {
|
||||
b := toolBlocks[i]
|
||||
if recentTokens+b.tokens > int64(s.maxRecent) {
|
||||
olderBlocks = append([]block{b}, olderBlocks...)
|
||||
continue
|
||||
}
|
||||
recentBlocks = append([]block{b}, recentBlocks...)
|
||||
recentTokens += b.tokens
|
||||
}
|
||||
|
||||
joinBlocks := func(bs []block) string {
|
||||
var sb strings.Builder
|
||||
for _, b := range bs {
|
||||
for _, m := range b.msgs {
|
||||
sb.WriteString(renderMsg(m))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
olderText := joinBlocks(olderBlocks)
|
||||
recentText := joinBlocks(recentBlocks)
|
||||
|
||||
msg, err := s.summarizer.Invoke(ctx, map[string]any{
|
||||
"system_prompt": joinBlocks([]block{systemBlock}),
|
||||
"user_messages": joinBlocks([]block{userBlock}),
|
||||
"previous_summary": joinBlocks([]block{summaryBlock}),
|
||||
"older_messages": olderText,
|
||||
"recent_messages": recentText,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("summarize failed, err=%w", err)
|
||||
}
|
||||
|
||||
summaryMsg := schema.AssistantMessage(msg.Content, nil)
|
||||
msg.Name = "summary"
|
||||
summaryMsg.Extra = map[string]any{
|
||||
summaryMessageFlag: true,
|
||||
}
|
||||
|
||||
// Build new state: prepend summary message, keep recent messages
|
||||
newMessages := make([]*schema.Message, 0, len(messages))
|
||||
newMessages = append(newMessages, systemBlock.msgs...)
|
||||
newMessages = append(newMessages, userBlock.msgs...)
|
||||
newMessages = append(newMessages, summaryMsg)
|
||||
for _, b := range recentBlocks {
|
||||
newMessages = append(newMessages, b.msgs...)
|
||||
}
|
||||
|
||||
state.Messages = newMessages
|
||||
return nil
|
||||
}
|
||||
|
||||
// Render messages into strings
|
||||
func renderMsg(m *schema.Message) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
var sb strings.Builder
|
||||
if m.Role == schema.Tool {
|
||||
if m.ToolName != "" {
|
||||
sb.WriteString("[tool:")
|
||||
sb.WriteString(m.ToolName)
|
||||
sb.WriteString("]\n")
|
||||
} else {
|
||||
sb.WriteString("[tool]\n")
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("[")
|
||||
sb.WriteString(string(m.Role))
|
||||
sb.WriteString("]\n")
|
||||
}
|
||||
if m.Content != "" {
|
||||
sb.WriteString(m.Content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
if m.Role == schema.Assistant && len(m.ToolCalls) > 0 {
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.Function.Name != "" {
|
||||
sb.WriteString("tool_call: ")
|
||||
sb.WriteString(tc.Function.Name)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
if tc.Function.Arguments != "" {
|
||||
sb.WriteString("args: ")
|
||||
sb.WriteString(tc.Function.Arguments)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, part := range m.UserInputMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||
sb.WriteString(part.Text)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
for _, part := range m.AssistantGenMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||
sb.WriteString(part.Text)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func defaultCounterToken(ctx context.Context, msgs []adk.Message) (tokenNum []int64, err error) {
|
||||
encoding := "cl100k_base"
|
||||
tkt, err := tiktoken.GetEncoding(encoding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get encoding failed, encoding=%v, err=%w", encoding, err)
|
||||
}
|
||||
tokenNum = make([]int64, len(msgs))
|
||||
|
||||
for i, m := range msgs {
|
||||
if m == nil {
|
||||
tokenNum[i] = 0
|
||||
continue
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Message role contributes to chat tokenization overhead; include it as text.
|
||||
if m.Role != "" {
|
||||
sb.WriteString(string(m.Role))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Core text content
|
||||
if m.Content != "" {
|
||||
sb.WriteString(m.Content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Reasoning content if present
|
||||
// Reasoning Content is not used by model
|
||||
// if m.ReasoningContent != "" {
|
||||
// sb.WriteString(m.ReasoningContent)
|
||||
// sb.WriteString("\n")
|
||||
// }
|
||||
|
||||
// Multi modal input/output text parts
|
||||
for _, part := range m.UserInputMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||
sb.WriteString(part.Text)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
for _, part := range m.AssistantGenMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||
sb.WriteString(part.Text)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Tool call textual context (name + arguments)
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.Function.Name != "" {
|
||||
sb.WriteString(tc.Function.Name)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
if tc.Function.Arguments != "" {
|
||||
sb.WriteString(tc.Function.Arguments)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
text := sb.String()
|
||||
if text == "" {
|
||||
tokenNum[i] = 0
|
||||
continue
|
||||
}
|
||||
|
||||
tokens := tkt.Encode(text, nil, nil)
|
||||
tokenNum[i] = int64(len(tokens))
|
||||
}
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/components/tool/utils"
|
||||
)
|
||||
|
||||
// RepeatSectionsInput is the argument object of the repeat_sections tool.
// The struct tags define the JSON field names and the descriptions attached
// to each field.
type RepeatSectionsInput struct {
	// Title is written as the heading of the generated section.
	Title string `json:"title" jsonschema_description:"Section title"`
	// Paragraphs are emitted in order, once per repetition.
	Paragraphs []string `json:"paragraphs" jsonschema_description:"Paragraphs to be repeated"`
	// RepeatCount is the number of repetitions; the tool uses 2 when it is <= 0.
	RepeatCount int `json:"repeat_count" jsonschema_description:"Times to repeat paragraphs, default 2"`
}
|
||||
|
||||
func NewRepeatSectionsTool() tool.InvokableTool {
|
||||
t, err := utils.InferTool(
|
||||
"repeat_sections",
|
||||
"Repeat given paragraphs to quickly accumulate context",
|
||||
func(ctx context.Context, in *RepeatSectionsInput) (string, error) {
|
||||
if in.RepeatCount <= 0 {
|
||||
in.RepeatCount = 2
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("## ")
|
||||
b.WriteString(in.Title)
|
||||
b.WriteString("\n")
|
||||
idx := 0
|
||||
for i := 0; i < in.RepeatCount; i++ {
|
||||
for _, sec := range in.Paragraphs {
|
||||
b.WriteString(strconv.Itoa(idx + 1))
|
||||
b.WriteString(". ")
|
||||
b.WriteString(sec)
|
||||
b.WriteString("\n")
|
||||
idx++
|
||||
}
|
||||
}
|
||||
callCount := 0
|
||||
times, ok := adk.GetSessionValue(ctx, "_tool_call_count")
|
||||
if !ok {
|
||||
times = 0
|
||||
} else {
|
||||
callCount = times.(int)
|
||||
}
|
||||
callCount++
|
||||
adk.AddSessionValue(ctx, "_tool_call_count", callCount)
|
||||
|
||||
b.WriteString(fmt.Sprintf("Tool calls so far: %d", callCount))
|
||||
return b.String(), nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create repeat_sections tool: %v", err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
Loading…
Reference in New Issue