diff --git a/README.md b/README.md index 4a7f971..0fcd349 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,16 @@ user -> main -> worker -> main -> user 完整示例见 [config.example.json](/Users/lpf/Desktop/project/clawgo/config.example.json)。 +## MCP 服务支持 + +ClawGo 现在支持通过 `tools.mcp` 接入 `stdio` 型 MCP server。 + +- 先在 `config.json -> tools.mcp.servers` 里声明 server +- 当前支持 `list_servers`、`list_tools`、`call_tool`、`list_resources`、`read_resource`、`list_prompts`、`get_prompt` +- 启动时会自动发现远端 MCP tools,并注册为本地工具,命名格式为 `mcp____` + +示例配置可直接参考 [config.example.json](/Users/lpf/Desktop/project/clawgo/config.example.json) 中的 `tools.mcp` 段落。 + ## Prompt 文件约定 推荐把 agent prompt 独立为文件: diff --git a/README_EN.md b/README_EN.md index 52fd77c..d084f65 100644 --- a/README_EN.md +++ b/README_EN.md @@ -200,6 +200,16 @@ Notes: See [config.example.json](/Users/lpf/Desktop/project/clawgo/config.example.json) for a full example. +## MCP Server Support + +ClawGo now supports `stdio` MCP servers through `tools.mcp`. + +- declare each server under `config.json -> tools.mcp.servers` +- the bridge supports `list_servers`, `list_tools`, `call_tool`, `list_resources`, `read_resource`, `list_prompts`, and `get_prompt` +- on startup, ClawGo discovers remote MCP tools and registers them as local tools using the `mcp____` naming pattern + +See the `tools.mcp` section in [config.example.json](/Users/lpf/Desktop/project/clawgo/config.example.json). + ## Prompt File Convention Keep agent prompts in dedicated files: diff --git a/cmd/clawgo/cmd_gateway.go b/cmd/clawgo/cmd_gateway.go index 41ca30d..2000da7 100644 --- a/cmd/clawgo/cmd_gateway.go +++ b/cmd/clawgo/cmd_gateway.go @@ -166,6 +166,9 @@ func gatewayCmd() { registryServer.SetSubagentHandler(func(cctx context.Context, action string, args map[string]interface{}) (interface{}, error) { return agentLoop.HandleSubagentRuntime(cctx, action, args) }) + registryServer.SetToolsCatalogHandler(func() interface{} { + return agentLoop.GetToolCatalog() + }) registryServer.SetCronHandler(func(action string, args map[string]interface{}) (interface{}, error) { getStr := func(k string) string { v, _ := args[k].(string) diff --git a/config.example.json b/config.example.json index 0b4d71c..b76f61c 100644 --- a/config.example.json +++ b/config.example.json @@ -253,6 +253,22 @@ "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 } + }, + "mcp": { + "enabled": false, + "request_timeout_sec": 20, + "servers": { + "context7": { + "enabled": false, + "transport": "stdio", + "command": "npx", + "args": ["-y", "@upstash/context7-mcp"], + "env": {}, + "working_dir": "/absolute/path/to/project", + "description": "Example MCP server", + "package": "@upstash/context7-mcp" + } + } } }, "gateway": { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 1f34736..29a2cdb 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -17,6 +17,7 @@ import ( "path/filepath" "regexp" "runtime" + "sort" "strings" "sync" "time" @@ -165,6 +166,15 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers webFetchTool := tools.NewWebFetchTool(50000) toolsRegistry.Register(webFetchTool) toolsRegistry.Register(tools.NewParallelFetchTool(webFetchTool, maxParallelCalls, parallelSafe)) + if cfg.Tools.MCP.Enabled { + mcpTool := tools.NewMCPTool(workspace, cfg.Tools.MCP) + toolsRegistry.Register(mcpTool) + discoveryCtx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.Tools.MCP.RequestTimeoutSec)*time.Second) + for _, remoteTool := range mcpTool.DiscoverTools(discoveryCtx) { + toolsRegistry.Register(remoteTool) + } + cancel() + } // Register message tool messageTool := tools.NewMessageTool() @@ -1680,6 +1690,22 @@ func (al *AgentLoop) GetStartupInfo() map[string]interface{} { return info } +func (al *AgentLoop) GetToolCatalog() []map[string]interface{} { + if al == nil || al.tools == nil { + return nil + } + items := al.tools.Catalog() + sort.Slice(items, func(i, j int) bool { + return fmt.Sprint(items[i]["name"]) < fmt.Sprint(items[j]["name"]) + }) + for _, item := range items { + if fmt.Sprint(item["source"]) != "mcp" { + item["source"] = "local" + } + } + return items +} + // formatMessagesForLog formats messages for logging func formatMessagesForLog(messages []providers.Message) string { if len(messages) == 0 { diff --git a/pkg/api/server.go b/pkg/api/server.go index b1fac5f..5790051 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -46,6 +46,7 @@ type Server struct { onConfigAfter func() onCron func(action string, args map[string]interface{}) (interface{}, error) onSubagents func(ctx context.Context, action string, args map[string]interface{}) (interface{}, error) + onToolsCatalog func() interface{} webUIDir string ekgCacheMu sync.Mutex ekgCachePath string @@ -81,9 +82,10 @@ func (s *Server) SetCronHandler(fn func(action string, args map[string]interface func (s *Server) SetSubagentHandler(fn func(ctx context.Context, action string, args map[string]interface{}) (interface{}, error)) { s.onSubagents = fn } -func (s *Server) SetWebUIDir(dir string) { s.webUIDir = strings.TrimSpace(dir) } -func (s *Server) SetGatewayVersion(v string) { s.gatewayVersion = strings.TrimSpace(v) } -func (s *Server) SetWebUIVersion(v string) { s.webuiVersion = strings.TrimSpace(v) } +func (s *Server) SetToolsCatalogHandler(fn func() interface{}) { s.onToolsCatalog = fn } +func (s *Server) SetWebUIDir(dir string) { s.webUIDir = strings.TrimSpace(dir) } +func (s *Server) SetGatewayVersion(v string) { s.gatewayVersion = strings.TrimSpace(v) } +func (s *Server) SetWebUIVersion(v string) { s.webuiVersion = strings.TrimSpace(v) } func (s *Server) Start(ctx context.Context) error { if s.mgr == nil { @@ -112,6 +114,8 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/webui/api/subagent_profiles", s.handleWebUISubagentProfiles) mux.HandleFunc("/webui/api/subagents_runtime", s.handleWebUISubagentsRuntime) mux.HandleFunc("/webui/api/tool_allowlist_groups", s.handleWebUIToolAllowlistGroups) + mux.HandleFunc("/webui/api/tools", s.handleWebUITools) + mux.HandleFunc("/webui/api/mcp/install", s.handleWebUIMCPInstall) mux.HandleFunc("/webui/api/task_audit", s.handleWebUITaskAudit) mux.HandleFunc("/webui/api/task_queue", s.handleWebUITaskQueue) mux.HandleFunc("/webui/api/ekg_stats", s.handleWebUIEKGStats) @@ -592,6 +596,73 @@ func (s *Server) handleWebUIVersion(w http.ResponseWriter, r *http.Request) { }) } +func (s *Server) handleWebUITools(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + toolsList := []map[string]interface{}{} + if s.onToolsCatalog != nil { + if items, ok := s.onToolsCatalog().([]map[string]interface{}); ok && items != nil { + toolsList = items + } + } + mcpItems := make([]map[string]interface{}, 0) + for _, item := range toolsList { + if strings.TrimSpace(fmt.Sprint(item["source"])) == "mcp" { + mcpItems = append(mcpItems, item) + } + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "tools": toolsList, + "mcp_tools": mcpItems, + }) +} + +func (s *Server) handleWebUIMCPInstall(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Package string `json:"package"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + pkgName := strings.TrimSpace(body.Package) + if pkgName == "" { + http.Error(w, "package required", http.StatusBadRequest) + return + } + out, binName, binPath, err := ensureMCPPackageInstalled(r.Context(), pkgName) + if err != nil { + msg := err.Error() + if strings.TrimSpace(out) != "" { + msg = strings.TrimSpace(out) + "\n" + msg + } + http.Error(w, strings.TrimSpace(msg), http.StatusInternalServerError) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "ok": true, + "package": pkgName, + "output": out, + "bin_name": binName, + "bin_path": binPath, + }) +} + func (s *Server) handleWebUINodes(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -1694,6 +1765,102 @@ func ensureClawHubReady(ctx context.Context) (string, error) { return strings.Join(outs, "\n"), fmt.Errorf("installed clawhub but executable still not found in PATH") } +func ensureMCPPackageInstalled(ctx context.Context, pkgName string) (output string, binName string, binPath string, err error) { + pkgName = strings.TrimSpace(pkgName) + if pkgName == "" { + return "", "", "", fmt.Errorf("package empty") + } + outs := make([]string, 0, 4) + nodeOut, err := ensureNodeRuntime(ctx) + if nodeOut != "" { + outs = append(outs, nodeOut) + } + if err != nil { + return strings.Join(outs, "\n"), "", "", err + } + installOut, err := runInstallCommand(ctx, "npm i -g "+shellEscapeArg(pkgName)) + if installOut != "" { + outs = append(outs, installOut) + } + if err != nil { + return strings.Join(outs, "\n"), "", "", err + } + binName, err = resolveNpmPackageBin(ctx, pkgName) + if err != nil { + return strings.Join(outs, "\n"), "", "", err + } + binPath = resolveInstalledBinary(ctx, binName) + if strings.TrimSpace(binPath) == "" { + return strings.Join(outs, "\n"), binName, "", fmt.Errorf("installed %s but binary %q not found in PATH", pkgName, binName) + } + outs = append(outs, fmt.Sprintf("installed %s", pkgName)) + outs = append(outs, fmt.Sprintf("resolved binary: %s", binPath)) + return strings.Join(outs, "\n"), binName, binPath, nil +} + +func resolveNpmPackageBin(ctx context.Context, pkgName string) (string, error) { + cctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + cmd := exec.CommandContext(cctx, "npm", "view", pkgName, "bin", "--json") + out, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("failed to query npm bin for %s: %w", pkgName, err) + } + trimmed := strings.TrimSpace(string(out)) + if trimmed == "" || trimmed == "null" { + return "", fmt.Errorf("npm package %s does not expose a bin", pkgName) + } + var obj map[string]interface{} + if err := json.Unmarshal(out, &obj); err == nil && len(obj) > 0 { + keys := make([]string, 0, len(obj)) + for key := range obj { + keys = append(keys, key) + } + sort.Strings(keys) + return keys[0], nil + } + var text string + if err := json.Unmarshal(out, &text); err == nil && strings.TrimSpace(text) != "" { + return strings.TrimSpace(text), nil + } + return "", fmt.Errorf("unable to resolve bin for npm package %s", pkgName) +} + +func resolveInstalledBinary(ctx context.Context, binName string) string { + binName = strings.TrimSpace(binName) + if binName == "" { + return "" + } + if p, err := exec.LookPath(binName); err == nil { + return p + } + prefix := strings.TrimSpace(npmGlobalPrefix(ctx)) + if prefix != "" { + cand := filepath.Join(prefix, "bin", binName) + if st, err := os.Stat(cand); err == nil && !st.IsDir() { + return cand + } + } + cands := []string{ + filepath.Join("/usr/local/bin", binName), + filepath.Join("/opt/homebrew/bin", binName), + filepath.Join(os.Getenv("HOME"), ".npm-global", "bin", binName), + } + for _, cand := range cands { + if st, err := os.Stat(cand); err == nil && !st.IsDir() { + return cand + } + } + return "" +} + +func shellEscapeArg(in string) string { + if strings.TrimSpace(in) == "" { + return "''" + } + return "'" + strings.ReplaceAll(in, "'", `'\''`) + "'" +} + func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]string, error) { if err := r.ParseMultipartForm(128 << 20); err != nil { return nil, err diff --git a/pkg/config/config.go b/pkg/config/config.go index a608317..721e2fd 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -326,10 +326,28 @@ type SandboxConfig struct { type FilesystemConfig struct{} +type MCPServerConfig struct { + Enabled bool `json:"enabled"` + Transport string `json:"transport"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` + WorkingDir string `json:"working_dir,omitempty"` + Description string `json:"description,omitempty"` + Package string `json:"package,omitempty"` +} + +type MCPToolsConfig struct { + Enabled bool `json:"enabled"` + RequestTimeoutSec int `json:"request_timeout_sec"` + Servers map[string]MCPServerConfig `json:"servers,omitempty"` +} + type ToolsConfig struct { Web WebToolsConfig `json:"web"` Shell ShellConfig `json:"shell"` Filesystem FilesystemConfig `json:"filesystem"` + MCP MCPToolsConfig `json:"mcp"` } type LoggingConfig struct { @@ -540,6 +558,11 @@ func DefaultConfig() *Config { }, }, Filesystem: FilesystemConfig{}, + MCP: MCPToolsConfig{ + Enabled: false, + RequestTimeoutSec: 20, + Servers: map[string]MCPServerConfig{}, + }, }, Logging: LoggingConfig{ Enabled: true, diff --git a/pkg/config/validate.go b/pkg/config/validate.go index b00fba7..bd9efe0 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -165,6 +165,7 @@ func Validate(cfg *Config) []error { if cfg.Memory.RecentDays <= 0 { errs = append(errs, fmt.Errorf("memory.recent_days must be > 0")) } + errs = append(errs, validateMCPTools(cfg)...) if cfg.Channels.InboundMessageIDDedupeTTLSeconds <= 0 { errs = append(errs, fmt.Errorf("channels.inbound_message_id_dedupe_ttl_seconds must be > 0")) @@ -212,6 +213,40 @@ func Validate(cfg *Config) []error { return errs } +func validateMCPTools(cfg *Config) []error { + var errs []error + mcp := cfg.Tools.MCP + if !mcp.Enabled { + return errs + } + if mcp.RequestTimeoutSec <= 0 { + errs = append(errs, fmt.Errorf("tools.mcp.request_timeout_sec must be > 0 when tools.mcp.enabled=true")) + } + for name, server := range mcp.Servers { + if strings.TrimSpace(name) == "" { + errs = append(errs, fmt.Errorf("tools.mcp.servers contains an empty server name")) + continue + } + if !server.Enabled { + continue + } + transport := strings.ToLower(strings.TrimSpace(server.Transport)) + if transport == "" { + transport = "stdio" + } + if transport != "stdio" { + errs = append(errs, fmt.Errorf("tools.mcp.servers.%s.transport must be 'stdio'", name)) + } + if strings.TrimSpace(server.Command) == "" { + errs = append(errs, fmt.Errorf("tools.mcp.servers.%s.command is required when enabled=true", name)) + } + if wd := strings.TrimSpace(server.WorkingDir); wd != "" && !filepath.IsAbs(wd) { + errs = append(errs, fmt.Errorf("tools.mcp.servers.%s.working_dir must be an absolute path", name)) + } + } + return errs +} + func validateAgentRouter(cfg *Config) []error { router := cfg.Agents.Router var errs []error diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 8fea049..2875af8 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -23,6 +23,10 @@ type ResourceScopedTool interface { ResourceKeys(args map[string]interface{}) []string } +type CatalogTool interface { + CatalogEntry() map[string]interface{} +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function", diff --git a/pkg/tools/mcp.go b/pkg/tools/mcp.go new file mode 100644 index 0000000..9813526 --- /dev/null +++ b/pkg/tools/mcp.go @@ -0,0 +1,800 @@ +package tools + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "hash/fnv" + "io" + "net/url" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "clawgo/pkg/config" +) + +const mcpProtocolVersion = "2025-06-18" + +type MCPTool struct { + workspace string + cfg config.MCPToolsConfig +} + +type MCPRemoteTool struct { + bridge *MCPTool + serverName string + remoteName string + localName string + description string + parameters map[string]interface{} +} + +func NewMCPTool(workspace string, cfg config.MCPToolsConfig) *MCPTool { + if cfg.RequestTimeoutSec <= 0 { + cfg.RequestTimeoutSec = 20 + } + if cfg.Servers == nil { + cfg.Servers = map[string]config.MCPServerConfig{} + } + return &MCPTool{workspace: workspace, cfg: cfg} +} + +func (t *MCPTool) Name() string { + return "mcp" +} + +func (t *MCPTool) Description() string { + return "Call configured MCP servers over stdio. Supports listing servers, tools, resources, prompts, and invoking remote MCP tools." +} + +func (t *MCPTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{ + "type": "string", + "description": "Operation to perform", + "enum": []string{"list_servers", "list_tools", "call_tool", "list_resources", "read_resource", "list_prompts", "get_prompt"}, + }, + "server": map[string]interface{}{ + "type": "string", + "description": "Configured MCP server name", + }, + "tool": map[string]interface{}{ + "type": "string", + "description": "MCP tool name for action=call_tool", + }, + "arguments": map[string]interface{}{ + "type": "object", + "description": "Arguments for call_tool or get_prompt", + }, + "uri": map[string]interface{}{ + "type": "string", + "description": "Resource URI for action=read_resource", + }, + "prompt": map[string]interface{}{ + "type": "string", + "description": "Prompt name for action=get_prompt", + }, + }, + "required": []string{"action"}, + } +} + +func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + action := strings.TrimSpace(mcpStringArg(args, "action")) + if action == "" { + return "", fmt.Errorf("action is required") + } + if action == "list_servers" { + return t.listServers(), nil + } + + serverName := strings.TrimSpace(mcpStringArg(args, "server")) + if serverName == "" { + return "", fmt.Errorf("server is required for action %q", action) + } + serverCfg, ok := t.cfg.Servers[serverName] + if !ok || !serverCfg.Enabled { + return "", fmt.Errorf("mcp server %q is not configured or not enabled", serverName) + } + + timeout := time.Duration(t.cfg.RequestTimeoutSec) * time.Second + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 && remaining < timeout { + timeout = remaining + } + } + callCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + client, err := newMCPStdioClient(callCtx, t.workspace, serverName, serverCfg) + if err != nil { + return "", err + } + defer client.Close() + + switch action { + case "list_tools": + out, err := client.listAll(callCtx, "tools/list", "tools") + if err != nil { + return "", err + } + return prettyJSON(out) + case "call_tool": + toolName := strings.TrimSpace(mcpStringArg(args, "tool")) + if toolName == "" { + return "", fmt.Errorf("tool is required for action=call_tool") + } + params := map[string]interface{}{ + "name": toolName, + "arguments": mcpObjectArg(args, "arguments"), + } + out, err := client.request(callCtx, "tools/call", params) + if err != nil { + return "", err + } + return prettyJSON(out) + case "list_resources": + out, err := client.listAll(callCtx, "resources/list", "resources") + if err != nil { + return "", err + } + return prettyJSON(out) + case "read_resource": + resourceURI := strings.TrimSpace(mcpStringArg(args, "uri")) + if resourceURI == "" { + return "", fmt.Errorf("uri is required for action=read_resource") + } + out, err := client.request(callCtx, "resources/read", map[string]interface{}{"uri": resourceURI}) + if err != nil { + return "", err + } + return prettyJSON(out) + case "list_prompts": + out, err := client.listAll(callCtx, "prompts/list", "prompts") + if err != nil { + return "", err + } + return prettyJSON(out) + case "get_prompt": + promptName := strings.TrimSpace(mcpStringArg(args, "prompt")) + if promptName == "" { + return "", fmt.Errorf("prompt is required for action=get_prompt") + } + out, err := client.request(callCtx, "prompts/get", map[string]interface{}{ + "name": promptName, + "arguments": mcpObjectArg(args, "arguments"), + }) + if err != nil { + return "", err + } + return prettyJSON(out) + default: + return "", fmt.Errorf("unsupported action %q", action) + } +} + +func (t *MCPTool) DiscoverTools(ctx context.Context) []Tool { + if t == nil || !t.cfg.Enabled { + return nil + } + names := make([]string, 0, len(t.cfg.Servers)) + for name, server := range t.cfg.Servers { + if server.Enabled { + names = append(names, name) + } + } + sort.Strings(names) + tools := make([]Tool, 0) + seen := map[string]int{} + for _, serverName := range names { + serverCfg := t.cfg.Servers[serverName] + client, err := newMCPStdioClient(ctx, t.workspace, serverName, serverCfg) + if err != nil { + continue + } + result, err := client.listAll(ctx, "tools/list", "tools") + _ = client.Close() + if err != nil { + continue + } + items, _ := result["tools"].([]interface{}) + for _, item := range items { + toolMap, _ := item.(map[string]interface{}) + remoteName := strings.TrimSpace(mcpStringArg(toolMap, "name")) + if remoteName == "" { + continue + } + localName := buildMCPDynamicToolName(serverName, remoteName) + if count := seen[localName]; count > 0 { + localName = fmt.Sprintf("%s_%d", localName, count+1) + } + seen[localName]++ + tools = append(tools, &MCPRemoteTool{ + bridge: t, + serverName: serverName, + remoteName: remoteName, + localName: localName, + description: buildMCPDynamicToolDescription(serverName, toolMap), + parameters: normalizeMCPSchema(toolMap["inputSchema"]), + }) + } + } + return tools +} + +func (t *MCPTool) callServerTool(ctx context.Context, serverName, remoteToolName string, arguments map[string]interface{}) (string, error) { + serverCfg, ok := t.cfg.Servers[serverName] + if !ok || !serverCfg.Enabled { + return "", fmt.Errorf("mcp server %q is not configured or not enabled", serverName) + } + timeout := time.Duration(t.cfg.RequestTimeoutSec) * time.Second + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 && remaining < timeout { + timeout = remaining + } + } + callCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + client, err := newMCPStdioClient(callCtx, t.workspace, serverName, serverCfg) + if err != nil { + return "", err + } + defer client.Close() + out, err := client.request(callCtx, "tools/call", map[string]interface{}{ + "name": remoteToolName, + "arguments": arguments, + }) + if err != nil { + return "", err + } + return renderMCPToolCallResult(out) +} + +func (t *MCPTool) listServers() string { + type item struct { + Name string `json:"name"` + Transport string `json:"transport"` + Command string `json:"command"` + WorkingDir string `json:"working_dir,omitempty"` + Description string `json:"description,omitempty"` + } + names := make([]string, 0, len(t.cfg.Servers)) + for name, server := range t.cfg.Servers { + if server.Enabled { + names = append(names, name) + } + } + sort.Strings(names) + items := make([]item, 0, len(names)) + for _, name := range names { + server := t.cfg.Servers[name] + transport := strings.TrimSpace(server.Transport) + if transport == "" { + transport = "stdio" + } + items = append(items, item{ + Name: name, + Transport: transport, + Command: server.Command, + WorkingDir: server.WorkingDir, + Description: server.Description, + }) + } + out, _ := json.MarshalIndent(map[string]interface{}{"servers": items}, "", " ") + return string(out) +} + +func (t *MCPRemoteTool) Name() string { + return t.localName +} + +func (t *MCPRemoteTool) Description() string { + return t.description +} + +func (t *MCPRemoteTool) Parameters() map[string]interface{} { + return t.parameters +} + +func (t *MCPRemoteTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + return t.bridge.callServerTool(ctx, t.serverName, t.remoteName, args) +} + +func (t *MCPRemoteTool) CatalogEntry() map[string]interface{} { + return map[string]interface{}{ + "source": "mcp", + "mcp": map[string]interface{}{ + "server": t.serverName, + "remote_tool": t.remoteName, + }, + } +} + +type mcpClient struct { + workspace string + serverName string + cmd *exec.Cmd + stdin io.WriteCloser + reader *bufio.Reader + stderr bytes.Buffer + + writeMu sync.Mutex + waiters sync.Map + nextID atomic.Int64 +} + +type mcpInbound struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params map[string]interface{} `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *mcpResponseError `json:"error,omitempty"` +} + +type mcpResponseError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type mcpResponseWaiter struct { + ch chan mcpInbound +} + +func newMCPStdioClient(ctx context.Context, workspace, serverName string, cfg config.MCPServerConfig) (*mcpClient, error) { + command := strings.TrimSpace(cfg.Command) + if command == "" { + return nil, fmt.Errorf("mcp server %q command is empty", serverName) + } + cmd := exec.CommandContext(ctx, command, cfg.Args...) + cmd.Env = buildMCPEnv(cfg.Env) + cmd.Dir = resolveMCPWorkingDir(workspace, cfg.WorkingDir) + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("open stdin for mcp server %q: %w", serverName, err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("open stdout for mcp server %q: %w", serverName, err) + } + client := &mcpClient{ + workspace: workspace, + serverName: serverName, + cmd: cmd, + stdin: stdin, + reader: bufio.NewReader(stdout), + } + cmd.Stderr = &client.stderr + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start mcp server %q: %w", serverName, err) + } + go client.readLoop() + if err := client.initialize(ctx); err != nil { + client.Close() + return nil, err + } + return client, nil +} + +func (c *mcpClient) Close() error { + if c == nil || c.cmd == nil { + return nil + } + _ = c.stdin.Close() + done := make(chan error, 1) + go func() { + done <- c.cmd.Wait() + }() + select { + case err := <-done: + return err + case <-time.After(500 * time.Millisecond): + if c.cmd.Process != nil { + _ = c.cmd.Process.Kill() + } + <-done + return nil + } +} + +func (c *mcpClient) initialize(ctx context.Context) error { + result, err := c.request(ctx, "initialize", map[string]interface{}{ + "protocolVersion": mcpProtocolVersion, + "capabilities": map[string]interface{}{ + "roots": map[string]interface{}{ + "listChanged": false, + }, + }, + "clientInfo": map[string]interface{}{ + "name": "clawgo", + "version": "dev", + }, + }) + if err != nil { + return err + } + if _, ok := result["protocolVersion"]; !ok { + return fmt.Errorf("mcp server %q initialize missing protocolVersion", c.serverName) + } + return c.notify("notifications/initialized", map[string]interface{}{}) +} + +func (c *mcpClient) listAll(ctx context.Context, method, field string) (map[string]interface{}, error) { + items := make([]interface{}, 0) + cursor := "" + for { + params := map[string]interface{}{} + if strings.TrimSpace(cursor) != "" { + params["cursor"] = cursor + } + result, err := c.request(ctx, method, params) + if err != nil { + return nil, err + } + batch, _ := result[field].([]interface{}) + items = append(items, batch...) + next, _ := result["nextCursor"].(string) + if strings.TrimSpace(next) == "" { + return map[string]interface{}{field: items}, nil + } + cursor = next + } +} + +func (c *mcpClient) request(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) { + id := strconv.FormatInt(c.nextID.Add(1), 10) + waiter := &mcpResponseWaiter{ch: make(chan mcpInbound, 1)} + c.waiters.Store(id, waiter) + defer c.waiters.Delete(id) + + msg := map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + } + if err := c.writeMessage(msg); err != nil { + return nil, err + } + + select { + case resp := <-waiter.ch: + if resp.Error != nil { + return nil, fmt.Errorf("mcp %s %s failed: %s", c.serverName, method, resp.Error.Message) + } + var out map[string]interface{} + if len(resp.Result) == 0 { + return map[string]interface{}{}, nil + } + if err := json.Unmarshal(resp.Result, &out); err != nil { + return nil, fmt.Errorf("decode mcp %s %s result: %w", c.serverName, method, err) + } + return out, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (c *mcpClient) notify(method string, params map[string]interface{}) error { + return c.writeMessage(map[string]interface{}{ + "jsonrpc": "2.0", + "method": method, + "params": params, + }) +} + +func (c *mcpClient) writeMessage(payload map[string]interface{}) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + frame := fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(data), data) + c.writeMu.Lock() + defer c.writeMu.Unlock() + _, err = io.WriteString(c.stdin, frame) + return err +} + +func (c *mcpClient) readLoop() { + for { + msg, err := c.readMessage() + if err != nil { + c.failAll(err) + return + } + if msg.Method != "" && msg.ID != nil { + _ = c.handleServerRequest(msg) + continue + } + if msg.Method != "" { + continue + } + if key, ok := normalizeMCPID(msg.ID); ok { + if raw, ok := c.waiters.Load(key); ok { + raw.(*mcpResponseWaiter).ch <- msg + } + } + } +} + +func (c *mcpClient) handleServerRequest(msg mcpInbound) error { + method := strings.TrimSpace(msg.Method) + switch method { + case "roots/list": + return c.reply(msg.ID, map[string]interface{}{ + "roots": []map[string]interface{}{ + { + "uri": fileURI(resolveMCPWorkingDir(c.workspace, "")), + "name": filepath.Base(resolveMCPWorkingDir(c.workspace, "")), + }, + }, + }) + case "ping": + return c.reply(msg.ID, map[string]interface{}{}) + default: + return c.replyError(msg.ID, -32601, "method not supported by clawgo mcp client") + } +} + +func (c *mcpClient) reply(id interface{}, result map[string]interface{}) error { + return c.writeMessage(map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": result, + }) +} + +func (c *mcpClient) replyError(id interface{}, code int, message string) error { + return c.writeMessage(map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "error": map[string]interface{}{ + "code": code, + "message": message, + }, + }) +} + +func (c *mcpClient) failAll(err error) { + message := err.Error() + if stderr := strings.TrimSpace(c.stderr.String()); stderr != "" { + message += ": " + stderr + } + c.waiters.Range(func(_, value interface{}) bool { + value.(*mcpResponseWaiter).ch <- mcpInbound{ + Error: &mcpResponseError{Message: message}, + } + return true + }) +} + +func (c *mcpClient) readMessage() (mcpInbound, error) { + length := 0 + for { + line, err := c.reader.ReadString('\n') + if err != nil { + return mcpInbound{}, err + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + break + } + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + continue + } + if strings.EqualFold(strings.TrimSpace(parts[0]), "Content-Length") { + length, _ = strconv.Atoi(strings.TrimSpace(parts[1])) + } + } + if length <= 0 { + return mcpInbound{}, fmt.Errorf("invalid mcp content length") + } + body := make([]byte, length) + if _, err := io.ReadFull(c.reader, body); err != nil { + return mcpInbound{}, err + } + var msg mcpInbound + if err := json.Unmarshal(body, &msg); err != nil { + return mcpInbound{}, err + } + return msg, nil +} + +func buildMCPEnv(overrides map[string]string) []string { + env := os.Environ() + path := os.Getenv("PATH") + fallback := "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/homebrew/bin:/opt/homebrew/sbin" + if strings.TrimSpace(path) == "" { + env = append(env, "PATH="+fallback) + } else { + env = append(env, "PATH="+path+":"+fallback) + } + for key, value := range overrides { + env = append(env, key+"="+value) + } + return env +} + +func resolveMCPWorkingDir(workspace, wd string) string { + wd = strings.TrimSpace(wd) + if wd != "" { + return wd + } + if abs, err := filepath.Abs(workspace); err == nil { + return abs + } + return workspace +} + +func fileURI(path string) string { + abs, err := filepath.Abs(path) + if err != nil { + abs = path + } + return (&url.URL{Scheme: "file", Path: filepath.ToSlash(abs)}).String() +} + +func normalizeMCPID(id interface{}) (string, bool) { + switch v := id.(type) { + case string: + return v, v != "" + case float64: + return strconv.FormatInt(int64(v), 10), true + case int: + return strconv.Itoa(v), true + case int64: + return strconv.FormatInt(v, 10), true + default: + return "", false + } +} + +func prettyJSON(v interface{}) (string, error) { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return "", err + } + return string(data), nil +} + +func buildMCPDynamicToolName(serverName, remoteName string) string { + base := "mcp__" + sanitizeMCPToolSegment(serverName) + "__" + sanitizeMCPToolSegment(remoteName) + if len(base) <= 64 { + return base + } + hash := fnv.New32a() + _, _ = hash.Write([]byte(serverName + "::" + remoteName)) + suffix := fmt.Sprintf("_%x", hash.Sum32()) + trimmed := base + if len(trimmed)+len(suffix) > 64 { + trimmed = trimmed[:64-len(suffix)] + } + return trimmed + suffix +} + +func ParseMCPDynamicToolName(name string) (serverName string, remoteName string, ok bool) { + const prefix = "mcp__" + if !strings.HasPrefix(strings.TrimSpace(name), prefix) { + return "", "", false + } + rest := strings.TrimPrefix(strings.TrimSpace(name), prefix) + parts := strings.SplitN(rest, "__", 2) + if len(parts) != 2 || strings.TrimSpace(parts[0]) == "" || strings.TrimSpace(parts[1]) == "" { + return "", "", false + } + return parts[0], parts[1], true +} + +var mcpToolSegmentPattern = regexp.MustCompile(`[^a-zA-Z0-9_]+`) + +func sanitizeMCPToolSegment(in string) string { + in = strings.TrimSpace(strings.ToLower(in)) + in = mcpToolSegmentPattern.ReplaceAllString(in, "_") + in = strings.Trim(in, "_") + if in == "" { + return "tool" + } + return in +} + +func buildMCPDynamicToolDescription(serverName string, toolMap map[string]interface{}) string { + desc := strings.TrimSpace(mcpStringArg(toolMap, "description")) + remoteName := strings.TrimSpace(mcpStringArg(toolMap, "name")) + if desc == "" { + desc = fmt.Sprintf("Proxy to MCP tool %q on server %q.", remoteName, serverName) + } else { + desc = fmt.Sprintf("%s (MCP server: %s, remote tool: %s)", desc, serverName, remoteName) + } + return desc +} + +func normalizeMCPSchema(raw interface{}) map[string]interface{} { + schema, _ := raw.(map[string]interface{}) + if schema == nil { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } + } + out := map[string]interface{}{} + for k, v := range schema { + out[k] = v + } + if _, ok := out["type"]; !ok { + out["type"] = "object" + } + if _, ok := out["properties"]; !ok { + out["properties"] = map[string]interface{}{} + } + return out +} + +func renderMCPToolCallResult(result map[string]interface{}) (string, error) { + if result == nil { + return "", nil + } + if content, ok := result["content"].([]interface{}); ok && len(content) > 0 { + parts := make([]string, 0, len(content)) + for _, item := range content { + m, _ := item.(map[string]interface{}) + if m == nil { + continue + } + kind := strings.TrimSpace(mcpStringArg(m, "type")) + switch kind { + case "text": + if text := mcpStringArg(m, "text"); strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + default: + if text := mcpStringArg(m, "text"); strings.TrimSpace(text) != "" { + parts = append(parts, text) + } else { + data, err := prettyJSON(m) + if err == nil { + parts = append(parts, data) + } + } + } + } + if len(parts) > 0 { + if structured, ok := result["structuredContent"]; ok { + data, err := prettyJSON(structured) + if err == nil && strings.TrimSpace(data) != "" && data != "{}" { + parts = append(parts, data) + } + } + return strings.Join(parts, "\n\n"), nil + } + } + return prettyJSON(result) +} + +func mcpStringArg(args map[string]interface{}, key string) string { + v, _ := args[key].(string) + return v +} + +func mcpObjectArg(args map[string]interface{}, key string) map[string]interface{} { + v, _ := args[key].(map[string]interface{}) + if v == nil { + return map[string]interface{}{} + } + return v +} diff --git a/pkg/tools/mcp_test.go b/pkg/tools/mcp_test.go new file mode 100644 index 0000000..d589e45 --- /dev/null +++ b/pkg/tools/mcp_test.go @@ -0,0 +1,354 @@ +package tools + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "strconv" + "strings" + "testing" + "time" + + "clawgo/pkg/config" +) + +func TestMCPToolListServers(t *testing.T) { + tool := NewMCPTool("/tmp/workspace", config.MCPToolsConfig{ + Enabled: true, + RequestTimeoutSec: 5, + Servers: map[string]config.MCPServerConfig{ + "demo": { + Enabled: true, + Transport: "stdio", + Command: "demo-server", + Description: "demo", + }, + "disabled": { + Enabled: false, + Transport: "stdio", + Command: "nope", + }, + }, + }) + out, err := tool.Execute(context.Background(), map[string]interface{}{"action": "list_servers"}) + if err != nil { + t.Fatalf("list_servers returned error: %v", err) + } + if !strings.Contains(out, `"name": "demo"`) { + t.Fatalf("expected enabled server in output, got: %s", out) + } + if strings.Contains(out, "disabled") { + t.Fatalf("did not expect disabled server in output, got: %s", out) + } +} + +func TestMCPToolCallTool(t *testing.T) { + tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{ + Enabled: true, + RequestTimeoutSec: 5, + Servers: map[string]config.MCPServerConfig{ + "helper": { + Enabled: true, + Transport: "stdio", + Command: os.Args[0], + Args: []string{"-test.run=TestMCPHelperProcess", "--"}, + Env: map[string]string{ + "GO_WANT_HELPER_PROCESS": "1", + }, + }, + }, + }) + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "call_tool", + "server": "helper", + "tool": "echo", + "arguments": map[string]interface{}{"text": "hello"}, + }) + if err != nil { + t.Fatalf("call_tool returned error: %v", err) + } + if !strings.Contains(out, "echo:hello") { + t.Fatalf("expected echo output, got: %s", out) + } +} + +func TestMCPToolDiscoverTools(t *testing.T) { + tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{ + Enabled: true, + RequestTimeoutSec: 5, + Servers: map[string]config.MCPServerConfig{ + "helper": { + Enabled: true, + Transport: "stdio", + Command: os.Args[0], + Args: []string{"-test.run=TestMCPHelperProcess", "--"}, + Env: map[string]string{ + "GO_WANT_HELPER_PROCESS": "1", + }, + }, + }, + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + discovered := tool.DiscoverTools(ctx) + if len(discovered) != 1 { + t.Fatalf("expected 1 discovered tool, got %d", len(discovered)) + } + if got := discovered[0].Name(); got != "mcp__helper__echo" { + t.Fatalf("unexpected discovered tool name: %s", got) + } + out, err := discovered[0].Execute(ctx, map[string]interface{}{"text": "world"}) + if err != nil { + t.Fatalf("discovered tool execute returned error: %v", err) + } + if strings.TrimSpace(out) != "echo:world" { + t.Fatalf("unexpected discovered tool output: %q", out) + } +} + +func TestMCPHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + runMCPHelper() + os.Exit(0) +} + +func runMCPHelper() { + reader := bufio.NewReader(os.Stdin) + writer := bufio.NewWriter(os.Stdout) + for { + msg, err := readHelperFrame(reader) + if err != nil { + return + } + method, _ := msg["method"].(string) + id, hasID := msg["id"] + switch method { + case "initialize": + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "protocolVersion": mcpProtocolVersion, + "capabilities": map[string]interface{}{ + "tools": map[string]interface{}{}, + "resources": map[string]interface{}{}, + "prompts": map[string]interface{}{}, + }, + "serverInfo": map[string]interface{}{ + "name": "helper", + "version": "1.0.0", + }, + }, + }) + case "notifications/initialized": + continue + case "tools/list": + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "echo", + "description": "Echo the provided text", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "text": map[string]interface{}{ + "type": "string", + }, + }, + "required": []string{"text"}, + }, + }, + }, + }, + }) + case "tools/call": + params, _ := msg["params"].(map[string]interface{}) + args, _ := params["arguments"].(map[string]interface{}) + text, _ := args["text"].(string) + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "content": []map[string]interface{}{ + { + "type": "text", + "text": "echo:" + text, + }, + }, + }, + }) + case "resources/list": + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "resources": []map[string]interface{}{ + {"uri": "file:///tmp/demo.txt", "name": "demo"}, + }, + }, + }) + case "resources/read": + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "contents": []map[string]interface{}{ + {"uri": "file:///tmp/demo.txt", "mimeType": "text/plain", "text": "demo content"}, + }, + }, + }) + case "prompts/list": + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "prompts": []map[string]interface{}{ + {"name": "greeter", "description": "Greets"}, + }, + }, + }) + case "prompts/get": + params, _ := msg["params"].(map[string]interface{}) + args, _ := params["arguments"].(map[string]interface{}) + name, _ := args["name"].(string) + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "description": "Greets", + "messages": []map[string]interface{}{ + { + "role": "user", + "content": map[string]interface{}{ + "type": "text", + "text": "hello " + name, + }, + }, + }, + }, + }) + default: + if hasID { + writeHelperFrame(writer, map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "error": map[string]interface{}{ + "code": -32601, + "message": "method not found", + }, + }) + } + } + } +} + +func readHelperFrame(r *bufio.Reader) (map[string]interface{}, error) { + length := 0 + for { + line, err := r.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + break + } + parts := strings.SplitN(line, ":", 2) + if len(parts) == 2 && strings.EqualFold(strings.TrimSpace(parts[0]), "Content-Length") { + length, _ = strconv.Atoi(strings.TrimSpace(parts[1])) + } + } + body := make([]byte, length) + if _, err := io.ReadFull(r, body); err != nil { + return nil, err + } + var msg map[string]interface{} + if err := json.Unmarshal(body, &msg); err != nil { + return nil, err + } + return msg, nil +} + +func writeHelperFrame(w *bufio.Writer, payload map[string]interface{}) { + data, _ := json.Marshal(payload) + _, _ = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n%s", len(data), data) + _ = w.Flush() +} + +func TestValidateMCPTools(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Tools.MCP.Enabled = true + cfg.Tools.MCP.RequestTimeoutSec = 0 + cfg.Tools.MCP.Servers = map[string]config.MCPServerConfig{ + "bad": { + Enabled: true, + Transport: "http", + Command: "", + WorkingDir: "relative", + }, + } + errs := config.Validate(cfg) + if len(errs) == 0 { + t.Fatal("expected validation errors") + } + got := make([]string, 0, len(errs)) + for _, err := range errs { + got = append(got, err.Error()) + } + joined := strings.Join(got, "\n") + for _, want := range []string{ + "tools.mcp.request_timeout_sec must be > 0 when tools.mcp.enabled=true", + "tools.mcp.servers.bad.transport must be 'stdio'", + "tools.mcp.servers.bad.command is required when enabled=true", + "tools.mcp.servers.bad.working_dir must be an absolute path", + } { + if !strings.Contains(joined, want) { + t.Fatalf("expected validation error %q in:\n%s", want, joined) + } + } +} + +func TestMCPToolListTools(t *testing.T) { + tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{ + Enabled: true, + RequestTimeoutSec: 5, + Servers: map[string]config.MCPServerConfig{ + "helper": { + Enabled: true, + Transport: "stdio", + Command: os.Args[0], + Args: []string{"-test.run=TestMCPHelperProcess", "--"}, + Env: map[string]string{ + "GO_WANT_HELPER_PROCESS": "1", + }, + }, + }, + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + out, err := tool.Execute(ctx, map[string]interface{}{ + "action": "list_tools", + "server": "helper", + }) + if err != nil { + t.Fatalf("list_tools returned error: %v", err) + } + if !strings.Contains(out, `"name": "echo"`) { + t.Fatalf("expected tool listing, got: %s", out) + } +} + +func TestBuildMCPDynamicToolName(t *testing.T) { + got := buildMCPDynamicToolName("Context7 Server", "resolve-library.id") + if got != "mcp__context7_server__resolve_library_id" { + t.Fatalf("unexpected tool name: %s", got) + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 3790653..174aabe 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -87,6 +87,25 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { return definitions } +func (r *ToolRegistry) Catalog() []map[string]interface{} { + cur, _ := r.snapshot.Load().(map[string]Tool) + items := make([]map[string]interface{}, 0, len(cur)) + for _, tool := range cur { + item := map[string]interface{}{ + "name": tool.Name(), + "description": tool.Description(), + "parameters": tool.Parameters(), + } + if ct, ok := tool.(CatalogTool); ok { + for k, v := range ct.CatalogEntry() { + item[k] = v + } + } + items = append(items, item) + } + return items +} + // List returns a list of all registered tool names. func (r *ToolRegistry) List() []string { cur, _ := r.snapshot.Load().(map[string]Tool) diff --git a/webui/src/App.tsx b/webui/src/App.tsx index 25a974d..4848fe4 100644 --- a/webui/src/App.tsx +++ b/webui/src/App.tsx @@ -9,6 +9,7 @@ import Config from './pages/Config'; import Cron from './pages/Cron'; import Logs from './pages/Logs'; import Skills from './pages/Skills'; +import MCP from './pages/MCP'; import Memory from './pages/Memory'; import TaskAudit from './pages/TaskAudit'; import EKG from './pages/EKG'; @@ -28,6 +29,7 @@ export default function App() { } /> } /> } /> + } /> } /> } /> } /> diff --git a/webui/src/components/Sidebar.tsx b/webui/src/components/Sidebar.tsx index b5aea96..9aa5e00 100644 --- a/webui/src/components/Sidebar.tsx +++ b/webui/src/components/Sidebar.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { LayoutDashboard, MessageSquare, Settings, Clock, Terminal, Zap, FolderOpen, ClipboardList, BrainCircuit, Hash, Bot, Boxes, PanelLeftClose, PanelLeftOpen } from 'lucide-react'; +import { LayoutDashboard, MessageSquare, Settings, Clock, Terminal, Zap, FolderOpen, ClipboardList, BrainCircuit, Hash, Bot, Boxes, PanelLeftClose, PanelLeftOpen, Plug } from 'lucide-react'; import { useTranslation } from 'react-i18next'; import { useAppContext } from '../context/AppContext'; import NavItem from './NavItem'; @@ -29,6 +29,7 @@ const Sidebar: React.FC = () => { title: t('sidebarConfig'), items: [ { icon: , label: t('config'), to: '/config' }, + { icon: , label: t('mcpServices'), to: '/mcp' }, { icon: , label: t('subagentProfiles'), to: '/subagent-profiles' }, { icon: , label: t('cronJobs'), to: '/cron' }, ], diff --git a/webui/src/i18n/index.ts b/webui/src/i18n/index.ts index 64b075d..981ce1e 100644 --- a/webui/src/i18n/index.ts +++ b/webui/src/i18n/index.ts @@ -8,6 +8,8 @@ const resources = { dashboard: 'Dashboard', chat: 'Chat', config: 'Config', + mcpServices: 'MCP', + mcpServicesHint: 'Manage MCP servers, install packages, and inspect discovered remote tools.', cronJobs: 'Cron Jobs', nodes: 'Nodes', agentTree: 'Agent Tree', @@ -298,6 +300,23 @@ const resources = { configProxies: 'Proxies', configNewProviderName: 'new provider name', configNoCustomProviders: 'No custom providers yet.', + configMCPServers: 'MCP Servers', + configNewMCPServerName: 'new MCP server name', + configNoMCPServers: 'No MCP servers configured yet.', + configMCPInstallTitle: 'Install MCP Server Package', + configMCPInstallMessage: 'Install an npm package for MCP server "{{name}}"?', + configMCPInstallPlaceholder: '@scope/package', + configMCPInstalling: 'Installing MCP package...', + configMCPInstallFailedTitle: 'MCP install failed', + configMCPInstallFailedMessage: 'Failed to install MCP package', + configMCPInstallDoneTitle: 'MCP package installed', + configMCPInstallDoneMessage: 'Installed {{package}} and resolved binary {{bin}}.', + configMCPInstallDoneFallback: 'MCP package installed.', + configMCPDiscoveredTools: 'Discovered MCP Tools', + configMCPDiscoveredToolsCount: '{{count}} discovered', + configNoMCPDiscoveredTools: 'No MCP tools discovered yet.', + configDeleteMCPServerConfirmTitle: 'Delete MCP Server', + configDeleteMCPServerConfirmMessage: 'Delete MCP server "{{name}}" from current config?', configNoGroups: 'No config groups found.', configDiffPreviewCount: 'Diff Preview ({{count}} items)', saveConfigFailed: 'Failed to save config', @@ -388,6 +407,7 @@ const resources = { api_base: 'API Base', protocol: 'Protocol', models: 'Models', + command: 'Command', responses: 'Responses', streaming: 'Streaming', web_search_enabled: 'Web Search Enabled', @@ -403,6 +423,7 @@ const resources = { version: 'Version', name: 'Name', description: 'Description', + package: 'Package', system_prompt: 'System Prompt', tools: 'Tools', auth: 'Authentication', @@ -479,8 +500,14 @@ const resources = { sandbox: 'Sandbox', image: 'Image', web: 'Web', + mcp: 'MCP', search: 'Search', max_results: 'Max Results', + request_timeout_sec: 'Request Timeout (Seconds)', + servers: 'Servers', + transport: 'Transport', + args: 'Arguments', + env: 'Environment', proxies: 'Proxies', cross_session_call_id: 'Cross-session Call ID', supports_responses_compact: 'Supports Responses Compact', @@ -502,6 +529,8 @@ const resources = { dashboard: '仪表盘', chat: '对话', config: '配置', + mcpServices: 'MCP', + mcpServicesHint: '管理 MCP 服务、安装服务包,并查看已发现的远端工具。', cronJobs: '定时任务', nodes: '节点', agentTree: '代理树', @@ -792,6 +821,23 @@ const resources = { configProxies: '代理配置', configNewProviderName: '新 provider 名称', configNoCustomProviders: '暂无自定义 provider。', + configMCPServers: 'MCP 服务', + configNewMCPServerName: '新的 MCP 服务名', + configNoMCPServers: '暂无 MCP 服务配置。', + configMCPInstallTitle: '安装 MCP 服务包', + configMCPInstallMessage: '是否为 MCP 服务 “{{name}}” 安装 npm 包?', + configMCPInstallPlaceholder: '@scope/package', + configMCPInstalling: '正在安装 MCP 包...', + configMCPInstallFailedTitle: 'MCP 安装失败', + configMCPInstallFailedMessage: '安装 MCP 包失败', + configMCPInstallDoneTitle: 'MCP 包安装完成', + configMCPInstallDoneMessage: '已安装 {{package}},并解析到可执行文件 {{bin}}。', + configMCPInstallDoneFallback: 'MCP 包已安装。', + configMCPDiscoveredTools: '已发现的 MCP 工具', + configMCPDiscoveredToolsCount: '已发现 {{count}} 个', + configNoMCPDiscoveredTools: '暂未发现 MCP 工具。', + configDeleteMCPServerConfirmTitle: '删除 MCP 服务', + configDeleteMCPServerConfirmMessage: '确认从当前配置中删除 MCP 服务 “{{name}}”吗?', configNoGroups: '未找到配置分组。', configDiffPreviewCount: '配置差异预览({{count}}项)', saveConfigFailed: '保存配置失败', @@ -882,6 +928,7 @@ const resources = { api_base: 'API 基础地址', protocol: '协议', models: '模型列表', + command: '命令', responses: 'Responses 配置', streaming: '流式输出', web_search_enabled: '启用网页搜索', @@ -897,6 +944,7 @@ const resources = { version: '版本', name: '名称', description: '描述', + package: '包名', system_prompt: '系统提示词', tools: '工具', auth: '身份验证', @@ -973,8 +1021,14 @@ const resources = { sandbox: '沙箱', image: '镜像', web: 'Web', + mcp: 'MCP', search: '搜索', max_results: '最大结果数', + request_timeout_sec: '请求超时(秒)', + servers: '服务列表', + transport: '传输方式', + args: '参数', + env: '环境变量', proxies: '代理集合', cross_session_call_id: '跨会话调用 ID', supports_responses_compact: '支持紧凑 responses', diff --git a/webui/src/pages/MCP.tsx b/webui/src/pages/MCP.tsx new file mode 100644 index 0000000..9c0f1e5 --- /dev/null +++ b/webui/src/pages/MCP.tsx @@ -0,0 +1,308 @@ +import React, { useEffect, useMemo, useState } from 'react'; +import { RefreshCw, Save } from 'lucide-react'; +import { useTranslation } from 'react-i18next'; +import { useAppContext } from '../context/AppContext'; +import { useUI } from '../context/UIContext'; + +function setPath(obj: any, path: string, value: any) { + const keys = path.split('.'); + const next = JSON.parse(JSON.stringify(obj || {})); + let cur = next; + for (let i = 0; i < keys.length - 1; i++) { + const k = keys[i]; + if (typeof cur[k] !== 'object' || cur[k] === null) cur[k] = {}; + cur = cur[k]; + } + cur[keys[keys.length - 1]] = value; + return next; +} + +const MCP: React.FC = () => { + const { t } = useTranslation(); + const { cfg, setCfg, q, loadConfig, setConfigEditing } = useAppContext(); + const ui = useUI(); + const [newMCPServerName, setNewMCPServerName] = useState(''); + const [mcpTools, setMcpTools] = useState>([]); + const [baseline, setBaseline] = useState(null); + + const currentPayload = useMemo(() => cfg || {}, [cfg]); + const isDirty = useMemo(() => { + if (baseline == null) return false; + return JSON.stringify(baseline) !== JSON.stringify(currentPayload); + }, [baseline, currentPayload]); + + useEffect(() => { + if (baseline == null && cfg && Object.keys(cfg).length > 0) { + setBaseline(JSON.parse(JSON.stringify(cfg))); + } + }, [cfg, baseline]); + + useEffect(() => { + setConfigEditing(isDirty); + return () => setConfigEditing(false); + }, [isDirty, setConfigEditing]); + + async function refreshMCPTools(cancelled = false) { + try { + const r = await fetch(`/webui/api/tools${q}`); + if (!r.ok) throw new Error('Failed to load tools'); + const data = await r.json(); + if (!cancelled) { + setMcpTools(Array.isArray(data?.mcp_tools) ? data.mcp_tools : []); + } + } catch { + if (!cancelled) setMcpTools([]); + } + } + + useEffect(() => { + let cancelled = false; + void refreshMCPTools(cancelled); + return () => { + cancelled = true; + }; + }, [q]); + + function updateMCPServerField(name: string, field: string, value: any) { + setCfg((v) => setPath(v, `tools.mcp.servers.${name}.${field}`, value)); + } + + function addMCPServer() { + const name = newMCPServerName.trim(); + if (!name) return; + setCfg((v) => { + const next = JSON.parse(JSON.stringify(v || {})); + if (!next.tools || typeof next.tools !== 'object') next.tools = {}; + if (!next.tools.mcp || typeof next.tools.mcp !== 'object') { + next.tools.mcp = { enabled: true, request_timeout_sec: 20, servers: {} }; + } + if (!next.tools.mcp.servers || typeof next.tools.mcp.servers !== 'object' || Array.isArray(next.tools.mcp.servers)) { + next.tools.mcp.servers = {}; + } + if (!next.tools.mcp.servers[name]) { + next.tools.mcp.servers[name] = { + enabled: true, + transport: 'stdio', + command: '', + args: [], + env: {}, + working_dir: '', + description: '', + package: '', + }; + } + return next; + }); + setNewMCPServerName(''); + } + + async function removeMCPServer(name: string) { + const ok = await ui.confirmDialog({ + title: t('configDeleteMCPServerConfirmTitle'), + message: t('configDeleteMCPServerConfirmMessage', { name }), + danger: true, + confirmText: t('delete'), + }); + if (!ok) return; + setCfg((v) => { + const next = JSON.parse(JSON.stringify(v || {})); + if (next?.tools?.mcp?.servers && typeof next.tools.mcp.servers === 'object') { + delete next.tools.mcp.servers[name]; + } + return next; + }); + } + + function inferMCPPackage(server: any): string { + if (typeof server?.package === 'string' && server.package.trim()) return server.package.trim(); + const command = String(server?.command || '').trim(); + const args = Array.isArray(server?.args) ? server.args.map((x: any) => String(x).trim()).filter(Boolean) : []; + if (command === 'npx' || command.endsWith('/npx')) { + const pkg = args.find((arg: string) => !arg.startsWith('-')); + return pkg || ''; + } + return ''; + } + + async function installMCPServerPackage(name: string, server: any) { + const defaultPkg = inferMCPPackage(server); + const pkg = await ui.promptDialog({ + title: t('configMCPInstallTitle'), + message: t('configMCPInstallMessage', { name }), + inputPlaceholder: defaultPkg || t('configMCPInstallPlaceholder'), + initialValue: defaultPkg, + confirmText: t('install'), + }); + const packageName = String(pkg || '').trim(); + if (!packageName) return; + + ui.showLoading(t('configMCPInstalling')); + try { + const r = await fetch(`/webui/api/mcp/install${q}`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ package: packageName }), + }); + const text = await r.text(); + if (!r.ok) { + await ui.notify({ title: t('configMCPInstallFailedTitle'), message: text || t('configMCPInstallFailedMessage') }); + return; + } + let data: any = null; + try { + data = JSON.parse(text); + } catch { + data = null; + } + if (data?.bin_path) { + updateMCPServerField(name, 'command', data.bin_path); + updateMCPServerField(name, 'args', []); + updateMCPServerField(name, 'package', packageName); + } else { + updateMCPServerField(name, 'package', packageName); + } + await ui.notify({ + title: t('configMCPInstallDoneTitle'), + message: data?.bin_path + ? t('configMCPInstallDoneMessage', { package: packageName, bin: data.bin_path }) + : (text || t('configMCPInstallDoneFallback')), + }); + } finally { + ui.hideLoading(); + } + } + + async function saveConfig() { + try { + const payload = cfg; + const submit = async (confirmRisky: boolean) => { + const body = confirmRisky ? { ...payload, confirm_risky: true } : payload; + return ui.withLoading(async () => { + const r = await fetch(`/webui/api/config${q}`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body), + }); + const text = await r.text(); + let data: any = null; + try { + data = text ? JSON.parse(text) : null; + } catch { + data = null; + } + return { ok: r.ok, text, data }; + }, t('saving')); + }; + + let result = await submit(false); + if (!result.ok && result.data?.requires_confirm) { + const changedFields = Array.isArray(result.data?.changed_fields) ? result.data.changed_fields.join(', ') : ''; + const ok = await ui.confirmDialog({ + title: t('configRiskyChangeConfirmTitle'), + message: t('configRiskyChangeConfirmMessage', { fields: changedFields || '-' }), + danger: true, + confirmText: t('saveChanges'), + }); + if (!ok) return; + result = await submit(true); + } + + if (!result.ok) { + throw new Error(result.data?.error || result.text || 'save failed'); + } + + await ui.notify({ title: t('saved'), message: t('configSaved') }); + setBaseline(JSON.parse(JSON.stringify(payload))); + setConfigEditing(false); + await loadConfig(true); + await refreshMCPTools(); + } catch (e) { + await ui.notify({ title: t('requestFailed'), message: `${t('saveConfigFailed')}: ${e}` }); + } + } + + return ( +
+
+
+

