feat(adk): add agent_with_summarization based on agent_middleware (#135)
parent
d41b497bdc
commit
85592f893f
@ -0,0 +1,118 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/components/tool"
|
||||||
|
"github.com/cloudwego/eino/compose"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-examples/adk/common/model"
|
||||||
|
"github.com/cloudwego/eino-examples/adk/common/prints"
|
||||||
|
"github.com/cloudwego/eino-examples/adk/common/trace"
|
||||||
|
"github.com/cloudwego/eino-examples/adk/intro/agent_with_summarization/summarization"
|
||||||
|
"github.com/cloudwego/eino-examples/internal/logs"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
summaryMaxTokensBefore = 10 * 1024
|
||||||
|
summaryMaxTokensRecent = 2 * 1024
|
||||||
|
agentMaxIterations = 30
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
a, err := newAgent(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logs.Fatalf("create agent failed, err=%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
traceCloseFn, startSpanFn := trace.AppendCozeLoopCallbackIfConfigured(ctx)
|
||||||
|
defer traceCloseFn(ctx)
|
||||||
|
|
||||||
|
runner := adk.NewRunner(ctx, adk.RunnerConfig{
|
||||||
|
EnableStreaming: true, // you can disable streaming here
|
||||||
|
Agent: a,
|
||||||
|
})
|
||||||
|
|
||||||
|
query := `Write a very long report on the history of artificial intelligence.`
|
||||||
|
ctx, endSpanFn := startSpanFn(ctx, "Agent", query)
|
||||||
|
|
||||||
|
iter := runner.Query(ctx, query)
|
||||||
|
|
||||||
|
var lastMessage adk.Message
|
||||||
|
for {
|
||||||
|
event, ok := iter.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if event.Err != nil {
|
||||||
|
log.Fatal(event.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prints.Event(event)
|
||||||
|
|
||||||
|
if event.Output != nil {
|
||||||
|
lastMessage, _, err = adk.GetMessage(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
endSpanFn(ctx, lastMessage)
|
||||||
|
|
||||||
|
// wait for all span to be ended
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAgent(ctx context.Context) (adk.Agent, error) {
|
||||||
|
sumMW, err := summarization.New(ctx, &summarization.Config{
|
||||||
|
Model: model.NewChatModel(),
|
||||||
|
MaxTokensBeforeSummary: summaryMaxTokensBefore,
|
||||||
|
MaxTokensForRecentMessages: summaryMaxTokensRecent,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||||
|
Name: "main_agent",
|
||||||
|
Description: "A long-form report assistant",
|
||||||
|
Instruction: `You are a long-form report writer working in ReAct mode.
|
||||||
|
Think step by step, call tools to expand content by repeating paragraphs, then synthesize a cohesive response.
|
||||||
|
one time call one tool, do not call multiple tools in one turn.
|
||||||
|
Each tool call should indicate the call number. After 20 tool calls, produce a final summary.`,
|
||||||
|
Model: model.NewChatModel(),
|
||||||
|
Middlewares: []adk.AgentMiddleware{sumMW},
|
||||||
|
ToolsConfig: adk.ToolsConfig{
|
||||||
|
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||||
|
Tools: []tool.BaseTool{
|
||||||
|
NewRepeatSectionsTool(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MaxIterations: agentMaxIterations,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
@ -0,0 +1,403 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package summarization
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/components/model"
|
||||||
|
"github.com/cloudwego/eino/components/prompt"
|
||||||
|
"github.com/cloudwego/eino/compose"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenCounter func(ctx context.Context, msgs []adk.Message) (tokenNum []int64, err error)
|
||||||
|
|
||||||
|
// Config defines parameters for the conversation summarization middleware.
|
||||||
|
// It controls when summarization is triggered and how much recent context is retained.
|
||||||
|
// Required: Model. Optional: SystemPrompt, Counter, and token budgets.
|
||||||
|
type Config struct {
|
||||||
|
// MaxTokensBeforeSummary is the max token threshold to trigger summarization based on total context
|
||||||
|
// (system prompt + history). Uses DefaultMaxTokensBeforeSummary when <= 0.
|
||||||
|
MaxTokensBeforeSummary int
|
||||||
|
|
||||||
|
// MaxTokensForRecentMessages is the max token budget reserved for recent messages after summarization.
|
||||||
|
// Uses DefaultMaxTokensForRecentMessages when <= 0.
|
||||||
|
MaxTokensForRecentMessages int
|
||||||
|
|
||||||
|
// Counter custom token counter.
|
||||||
|
// Optional
|
||||||
|
Counter TokenCounter
|
||||||
|
|
||||||
|
// Model used to generate the summary. Must be provided.
|
||||||
|
// Required.
|
||||||
|
Model model.BaseChatModel
|
||||||
|
|
||||||
|
// SystemPrompt is the system prompt for the summarizer.
|
||||||
|
// Optional. If empty, PromptOfSummary is used.
|
||||||
|
SystemPrompt string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates an AgentMiddleware that compacts long conversation history
|
||||||
|
// into a single summary message when the token threshold is exceeded.
|
||||||
|
// The summarizer chain is: ChatTemplate(SystemPrompt) -> ChatModel(Model).
|
||||||
|
// It applies defaults for token budgets and allows a custom Counter.
|
||||||
|
func New(ctx context.Context, cfg *Config) (adk.AgentMiddleware, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return adk.AgentMiddleware{}, fmt.Errorf("config is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
systemPrompt := cfg.SystemPrompt
|
||||||
|
if systemPrompt == "" {
|
||||||
|
systemPrompt = PromptOfSummary
|
||||||
|
}
|
||||||
|
maxBefore := DefaultMaxTokensBeforeSummary
|
||||||
|
if cfg.MaxTokensBeforeSummary > 0 {
|
||||||
|
maxBefore = cfg.MaxTokensBeforeSummary
|
||||||
|
}
|
||||||
|
maxRecent := DefaultMaxTokensForRecentMessages
|
||||||
|
if cfg.MaxTokensForRecentMessages > 0 {
|
||||||
|
maxRecent = cfg.MaxTokensForRecentMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
tpl := prompt.FromMessages(schema.FString,
|
||||||
|
schema.SystemMessage(systemPrompt),
|
||||||
|
schema.UserMessage("summarize 'older_messages': "))
|
||||||
|
|
||||||
|
summarizer, err := compose.NewChain[map[string]any, *schema.Message]().
|
||||||
|
AppendChatTemplate(tpl).
|
||||||
|
AppendChatModel(cfg.Model).
|
||||||
|
Compile(ctx, compose.WithGraphName("Summarizer"))
|
||||||
|
if err != nil {
|
||||||
|
return adk.AgentMiddleware{}, fmt.Errorf("compile summarizer failed, err=%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sm := &summaryMiddleware{
|
||||||
|
counter: defaultCounterToken,
|
||||||
|
maxBefore: maxBefore,
|
||||||
|
maxRecent: maxRecent,
|
||||||
|
summarizer: summarizer,
|
||||||
|
}
|
||||||
|
if cfg.Counter != nil {
|
||||||
|
sm.counter = cfg.Counter
|
||||||
|
}
|
||||||
|
return adk.AgentMiddleware{BeforeChatModel: sm.BeforeModel}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const summaryMessageFlag = "_agent_middleware_summary_message"
|
||||||
|
|
||||||
|
type summaryMiddleware struct {
|
||||||
|
counter TokenCounter
|
||||||
|
maxBefore int
|
||||||
|
maxRecent int
|
||||||
|
|
||||||
|
summarizer compose.Runnable[map[string]any, *schema.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *summaryMiddleware) BeforeModel(ctx context.Context, state *adk.ChatModelAgentState) (err error) {
|
||||||
|
if state == nil || len(state.Messages) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := state.Messages
|
||||||
|
msgsToken, err := s.counter(ctx, messages)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("count token failed, err=%w", err)
|
||||||
|
}
|
||||||
|
if len(messages) != len(msgsToken) {
|
||||||
|
return fmt.Errorf("token count mismatch, msgNum=%d, tokenCountNum=%d", len(messages), len(msgsToken))
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
for _, t := range msgsToken {
|
||||||
|
total += t
|
||||||
|
}
|
||||||
|
// Trigger summarization only when exceeding threshold
|
||||||
|
if total <= int64(s.maxBefore) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build blocks with user-messages, summary-message, tool-call pairings
|
||||||
|
type block struct {
|
||||||
|
msgs []*schema.Message
|
||||||
|
tokens int64
|
||||||
|
}
|
||||||
|
idx := 0
|
||||||
|
|
||||||
|
systemBlock := block{}
|
||||||
|
if idx < len(messages) {
|
||||||
|
m := messages[idx]
|
||||||
|
if m != nil && m.Role == schema.System {
|
||||||
|
systemBlock.msgs = append(systemBlock.msgs, m)
|
||||||
|
systemBlock.tokens += msgsToken[idx]
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userBlock := block{}
|
||||||
|
for idx < len(messages) {
|
||||||
|
m := messages[idx]
|
||||||
|
if m == nil {
|
||||||
|
idx++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.Role != schema.User {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
userBlock.msgs = append(userBlock.msgs, m)
|
||||||
|
userBlock.tokens += msgsToken[idx]
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
summaryBlock := block{}
|
||||||
|
if idx < len(messages) {
|
||||||
|
m := messages[idx]
|
||||||
|
if m != nil && m.Role == schema.Assistant {
|
||||||
|
if _, ok := m.Extra[summaryMessageFlag]; ok {
|
||||||
|
summaryBlock.msgs = append(summaryBlock.msgs, m)
|
||||||
|
summaryBlock.tokens += msgsToken[idx]
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolBlocks := make([]block, 0)
|
||||||
|
for i := idx; i < len(messages); i++ {
|
||||||
|
m := messages[i]
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.Role == schema.Assistant && len(m.ToolCalls) > 0 {
|
||||||
|
b := block{msgs: []*schema.Message{m}, tokens: msgsToken[i]}
|
||||||
|
// Collect subsequent tool messages matching any tool call id
|
||||||
|
callIDs := make(map[string]struct{}, len(m.ToolCalls))
|
||||||
|
for _, tc := range m.ToolCalls {
|
||||||
|
callIDs[tc.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
j := i + 1
|
||||||
|
for j < len(messages) {
|
||||||
|
nm := messages[j]
|
||||||
|
if nm == nil || nm.Role != schema.Tool {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Match by ToolCallID when available; if empty, include but keep boundary
|
||||||
|
if nm.ToolCallID == "" {
|
||||||
|
b.msgs = append(b.msgs, nm)
|
||||||
|
b.tokens += msgsToken[j]
|
||||||
|
} else {
|
||||||
|
if _, ok := callIDs[nm.ToolCallID]; !ok {
|
||||||
|
// Tool message not belonging to this assistant call -> end pairing
|
||||||
|
break
|
||||||
|
}
|
||||||
|
b.msgs = append(b.msgs, nm)
|
||||||
|
b.tokens += msgsToken[j]
|
||||||
|
}
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
toolBlocks = append(toolBlocks, b)
|
||||||
|
i = j - 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toolBlocks = append(toolBlocks, block{msgs: []*schema.Message{m}, tokens: msgsToken[i]})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split into recent and older within token budget, from newest to oldest
|
||||||
|
var recentBlocks []block
|
||||||
|
var olderBlocks []block
|
||||||
|
var recentTokens int64
|
||||||
|
for i := len(toolBlocks) - 1; i >= 0; i-- {
|
||||||
|
b := toolBlocks[i]
|
||||||
|
if recentTokens+b.tokens > int64(s.maxRecent) {
|
||||||
|
olderBlocks = append([]block{b}, olderBlocks...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
recentBlocks = append([]block{b}, recentBlocks...)
|
||||||
|
recentTokens += b.tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
joinBlocks := func(bs []block) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, b := range bs {
|
||||||
|
for _, m := range b.msgs {
|
||||||
|
sb.WriteString(renderMsg(m))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
olderText := joinBlocks(olderBlocks)
|
||||||
|
recentText := joinBlocks(recentBlocks)
|
||||||
|
|
||||||
|
msg, err := s.summarizer.Invoke(ctx, map[string]any{
|
||||||
|
"system_prompt": joinBlocks([]block{systemBlock}),
|
||||||
|
"user_messages": joinBlocks([]block{userBlock}),
|
||||||
|
"previous_summary": joinBlocks([]block{summaryBlock}),
|
||||||
|
"older_messages": olderText,
|
||||||
|
"recent_messages": recentText,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("summarize failed, err=%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
summaryMsg := schema.AssistantMessage(msg.Content, nil)
|
||||||
|
msg.Name = "summary"
|
||||||
|
summaryMsg.Extra = map[string]any{
|
||||||
|
summaryMessageFlag: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build new state: prepend summary message, keep recent messages
|
||||||
|
newMessages := make([]*schema.Message, 0, len(messages))
|
||||||
|
newMessages = append(newMessages, systemBlock.msgs...)
|
||||||
|
newMessages = append(newMessages, userBlock.msgs...)
|
||||||
|
newMessages = append(newMessages, summaryMsg)
|
||||||
|
for _, b := range recentBlocks {
|
||||||
|
newMessages = append(newMessages, b.msgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Messages = newMessages
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render messages into strings
|
||||||
|
func renderMsg(m *schema.Message) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var sb strings.Builder
|
||||||
|
if m.Role == schema.Tool {
|
||||||
|
if m.ToolName != "" {
|
||||||
|
sb.WriteString("[tool:")
|
||||||
|
sb.WriteString(m.ToolName)
|
||||||
|
sb.WriteString("]\n")
|
||||||
|
} else {
|
||||||
|
sb.WriteString("[tool]\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sb.WriteString("[")
|
||||||
|
sb.WriteString(string(m.Role))
|
||||||
|
sb.WriteString("]\n")
|
||||||
|
}
|
||||||
|
if m.Content != "" {
|
||||||
|
sb.WriteString(m.Content)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
if m.Role == schema.Assistant && len(m.ToolCalls) > 0 {
|
||||||
|
for _, tc := range m.ToolCalls {
|
||||||
|
if tc.Function.Name != "" {
|
||||||
|
sb.WriteString("tool_call: ")
|
||||||
|
sb.WriteString(tc.Function.Name)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
if tc.Function.Arguments != "" {
|
||||||
|
sb.WriteString("args: ")
|
||||||
|
sb.WriteString(tc.Function.Arguments)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, part := range m.UserInputMultiContent {
|
||||||
|
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||||
|
sb.WriteString(part.Text)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, part := range m.AssistantGenMultiContent {
|
||||||
|
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||||
|
sb.WriteString(part.Text)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultCounterToken(ctx context.Context, msgs []adk.Message) (tokenNum []int64, err error) {
|
||||||
|
encoding := "cl100k_base"
|
||||||
|
tkt, err := tiktoken.GetEncoding(encoding)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get encoding failed, encoding=%v, err=%w", encoding, err)
|
||||||
|
}
|
||||||
|
tokenNum = make([]int64, len(msgs))
|
||||||
|
|
||||||
|
for i, m := range msgs {
|
||||||
|
if m == nil {
|
||||||
|
tokenNum[i] = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// Message role contributes to chat tokenization overhead; include it as text.
|
||||||
|
if m.Role != "" {
|
||||||
|
sb.WriteString(string(m.Role))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Core text content
|
||||||
|
if m.Content != "" {
|
||||||
|
sb.WriteString(m.Content)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning content if present
|
||||||
|
// Reasoning Content is not used by model
|
||||||
|
// if m.ReasoningContent != "" {
|
||||||
|
// sb.WriteString(m.ReasoningContent)
|
||||||
|
// sb.WriteString("\n")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Multi modal input/output text parts
|
||||||
|
for _, part := range m.UserInputMultiContent {
|
||||||
|
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||||
|
sb.WriteString(part.Text)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, part := range m.AssistantGenMultiContent {
|
||||||
|
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||||
|
sb.WriteString(part.Text)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool call textual context (name + arguments)
|
||||||
|
for _, tc := range m.ToolCalls {
|
||||||
|
if tc.Function.Name != "" {
|
||||||
|
sb.WriteString(tc.Function.Name)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
if tc.Function.Arguments != "" {
|
||||||
|
sb.WriteString(tc.Function.Arguments)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
text := sb.String()
|
||||||
|
if text == "" {
|
||||||
|
tokenNum[i] = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := tkt.Encode(text, nil, nil)
|
||||||
|
tokenNum[i] = int64(len(tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenNum, nil
|
||||||
|
}
|
||||||
@ -0,0 +1,77 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/components/tool"
|
||||||
|
"github.com/cloudwego/eino/components/tool/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RepeatSectionsInput struct {
|
||||||
|
Title string `json:"title" jsonschema_description:"Section title"`
|
||||||
|
Paragraphs []string `json:"paragraphs" jsonschema_description:"Paragraphs to be repeated"`
|
||||||
|
RepeatCount int `json:"repeat_count" jsonschema_description:"Times to repeat paragraphs, default 2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRepeatSectionsTool() tool.InvokableTool {
|
||||||
|
t, err := utils.InferTool(
|
||||||
|
"repeat_sections",
|
||||||
|
"Repeat given paragraphs to quickly accumulate context",
|
||||||
|
func(ctx context.Context, in *RepeatSectionsInput) (string, error) {
|
||||||
|
if in.RepeatCount <= 0 {
|
||||||
|
in.RepeatCount = 2
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## ")
|
||||||
|
b.WriteString(in.Title)
|
||||||
|
b.WriteString("\n")
|
||||||
|
idx := 0
|
||||||
|
for i := 0; i < in.RepeatCount; i++ {
|
||||||
|
for _, sec := range in.Paragraphs {
|
||||||
|
b.WriteString(strconv.Itoa(idx + 1))
|
||||||
|
b.WriteString(". ")
|
||||||
|
b.WriteString(sec)
|
||||||
|
b.WriteString("\n")
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
callCount := 0
|
||||||
|
times, ok := adk.GetSessionValue(ctx, "_tool_call_count")
|
||||||
|
if !ok {
|
||||||
|
times = 0
|
||||||
|
} else {
|
||||||
|
callCount = times.(int)
|
||||||
|
}
|
||||||
|
callCount++
|
||||||
|
adk.AddSessionValue(ctx, "_tool_call_count", callCount)
|
||||||
|
|
||||||
|
b.WriteString(fmt.Sprintf("Tool calls so far: %d", callCount))
|
||||||
|
return b.String(), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create repeat_sections tool: %v", err)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue