You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
133 lines
3.5 KiB
Go
133 lines
3.5 KiB
Go
/*
|
|
* Copyright 2025 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package utils
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
|
|
"github.com/cloudwego/eino-ext/components/model/ark"
|
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
|
"github.com/cloudwego/eino/components/model"
|
|
arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
|
)
|
|
|
|
// CreateChatModelOption mutates the internal option struct consumed by
// NewChatModel; build values with the With* helpers below.
type CreateChatModelOption func(o *option)
|
|
|
|
func NewChatModel(ctx context.Context, opts ...CreateChatModelOption) (cm model.ToolCallingChatModel, err error) {
|
|
o := &option{}
|
|
for _, opt := range opts {
|
|
opt(o)
|
|
}
|
|
|
|
if modelName := os.Getenv("ARK_MODEL"); modelName != "" {
|
|
conf := &ark.ChatModelConfig{
|
|
APIKey: os.Getenv("ARK_API_KEY"),
|
|
BaseURL: os.Getenv("ARK_BASE_URL"),
|
|
Region: os.Getenv("ARK_REGION"),
|
|
Model: modelName,
|
|
MaxTokens: o.MaxTokens,
|
|
Temperature: o.Temperature,
|
|
TopP: o.TopP,
|
|
}
|
|
if o.DisableThinking != nil && *o.DisableThinking {
|
|
conf.Thinking = &arkmodel.Thinking{
|
|
Type: arkmodel.ThinkingTypeDisabled,
|
|
}
|
|
}
|
|
if o.JsonSchema != nil {
|
|
conf.ResponseFormat = &ark.ResponseFormat{
|
|
Type: arkmodel.ResponseFormatJSONSchema,
|
|
JSONSchema: &arkmodel.ResponseFormatJSONSchemaJSONSchemaParam{
|
|
Name: o.JsonSchema.Name,
|
|
Description: o.JsonSchema.Description,
|
|
Schema: o.JsonSchema.JSONSchema,
|
|
Strict: o.JsonSchema.Strict,
|
|
},
|
|
}
|
|
}
|
|
cm, err = ark.NewChatModel(ctx, conf)
|
|
|
|
} else if modelName = os.Getenv("OPENAI_MODEL"); modelName != "" {
|
|
conf := &openai.ChatModelConfig{
|
|
APIKey: os.Getenv("OPENAI_API_KEY"),
|
|
ByAzure: func() bool {
|
|
return os.Getenv("OPENAI_BY_AZURE") == "true"
|
|
}(),
|
|
BaseURL: os.Getenv("OPENAI_BASE_URL"),
|
|
Model: modelName,
|
|
MaxTokens: o.MaxTokens,
|
|
Temperature: o.Temperature,
|
|
TopP: o.TopP,
|
|
}
|
|
if o.JsonSchema != nil {
|
|
conf.ResponseFormat = &openai.ChatCompletionResponseFormat{
|
|
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
|
|
JSONSchema: o.JsonSchema,
|
|
}
|
|
}
|
|
cm, err = openai.NewChatModel(ctx, conf)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if cm == nil {
|
|
return nil, fmt.Errorf("no model config")
|
|
}
|
|
|
|
return cm, nil
|
|
}
|
|
|
|
// option collects the optional model settings applied by
// CreateChatModelOption functions; a nil field means "leave the provider
// default in place".
type option struct {
	// MaxTokens caps the number of tokens generated per completion.
	MaxTokens *int
	// Temperature is the sampling temperature forwarded to the provider.
	Temperature *float32
	// TopP is the nucleus-sampling probability mass.
	TopP *float32
	// DisableThinking, when non-nil and true, disables Ark "thinking" mode.
	// It is only consulted on the Ark path; the OpenAI path ignores it.
	DisableThinking *bool
	// JsonSchema, when non-nil, enables JSON-schema structured output on
	// both the Ark and OpenAI paths.
	JsonSchema *openai.ChatCompletionResponseFormatJSONSchema
}
|
|
|
|
func WithMaxTokens(maxTokens int) CreateChatModelOption {
|
|
return func(o *option) {
|
|
o.MaxTokens = &maxTokens
|
|
}
|
|
}
|
|
|
|
func WithTemperature(temp float32) CreateChatModelOption {
|
|
return func(o *option) {
|
|
o.Temperature = &temp
|
|
}
|
|
}
|
|
|
|
func WithTopP(topP float32) CreateChatModelOption {
|
|
return func(o *option) {
|
|
o.TopP = &topP
|
|
}
|
|
}
|
|
|
|
func WithDisableThinking(disable bool) CreateChatModelOption {
|
|
return func(o *option) {
|
|
o.DisableThinking = &disable
|
|
}
|
|
}
|
|
|
|
func WithResponseFormatJsonSchema(schema *openai.ChatCompletionResponseFormatJSONSchema) CreateChatModelOption {
|
|
return func(o *option) {
|
|
o.JsonSchema = schema
|
|
}
|
|
}
|