{t('mcpServices')}

+

{t('mcpServicesHint')}

+
+ +
+ +
+ +
+ +
+
+
{t('configMCPServers')}
+
+ setNewMCPServerName(e.target.value)} placeholder={t('configNewMCPServerName')} className="px-2 py-1 rounded-lg bg-zinc-900/70 border border-zinc-700 text-xs" /> + +
+
+
+ {Object.entries((((cfg as any)?.tools?.mcp?.servers) || {}) as Record).map(([name, server]) => ( +
+
{name}
+ + updateMCPServerField(name, 'command', e.target.value)} placeholder={t('configLabels.command')} className="md:col-span-2 px-2 py-1 rounded-lg bg-zinc-950/70 border border-zinc-800" /> + updateMCPServerField(name, 'working_dir', e.target.value)} placeholder={t('configLabels.working_dir')} className="md:col-span-2 px-2 py-1 rounded-lg bg-zinc-950/70 border border-zinc-800" /> + updateMCPServerField(name, 'args', e.target.value.split(',').map(s=>s.trim()).filter(Boolean))} placeholder={`${t('configLabels.args')}${t('configCommaSeparatedHint')}`} className="md:col-span-2 px-2 py-1 rounded-lg bg-zinc-950/70 border border-zinc-800" /> + updateMCPServerField(name, 'package', e.target.value)} placeholder={t('configLabels.package')} className="md:col-span-1 px-2 py-1 rounded-lg bg-zinc-950/70 border border-zinc-800" /> + updateMCPServerField(name, 'description', e.target.value)} placeholder={t('configLabels.description')} className="md:col-span-1 px-2 py-1 rounded-lg bg-zinc-950/70 border border-zinc-800" /> + + +
+ ))} + {Object.keys((((cfg as any)?.tools?.mcp?.servers) || {}) as Record).length === 0 && ( +
{t('configNoMCPServers')}
+ )} +
+
+ +
+
+
{t('configMCPDiscoveredTools')}
+
{t('configMCPDiscoveredToolsCount', { count: mcpTools.length })}
+
+
+ {mcpTools.map((tool) => ( +
+
+
{tool.name}
+
+ {(tool.mcp?.server || '-')}{' · '}{(tool.mcp?.remote_tool || '-')} +
+
+ {tool.description && ( +
{tool.description}
+ )} +
+ ))} + {mcpTools.length === 0 && ( +
{t('configNoMCPDiscoveredTools')}
+ )} +
+
+
+ ); +}; + +export default MCP;