You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

211 lines
4.8 KiB
Go

/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package infra
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/RanFeng/ilog"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/cloudwego/eino-examples/flow/agent/deer-go/conf"
)
// Transport identifiers returned by the GetType methods of the server
// configuration types in this file.
const (
// transportStdio is the type reported by STDIOServerConfig.
transportStdio = "stdio"
// transportSSE is the type reported by SSEServerConfig.
transportSSE = "sse"
)
var (
// MCPServer holds the MCP clients keyed by server name. It is assigned by
// InitMCP from the result of CreateMCPClients.
MCPServer map[string]client.MCPClient
)
// InitMCP populates the package-level MCPServer map with the result of
// CreateMCPClients. It panics with the returned error when client creation
// fails; MCPServer is overwritten with the (nil) result before the panic.
func InitMCP() {
	var err error
	if MCPServer, err = CreateMCPClients(); err != nil {
		panic(err)
	}
}
// MCPConfig is the set of MCP servers to connect to, serialized in JSON under
// the "mcpServers" key.
type MCPConfig struct {
// MCPServers maps a server name to its transport-specific configuration.
MCPServers map[string]ServerConfigWrapper `json:"mcpServers"`
}
// ServerConfig is implemented by each transport-specific server
// configuration (STDIOServerConfig and SSEServerConfig in this file).
type ServerConfig interface {
// GetType returns the transport identifier of the configuration.
GetType() string
}
// STDIOServerConfig describes an MCP server reached over the stdio transport.
// CreateMCPClients passes these fields to client.NewStdioMCPClient.
type STDIOServerConfig struct {
// Command is the command handed to the stdio client constructor.
Command string `json:"command"`
// Args are the arguments passed along with Command.
Args []string `json:"args"`
// Env holds extra environment variables; CreateMCPClients renders each
// entry as "KEY=VALUE" before passing it on.
Env map[string]string `json:"env,omitempty"`
}
// GetType reports the stdio transport identifier. The receiver's fields are
// not consulted.
func (STDIOServerConfig) GetType() string { return transportStdio }
// SSEServerConfig describes an MCP server reached over the SSE transport.
// CreateMCPClients passes these fields to client.NewSSEMCPClient.
type SSEServerConfig struct {
// Url is the address handed to the SSE client constructor. A non-empty
// "url" key is also what ServerConfigWrapper.UnmarshalJSON uses to select
// this type.
Url string `json:"url"`
// Headers are "Key: Value" strings. CreateMCPClients splits each entry on
// its first colon and trims both sides; entries without a colon are
// ignored.
Headers []string `json:"headers,omitempty"`
}
// GetType reports the SSE transport identifier. The receiver's fields are not
// consulted.
func (SSEServerConfig) GetType() string { return transportSSE }
// ServerConfigWrapper holds a ServerConfig and supplies the custom JSON
// (un)marshalling that chooses between the concrete configuration types.
type ServerConfigWrapper struct {
// Config is the concrete configuration: an SSEServerConfig or a
// STDIOServerConfig when produced by UnmarshalJSON.
Config ServerConfig
}
// UnmarshalJSON decodes one server entry into the matching concrete type.
//
// The payload is first probed for its "url" key: a non-empty URL selects
// SSEServerConfig, anything else selects STDIOServerConfig. On any decode
// error w.Config is left unchanged and the error is returned as-is.
func (w *ServerConfigWrapper) UnmarshalJSON(data []byte) error {
	// Only the "url" key matters for picking the transport.
	var probe struct {
		URL string `json:"url"`
	}
	if err := json.Unmarshal(data, &probe); err != nil {
		return err
	}

	if probe.URL == "" {
		// No URL: decode as a stdio server.
		var stdioCfg STDIOServerConfig
		if err := json.Unmarshal(data, &stdioCfg); err != nil {
			return err
		}
		w.Config = stdioCfg
		return nil
	}

	// URL present: decode as an SSE server.
	var sseCfg SSEServerConfig
	if err := json.Unmarshal(data, &sseCfg); err != nil {
		return err
	}
	w.Config = sseCfg
	return nil
}
// MarshalJSON encodes the wrapped Config value directly, so the wrapper itself
// contributes no extra nesting to the JSON output.
func (w ServerConfigWrapper) MarshalJSON() ([]byte, error) {
	encoded, err := json.Marshal(w.Config)
	return encoded, err
}
// CreateMCPClients creates and initializes one MCP client for every server
// listed in conf.Config.MCP.Servers and returns them keyed by server name.
//
// Creation is all-or-nothing: if any client cannot be created, started or
// initialized, every client opened so far (including the failing one) is
// closed and a wrapped error naming the failing server is returned together
// with a nil map.
func CreateMCPClients() (map[string]client.MCPClient, error) {
	// Convert the application configuration into an MCPConfig.
	// NOTE(review): only command/args/env are copied, so every entry becomes a
	// stdio server and the SSE path is not reachable through this conversion —
	// confirm whether conf is expected to carry a URL as well.
	mcpConfig := &MCPConfig{
		MCPServers: make(map[string]ServerConfigWrapper),
	}
	for name, server := range conf.Config.MCP.Servers {
		mcpConfig.MCPServers[name] = ServerConfigWrapper{
			Config: STDIOServerConfig{
				Command: server.Command,
				Args:    server.Args,
				Env:     server.Env,
			},
		}
	}

	clients := make(map[string]client.MCPClient, len(mcpConfig.MCPServers))
	for name, server := range mcpConfig.MCPServers {
		ilog.EventInfo(context.Background(), "load mcp client", name, server.Config.GetType())

		mcpClient, err := newMCPClient(server.Config)
		if err != nil {
			closeMCPClients(clients)
			return nil, fmt.Errorf(
				"failed to create MCP client for %s: %w",
				name,
				err,
			)
		}

		if err := initializeMCPClient(name, mcpClient); err != nil {
			_ = mcpClient.Close()
			closeMCPClients(clients)
			return nil, fmt.Errorf(
				"failed to initialize MCP client for %s: %w",
				name,
				err,
			)
		}
		clients[name] = mcpClient
	}
	return clients, nil
}

