feat(tool): json fix tool middleware
Change-Id: I00d90a0d23b68faccad77337aa1b511e22862b28drew/english
parent
589b192304
commit
a18a2c0538
@ -0,0 +1,75 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// This example shows how to configure the jsonfix middleware on a ToolsNode
|
||||
// to repair invalid JSON arguments before invoking a local tool.
|
||||
// Run: go run ./components/tool/middlewares/jsonfix/example
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/components/tool/utils"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
jsonfix "github.com/cloudwego/eino-examples/components/tool/middlewares/jsonfix"
|
||||
)
|
||||
|
||||
// greetReq is the JSON-decoded input of the example greeter tool.
type greetReq struct {
	// Name is read from the "name" key and appended to the greeting.
	Name string `json:"name"`
}
|
||||
|
||||
// greetResp is the JSON-encoded output of the example greeter tool.
type greetResp struct {
	// Greeting is written under the "greeting" key.
	Greeting string `json:"greeting"`
}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Define a simple local tool with JSON input/output using InferTool.
|
||||
greeter, _ := utils.InferTool("greeter", "greet by name", func(ctx context.Context, in *greetReq) (*greetResp, error) {
|
||||
return &greetResp{Greeting: "Hello, " + in.Name}, nil
|
||||
})
|
||||
|
||||
// Create ToolsNode and register the jsonfix middleware.
|
||||
tn, _ := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{
|
||||
Tools: []tool.BaseTool{greeter},
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{jsonfix.Middleware()},
|
||||
})
|
||||
|
||||
// Craft an Assistant message with an invalid JSON argument to simulate LLM output.
|
||||
msg := schema.AssistantMessage("", nil)
|
||||
msg.ToolCalls = []schema.ToolCall{{
|
||||
ID: "1",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "greeter",
|
||||
Arguments: "noise <|FunctionCallBegin|>{\"name\":\"Alice\"1\"\"}<|FunctionCallEnd|>",
|
||||
},
|
||||
}}
|
||||
|
||||
// ToolsNode invokes the tool. Middleware repairs the argument first.
|
||||
outs, err := tn.Invoke(ctx, msg)
|
||||
if err != nil {
|
||||
fmt.Println("error:", err)
|
||||
return
|
||||
}
|
||||
for _, o := range outs {
|
||||
fmt.Println("tool:", o.ToolName, "id:", o.ToolCallID, "content:", o.Content)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,113 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Package jsonfix provides a ToolMiddleware for Eino's ToolsNode that
|
||||
// repairs malformed JSON arguments produced by LLMs before tool execution.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// conf := &compose.ToolsNodeConfig{
|
||||
// Tools: []tool.BaseTool{yourTool},
|
||||
// ToolCallMiddlewares: []compose.ToolMiddleware{jsonfix.Middleware()},
|
||||
// }
|
||||
//
|
||||
// Behavior:
|
||||
// - Fast-path returns original arguments when already valid JSON.
|
||||
// - Isolates the span from the first '{' to the last '}' and strips common LLM artifacts.
|
||||
// - Applies robust fix using jsonrepair only when input is invalid.
|
||||
// - Safe for both invokable and streamable tools.
|
||||
package jsonfix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/kaptinlin/jsonrepair"
|
||||
)
|
||||
|
||||
// FixJSON is a helper endpoint that returns repaired JSON for diagnostics
|
||||
// or standalone use. ToolsNode should prefer Invokable/Streamable middleware.
|
||||
func FixJSON(ctx context.Context, in *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
fixed := repair(in.Arguments)
|
||||
return &compose.ToolOutput{Result: fixed}, nil
|
||||
}
|
||||
|
||||
// Invokable wraps a non-stream tool endpoint to sanitize JSON arguments.
|
||||
// Register via ToolCallMiddlewares to apply automatically to invokable tools.
|
||||
func Invokable(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
|
||||
return func(ctx context.Context, in *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
in.Arguments = repair(in.Arguments)
|
||||
return next(ctx, in)
|
||||
}
|
||||
}
|
||||
|
||||
// Streamable wraps a stream tool endpoint to sanitize JSON arguments.
|
||||
// Register via ToolCallMiddlewares to apply automatically to streamable tools.
|
||||
func Streamable(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
|
||||
return func(ctx context.Context, in *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
in.Arguments = repair(in.Arguments)
|
||||
return next(ctx, in)
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware bundles both invokable and streamable wrappers for convenience.
|
||||
func Middleware() compose.ToolMiddleware {
|
||||
return compose.ToolMiddleware{Invokable: Invokable, Streamable: Streamable}
|
||||
}
|
||||
|
||||
// repair attempts minimal work first (validity check, region isolation) and
|
||||
// only uses jsonrepair when necessary. It trims common LLM artifacts.
|
||||
func repair(input string) string {
|
||||
s := strings.TrimSpace(input)
|
||||
// Fast-path: valid JSON as-is
|
||||
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") && json.Valid([]byte(s)) {
|
||||
return s
|
||||
}
|
||||
|
||||
// Isolate JSON object region if present; strip noise if object-only is valid
|
||||
i := strings.IndexByte(s, '{')
|
||||
j := strings.LastIndexByte(s, '}')
|
||||
if i >= 0 && j >= i {
|
||||
sub := s[i : j+1]
|
||||
if json.Valid([]byte(sub)) {
|
||||
return sub
|
||||
}
|
||||
s = sub
|
||||
}
|
||||
|
||||
// Remove common LLM artifacts
|
||||
s = strings.TrimPrefix(s, "<|FunctionCallBegin|>")
|
||||
s = strings.TrimSuffix(s, "<|FunctionCallEnd|>")
|
||||
s = strings.TrimPrefix(s, "<think>")
|
||||
|
||||
// Attempt robust repair only when invalid
|
||||
if json.Valid([]byte(s)) {
|
||||
return s
|
||||
}
|
||||
// Heuristic: add missing leading/trailing brace if one side exists
|
||||
if !strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
|
||||
s = "{" + s
|
||||
} else if strings.HasPrefix(s, "{") && !strings.HasSuffix(s, "}") {
|
||||
s = s + "}"
|
||||
}
|
||||
out, err := jsonrepair.JSONRepair(s)
|
||||
if err != nil {
|
||||
return s
|
||||
}
|
||||
return out
|
||||
}
|
||||
@ -0,0 +1,185 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package jsonfix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func echoEndpoint(_ context.Context, in *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
return &compose.ToolOutput{Result: in.Arguments}, nil
|
||||
}
|
||||
|
||||
func TestInvokableMiddleware_RepairsJSON(t *testing.T) {
|
||||
mw := Invokable
|
||||
chained := mw(echoEndpoint)
|
||||
|
||||
input := &compose.ToolInput{
|
||||
Name: "test_tool",
|
||||
Arguments: "noise <|FunctionCallBegin|>{\"a\":1}<|FunctionCallEnd|> more",
|
||||
CallID: "id1",
|
||||
}
|
||||
|
||||
out, err := chained(context.Background(), input)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
var m map[string]int
|
||||
if e := json.Unmarshal([]byte(out.Result), &m); e != nil || m["a"] != 1 {
|
||||
t.Fatalf("repair failed: %v %v", out.Result, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamableMiddleware_RepairsJSON(t *testing.T) {
|
||||
mw := Streamable
|
||||
var captured string
|
||||
next := func(_ context.Context, in *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
captured = in.Arguments
|
||||
return &compose.StreamToolOutput{Result: schema.StreamReaderFromArray([]string{"ok"})}, nil
|
||||
}
|
||||
chained := mw(next)
|
||||
_, err := chained(context.Background(), &compose.ToolInput{Name: "t", Arguments: "{\"text\": \"He said \"hello\" to me\"}"})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
var m map[string]string
|
||||
if e := json.Unmarshal([]byte(captured), &m); e != nil || m["text"] != "He said \"hello\" to me" {
|
||||
t.Fatalf("repair failed: %v %v", captured, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvokableMiddleware_NoChangeForValidJSON(t *testing.T) {
|
||||
mw := Invokable
|
||||
chained := mw(echoEndpoint)
|
||||
|
||||
original := "{\"a\":1}"
|
||||
out, err := chained(context.Background(), &compose.ToolInput{Name: "t", Arguments: original})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if out.Result != original {
|
||||
t.Fatalf("should not change valid json: got %s", out.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvokableMiddleware_MissingBracesAndUnicode(t *testing.T) {
|
||||
mw := Invokable
|
||||
chained := mw(echoEndpoint)
|
||||
|
||||
inputs := []string{
|
||||
"{\"key\":\"value\"", // missing tail
|
||||
"\"key\":\"value\"}", // missing head
|
||||
"{\"emoji\": \"😀😎\"}",
|
||||
"{\"text\": \"line1\nline2\tTabbed\"}",
|
||||
}
|
||||
for _, in := range inputs {
|
||||
out, err := chained(context.Background(), &compose.ToolInput{Name: "t", Arguments: in})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
var any map[string]string
|
||||
if e := json.Unmarshal([]byte(out.Result), &any); e != nil {
|
||||
t.Fatalf("unmarshal failed: %v for %v", e, out.Result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type spyInvokable struct{}
|
||||
|
||||
func (s *spyInvokable) Info(_ context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{Name: "spy", Desc: ""}, nil
|
||||
}
|
||||
func (s *spyInvokable) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) {
|
||||
return argumentsInJSON, nil
|
||||
}
|
||||
|
||||
func TestToolsNodeMiddleware_RepairsInvokable(t *testing.T) {
|
||||
tn, err := compose.NewToolNode(context.Background(), &compose.ToolsNodeConfig{
|
||||
Tools: []tool.BaseTool{&spyInvokable{}},
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{Middleware()},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("new tools node err: %v", err)
|
||||
}
|
||||
msg := schema.AssistantMessage("", nil)
|
||||
msg.ToolCalls = []schema.ToolCall{{
|
||||
ID: "1",
|
||||
Function: schema.FunctionCall{Name: "spy", Arguments: "garbage {\"x\": \"a \"quote\" b\"} tail"},
|
||||
}}
|
||||
outs, err := tn.Invoke(context.Background(), msg)
|
||||
if err != nil {
|
||||
t.Fatalf("invoke err: %v", err)
|
||||
}
|
||||
if len(outs) != 1 {
|
||||
t.Fatalf("unexpected outs: %d", len(outs))
|
||||
}
|
||||
var m map[string]string
|
||||
if e := json.Unmarshal([]byte(outs[0].Content), &m); e != nil || m["x"] != "a \"quote\" b" {
|
||||
t.Fatalf("repair failed: %v %v", outs[0].Content, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvokableMiddleware_UnescapedBackslashesAndSingleQuotesAndTrailingComma(t *testing.T) {
|
||||
mw := Invokable
|
||||
chained := mw(echoEndpoint)
|
||||
|
||||
cases := []string{
|
||||
// unescaped backslashes in Windows path
|
||||
"{\"path\": \"C:\\Users\\name\\file.txt\"}",
|
||||
// single-quoted JSON
|
||||
"{'a': 'b'}",
|
||||
// trailing comma
|
||||
"{\"a\": 1, \"b\": 2, }",
|
||||
}
|
||||
|
||||
for _, in := range cases {
|
||||
out, err := chained(context.Background(), &compose.ToolInput{Name: "t", Arguments: in})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
var any map[string]any
|
||||
if e := json.Unmarshal([]byte(out.Result), &any); e != nil {
|
||||
t.Fatalf("unmarshal failed: %v for %v", e, out.Result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvokableMiddleware_RawQuotesInsideStringValue(t *testing.T) {
|
||||
mw := Invokable
|
||||
chained := mw(echoEndpoint)
|
||||
|
||||
// invalid JSON: raw quotes inside value
|
||||
in := "{\"a\":\"b\"1\"\"}"
|
||||
out, err := chained(context.Background(), &compose.ToolInput{Name: "t", Arguments: in})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
var m map[string]string
|
||||
if e := json.Unmarshal([]byte(out.Result), &m); e != nil {
|
||||
t.Fatalf("unmarshal failed: %v for %v", e, out.Result)
|
||||
}
|
||||
v := m["a"]
|
||||
if v == "" || v[0] != 'b' {
|
||||
t.Fatalf("unexpected repaired value: %q", v)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue