You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

119 lines
2.4 KiB
Go

package main
import (
"bufio"
"context"
"errors"
"flag"
"fmt"
"io"
"log"
"os"
"strings"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/schema"
)
// sourceEnv loads KEY=VALUE pairs from a ".env" file in the current working
// directory into the process environment.
//
// If the file cannot be opened, a notice is printed and the function returns
// without touching the environment. Blank lines and lines starting with '#'
// are skipped. Every other line is split at its first '=', so a value may
// itself contain '=' (base64 padding, URLs with query strings, ...).
// Whitespace around the key and around the value is trimmed.
//
// A line without '=' or with an empty key, a failing os.Setenv, or a scanner
// error terminates the program through log.Fatal.
func sourceEnv() {
	file, err := os.Open(".env")
	if err != nil {
		fmt.Println(".env not found or cannot be read")
		return
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for lineNo := 1; scanner.Scan(); lineNo++ {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		// Cut only at the first '=' so the value keeps any further '=' intact.
		key, value, found := strings.Cut(line, "=")
		key = strings.TrimSpace(key)
		if !found || key == "" {
			log.Fatalf(".env line %d: expected '=' to delimit key and value. Check .env file for proper format.", lineNo)
		}

		fmt.Printf("Setting from .env: %s\n", key)
		if err := os.Setenv(key, strings.TrimSpace(value)); err != nil {
			log.Fatalf(".env line %d: %v", lineNo, err)
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
}
// AgentConfig bundles the connection settings for the chat model.
//
// NewAgentConfig fills the fields from the OPENAI_BASE_URL, OPENAI_API_MODEL
// and OPENAI_API_KEY environment variables, in that order. The field order is
// significant: the struct is built with an unkeyed literal.
type AgentConfig struct {
	BaseUrl, ModelId, ApiKey string
}
// NewAgentConfig builds an AgentConfig from the process environment.
//
// OPENAI_BASE_URL and OPENAI_API_MODEL are mandatory: if either variable is
// not present the program terminates through log.Fatal. OPENAI_API_KEY is
// optional; when it is not present a notice is printed and the placeholder
// "dummyvalue" is used. A variable that is present but empty counts as set.
func NewAgentConfig() *AgentConfig {
	cfg := new(AgentConfig)

	var found bool
	if cfg.BaseUrl, found = os.LookupEnv("OPENAI_BASE_URL"); !found {
		log.Fatal("An OPENAI_BASE_URL must be specified as an environment variable.")
	}
	if cfg.ModelId, found = os.LookupEnv("OPENAI_API_MODEL"); !found {
		log.Fatal("A model id must be specified with OPENAI_API_MODEL as an environment variable.")
	}
	if cfg.ApiKey, found = os.LookupEnv("OPENAI_API_KEY"); !found {
		fmt.Println("No API found as OPENAI_API_KEY. Using dummy value")
		cfg.ApiKey = "dummyvalue"
	}
	return cfg
}
// main loads .env settings, reads the system prompt and the user query from
// the command line, and streams the chat model's reply to stdout.
//
// The query is all positional arguments joined by single spaces. An empty
// query prints a usage line to stderr and exits with status 2. Failure to
// construct the chat model prints the error to stderr and exits with status
// 1; errors opening or reading the stream go through log.Fatal.
func main() {
	sourceEnv()

	instruction := flag.String("instruction", "You are a helpful assistant.", "Set a system prompt for your agent.")
	flag.Parse()

	query := strings.Join(flag.Args(), " ")
	query = strings.TrimSpace(query)
	if len(query) == 0 {
		fmt.Fprintln(os.Stderr, "usage: go run main -- \"your question\"")
		os.Exit(2)
	}

	ctx := context.Background()
	cfg := NewAgentConfig()
	modelCfg := openai.ChatModelConfig{
		Model:   cfg.ModelId,
		BaseURL: cfg.BaseUrl,
		APIKey:  cfg.ApiKey,
	}
	chatModel, err := openai.NewChatModel(ctx, &modelCfg)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// Conversation: the system prompt followed by the single user turn.
	conversation := []*schema.Message{
		schema.SystemMessage(*instruction),
		schema.UserMessage(query),
	}

	fmt.Print("[assistant] ")
	reply, err := chatModel.Stream(ctx, conversation)
	if err != nil {
		log.Fatal(err)
	}
	defer reply.Close()

	// Echo each received chunk as it arrives until the stream reports io.EOF.
receive:
	for {
		chunk, recvErr := reply.Recv()
		switch {
		case errors.Is(recvErr, io.EOF):
			break receive
		case recvErr != nil:
			log.Fatal(recvErr)
		case chunk != nil:
			fmt.Print(chunk.Content)
		}
	}
	fmt.Println()
}