feat(react): unknown tool handler example
Change-Id: I4e17983373ec7a07f43601aaeccfafcc8e3d739adrew/english
parent
77bd95ba71
commit
d099dce571
@ -0,0 +1,20 @@
|
|||||||
|
# Unknown Tools Handler for ReAct Agent
|
||||||
|
|
||||||
|
- Demonstrates `UnknownToolsHandler` in `compose.ToolsNodeConfig` when the model emits an unknown tool call.
|
||||||
|
- Mock ChatModel produces three turns: unknown tool call → correct tool call → final answer.
|
||||||
|
- Builds on the ReAct agent from the flow package.
|
||||||
|
|
||||||
|
## Rationale
|
||||||
|
- ReAct agents often rely on the model to select tool names from a provided list. In practice, models may hallucinate a tool name not registered with the `ToolsNode`.
|
||||||
|
- Instead of aborting the agent on such an error, the `UnknownToolsHandler` produces a clear, structured message that is fed back to the ChatModel as the tool result.
|
||||||
|
- This feedback informs the model that the tool name is invalid and encourages it to pick a valid tool in the next turn, improving robustness and convergence.
|
||||||
|
- The example shows: first turn emits an unknown tool call; the handler returns guidance; the second turn uses the correct tool; the final turn produces the answer.
|
||||||
|
|
||||||
|
## Run
|
||||||
|
- `cd flow/agent/react/unknown_tool_handler_example`
|
||||||
|
- `go run main.go`
|
||||||
|
|
||||||
|
## Expected
|
||||||
|
- Prints a handler message for the unknown tool name.
|
||||||
|
- Executes the `sum` tool on the second turn and returns `{"sum":3}`.
|
||||||
|
- Outputs the final assistant answer `3`.
|
||||||
@ -0,0 +1,106 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/callbacks"
|
||||||
|
"github.com/cloudwego/eino/components/model"
|
||||||
|
"github.com/cloudwego/eino/components/tool"
|
||||||
|
"github.com/cloudwego/eino/compose"
|
||||||
|
"github.com/cloudwego/eino/flow/agent"
|
||||||
|
"github.com/cloudwego/eino/flow/agent/react"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
|
||||||
|
extools "github.com/cloudwego/eino-examples/flow/agent/react/unknown_tool_handler_example/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
unknown := func(ctx context.Context, name, input string) (string, error) {
|
||||||
|
return fmt.Sprintf("unknown tool: %s; you made it up, try again with the correct tool name", name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rAgent, err := react.NewAgent(ctx, &react.AgentConfig{
|
||||||
|
ToolCallingModel: &mockToolCallingModel{},
|
||||||
|
ToolsConfig: compose.ToolsNodeConfig{
|
||||||
|
Tools: []tool.BaseTool{extools.SumToolFn()},
|
||||||
|
UnknownToolsHandler: unknown,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := rAgent.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "Add 1 and 2"}}, agent.WithComposeOptions(compose.WithCallbacks(&simpleLogger{})))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Println(msg.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockToolCallingModel struct{ step int }
|
||||||
|
|
||||||
|
func (m *mockToolCallingModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
|
||||||
|
switch m.step {
|
||||||
|
case 0:
|
||||||
|
m.step++
|
||||||
|
return &schema.Message{Role: schema.Assistant, ToolCalls: []schema.ToolCall{{ID: "1", Function: schema.FunctionCall{Name: "sumx", Arguments: "{\"a\":1,\"b\":2}"}}}}, nil
|
||||||
|
case 1:
|
||||||
|
m.step++
|
||||||
|
return &schema.Message{Role: schema.Assistant, ToolCalls: []schema.ToolCall{{ID: "2", Function: schema.FunctionCall{Name: "sum", Arguments: "{\"a\":1,\"b\":2}"}}}}, nil
|
||||||
|
default:
|
||||||
|
return &schema.Message{Role: schema.Assistant, Content: "3"}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockToolCallingModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
|
||||||
|
return nil, fmt.Errorf("not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockToolCallingModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type simpleLogger struct{ callbacks.HandlerBuilder }
|
||||||
|
|
||||||
|
func (l *simpleLogger) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *simpleLogger) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||||
|
fmt.Println(output)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *simpleLogger) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||||
|
output.Close()
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *simpleLogger) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
||||||
|
input.Close()
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *simpleLogger) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||||
|
fmt.Println(err)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
@ -0,0 +1,56 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/components/tool"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SumTool struct{}
|
||||||
|
|
||||||
|
func SumToolFn() tool.InvokableTool { return &SumTool{} }
|
||||||
|
|
||||||
|
func (t *SumTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||||
|
return &schema.ToolInfo{
|
||||||
|
Name: "sum",
|
||||||
|
Desc: "Add two integers",
|
||||||
|
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||||
|
"a": {Type: "number", Desc: "first operand", Required: true},
|
||||||
|
"b": {Type: "number", Desc: "second operand", Required: true},
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *SumTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||||
|
var p struct {
|
||||||
|
A int `json:"a"`
|
||||||
|
B int `json:"b"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(argumentsInJSON), &p); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
res := map[string]int{"sum": p.A + p.B}
|
||||||
|
b, err := json.Marshal(res)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue