You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
132 lines
4.8 KiB
Go
132 lines
4.8 KiB
Go
/*
|
|
* Copyright 2025 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package abtest
|
|
|
|
import (
	"context"
	"errors"
	"fmt"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/schema"
)
|
|
|
|
// ModelRouter chooses the chat model that an ABRouterChatModel delegates a
// single Generate or Stream call to. It is invoked once per call with that
// call's ctx, input messages and options.
//
// It returns a name for the chosen model (used as the callback RunInfo name),
// the model itself, and an error. Returning a non-nil error or a nil model makes
// the call fail without invoking any model.
type ModelRouter func(ctx context.Context, input []*schema.Message, opts ...model.Option) (string, model.BaseChatModel, error)
|
|
|
|
// ABRouterChatModel is a dynamic router over chat models that implements ToolCallingChatModel.
|
|
//
|
|
// Behavior:
|
|
// - Routing: delegates the choice to a user-provided ModelRouter which returns (modelName, BaseChatModel).
|
|
// - RunInfo naming: uses the returned modelName when calling callbacks.EnsureRunInfo so callbacks can log the chosen model.
|
|
// - Tools: stores tool infos via WithTools and applies them lazily if the chosen model supports ToolCallingChatModel.
|
|
// - Callbacks: if the chosen model exposes components.Checker and IsCallbacksEnabled()==true, delegates directly;
|
|
// otherwise injects OnStart/OnEnd/OnError around Generate/Stream.
|
|
// - IsCallbacksEnabled: returns true to indicate this wrapper already coordinates callback triggering.
|
|
//
|
|
// Typical usage:
|
|
//
|
|
// router := NewABRouterChatModel(func(ctx, msgs, opts...) (string, model.BaseChatModel, error) {
|
|
// return "openai", openaiModel, nil
|
|
// })
|
|
// router = router.WithTools(toolInfos) // optional
|
|
// msg, _ := router.Generate(ctx, input)
|
|
type ABRouterChatModel struct {
|
|
router ModelRouter
|
|
tools []*schema.ToolInfo
|
|
}
|
|
|
|
func NewABRouterChatModel(router ModelRouter) *ABRouterChatModel {
|
|
return &ABRouterChatModel{router: router}
|
|
}
|
|
|
|
func (a *ABRouterChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
|
return &ABRouterChatModel{router: a.router, tools: tools}, nil
|
|
}
|
|
|
|
func (a *ABRouterChatModel) pickModel(ctx context.Context, input []*schema.Message, opts ...model.Option) (string, model.BaseChatModel, error) {
|
|
if a.router == nil {
|
|
return "", nil, errors.New("no router")
|
|
}
|
|
name, base, err := a.router(ctx, input, opts...)
|
|
if err != nil || base == nil {
|
|
return "", nil, err
|
|
}
|
|
if tcm, ok := base.(model.ToolCallingChatModel); ok && len(a.tools) > 0 {
|
|
nTcm, wErr := tcm.WithTools(a.tools)
|
|
if wErr != nil {
|
|
return "", nil, wErr
|
|
}
|
|
base = nTcm
|
|
}
|
|
return name, base, nil
|
|
}
|
|
|
|
func (a *ABRouterChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
|
|
name, base, err := a.pickModel(ctx, input, opts...)
|
|
if err != nil || base == nil {
|
|
if err == nil {
|
|
err = errors.New("router returned nil model")
|
|
}
|
|
callbacks.OnError(ctx, err)
|
|
return nil, err
|
|
}
|
|
ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Name: name, Component: components.ComponentOfChatModel})
|
|
if ch, ok := base.(components.Checker); ok && ch.IsCallbacksEnabled() {
|
|
return base.Generate(ctx, input, opts...)
|
|
}
|
|
nCtx := callbacks.OnStart(ctx, &model.CallbackInput{Messages: input})
|
|
out, err := base.Generate(nCtx, input, opts...)
|
|
if err != nil {
|
|
callbacks.OnError(nCtx, err)
|
|
return nil, err
|
|
}
|
|
callbacks.OnEnd(nCtx, &model.CallbackOutput{Message: out})
|
|
return out, nil
|
|
}
|
|
|
|
func (a *ABRouterChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
|
|
name, base, err := a.pickModel(ctx, input, opts...)
|
|
if err != nil || base == nil {
|
|
if err == nil {
|
|
err = errors.New("router returned nil model")
|
|
}
|
|
callbacks.OnError(ctx, err)
|
|
return nil, err
|
|
}
|
|
ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Name: name, Component: components.ComponentOfChatModel})
|
|
if ch, ok := base.(components.Checker); ok && ch.IsCallbacksEnabled() {
|
|
return base.Stream(ctx, input, opts...)
|
|
}
|
|
nCtx := callbacks.OnStart(ctx, &model.CallbackInput{Messages: input})
|
|
sr, err := base.Stream(nCtx, input, opts...)
|
|
if err != nil {
|
|
callbacks.OnError(nCtx, err)
|
|
return nil, err
|
|
}
|
|
out := schema.StreamReaderWithConvert(sr, func(m *schema.Message) (*model.CallbackOutput, error) {
|
|
return &model.CallbackOutput{Message: m}, nil
|
|
})
|
|
_, out = callbacks.OnEndWithStreamOutput(nCtx, out)
|
|
back := schema.StreamReaderWithConvert(out, func(o *model.CallbackOutput) (*schema.Message, error) {
|
|
return o.Message, nil
|
|
})
|
|
return back, nil
|
|
}
|
|
|
|
// IsCallbacksEnabled implements components.Checker and always reports true.
// ABRouterChatModel emits chat-model callbacks itself, or delegates to a routed
// model that emits its own.
func (a *ABRouterChatModel) IsCallbacksEnabled() bool {
	return true
}
|