You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

145 lines
4.2 KiB
Go

/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by hertz generator.
package main
import (
"bufio"
"context"
"fmt"
"os"
"strings"
"time"
"github.com/RanFeng/ilog"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/cors"
"github.com/cloudwego/eino-examples/flow/agent/deer-go/biz/consts"
"github.com/cloudwego/eino-examples/flow/agent/deer-go/biz/eino"
"github.com/cloudwego/eino-examples/flow/agent/deer-go/biz/infra"
"github.com/cloudwego/eino-examples/flow/agent/deer-go/biz/model"
"github.com/cloudwego/eino-examples/flow/agent/deer-go/conf"
hertztracing "github.com/hertz-contrib/obs-opentelemetry/tracing"
)
// CorsMw returns the CORS middleware that runServer installs on every route.
//
// The allowed request headers and the exposed response headers are each
// passed as a single comma-separated entry, exactly as originally configured.
//
// NOTE(review): AllowOrigins is "*" while AllowCredentials is true. If the
// middleware answers with a literal "*" in Access-Control-Allow-Origin,
// browsers reject credentialed cross-origin responses. Confirm whether
// credentials are actually needed, or list explicit origins.
func CorsMw() app.HandlerFunc {
	// Request headers a cross-origin client is permitted to send.
	allowedHeaders := "Authorization, Content-Length, X-CSRF-Token, Token,session,X_Requested_With,Accept, Origin, Host, Connection, Accept-Encoding, Accept-Language,DNT, X-CustomHeader, Keep-Alive, User-Agent, X-Requested-With, If-Modified-Since, Cache-Control, Content-Type, Pragma"
	// Response headers the browser may expose to client-side scripts.
	exposedHeaders := "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers,Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma,FooBar"

	cfg := cors.Config{
		AllowOrigins:     []string{"*"},                                      // origins allowed to make cross-origin calls
		AllowMethods:     []string{"POST", "GET", "PUT", "DELETE", "OPTIONS"}, // HTTP methods allowed cross-origin
		AllowHeaders:     []string{allowedHeaders},
		ExposeHeaders:    []string{exposedHeaders},
		AllowCredentials: true,           // let clients send credentials (cookies, auth headers)
		MaxAge:           12 * time.Hour, // how long a preflight response may be cached
	}
	return cors.New(cfg)
}
// runServer starts the HTTP server on :8000 and blocks until it stops.
//
// Startup sequence:
//  1. Load configuration.
//  2. Initialise models and MCP.
//  3. Set up APMPlus and CozeLoop tracing.
//  4. Install CORS, register routes, then call h.Spin().
//
// APMPlus server tracing is wired in only when InitAPMPlusTracing returns
// both a usable server option (tracer.F != nil) and a tracing config. The
// shutdown hook it returns, if any, runs when runServer returns.
func runServer() {
	const addr = ":8000"

	ilog.SetGlobalLogLevel(ilog.LevelInfo)
	ctx := context.Background()
	conf.LoadDeerConfig(ctx)
	infra.InitModel()
	infra.InitMCP()

	tracer, cfg, shutdown := infra.InitAPMPlusTracing(ctx, true)
	defer func() {
		// shutdown may be nil (presumably when tracing is disabled).
		if shutdown != nil {
			shutdown(ctx)
		}
	}()
	infra.InitCozeLoopTracing()

	// Construct the server exactly once. Previously a server was always built
	// with tracer and then discarded when tracing was enabled.
	// Hertz applies each option by calling its F, so passing tracer while
	// tracer.F is nil would panic. The tracer option is therefore only
	// supplied when tracing is actually available.
	tracingEnabled := tracer.F != nil && cfg != nil
	var h *server.Hertz
	if tracingEnabled {
		h = server.Default(server.WithHostPorts(addr), tracer)
		h.Use(hertztracing.ServerMiddleware(cfg))
	} else {
		h = server.Default(server.WithHostPorts(addr))
	}

	h.Use(CorsMw())
	register(h)
	h.Spin()
}
// runConsole runs a single interactive session:
//   - Reads one line from stdin.
//   - Seeds the agent graph's initial state with that line as a user message.
//   - Runs the graph in stream mode.
//   - Prints whatever the LoggerCallback pushes onto its output channel.
//
// If InitAPMPlusTracing returned a shutdown hook, the hook is invoked when the
// function returns and is followed by a fixed 5-second pause (presumably so
// exporters can flush — confirm).
func runConsole() {
	ctx := context.WithValue(context.Background(), ilog.LogLevelKey, ilog.LevelInfo)
	conf.LoadDeerConfig(ctx)

	_, _, shutdown := infra.InitAPMPlusTracing(ctx, false)
	defer func() {
		if shutdown == nil {
			return
		}
		shutdown(ctx)
		time.Sleep(5 * time.Second)
	}()

	infra.InitCozeLoopTracing()
	infra.InitModel()
	infra.InitMCP()

	stdin := bufio.NewReader(os.Stdin)
	fmt.Print("请输入你的需求: ")
	// The read error is intentionally discarded. ReadString still returns
	// whatever it read before the error, so a missing trailing newline is
	// tolerated.
	line, _ := stdin.ReadString('\n')
	prompt := strings.TrimSpace(line) // strip the trailing newline

	history := []*schema.Message{schema.UserMessage(prompt)}

	// newState builds the graph's initial state from the configured limits,
	// the user's message, and Goto pointing at the coordinator.
	newState := func(_ context.Context) *model.State {
		return &model.State{
			MaxPlanIterations: conf.Config.Setting.MaxPlanIterations,
			AutoAcceptedPlan:  true,
			MaxStepNum:        conf.Config.Setting.MaxStepNum,
			Messages:          history,
			Goto:              consts.Coordinator,
		}
	}
	runnable := eino.Builder[string, string, *model.State](ctx, newState)

	// A background goroutine echoes everything sent on chunks to stdout.
	// chunks is never closed here, so the goroutine lives until process exit.
	chunks := make(chan string)
	go func() {
		for chunk := range chunks {
			fmt.Print(chunk)
		}
	}()

	logger := &infra.LoggerCallback{Out: chunks}
	if _, err := runnable.Stream(ctx, consts.Coordinator, compose.WithCallbacks(logger)); err != nil {
		ilog.EventError(ctx, err, "run failed")
	}
	ilog.EventInfo(ctx, "run console finish", time.Now())
}
func main() {
if len(os.Args) == 2 && os.Args[1] == "-s" {
runServer()
return
}
runConsole()
}