You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
2.3 KiB
Go

/*
* Copyright 2026 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"errors"
"flag"
"fmt"
"io"
"os"
"strings"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
)
// main is the command-line entry point. It reads the system prompt from the
// -instruction flag, joins the remaining positional arguments into the user
// query, and streams the model's reply to stdout.
//
// Exit status is 2 when no query was supplied and 1 when the chat model cannot
// be created or streaming fails; the error is printed to stderr in both of the
// latter cases.
func main() {
	var instruction string
	flag.StringVar(&instruction, "instruction", "You are a helpful assistant.", "")
	flag.Parse()

	query := strings.TrimSpace(strings.Join(flag.Args(), " "))
	if query == "" {
		_, _ = fmt.Fprintln(os.Stderr, "usage: go run ./cmd/ch01 -- \"your question\"")
		os.Exit(2)
	}

	// os.Exit skips deferred calls, so everything that needs cleanup lives in
	// streamReply and main only exits after that function has returned.
	if err := streamReply(context.Background(), instruction, query); err != nil {
		_, _ = fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// streamReply builds the chat model, sends instruction as the system message
// and query as the user message, and copies the streamed reply to stdout
// prefixed with "[assistant] " and terminated by a newline.
//
// It returns the first error from creating the model, opening the stream, or
// receiving a frame. The stream is closed on every return path once opened.
func streamReply(ctx context.Context, instruction, query string) error {
	cm, err := newChatModel(ctx)
	if err != nil {
		return err
	}

	messages := []*schema.Message{
		schema.SystemMessage(instruction),
		schema.UserMessage(query),
	}

	_, _ = fmt.Fprint(os.Stdout, "[assistant] ")
	stream, err := cm.Stream(ctx, messages)
	if err != nil {
		return err
	}
	defer stream.Close()

	for {
		frame, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			// io.EOF marks the normal end of the stream.
			break
		}
		if err != nil {
			return err
		}
		if frame != nil {
			_, _ = fmt.Fprint(os.Stdout, frame.Content)
		}
	}

	_, _ = fmt.Fprintln(os.Stdout)
	return nil
}
// newChatModel constructs the chat model backend selected by the MODEL_TYPE
// environment variable.
//
// When MODEL_TYPE is exactly "ark", an Ark model is created from ARK_API_KEY,
// ARK_MODEL and ARK_BASE_URL. For any other value (including unset), an OpenAI
// model is created from OPENAI_API_KEY, OPENAI_MODEL and OPENAI_BASE_URL, with
// ByAzure enabled only when OPENAI_BY_AZURE is exactly "true".
//
// The values returned by the underlying constructor are passed through
// unchanged.
func newChatModel(ctx context.Context) (model.ToolCallingChatModel, error) {
	switch os.Getenv("MODEL_TYPE") {
	case "ark":
		arkCfg := &ark.ChatModelConfig{
			APIKey:  os.Getenv("ARK_API_KEY"),
			Model:   os.Getenv("ARK_MODEL"),
			BaseURL: os.Getenv("ARK_BASE_URL"),
		}
		return ark.NewChatModel(ctx, arkCfg)
	default:
		useAzure := os.Getenv("OPENAI_BY_AZURE") == "true"
		openaiCfg := &openai.ChatModelConfig{
			APIKey:  os.Getenv("OPENAI_API_KEY"),
			Model:   os.Getenv("OPENAI_MODEL"),
			BaseURL: os.Getenv("OPENAI_BASE_URL"),
			ByAzure: useAzure,
		}
		return openai.NewChatModel(ctx, openaiCfg)
	}
}