You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

404 lines
11 KiB
Go

/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package summarization
import (
"context"
"fmt"
"strings"
"github.com/pkoukk/tiktoken-go"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
// TokenCounter reports the token count of every message in msgs.
// The returned slice must contain exactly one entry per message, at the same
// index as the message it describes; the middleware returns an error when the
// lengths differ.
type TokenCounter func(ctx context.Context, msgs []adk.Message) (tokenNum []int64, err error)
// Config defines parameters for the conversation summarization middleware.
// It controls when summarization is triggered and how much recent context is
// retained verbatim. Model is required; every other field is optional.
type Config struct {
// MaxTokensBeforeSummary is the token threshold that triggers summarization.
// Summarization runs only when the summed token count of all messages in the
// agent state is greater than this value.
// DefaultMaxTokensBeforeSummary is used when the value is <= 0.
MaxTokensBeforeSummary int
// MaxTokensForRecentMessages is the token budget reserved for the recent
// messages that are kept as-is after summarization.
// DefaultMaxTokensForRecentMessages is used when the value is <= 0.
MaxTokensForRecentMessages int
// Counter is a custom token counter.
// Optional. When nil, the package's built-in counter is used.
Counter TokenCounter
// Model is the chat model used to generate the summary.
// Required.
Model model.BaseChatModel
// SystemPrompt is the system prompt for the summarizer.
// Optional. If empty, PromptOfSummary is used.
// The prompt is used as a schema.FString chat template; the summarizer is
// invoked with the variables system_prompt, user_messages, previous_summary,
// older_messages and recent_messages.
SystemPrompt string
}
// New creates an AgentMiddleware that compacts long conversation history
// into a single summary message when the token threshold is exceeded.
//
// The summarizer chain is: ChatTemplate(SystemPrompt) -> ChatModel(Model).
// Defaults are applied for an empty SystemPrompt (PromptOfSummary), for
// non-positive token budgets (DefaultMaxTokensBeforeSummary and
// DefaultMaxTokensForRecentMessages) and for a nil Counter
// (defaultCounterToken).
//
// It returns an error when cfg is nil, when cfg.Model is nil, or when the
// summarizer chain fails to compile.
func New(ctx context.Context, cfg *Config) (adk.AgentMiddleware, error) {
	if cfg == nil {
		return adk.AgentMiddleware{}, fmt.Errorf("config is nil")
	}
	// Model is documented as required; reject a missing model here instead of
	// letting the failure surface later from inside the summarizer chain.
	if cfg.Model == nil {
		return adk.AgentMiddleware{}, fmt.Errorf("config model is nil")
	}

	systemPrompt := cfg.SystemPrompt
	if systemPrompt == "" {
		systemPrompt = PromptOfSummary
	}

	maxBefore := DefaultMaxTokensBeforeSummary
	if cfg.MaxTokensBeforeSummary > 0 {
		maxBefore = cfg.MaxTokensBeforeSummary
	}

	maxRecent := DefaultMaxTokensForRecentMessages
	if cfg.MaxTokensForRecentMessages > 0 {
		maxRecent = cfg.MaxTokensForRecentMessages
	}

	counter := cfg.Counter
	if counter == nil {
		counter = defaultCounterToken
	}

	tpl := prompt.FromMessages(schema.FString,
		schema.SystemMessage(systemPrompt),
		schema.UserMessage("summarize 'older_messages': "))

	summarizer, err := compose.NewChain[map[string]any, *schema.Message]().
		AppendChatTemplate(tpl).
		AppendChatModel(cfg.Model).
		Compile(ctx, compose.WithGraphName("Summarizer"))
	if err != nil {
		return adk.AgentMiddleware{}, fmt.Errorf("compile summarizer failed, err=%w", err)
	}

	sm := &summaryMiddleware{
		counter:    counter,
		maxBefore:  maxBefore,
		maxRecent:  maxRecent,
		summarizer: summarizer,
	}
	return adk.AgentMiddleware{BeforeChatModel: sm.BeforeModel}, nil
}
// summaryMessageFlag is the Message.Extra key set on summary messages produced
// by this middleware. Its presence on an assistant message is how a previous
// summary is recognized on later runs.
const summaryMessageFlag = "_agent_middleware_summary_message"
// summaryMiddleware holds the resolved settings of the summarization
// middleware; its BeforeModel method is installed as the BeforeChatModel hook.
type summaryMiddleware struct {
// counter returns the token count of each message.
counter TokenCounter
// maxBefore is the total token count above which summarization is triggered.
maxBefore int
// maxRecent is the token budget for recent messages kept after summarization.
maxRecent int
// summarizer is the compiled chain that produces the summary message from
// the template variables.
summarizer compose.Runnable[map[string]any, *schema.Message]
}
// BeforeModel compacts state.Messages in place when their total token count
// exceeds s.maxBefore.
//
// The history is partitioned into:
//   - an optional leading system message,
//   - the run of user messages that follows it,
//   - an optional previous summary (an assistant message carrying
//     summaryMessageFlag in Extra),
//   - the remaining conversation, grouped so that an assistant tool-call
//     message stays together with the tool messages that answer it.
//
// The newest contiguous groups that fit into s.maxRecent are kept as-is; every
// group before them is handed to the summarizer, and the resulting summary
// message replaces the previous summary and those older groups. Nil messages
// are dropped from the rebuilt history.
//
// state is left untouched when it is nil or empty, when the threshold is not
// exceeded, when there is nothing older to summarize, and on error. An error is
// returned when counting tokens fails, when the counter returns a slice whose
// length differs from the number of messages, or when the summarizer fails or
// yields no message.
func (s *summaryMiddleware) BeforeModel(ctx context.Context, state *adk.ChatModelAgentState) (err error) {
	if state == nil || len(state.Messages) == 0 {
		return nil
	}

	messages := state.Messages
	msgsToken, err := s.counter(ctx, messages)
	if err != nil {
		return fmt.Errorf("count token failed, err=%w", err)
	}
	if len(messages) != len(msgsToken) {
		return fmt.Errorf("token count mismatch, msgNum=%d, tokenCountNum=%d", len(messages), len(msgsToken))
	}

	var total int64
	for _, t := range msgsToken {
		total += t
	}
	// Trigger summarization only when exceeding threshold.
	if total <= int64(s.maxBefore) {
		return nil
	}

	// block is a run of messages that must be kept or summarized together.
	type block struct {
		msgs   []*schema.Message
		tokens int64
	}

	// Head of the history: system message, leading user messages and the
	// previous summary are always preserved or replaced as a whole.
	idx := 0
	var systemBlock block
	if m := messages[idx]; m != nil && m.Role == schema.System {
		systemBlock.msgs = append(systemBlock.msgs, m)
		systemBlock.tokens += msgsToken[idx]
		idx++
	}

	var userBlock block
	for idx < len(messages) {
		m := messages[idx]
		if m == nil {
			idx++
			continue
		}
		if m.Role != schema.User {
			break
		}
		userBlock.msgs = append(userBlock.msgs, m)
		userBlock.tokens += msgsToken[idx]
		idx++
	}

	var summaryBlock block
	if idx < len(messages) {
		m := messages[idx]
		if m != nil && m.Role == schema.Assistant {
			if _, ok := m.Extra[summaryMessageFlag]; ok {
				summaryBlock.msgs = append(summaryBlock.msgs, m)
				summaryBlock.tokens += msgsToken[idx]
				idx++
			}
		}
	}

	// Group the rest of the conversation. An assistant message with tool calls
	// forms one block with the tool messages that directly follow it, so a call
	// is never separated from its results.
	blocks := make([]block, 0, len(messages)-idx)
	for i := idx; i < len(messages); i++ {
		m := messages[i]
		if m == nil {
			continue
		}
		b := block{msgs: []*schema.Message{m}, tokens: msgsToken[i]}
		if m.Role == schema.Assistant && len(m.ToolCalls) > 0 {
			callIDs := make(map[string]struct{}, len(m.ToolCalls))
			for _, tc := range m.ToolCalls {
				callIDs[tc.ID] = struct{}{}
			}
			j := i + 1
			for j < len(messages) {
				nm := messages[j]
				if nm == nil || nm.Role != schema.Tool {
					break
				}
				// A tool message without an ID is attached to this call; one
				// with an ID must match, otherwise the pairing ends here.
				if nm.ToolCallID != "" {
					if _, ok := callIDs[nm.ToolCallID]; !ok {
						break
					}
				}
				b.msgs = append(b.msgs, nm)
				b.tokens += msgsToken[j]
				j++
			}
			i = j - 1
		}
		blocks = append(blocks, b)
	}

	// Walk from newest to oldest and keep the longest suffix that fits the
	// budget. Stop at the first block that does not fit: keeping an even older
	// block while summarizing a newer one would place the summary out of
	// chronological order.
	split := len(blocks)
	var recentTokens int64
	for split > 0 {
		next := recentTokens + blocks[split-1].tokens
		if next > int64(s.maxRecent) {
			break
		}
		recentTokens = next
		split--
	}
	olderBlocks, recentBlocks := blocks[:split], blocks[split:]

	// Nothing lies before the recent window, so a summary could only add a
	// message to the context instead of shrinking it.
	if len(olderBlocks) == 0 {
		return nil
	}

	joinBlocks := func(bs []block) string {
		var sb strings.Builder
		for _, b := range bs {
			for _, m := range b.msgs {
				sb.WriteString(renderMsg(m))
				sb.WriteString("\n")
			}
		}
		return sb.String()
	}

	msg, err := s.summarizer.Invoke(ctx, map[string]any{
		"system_prompt":    joinBlocks([]block{systemBlock}),
		"user_messages":    joinBlocks([]block{userBlock}),
		"previous_summary": joinBlocks([]block{summaryBlock}),
		"older_messages":   joinBlocks(olderBlocks),
		"recent_messages":  joinBlocks(recentBlocks),
	})
	if err != nil {
		return fmt.Errorf("summarize failed, err=%w", err)
	}
	if msg == nil {
		return fmt.Errorf("summarize failed, err=summarizer returned nil message")
	}

	// Name and flag the message that is actually stored in the history, not the
	// summarizer output that is discarded after this call.
	summaryMsg := schema.AssistantMessage(msg.Content, nil)
	summaryMsg.Name = "summary"
	summaryMsg.Extra = map[string]any{
		summaryMessageFlag: true,
	}

	// Build new state: system and user head, the new summary, then the recent
	// blocks in their original order.
	newMessages := make([]*schema.Message, 0, len(messages))
	newMessages = append(newMessages, systemBlock.msgs...)
	newMessages = append(newMessages, userBlock.msgs...)
	newMessages = append(newMessages, summaryMsg)
	for _, b := range recentBlocks {
		newMessages = append(newMessages, b.msgs...)
	}
	state.Messages = newMessages
	return nil
}
// renderMsg formats a single message as plain text for the summarizer prompt.
//
// The output starts with a header line: "[tool:<name>]" or "[tool]" for tool
// messages, "[<role>]" for every other role. It is followed, one item per
// line, by the non-empty content, the name and arguments of each tool call of
// an assistant message, and the non-empty text parts of the multi-content
// fields. A nil message renders as the empty string.
func renderMsg(m *schema.Message) string {
	if m == nil {
		return ""
	}

	var out strings.Builder
	// writeLine emits the concatenation of parts followed by a newline.
	writeLine := func(parts ...string) {
		for _, p := range parts {
			out.WriteString(p)
		}
		out.WriteString("\n")
	}

	switch {
	case m.Role != schema.Tool:
		writeLine("[", string(m.Role), "]")
	case m.ToolName == "":
		writeLine("[tool]")
	default:
		writeLine("[tool:", m.ToolName, "]")
	}

	if m.Content != "" {
		writeLine(m.Content)
	}

	// Tool calls are rendered for assistant messages only.
	if m.Role == schema.Assistant {
		for i := range m.ToolCalls {
			fn := m.ToolCalls[i].Function
			if fn.Name != "" {
				writeLine("tool_call: ", fn.Name)
			}
			if fn.Arguments != "" {
				writeLine("args: ", fn.Arguments)
			}
		}
	}

	// Only text parts are rendered; other part types are skipped.
	for i := range m.UserInputMultiContent {
		p := m.UserInputMultiContent[i]
		if p.Type == schema.ChatMessagePartTypeText && p.Text != "" {
			writeLine(p.Text)
		}
	}
	for i := range m.AssistantGenMultiContent {
		p := m.AssistantGenMultiContent[i]
		if p.Type == schema.ChatMessagePartTypeText && p.Text != "" {
			writeLine(p.Text)
		}
	}

	return out.String()
}
// defaultCounterToken is the TokenCounter used when Config.Counter is nil.
//
// For every message it builds a text made of the role, the content, the text
// parts of the multi-content fields and the name and arguments of each tool
// call (one item per line, empty items skipped), encodes that text with the
// tiktoken "cl100k_base" encoding and reports the number of tokens. Nil
// messages and messages that produce no text count as 0. Reasoning content is
// deliberately left out because it is not used by the model.
//
// It returns an error only when the encoding cannot be obtained.
func defaultCounterToken(ctx context.Context, msgs []adk.Message) (tokenNum []int64, err error) {
	const encodingName = "cl100k_base"

	codec, err := tiktoken.GetEncoding(encodingName)
	if err != nil {
		return nil, fmt.Errorf("get encoding failed, encoding=%v, err=%w", encodingName, err)
	}

	// Entries start at zero, so skipped messages need no explicit assignment.
	counts := make([]int64, len(msgs))
	for i, msg := range msgs {
		if msg == nil {
			continue
		}

		var text strings.Builder
		// addLine appends s and a newline, ignoring empty values.
		addLine := func(s string) {
			if s == "" {
				return
			}
			text.WriteString(s)
			text.WriteString("\n")
		}

		// The role contributes to chat tokenization overhead; count it as text.
		addLine(string(msg.Role))
		addLine(msg.Content)

		// Text parts of multi-modal input and output.
		for _, part := range msg.UserInputMultiContent {
			if part.Type == schema.ChatMessagePartTypeText {
				addLine(part.Text)
			}
		}
		for _, part := range msg.AssistantGenMultiContent {
			if part.Type == schema.ChatMessagePartTypeText {
				addLine(part.Text)
			}
		}

		// Textual context of tool calls: name, then arguments.
		for _, call := range msg.ToolCalls {
			addLine(call.Function.Name)
			addLine(call.Function.Arguments)
		}

		if text.Len() == 0 {
			continue
		}
		counts[i] = int64(len(codec.Encode(text.String(), nil, nil)))
	}
	return counts, nil
}