You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

165 lines
4.1 KiB
Go

/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package tools
import (
"context"
"encoding/json"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
)
func GetRestaurantTool() tool.InvokableTool {
return &ToolQueryRestaurants{
backService: restService,
}
}
func GetDishTool() tool.InvokableTool {
return &ToolQueryDishes{
backService: restService,
}
}
type ToolQueryRestaurants struct {
backService *fakeService // fake service
}
func (t *ToolQueryRestaurants) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: "query_restaurants",
Desc: "Query restaurants",
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"location": {
Type: "string",
Desc: "The location of the restaurant",
Required: true,
},
"topn": {
Type: "number",
Desc: "top n restaurant in some location sorted by score",
},
}),
}, nil
}
// InvokableRun
// tool 接收的参数和返回都是 string, 就如大模型的 tool call 的返回一样, 因此需要自行处理参数和结果的序列化.
// 返回的 content 会作为 schema.Message 的 content, 一般来说是作为大模型的输入, 因此处理成大模型能更好理解的结构最好.
// 因此,如果是 json 格式,就需要注意 key 和 value 的表意, 不要用 int Enum 代表一个业务含义,比如 `不要用 1 代表 male, 2 代表 female` 这类.
func (t *ToolQueryRestaurants) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
// 解析参数
p := &QueryRestaurantsParam{}
err := json.Unmarshal([]byte(argumentsInJSON), p)
if err != nil {
return "", err
}
if p.Topn == 0 {
p.Topn = 3
}
// 请求后端服务
rests, err := t.backService.QueryRestaurants(ctx, p)
if err != nil {
return "", err
}
// 序列化结果
res, err := json.Marshal(rests)
if err != nil {
return "", err
}
return string(res), nil
}
type QueryRestaurantsParam struct {
Location string `json:"location"`
Topn int `json:"topn"`
}
type Restaurant struct {
ID string `json:"id"`
Name string `json:"name"`
Place string `json:"place"`
Desc string `json:"desc"`
Score int `json:"score"`
}
// ToolQueryDishes.
type ToolQueryDishes struct {
backService *fakeService // fake service
}
func (t *ToolQueryDishes) Info(ctx context.Context) (*schema.ToolInfo, error) {
return &schema.ToolInfo{
Name: "query_dishes",
Desc: "查询一家餐厅有哪些菜品",
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"restaurant_id": {
Type: "string",
Desc: "The id of one restaurant",
Required: true,
},
"topn": {
Type: "number",
Desc: "top n dishes in one restaurant sorted by score",
},
}),
}, nil
}
func (t *ToolQueryDishes) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
// 解析参数
p := &QueryDishesParam{}
err := json.Unmarshal([]byte(argumentsInJSON), p)
if err != nil {
return "", err
}
if p.Topn == 0 {
p.Topn = 5
}
// 请求后端服务
rests, err := t.backService.QueryDishes(ctx, p)
if err != nil {
return "", err
}
// 序列化结果
res, err := json.Marshal(rests)
if err != nil {
return "", err
}
return string(res), nil
}
type QueryDishesParam struct {
RestaurantID string `json:"restaurant_id"`
Topn int `json:"topn"`
}
type Dish struct {
Name string `json:"name"`
Desc string `json:"desc"`
Price int `json:"price"`
Score int `json:"score"`
}