You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

124 lines
3.1 KiB
Go

/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"slices"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
)
// NewCachedARKChatModel builds an ARK chat model from cfg via ark.NewChatModel
// and wraps it in a CachedARKChatModel.
//
// Any error reported by ark.NewChatModel is returned unchanged together with a
// nil wrapper.
func NewCachedARKChatModel(ctx context.Context, cfg *ark.ChatModelConfig) (*CachedARKChatModel, error) {
	inner, err := ark.NewChatModel(ctx, cfg)
	if err != nil {
		return nil, err
	}

	wrapped := new(CachedARKChatModel)
	wrapped.ChatModel = inner
	return wrapped, nil
}
// CachedARKChatModel wraps a tool-calling chat model. Its Generate and Stream
// methods read an *ark.CacheOption from the call context (see WithCacheCtx)
// and forward the call to the wrapped model.
type CachedARKChatModel struct {
// ChatModel is the wrapped model that every call is delegated to.
ChatModel model.ToolCallingChatModel
}
// cacheCtxKey is the unexported type of the context key under which the
// *ark.CacheOption is stored. Using a dedicated type instead of a plain string
// prevents collisions with values that other code stores on the same context
// under an equal string key.
type cacheCtxKey string

const (
	// cacheOptionCtxKey is the context key used by WithCacheCtx to store, and
	// by GetCacheCtx and Stream to look up, the per-call *ark.CacheOption.
	cacheOptionCtxKey cacheCtxKey = "ark-cache-option"
)
// WithCacheCtx returns a child of ctx that carries cache under
// cacheOptionCtxKey, where GetCacheCtx can retrieve it.
//
// The pointer itself is stored, so later writes to *cache are visible to
// everything that reads it back from the returned context.
func WithCacheCtx(ctx context.Context, cache *ark.CacheOption) context.Context {
	derived := context.WithValue(ctx, cacheOptionCtxKey, cache)
	return derived
}
// GetCacheCtx looks up the *ark.CacheOption stored on ctx under
// cacheOptionCtxKey.
//
// It returns nil when no value is stored under that key or when the stored
// value is not an *ark.CacheOption.
func GetCacheCtx(ctx context.Context) *ark.CacheOption {
	if opt, found := ctx.Value(cacheOptionCtxKey).(*ark.CacheOption); found {
		return opt
	}
	return nil
}
// Generate forwards the request to the wrapped chat model, applying the cache
// option found on ctx (see WithCacheCtx) when there is one.
//
// If ctx carries a cache option whose ContextID is already set, the call is
// made with a copy of opts extended by ark.WithCache(cacheOption) and an empty
// tool list, and only the last message of in is sent.
// NOTE(review): presumably the earlier messages and the tools are already held
// by the server-side context identified by ContextID — confirm against the ARK
// context-cache documentation.
//
// After a successful call, a context ID reported by ark.GetContextID on the
// response is written back into the cache option so later calls can reuse it.
// Without a cache option on ctx the response is returned untouched.
//
// Errors from the wrapped model are returned unchanged with a nil message.
func (cm *CachedARKChatModel) Generate(ctx context.Context, in []*schema.Message,
	opts ...model.Option) (outMsg *schema.Message, err error) {
	callOpts := opts
	callIn := in

	cacheOption := GetCacheCtx(ctx)
	if cacheOption != nil && cacheOption.ContextID != nil {
		// Clone first so the caller's backing array is never appended into.
		callOpts = append(slices.Clone(opts), ark.WithCache(cacheOption), model.WithTools([]*schema.ToolInfo{}))
		// Guard the slice expression: an empty input must not panic here.
		if n := len(in); n > 0 {
			callIn = in[n-1:]
		}
	}

	outMsg, err = cm.ChatModel.Generate(ctx, callIn, callOpts...)
	if err != nil {
		return nil, err
	}

	// Only record the ID when there is a cache option to record it on; the
	// context may legitimately carry none.
	if cacheOption != nil {
		if contextID, ok := ark.GetContextID(outMsg); ok {
			cacheOption.ContextID = &contextID
		}
	}
	return outMsg, nil
}
// Stream forwards the request to the wrapped chat model's Stream method,
// applying the cache option found on ctx under cacheOptionCtxKey when there is
// one.
//
// If ctx carries a cache option whose ContextID is already set, the call is
// made with a copy of opts extended by ark.WithCache(cacheOption) and an empty
// tool list, and only the last message of in is sent.
// NOTE(review): presumably the earlier messages and the tools are already held
// by the server-side context identified by ContextID — confirm against the ARK
// context-cache documentation.
//
// The returned stream passes every message through unchanged. As each message
// is converted, a context ID reported by ark.GetContextID is written back into
// the cache option; without a cache option on ctx nothing is recorded.
//
// Errors from the wrapped model are returned unchanged with a nil stream.
func (cm *CachedARKChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (
	outStream *schema.StreamReader[*schema.Message], err error) {
	callOpts := opts
	callIn := in

	// A missing or wrongly typed value leaves cacheOption nil.
	cacheOption, _ := ctx.Value(cacheOptionCtxKey).(*ark.CacheOption)
	if cacheOption != nil && cacheOption.ContextID != nil {
		// Clone first so the caller's backing array is never appended into.
		callOpts = append(slices.Clone(opts), ark.WithCache(cacheOption), model.WithTools([]*schema.ToolInfo{}))
		// Guard the slice expression: an empty input must not panic here.
		if n := len(in); n > 0 {
			callIn = in[n-1:]
		}
	}

	outStream, err = cm.ChatModel.Stream(ctx, callIn, callOpts...)
	if err != nil {
		return nil, err
	}

	converted := schema.StreamReaderWithConvert(outStream, func(msg *schema.Message) (*schema.Message, error) {
		// Only record the ID when there is a cache option to record it on;
		// the context may legitimately carry none.
		if cacheOption != nil {
			if contextID, ok := ark.GetContextID(msg); ok {
				cacheOption.ContextID = &contextID
			}
		}
		return msg, nil
	})
	return converted, nil
}
// WithTools asks the wrapped model for a variant bound to tools and returns a
// new CachedARKChatModel around it. The receiver is copied, not modified.
//
// An error from the wrapped model's WithTools is returned unchanged with a nil
// model.
func (cm *CachedARKChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
	bound, err := cm.ChatModel.WithTools(tools)
	if err != nil {
		return nil, err
	}

	// Start from a shallow copy of the receiver, then swap in the bound model.
	clone := new(CachedARKChatModel)
	*clone = *cm
	clone.ChatModel = bound
	return clone, nil
}