// newMCPClient creates the client matching the concrete type of cfg. An SSE
// client is also started; if starting fails the client is closed before the
// error is returned. Unknown configuration types yield an error.
func newMCPClient(cfg ServerConfig) (client.MCPClient, error) {
	switch c := cfg.(type) {
	case SSEServerConfig:
		var options []client.ClientOption
		if c.Headers != nil {
			options = append(options, client.WithHeaders(parseSSEHeaders(c.Headers)))
		}

		var (
			sseClient client.MCPClient
			err       error
		)
		sseClient, err = client.NewSSEMCPClient(c.Url, options...)
		if err != nil {
			return nil, err
		}
		// Start is invoked through the concrete *client.SSEMCPClient type.
		if err = sseClient.(*client.SSEMCPClient).Start(context.Background()); err != nil {
			// Do not leak a client that was created but could not start.
			_ = sseClient.Close()
			return nil, err
		}
		return sseClient, nil

	case STDIOServerConfig:
		var env []string
		for k, v := range c.Env {
			env = append(env, fmt.Sprintf("%s=%s", k, v))
		}
		stdioClient, err := client.NewStdioMCPClient(c.Command, env, c.Args...)
		if err != nil {
			return nil, err
		}
		return stdioClient, nil

	default:
		return nil, fmt.Errorf("unsupported MCP server config type %T", cfg)
	}
}

// parseSSEHeaders converts "Key: Value" strings into a header map. Each entry
// is split on its first colon and both sides are trimmed of surrounding
// whitespace; entries without a colon are skipped.
func parseSSEHeaders(raw []string) map[string]string {
	headers := make(map[string]string, len(raw))
	for _, header := range raw {
		parts := strings.SplitN(header, ":", 2)
		if len(parts) != 2 {
			continue
		}
		headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
	}
	return headers
}

// initializeMCPClient sends the MCP initialize request for the named server
// with a 60-second timeout. Living in its own function lets the deferred
// cancel release the timeout context as soon as this one handshake finishes,
// instead of piling up until every server has been processed.
func initializeMCPClient(name string, c client.MCPClient) error {
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	ilog.EventInfo(ctx, "Initializing server...", "name", name)

	initRequest := mcp.InitializeRequest{}
	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
	initRequest.Params.ClientInfo = mcp.Implementation{
		Name:    "mcphost",
		Version: "0.1.0",
	}
	initRequest.Params.Capabilities = mcp.ClientCapabilities{}

	_, err := c.Initialize(ctx, initRequest)
	return err
}

// closeMCPClients closes every client in the map. Close errors are ignored
// because this only runs while unwinding from an earlier failure that is
// already being reported to the caller.
func closeMCPClients(clients map[string]client.MCPClient) {
	for _, c := range clients {
		_ = c.Close()
	}
}