You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

132 lines
4.8 KiB
Go

/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package abtest
import (
"context"
"errors"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
)
// ModelRouter selects the chat model that should serve a single request.
//
// It receives the same context, input messages and options that were passed to
// Generate or Stream, and returns the name of the chosen model, the model
// itself, and an error. The returned name is used as the RunInfo name for
// callbacks. Generate and Stream report a nil model returned with a nil error
// as the error "router returned nil model".
type ModelRouter func(ctx context.Context, input []*schema.Message, opts ...model.Option) (string, model.BaseChatModel, error)
// ABRouterChatModel is a dynamic router over chat models that implements ToolCallingChatModel.
//
// Behavior:
//   - Routing: delegates the choice to a user-provided ModelRouter which returns (modelName, BaseChatModel).
//   - RunInfo naming: passes the returned modelName as RunInfo.Name to callbacks.ReuseHandlers so callbacks can log the chosen model.
//   - Tools: stores tool infos via WithTools and applies them lazily if the chosen model supports ToolCallingChatModel.
//   - Callbacks: if the chosen model exposes components.Checker and IsCallbacksEnabled()==true, delegates directly;
//     otherwise injects OnStart/OnEnd/OnError around Generate/Stream.
//   - IsCallbacksEnabled: returns true to indicate this wrapper already coordinates callback triggering.
//
// Typical usage:
//
//	router := NewABRouterChatModel(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (string, model.BaseChatModel, error) {
//		return "openai", openaiModel, nil
//	})
//	withTools, _ := router.WithTools(toolInfos) // optional; returns a new router, the receiver is not modified
//	msg, _ := withTools.Generate(ctx, input)
type ABRouterChatModel struct {
// router picks the concrete model for every Generate/Stream call.
router ModelRouter
// tools holds the tool infos given to WithTools; they are bound to the
// chosen model at call time when that model is a ToolCallingChatModel.
tools []*schema.ToolInfo
}
// NewABRouterChatModel builds an ABRouterChatModel that consults router on
// every Generate and Stream call. No tools are attached; use WithTools to
// obtain a variant that carries tool infos.
func NewABRouterChatModel(router ModelRouter) *ABRouterChatModel {
	m := new(ABRouterChatModel)
	m.router = router
	return m
}
// WithTools returns a new ABRouterChatModel that shares the receiver's router
// and carries the given tool infos. The receiver itself is left unchanged.
// The returned error is always nil.
func (a *ABRouterChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
	clone := new(ABRouterChatModel)
	clone.router = a.router
	clone.tools = tools
	return clone, nil
}
// pickModel asks the configured router which model should handle the request
// and binds the stored tools to it when possible.
//
// It returns the router-supplied model name and the model to invoke. An error
// is returned when:
//   - no router is configured ("no router");
//   - the router itself fails (its error is returned unchanged);
//   - the router returns a nil model without an error ("router returned nil model");
//   - binding the stored tools via WithTools fails (that error is returned unchanged).
//
// Tools are only bound when at least one tool is stored and the chosen model
// implements model.ToolCallingChatModel; otherwise the model is returned as
// the router produced it.
func (a *ABRouterChatModel) pickModel(ctx context.Context, input []*schema.Message, opts ...model.Option) (string, model.BaseChatModel, error) {
	if a.router == nil {
		return "", nil, errors.New("no router")
	}

	name, base, err := a.router(ctx, input, opts...)
	if err != nil {
		return "", nil, err
	}
	if base == nil {
		// Never hand back a nil model together with a nil error: report the
		// misbehaving router explicitly instead of leaving it to callers.
		return "", nil, errors.New("router returned nil model")
	}

	if len(a.tools) == 0 {
		return name, base, nil
	}
	tcm, ok := base.(model.ToolCallingChatModel)
	if !ok {
		// The chosen model cannot accept tools; use it as-is.
		return name, base, nil
	}

	bound, err := tcm.WithTools(a.tools)
	if err != nil {
		return "", nil, err
	}
	return name, bound, nil
}
// Generate routes the request to the model chosen by the router and returns
// that model's response message.
//
// If no model can be resolved, the error is reported through
// callbacks.OnError on the incoming context and returned. Otherwise the
// context is re-scoped with a RunInfo named after the chosen model. A model
// that reports IsCallbacksEnabled()==true is invoked directly; for any other
// model this method triggers OnStart before the call and OnEnd or OnError
// after it.
func (a *ABRouterChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	name, base, err := a.pickModel(ctx, input, opts...)
	if err == nil && base == nil {
		err = errors.New("router returned nil model")
	}
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	runInfo := &callbacks.RunInfo{Name: name, Component: components.ComponentOfChatModel}
	ctx = callbacks.ReuseHandlers(ctx, runInfo)

	// The chosen model fires its own callbacks: just delegate.
	checker, isChecker := base.(components.Checker)
	if isChecker && checker.IsCallbacksEnabled() {
		return base.Generate(ctx, input, opts...)
	}

	// Otherwise wrap the call with callbacks on the model's behalf.
	cbCtx := callbacks.OnStart(ctx, &model.CallbackInput{Messages: input})
	msg, genErr := base.Generate(cbCtx, input, opts...)
	if genErr != nil {
		callbacks.OnError(cbCtx, genErr)
		return nil, genErr
	}
	callbacks.OnEnd(cbCtx, &model.CallbackOutput{Message: msg})
	return msg, nil
}
// Stream routes the request to the model chosen by the router and returns
// that model's message stream.
//
// If no model can be resolved, the error is reported through
// callbacks.OnError on the incoming context and returned. Otherwise the
// context is re-scoped with a RunInfo named after the chosen model. A model
// that reports IsCallbacksEnabled()==true is invoked directly; for any other
// model this method triggers OnStart before the call, OnError if opening the
// stream fails, and OnEndWithStreamOutput on the resulting stream.
func (a *ABRouterChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	name, base, err := a.pickModel(ctx, input, opts...)
	if err == nil && base == nil {
		err = errors.New("router returned nil model")
	}
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	runInfo := &callbacks.RunInfo{Name: name, Component: components.ComponentOfChatModel}
	ctx = callbacks.ReuseHandlers(ctx, runInfo)

	// The chosen model fires its own callbacks: just delegate.
	checker, isChecker := base.(components.Checker)
	if isChecker && checker.IsCallbacksEnabled() {
		return base.Stream(ctx, input, opts...)
	}

	// Otherwise wrap the call with callbacks on the model's behalf.
	cbCtx := callbacks.OnStart(ctx, &model.CallbackInput{Messages: input})
	msgStream, streamErr := base.Stream(cbCtx, input, opts...)
	if streamErr != nil {
		callbacks.OnError(cbCtx, streamErr)
		return nil, streamErr
	}

	// Callbacks consume *model.CallbackOutput, so wrap each message on the
	// way in and unwrap it again on the way back to the caller.
	wrap := func(m *schema.Message) (*model.CallbackOutput, error) {
		return &model.CallbackOutput{Message: m}, nil
	}
	unwrap := func(o *model.CallbackOutput) (*schema.Message, error) {
		return o.Message, nil
	}

	cbStream := schema.StreamReaderWithConvert(msgStream, wrap)
	_, cbStream = callbacks.OnEndWithStreamOutput(cbCtx, cbStream)
	return schema.StreamReaderWithConvert(cbStream, unwrap), nil
}
// IsCallbacksEnabled always reports true: Generate and Stream either delegate
// to a model that triggers its own callbacks or trigger them around the call
// themselves.
func (a *ABRouterChatModel) IsCallbacksEnabled() bool {
	return true
}