You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
348 lines
8.7 KiB
Go
348 lines
8.7 KiB
Go
/*
|
|
* Copyright 2025 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package graphtool
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/cloudwego/eino/components/tool"
|
|
"github.com/cloudwego/eino/components/tool/utils"
|
|
"github.com/cloudwego/eino/compose"
|
|
"github.com/cloudwego/eino/schema"
|
|
)
|
|
|
|
type Compilable[I, O any] interface {
|
|
Compile(ctx context.Context, opts ...compose.GraphCompileOption) (compose.Runnable[I, O], error)
|
|
}
|
|
|
|
type InvokableGraphTool[I, O any] struct {
|
|
compilable Compilable[I, O]
|
|
compileOptions []compose.GraphCompileOption
|
|
tInfo *schema.ToolInfo
|
|
}
|
|
|
|
func NewInvokableGraphTool[I, O any](compilable Compilable[I, O],
|
|
name, desc string,
|
|
opts ...compose.GraphCompileOption,
|
|
) (*InvokableGraphTool[I, O], error) {
|
|
tInfo, err := utils.GoStruct2ToolInfo[I](name, desc)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &InvokableGraphTool[I, O]{
|
|
compilable: compilable,
|
|
compileOptions: opts,
|
|
tInfo: tInfo,
|
|
}, nil
|
|
}
|
|
|
|
type graphToolOptions struct {
|
|
composeOpts []compose.Option
|
|
}
|
|
|
|
func WithGraphToolOption(opts ...compose.Option) tool.Option {
|
|
return tool.WrapImplSpecificOptFn(func(opt *graphToolOptions) {
|
|
opt.composeOpts = opts
|
|
})
|
|
}
|
|
|
|
// graphToolInterruptState is the state attached to a tool interrupt when the
// wrapped graph interrupts. On resume, the graph tools read it back and
// continue from it. Fields are exported, presumably so the registered type
// can be serialized — confirm against schema.RegisterName.
type graphToolInterruptState struct {
	// Data is the graph checkpoint captured in the tool's checkpoint store.
	Data []byte
	// ToolInput is the raw JSON input of the interrupted call, replayed on resume.
	ToolInput string
}
|
|
|
|
func init() {
|
|
schema.RegisterName[*graphToolInterruptState]("_eino_graph_tool_interrupt_state")
|
|
}
|
|
|
|
// InvokableRun decodes the JSON input into I, invokes the wrapped graph, and
// returns the graph's O output encoded as JSON.
//
// Each call compiles the graph with a private single-slot checkpoint store.
// If the graph interrupts, the checkpoint it wrote is returned, together with
// the original input, inside a tool.CompositeInterrupt error. A later call that
// finds this state via tool.GetInterruptState restores both and resumes from
// the checkpoint instead of starting over.
//
// Extra compose call options can be supplied through WithGraphToolOption.
func (g *InvokableGraphTool[I, O]) InvokableRun(ctx context.Context, input string,
	opts ...tool.Option) (output string, err error) {
	// Copy into a fresh slice so the checkpoint-ID option is never appended into
	// a backing array owned by the caller of WithGraphToolOption.
	userOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts
	callOpts := make([]compose.Option, 0, len(userOpts)+1)
	callOpts = append(callOpts, userOpts...)
	callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID))

	// On resume, restore the original input and seed the store with the saved
	// checkpoint; otherwise start from an empty store.
	wasInterrupted, hasState, state := tool.GetInterruptState[*graphToolInterruptState](ctx)
	var checkpointStore *graphToolStore
	if wasInterrupted && hasState {
		input = state.ToolInput
		checkpointStore = newResumeStore(state.Data)
	} else {
		checkpointStore = newEmptyStore()
	}

	compileOptions := make([]compose.GraphCompileOption, 0, len(g.compileOptions)+1)
	compileOptions = append(compileOptions, g.compileOptions...)
	compileOptions = append(compileOptions, compose.WithCheckPointStore(checkpointStore))

	runnable, err := g.compilable.Compile(ctx, compileOptions...)
	if err != nil {
		return "", err
	}

	inputParams := NewInstance[I]()
	if err = sonic.UnmarshalString(input, &inputParams); err != nil {
		return "", fmt.Errorf("unmarshal graph tool input: %w", err)
	}

	originOutput, err := runnable.Invoke(ctx, inputParams, callOpts...)
	if err != nil {
		if _, ok := compose.ExtractInterruptInfo(err); !ok {
			return "", err
		}

		// The graph interrupted: capture its checkpoint so the tool call can be
		// resumed later with the same input.
		data, existed, getErr := checkpointStore.Get(ctx, graphToolCheckPointID)
		if getErr != nil {
			return "", getErr
		}
		if !existed {
			return "", fmt.Errorf("interrupt has happened, but checkpoint not exist in store")
		}

		return "", tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{
			Data:      data,
			ToolInput: input,
		}, err)
	}

	return sonic.MarshalString(originOutput)
}
|
|
|
|
// Info returns the tool description built at construction time. It ignores the
// context and never returns an error.
func (g *InvokableGraphTool[I, O]) Info(context.Context) (*schema.ToolInfo, error) {
	info := g.tInfo
	return info, nil
}
|
|
|
|
type StreamableGraphTool[I, O any] struct {
|
|
compilable Compilable[I, O]
|
|
compileOptions []compose.GraphCompileOption
|
|
tInfo *schema.ToolInfo
|
|
}
|
|
|
|
func NewStreamableGraphTool[I, O any](compilable Compilable[I, O],
|
|
name, desc string,
|
|
opts ...compose.GraphCompileOption,
|
|
) (*StreamableGraphTool[I, O], error) {
|
|
tInfo, err := utils.GoStruct2ToolInfo[I](name, desc)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &StreamableGraphTool[I, O]{
|
|
compilable: compilable,
|
|
compileOptions: opts,
|
|
tInfo: tInfo,
|
|
}, nil
|
|
}
|
|
|
|
// Info returns the tool description built at construction time. It ignores the
// context and never returns an error.
func (g *StreamableGraphTool[I, O]) Info(context.Context) (*schema.ToolInfo, error) {
	info := g.tInfo
	return info, nil
}
|
|
|
|
// StreamableRun decodes the JSON input into I, streams the wrapped graph, and
// returns a reader that yields each O chunk encoded as a JSON string.
//
// Compile and input-decoding errors are returned directly. Errors raised while
// the graph streams are delivered through the returned reader. If the graph
// interrupts, the reader yields a tool.CompositeInterrupt error carrying the
// captured checkpoint and the original input. A later call that finds this
// state via tool.GetInterruptState restores both and resumes from the
// checkpoint.
//
// Extra compose call options can be supplied through WithGraphToolOption.
func (g *StreamableGraphTool[I, O]) StreamableRun(ctx context.Context, input string,
	opts ...tool.Option) (*schema.StreamReader[string], error) {
	// Copy into a fresh slice so the checkpoint-ID option is never appended into
	// a backing array owned by the caller of WithGraphToolOption.
	userOpts := tool.GetImplSpecificOptions(&graphToolOptions{}, opts...).composeOpts
	callOpts := make([]compose.Option, 0, len(userOpts)+1)
	callOpts = append(callOpts, userOpts...)
	callOpts = append(callOpts, compose.WithCheckPointID(graphToolCheckPointID))

	// On resume, restore the original input and seed the store with the saved
	// checkpoint; otherwise start from an empty store.
	wasInterrupted, hasState, state := tool.GetInterruptState[*graphToolInterruptState](ctx)
	var checkpointStore *graphToolStore
	if wasInterrupted && hasState {
		input = state.ToolInput
		checkpointStore = newResumeStore(state.Data)
	} else {
		checkpointStore = newEmptyStore()
	}

	compileOptions := make([]compose.GraphCompileOption, 0, len(g.compileOptions)+1)
	compileOptions = append(compileOptions, g.compileOptions...)
	compileOptions = append(compileOptions, compose.WithCheckPointStore(checkpointStore))

	runnable, err := g.compilable.Compile(ctx, compileOptions...)
	if err != nil {
		return nil, err
	}

	inputParams := NewInstance[I]()
	if err = sonic.UnmarshalString(input, &inputParams); err != nil {
		return nil, fmt.Errorf("unmarshal graph tool input: %w", err)
	}

	// toToolError maps an error from the graph to the error the stream consumer
	// sees. Graph interrupts become resumable tool interrupts carrying the
	// captured checkpoint; every other error passes through unchanged.
	toToolError := func(graphErr error) error {
		if _, ok := compose.ExtractInterruptInfo(graphErr); !ok {
			return graphErr
		}
		data, existed, getErr := checkpointStore.Get(ctx, graphToolCheckPointID)
		if getErr != nil {
			return getErr
		}
		if !existed {
			return fmt.Errorf("interrupt has happened, but checkpoint not exist in store")
		}
		return tool.CompositeInterrupt(ctx, "graph tool interrupt", &graphToolInterruptState{
			Data:      data,
			ToolInput: input,
		}, graphErr)
	}

	sr, sw := schema.Pipe[string](1)

	go func() {
		// Deferred calls run LIFO: the recover below runs before sw.Close,
		// so a panic is reported on the stream before it is closed.
		defer sw.Close()
		defer func() {
			// A panic in a goroutine spawned by this library cannot be recovered by
			// the caller and would crash the process. Report it on the stream instead.
			if r := recover(); r != nil {
				sw.Send("", fmt.Errorf("graph tool stream panicked: %v", r))
			}
		}()

		outputStream, err := runnable.Stream(ctx, inputParams, callOpts...)
		if err != nil {
			sw.Send("", toToolError(err))
			return
		}
		defer outputStream.Close()

		for {
			chunk, err := outputStream.Recv()
			if err == io.EOF {
				return
			}
			if err != nil {
				sw.Send("", toToolError(err))
				return
			}

			chunkStr, err := sonic.MarshalString(chunk)
			if err != nil {
				sw.Send("", err)
				return
			}
			// Stop early if the consumer has closed its reader.
			if closed := sw.Send(chunkStr, nil); closed {
				return
			}
		}
	}()

	return sr, nil
}
|
|
|
|
const graphToolCheckPointID = "graph_tool_checkpoint_id"
|
|
|
|
func newEmptyStore() *graphToolStore {
|
|
return &graphToolStore{}
|
|
}
|
|
|
|
func newResumeStore(data []byte) *graphToolStore {
|
|
return &graphToolStore{
|
|
Data: data,
|
|
Valid: true,
|
|
}
|
|
}
|
|
|
|
type graphToolStore struct {
|
|
Data []byte
|
|
Valid bool
|
|
}
|
|
|
|
func (m *graphToolStore) Get(_ context.Context, _ string) ([]byte, bool, error) {
|
|
if m.Valid {
|
|
return m.Data, true, nil
|
|
}
|
|
return nil, false, nil
|
|
}
|
|
|
|
func (m *graphToolStore) Set(_ context.Context, _ string, checkPoint []byte) error {
|
|
m.Data = checkPoint
|
|
m.Valid = true
|
|
return nil
|
|
}
|
|
|
|
// NewInstance returns a value of type T that can be used directly as a decode
// target:
//   - maps are returned empty but non-nil;
//   - slices are returned empty (len and cap 0) but non-nil;
//   - pointers are allocated through every level of indirection, so for **S
//     the result is a non-nil pointer to a non-nil pointer to a zero S;
//   - every other kind (arrays, structs, interfaces, scalars) is the zero value.
func NewInstance[T any]() T {
	// Going through *T makes this work for interface types too, where
	// reflect.TypeOf on a zero T would return nil.
	typ := reflect.TypeOf((*T)(nil)).Elem()

	switch typ.Kind() {
	case reflect.Map:
		return reflect.MakeMap(typ).Interface().(T)
	case reflect.Slice:
		return reflect.MakeSlice(typ, 0, 0).Interface().(T)
	case reflect.Pointer:
		elem := typ.Elem()
		origin := reflect.New(elem)
		inst := origin
		// Allocate each nested pointer level so no intermediate pointer is nil.
		for elem.Kind() == reflect.Pointer {
			elem = elem.Elem()
			inst = inst.Elem()
			inst.Set(reflect.New(elem))
		}
		return origin.Interface().(T)
	default:
		// Arrays are value types: reflect.MakeSlice panics on them, and their
		// zero value is already a valid decode target.
		var zero T
		return zero
	}
}
|
|
|
|
func TypeOf[T any]() reflect.Type {
|
|
return reflect.TypeOf((*T)(nil)).Elem()
|
|
}
|