mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-05-08 10:17:30 +08:00
feat: expand mcp transports and skill execution
This commit is contained in:
531
pkg/tools/mcp.go
531
pkg/tools/mcp.go
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -39,6 +40,12 @@ type MCPRemoteTool struct {
|
||||
parameters map[string]interface{}
|
||||
}
|
||||
|
||||
type mcpRPCClient interface {
|
||||
listAll(ctx context.Context, method, field string) (map[string]interface{}, error)
|
||||
request(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
func NewMCPTool(workspace string, cfg config.MCPToolsConfig) *MCPTool {
|
||||
if cfg.RequestTimeoutSec <= 0 {
|
||||
cfg.RequestTimeoutSec = 20
|
||||
@@ -54,7 +61,7 @@ func (t *MCPTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *MCPTool) Description() string {
|
||||
return "Call configured MCP servers over stdio. Supports listing servers, tools, resources, prompts, and invoking remote MCP tools."
|
||||
return "Call configured MCP servers over stdio or HTTP transports. Supports listing servers, tools, resources, prompts, and invoking remote MCP tools."
|
||||
}
|
||||
|
||||
func (t *MCPTool) Parameters() map[string]interface{} {
|
||||
@@ -119,7 +126,7 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) (str
|
||||
callCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
client, err := newMCPStdioClient(callCtx, t.workspace, serverName, serverCfg)
|
||||
client, err := newMCPClient(callCtx, t.workspace, serverName, serverCfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -201,7 +208,7 @@ func (t *MCPTool) DiscoverTools(ctx context.Context) []Tool {
|
||||
seen := map[string]int{}
|
||||
for _, serverName := range names {
|
||||
serverCfg := t.cfg.Servers[serverName]
|
||||
client, err := newMCPStdioClient(ctx, t.workspace, serverName, serverCfg)
|
||||
client, err := newMCPClient(ctx, t.workspace, serverName, serverCfg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -249,7 +256,7 @@ func (t *MCPTool) callServerTool(ctx context.Context, serverName, remoteToolName
|
||||
}
|
||||
callCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
client, err := newMCPStdioClient(callCtx, t.workspace, serverName, serverCfg)
|
||||
client, err := newMCPClient(callCtx, t.workspace, serverName, serverCfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -268,6 +275,8 @@ func (t *MCPTool) listServers() string {
|
||||
type item struct {
|
||||
Name string `json:"name"`
|
||||
Transport string `json:"transport"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Permission string `json:"permission,omitempty"`
|
||||
Command string `json:"command"`
|
||||
WorkingDir string `json:"working_dir,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
@@ -286,9 +295,15 @@ func (t *MCPTool) listServers() string {
|
||||
if transport == "" {
|
||||
transport = "stdio"
|
||||
}
|
||||
permission := strings.TrimSpace(server.Permission)
|
||||
if permission == "" {
|
||||
permission = "workspace"
|
||||
}
|
||||
items = append(items, item{
|
||||
Name: name,
|
||||
Transport: transport,
|
||||
URL: strings.TrimSpace(server.URL),
|
||||
Permission: permission,
|
||||
Command: server.Command,
|
||||
WorkingDir: server.WorkingDir,
|
||||
Description: server.Description,
|
||||
@@ -326,6 +341,7 @@ func (t *MCPRemoteTool) CatalogEntry() map[string]interface{} {
|
||||
|
||||
type mcpClient struct {
|
||||
workspace string
|
||||
workingDir string
|
||||
serverName string
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
@@ -355,14 +371,35 @@ type mcpResponseWaiter struct {
|
||||
ch chan mcpInbound
|
||||
}
|
||||
|
||||
func newMCPClient(ctx context.Context, workspace, serverName string, cfg config.MCPServerConfig) (mcpRPCClient, error) {
|
||||
transport := strings.ToLower(strings.TrimSpace(cfg.Transport))
|
||||
if transport == "" {
|
||||
transport = "stdio"
|
||||
}
|
||||
switch transport {
|
||||
case "stdio":
|
||||
return newMCPStdioClient(ctx, workspace, serverName, cfg)
|
||||
case "sse":
|
||||
return newMCPSSEClient(ctx, workspace, serverName, cfg)
|
||||
case "http", "streamable_http":
|
||||
return newMCPHTTPClient(ctx, serverName, cfg)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported mcp transport %q", transport)
|
||||
}
|
||||
}
|
||||
|
||||
func newMCPStdioClient(ctx context.Context, workspace, serverName string, cfg config.MCPServerConfig) (*mcpClient, error) {
|
||||
command := strings.TrimSpace(cfg.Command)
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("mcp server %q command is empty", serverName)
|
||||
}
|
||||
workingDir, err := resolveMCPWorkingDir(workspace, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, command, cfg.Args...)
|
||||
cmd.Env = buildMCPEnv(cfg.Env)
|
||||
cmd.Dir = resolveMCPWorkingDir(workspace, cfg.WorkingDir)
|
||||
cmd.Dir = workingDir
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
@@ -374,6 +411,7 @@ func newMCPStdioClient(ctx context.Context, workspace, serverName string, cfg co
|
||||
}
|
||||
client := &mcpClient{
|
||||
workspace: workspace,
|
||||
workingDir: workingDir,
|
||||
serverName: serverName,
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
@@ -391,6 +429,450 @@ func newMCPStdioClient(ctx context.Context, workspace, serverName string, cfg co
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type mcpHTTPClient struct {
|
||||
serverName string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
nextID atomic.Int64
|
||||
}
|
||||
|
||||
type mcpSSEClient struct {
|
||||
workspace string
|
||||
serverName string
|
||||
baseURL string
|
||||
endpointURL string
|
||||
client *http.Client
|
||||
cancel context.CancelFunc
|
||||
respBody io.ReadCloser
|
||||
|
||||
writeMu sync.Mutex
|
||||
waiters sync.Map
|
||||
nextID atomic.Int64
|
||||
|
||||
endpointOnce sync.Once
|
||||
endpointCh chan string
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
func newMCPHTTPClient(ctx context.Context, serverName string, cfg config.MCPServerConfig) (*mcpHTTPClient, error) {
|
||||
baseURL := strings.TrimSpace(cfg.URL)
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("mcp server %q url is empty", serverName)
|
||||
}
|
||||
client := &mcpHTTPClient{
|
||||
serverName: serverName,
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
if err := client.initialize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func newMCPSSEClient(ctx context.Context, workspace, serverName string, cfg config.MCPServerConfig) (*mcpSSEClient, error) {
|
||||
baseURL := strings.TrimSpace(cfg.URL)
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("mcp server %q url is empty", serverName)
|
||||
}
|
||||
streamCtx, cancel := context.WithCancel(context.Background())
|
||||
client := &mcpSSEClient{
|
||||
workspace: workspace,
|
||||
serverName: serverName,
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{Timeout: 0},
|
||||
cancel: cancel,
|
||||
endpointCh: make(chan string, 1),
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
req, err := http.NewRequestWithContext(streamCtx, http.MethodGet, baseURL, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
resp, err := client.client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("connect sse for mcp server %q: %w", serverName, err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
resp.Body.Close()
|
||||
cancel()
|
||||
return nil, fmt.Errorf("connect sse for mcp server %q failed: http %d %s", serverName, resp.StatusCode, strings.TrimSpace(string(data)))
|
||||
}
|
||||
client.respBody = resp.Body
|
||||
go client.readLoop()
|
||||
select {
|
||||
case endpoint := <-client.endpointCh:
|
||||
client.endpointURL = endpoint
|
||||
case err := <-client.errCh:
|
||||
client.Close()
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
client.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if err := client.initialize(ctx); err != nil {
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) Close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
if c.respBody != nil {
|
||||
_ = c.respBody.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) initialize(ctx context.Context) error {
|
||||
result, err := c.request(ctx, "initialize", map[string]interface{}{
|
||||
"protocolVersion": mcpProtocolVersion,
|
||||
"capabilities": map[string]interface{}{
|
||||
"roots": map[string]interface{}{
|
||||
"listChanged": false,
|
||||
},
|
||||
},
|
||||
"clientInfo": map[string]interface{}{
|
||||
"name": "clawgo",
|
||||
"version": "dev",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := result["protocolVersion"]; !ok {
|
||||
return fmt.Errorf("mcp server %q initialize missing protocolVersion", c.serverName)
|
||||
}
|
||||
return c.notify("notifications/initialized", map[string]interface{}{})
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) listAll(ctx context.Context, method, field string) (map[string]interface{}, error) {
|
||||
items := make([]interface{}, 0)
|
||||
cursor := ""
|
||||
for {
|
||||
params := map[string]interface{}{}
|
||||
if strings.TrimSpace(cursor) != "" {
|
||||
params["cursor"] = cursor
|
||||
}
|
||||
result, err := c.request(ctx, method, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
batch, _ := result[field].([]interface{})
|
||||
items = append(items, batch...)
|
||||
next, _ := result["nextCursor"].(string)
|
||||
if strings.TrimSpace(next) == "" {
|
||||
return map[string]interface{}{field: items}, nil
|
||||
}
|
||||
cursor = next
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) request(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
|
||||
id := strconv.FormatInt(c.nextID.Add(1), 10)
|
||||
waiter := &mcpResponseWaiter{ch: make(chan mcpInbound, 1)}
|
||||
c.waiters.Store(id, waiter)
|
||||
defer c.waiters.Delete(id)
|
||||
if err := c.postMessage(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
select {
|
||||
case resp := <-waiter.ch:
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("mcp %s %s failed: %s", c.serverName, method, resp.Error.Message)
|
||||
}
|
||||
var out map[string]interface{}
|
||||
if len(resp.Result) == 0 {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
if err := json.Unmarshal(resp.Result, &out); err != nil {
|
||||
return nil, fmt.Errorf("decode mcp %s %s result: %w", c.serverName, method, err)
|
||||
}
|
||||
return out, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) notify(method string, params map[string]interface{}) error {
|
||||
return c.postMessage(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) postMessage(payload map[string]interface{}) error {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
req, err := http.NewRequest(http.MethodPost, c.endpointURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("mcp %s post failed: http %d %s", c.serverName, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) readLoop() {
|
||||
reader := bufio.NewReader(c.respBody)
|
||||
var eventName string
|
||||
var dataLines []string
|
||||
emit := func() bool {
|
||||
if len(dataLines) == 0 && strings.TrimSpace(eventName) == "" {
|
||||
eventName = ""
|
||||
dataLines = nil
|
||||
return true
|
||||
}
|
||||
event := strings.TrimSpace(eventName)
|
||||
if event == "" {
|
||||
event = "message"
|
||||
}
|
||||
data := strings.Join(dataLines, "\n")
|
||||
if err := c.handleSSEEvent(event, data); err != nil {
|
||||
c.signalErr(err)
|
||||
return false
|
||||
}
|
||||
eventName = ""
|
||||
dataLines = nil
|
||||
return true
|
||||
}
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
c.signalErr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
if !emit() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, ":") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) handleSSEEvent(eventName, data string) error {
|
||||
switch eventName {
|
||||
case "endpoint":
|
||||
endpoint := strings.TrimSpace(data)
|
||||
if endpoint == "" {
|
||||
return fmt.Errorf("mcp server %q sent empty endpoint event", c.serverName)
|
||||
}
|
||||
resolved, err := resolveRelativeURL(c.baseURL, endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.endpointOnce.Do(func() {
|
||||
c.endpointURL = resolved
|
||||
c.endpointCh <- resolved
|
||||
})
|
||||
return nil
|
||||
case "message":
|
||||
var msg mcpInbound
|
||||
if err := json.Unmarshal([]byte(data), &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Method != "" && msg.ID != nil {
|
||||
return c.handleServerRequest(msg)
|
||||
}
|
||||
if msg.Method != "" {
|
||||
return nil
|
||||
}
|
||||
if key, ok := normalizeMCPID(msg.ID); ok {
|
||||
if raw, ok := c.waiters.Load(key); ok {
|
||||
raw.(*mcpResponseWaiter).ch <- msg
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) handleServerRequest(msg mcpInbound) error {
|
||||
method := strings.TrimSpace(msg.Method)
|
||||
switch method {
|
||||
case "roots/list":
|
||||
root := resolveMCPDefaultRoot(c.workspace)
|
||||
return c.postMessage(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg.ID,
|
||||
"result": map[string]interface{}{
|
||||
"roots": []map[string]interface{}{
|
||||
{"uri": fileURI(root), "name": filepath.Base(root)},
|
||||
},
|
||||
},
|
||||
})
|
||||
case "ping":
|
||||
return c.postMessage(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg.ID,
|
||||
"result": map[string]interface{}{},
|
||||
})
|
||||
default:
|
||||
return c.postMessage(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg.ID,
|
||||
"error": map[string]interface{}{
|
||||
"code": -32601,
|
||||
"message": "method not supported by clawgo mcp client",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mcpSSEClient) signalErr(err error) {
|
||||
select {
|
||||
case c.errCh <- err:
|
||||
default:
|
||||
}
|
||||
c.waiters.Range(func(_, value interface{}) bool {
|
||||
value.(*mcpResponseWaiter).ch <- mcpInbound{
|
||||
Error: &mcpResponseError{Message: err.Error()},
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func resolveRelativeURL(baseURL, ref string) (string, error) {
|
||||
base, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
target, err := url.Parse(ref)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base.ResolveReference(target).String(), nil
|
||||
}
|
||||
|
||||
func (c *mcpHTTPClient) Close() error { return nil }
|
||||
|
||||
func (c *mcpHTTPClient) initialize(ctx context.Context) error {
|
||||
result, err := c.request(ctx, "initialize", map[string]interface{}{
|
||||
"protocolVersion": mcpProtocolVersion,
|
||||
"capabilities": map[string]interface{}{},
|
||||
"clientInfo": map[string]interface{}{
|
||||
"name": "clawgo",
|
||||
"version": "dev",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := result["protocolVersion"]; !ok {
|
||||
return fmt.Errorf("mcp server %q initialize missing protocolVersion", c.serverName)
|
||||
}
|
||||
_, _ = c.request(ctx, "notifications/initialized", map[string]interface{}{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mcpHTTPClient) listAll(ctx context.Context, method, field string) (map[string]interface{}, error) {
|
||||
items := make([]interface{}, 0)
|
||||
cursor := ""
|
||||
for {
|
||||
params := map[string]interface{}{}
|
||||
if strings.TrimSpace(cursor) != "" {
|
||||
params["cursor"] = cursor
|
||||
}
|
||||
result, err := c.request(ctx, method, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
batch, _ := result[field].([]interface{})
|
||||
items = append(items, batch...)
|
||||
next, _ := result["nextCursor"].(string)
|
||||
if strings.TrimSpace(next) == "" {
|
||||
return map[string]interface{}{field: items}, nil
|
||||
}
|
||||
cursor = next
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mcpHTTPClient) request(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
|
||||
id := strconv.FormatInt(c.nextID.Add(1), 10)
|
||||
payload := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mcp %s %s failed: %w", c.serverName, method, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, fmt.Errorf("mcp %s %s failed: http %d %s", c.serverName, method, resp.StatusCode, strings.TrimSpace(string(data)))
|
||||
}
|
||||
var msg mcpInbound
|
||||
if err := json.NewDecoder(resp.Body).Decode(&msg); err != nil {
|
||||
return nil, fmt.Errorf("decode mcp %s %s result: %w", c.serverName, method, err)
|
||||
}
|
||||
if msg.Error != nil {
|
||||
return nil, fmt.Errorf("mcp %s %s failed: %s", c.serverName, method, msg.Error.Message)
|
||||
}
|
||||
if len(msg.Result) == 0 {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
var out map[string]interface{}
|
||||
if err := json.Unmarshal(msg.Result, &out); err != nil {
|
||||
return nil, fmt.Errorf("decode mcp %s %s result: %w", c.serverName, method, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *mcpClient) Close() error {
|
||||
if c == nil || c.cmd == nil {
|
||||
return nil
|
||||
@@ -536,11 +1018,15 @@ func (c *mcpClient) handleServerRequest(msg mcpInbound) error {
|
||||
method := strings.TrimSpace(msg.Method)
|
||||
switch method {
|
||||
case "roots/list":
|
||||
rootDir := c.workingDir
|
||||
if strings.TrimSpace(rootDir) == "" {
|
||||
rootDir = resolveMCPDefaultRoot(c.workspace)
|
||||
}
|
||||
return c.reply(msg.ID, map[string]interface{}{
|
||||
"roots": []map[string]interface{}{
|
||||
{
|
||||
"uri": fileURI(resolveMCPWorkingDir(c.workspace, "")),
|
||||
"name": filepath.Base(resolveMCPWorkingDir(c.workspace, "")),
|
||||
"uri": fileURI(rootDir),
|
||||
"name": filepath.Base(rootDir),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -631,11 +1117,34 @@ func buildMCPEnv(overrides map[string]string) []string {
|
||||
return env
|
||||
}
|
||||
|
||||
func resolveMCPWorkingDir(workspace, wd string) string {
|
||||
wd = strings.TrimSpace(wd)
|
||||
if wd != "" {
|
||||
return wd
|
||||
func resolveMCPWorkingDir(workspace string, cfg config.MCPServerConfig) (string, error) {
|
||||
root := resolveMCPDefaultRoot(workspace)
|
||||
permission := strings.ToLower(strings.TrimSpace(cfg.Permission))
|
||||
if permission == "" {
|
||||
permission = "workspace"
|
||||
}
|
||||
wd := strings.TrimSpace(cfg.WorkingDir)
|
||||
if wd == "" {
|
||||
return root, nil
|
||||
}
|
||||
if permission == "full" {
|
||||
if !filepath.IsAbs(wd) {
|
||||
return "", fmt.Errorf("mcp server %q working_dir must be absolute when permission=full", strings.TrimSpace(cfg.Command))
|
||||
}
|
||||
return filepath.Clean(wd), nil
|
||||
}
|
||||
if filepath.IsAbs(wd) {
|
||||
clean := filepath.Clean(wd)
|
||||
rel, err := filepath.Rel(root, clean)
|
||||
if err != nil || strings.HasPrefix(rel, "..") {
|
||||
return "", fmt.Errorf("mcp working_dir %q must stay within workspace root %q unless permission=full", clean, root)
|
||||
}
|
||||
return clean, nil
|
||||
}
|
||||
return filepath.Clean(filepath.Join(root, wd)), nil
|
||||
}
|
||||
|
||||
func resolveMCPDefaultRoot(workspace string) string {
|
||||
if abs, err := filepath.Abs(workspace); err == nil {
|
||||
return abs
|
||||
}
|
||||
|
||||
@@ -6,7 +6,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -290,9 +293,9 @@ func TestValidateMCPTools(t *testing.T) {
|
||||
cfg.Tools.MCP.Servers = map[string]config.MCPServerConfig{
|
||||
"bad": {
|
||||
Enabled: true,
|
||||
Transport: "http",
|
||||
Transport: "ws",
|
||||
Command: "",
|
||||
WorkingDir: "relative",
|
||||
WorkingDir: "/outside-workspace",
|
||||
},
|
||||
}
|
||||
errs := config.Validate(cfg)
|
||||
@@ -306,9 +309,7 @@ func TestValidateMCPTools(t *testing.T) {
|
||||
joined := strings.Join(got, "\n")
|
||||
for _, want := range []string{
|
||||
"tools.mcp.request_timeout_sec must be > 0 when tools.mcp.enabled=true",
|
||||
"tools.mcp.servers.bad.transport must be 'stdio'",
|
||||
"tools.mcp.servers.bad.command is required when enabled=true",
|
||||
"tools.mcp.servers.bad.working_dir must be an absolute path",
|
||||
"tools.mcp.servers.bad.transport must be one of: stdio, http, streamable_http, sse",
|
||||
} {
|
||||
if !strings.Contains(joined, want) {
|
||||
t.Fatalf("expected validation error %q in:\n%s", want, joined)
|
||||
@@ -316,6 +317,31 @@ func TestValidateMCPTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMCPToolsFullPermissionRequiresAbsolutePath(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.MCP.Enabled = true
|
||||
cfg.Tools.MCP.Servers = map[string]config.MCPServerConfig{
|
||||
"full": {
|
||||
Enabled: true,
|
||||
Transport: "stdio",
|
||||
Command: "demo",
|
||||
Permission: "full",
|
||||
WorkingDir: "relative",
|
||||
},
|
||||
}
|
||||
errs := config.Validate(cfg)
|
||||
if len(errs) == 0 {
|
||||
t.Fatal("expected validation errors")
|
||||
}
|
||||
joined := ""
|
||||
for _, err := range errs {
|
||||
joined += err.Error() + "\n"
|
||||
}
|
||||
if !strings.Contains(joined, "tools.mcp.servers.full.working_dir must be an absolute path when permission=full") {
|
||||
t.Fatalf("unexpected validation errors:\n%s", joined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPToolListTools(t *testing.T) {
|
||||
tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{
|
||||
Enabled: true,
|
||||
@@ -346,9 +372,219 @@ func TestMCPToolListTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPToolHTTPTransport(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
var req map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
method, _ := req["method"].(string)
|
||||
id := req["id"]
|
||||
var resp map[string]interface{}
|
||||
switch method {
|
||||
case "initialize":
|
||||
resp = map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]interface{}{
|
||||
"protocolVersion": mcpProtocolVersion,
|
||||
},
|
||||
}
|
||||
case "notifications/initialized":
|
||||
resp = map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]interface{}{},
|
||||
}
|
||||
case "tools/list":
|
||||
resp = map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]interface{}{
|
||||
"tools": []map[string]interface{}{
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echo text",
|
||||
"inputSchema": map[string]interface{}{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case "tools/call":
|
||||
resp = map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "echo:http"},
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
resp = map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": map[string]interface{}{"code": -32601, "message": "unsupported"},
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{
|
||||
Enabled: true,
|
||||
RequestTimeoutSec: 5,
|
||||
Servers: map[string]config.MCPServerConfig{
|
||||
"httpdemo": {
|
||||
Enabled: true,
|
||||
Transport: "http",
|
||||
URL: server.URL,
|
||||
},
|
||||
},
|
||||
})
|
||||
out, err := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"action": "list_tools",
|
||||
"server": "httpdemo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("http list_tools returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"name": "echo"`) {
|
||||
t.Fatalf("expected http tool listing, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPToolSSETransport(t *testing.T) {
|
||||
messageCh := make(chan string, 8)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/sse":
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "flush unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "event: endpoint\n")
|
||||
_, _ = fmt.Fprintf(w, "data: /messages\n\n")
|
||||
flusher.Flush()
|
||||
notify := r.Context().Done()
|
||||
for {
|
||||
select {
|
||||
case payload := <-messageCh:
|
||||
_, _ = fmt.Fprintf(w, "event: message\n")
|
||||
_, _ = fmt.Fprintf(w, "data: %s\n\n", payload)
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/messages":
|
||||
defer r.Body.Close()
|
||||
var req map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
method, _ := req["method"].(string)
|
||||
id := req["id"]
|
||||
switch method {
|
||||
case "initialize":
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]interface{}{
|
||||
"protocolVersion": mcpProtocolVersion,
|
||||
},
|
||||
})
|
||||
messageCh <- string(payload)
|
||||
case "tools/list":
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]interface{}{
|
||||
"tools": []map[string]interface{}{
|
||||
{"name": "echo", "description": "Echo text", "inputSchema": map[string]interface{}{"type": "object"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
messageCh <- string(payload)
|
||||
case "tools/call":
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]interface{}{
|
||||
"content": []map[string]interface{}{{"type": "text", "text": "echo:sse"}},
|
||||
},
|
||||
})
|
||||
messageCh <- string(payload)
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{
|
||||
Enabled: true,
|
||||
RequestTimeoutSec: 5,
|
||||
Servers: map[string]config.MCPServerConfig{
|
||||
"ssedemo": {
|
||||
Enabled: true,
|
||||
Transport: "sse",
|
||||
URL: server.URL + "/sse",
|
||||
},
|
||||
},
|
||||
})
|
||||
out, err := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"action": "list_tools",
|
||||
"server": "ssedemo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("sse list_tools returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"name": "echo"`) {
|
||||
t.Fatalf("expected sse tool listing, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMCPDynamicToolName(t *testing.T) {
|
||||
got := buildMCPDynamicToolName("Context7 Server", "resolve-library.id")
|
||||
if got != "mcp__context7_server__resolve_library_id" {
|
||||
t.Fatalf("unexpected tool name: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMCPWorkingDirWorkspaceScoped(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
dir, err := resolveMCPWorkingDir(workspace, config.MCPServerConfig{WorkingDir: "tools/context7"})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMCPWorkingDir returned error: %v", err)
|
||||
}
|
||||
want := filepath.Join(workspace, "tools", "context7")
|
||||
if dir != want {
|
||||
t.Fatalf("unexpected working dir: got %q want %q", dir, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMCPWorkingDirRejectsOutsideWorkspaceWithoutFullPermission(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
_, err := resolveMCPWorkingDir(workspace, config.MCPServerConfig{WorkingDir: "/"})
|
||||
if err == nil {
|
||||
t.Fatal("expected outside-workspace path to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMCPWorkingDirAllowsAbsolutePathWithFullPermission(t *testing.T) {
|
||||
dir, err := resolveMCPWorkingDir(t.TempDir(), config.MCPServerConfig{Permission: "full", WorkingDir: "/"})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMCPWorkingDir returned error: %v", err)
|
||||
}
|
||||
if dir != "/" {
|
||||
t.Fatalf("unexpected working dir: %q", dir)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +62,23 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
skill, _ := args["skill"].(string)
|
||||
script, _ := args["script"].(string)
|
||||
reason, _ := args["reason"].(string)
|
||||
callerAgent, _ := args["caller_agent"].(string)
|
||||
callerScope, _ := args["caller_scope"].(string)
|
||||
reason = strings.TrimSpace(reason)
|
||||
if reason == "" {
|
||||
reason = "unspecified"
|
||||
}
|
||||
callerAgent = strings.TrimSpace(callerAgent)
|
||||
if callerAgent == "" {
|
||||
callerAgent = "main"
|
||||
}
|
||||
callerScope = strings.TrimSpace(callerScope)
|
||||
if callerScope == "" {
|
||||
callerScope = "main_agent"
|
||||
}
|
||||
if strings.TrimSpace(skill) == "" || strings.TrimSpace(script) == "" {
|
||||
err := fmt.Errorf("skill and script are required")
|
||||
t.writeAudit(skill, script, reason, false, err.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(t.workspace) != "" {
|
||||
@@ -77,31 +87,31 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
|
||||
skillDir, err := t.resolveSkillDir(skill)
|
||||
if err != nil {
|
||||
t.writeAudit(skill, script, reason, false, err.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
|
||||
return "", err
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(skillDir, "SKILL.md")); err != nil {
|
||||
err = fmt.Errorf("SKILL.md missing for skill: %s", skill)
|
||||
t.writeAudit(skill, script, reason, false, err.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
relScript := filepath.Clean(script)
|
||||
if strings.Contains(relScript, "..") || filepath.IsAbs(relScript) {
|
||||
err := fmt.Errorf("script must be relative path inside skill directory")
|
||||
t.writeAudit(skill, script, reason, false, err.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
|
||||
return "", err
|
||||
}
|
||||
if !strings.HasPrefix(relScript, "scripts"+string(os.PathSeparator)) {
|
||||
err := fmt.Errorf("script must be under scripts/ directory")
|
||||
t.writeAudit(skill, script, reason, false, err.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
scriptPath := filepath.Join(skillDir, relScript)
|
||||
if _, err := os.Stat(scriptPath); err != nil {
|
||||
err = fmt.Errorf("script not found: %s", scriptPath)
|
||||
t.writeAudit(skill, script, reason, false, err.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -124,7 +134,7 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
for attempt := 0; attempt <= policy.MaxRestarts; attempt++ {
|
||||
cmd, err := buildSkillCommand(ctx, scriptPath, cmdArgs)
|
||||
if err != nil {
|
||||
t.writeAudit(skill, script, reason, false, err.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
|
||||
return "", err
|
||||
}
|
||||
cmd.Dir = skillDir
|
||||
@@ -158,7 +168,7 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
}
|
||||
output := merged.String()
|
||||
if runErr != nil {
|
||||
t.writeAudit(skill, script, reason, false, runErr.Error())
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, runErr.Error())
|
||||
return "", fmt.Errorf("skill execution failed: %w\n%s", runErr, output)
|
||||
}
|
||||
|
||||
@@ -166,21 +176,23 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
if out == "" {
|
||||
out = "(no output)"
|
||||
}
|
||||
t.writeAudit(skill, script, reason, true, "")
|
||||
t.writeAudit(skill, script, reason, callerAgent, callerScope, true, "")
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t *SkillExecTool) writeAudit(skill, script, reason string, ok bool, errText string) {
|
||||
func (t *SkillExecTool) writeAudit(skill, script, reason, callerAgent, callerScope string, ok bool, errText string) {
|
||||
if strings.TrimSpace(t.workspace) == "" {
|
||||
return
|
||||
}
|
||||
memDir := filepath.Join(t.workspace, "memory")
|
||||
_ = os.MkdirAll(memDir, 0755)
|
||||
row := fmt.Sprintf("{\"time\":%q,\"skill\":%q,\"script\":%q,\"reason\":%q,\"ok\":%t,\"error\":%q}\n",
|
||||
row := fmt.Sprintf("{\"time\":%q,\"skill\":%q,\"script\":%q,\"reason\":%q,\"caller_agent\":%q,\"caller_scope\":%q,\"ok\":%t,\"error\":%q}\n",
|
||||
time.Now().UTC().Format(time.RFC3339),
|
||||
strings.TrimSpace(skill),
|
||||
strings.TrimSpace(script),
|
||||
strings.TrimSpace(reason),
|
||||
strings.TrimSpace(callerAgent),
|
||||
strings.TrimSpace(callerScope),
|
||||
ok,
|
||||
strings.TrimSpace(errText),
|
||||
)
|
||||
|
||||
27
pkg/tools/skill_exec_test.go
Normal file
27
pkg/tools/skill_exec_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSkillExecWriteAuditIncludesCallerIdentity(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
tool := NewSkillExecTool(workspace)
|
||||
|
||||
tool.writeAudit("demo", "scripts/run.sh", "test", "coder", "subagent", true, "")
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, "memory", "skill-audit.jsonl"))
|
||||
if err != nil {
|
||||
t.Fatalf("read audit file: %v", err)
|
||||
}
|
||||
text := string(data)
|
||||
if !strings.Contains(text, `"caller_agent":"coder"`) {
|
||||
t.Fatalf("expected caller_agent in audit row, got: %s", text)
|
||||
}
|
||||
if !strings.Contains(text, `"caller_scope":"subagent"`) {
|
||||
t.Fatalf("expected caller_scope in audit row, got: %s", text)
|
||||
}
|
||||
}
|
||||
@@ -49,6 +49,12 @@ var defaultToolAllowlistGroups = []ToolAllowlistGroup{
|
||||
Aliases: []string{"subagent", "agent_runtime"},
|
||||
Tools: []string{"spawn", "subagents", "subagent_profile"},
|
||||
},
|
||||
{
|
||||
Name: "skills",
|
||||
Description: "Skill script execution tools",
|
||||
Aliases: []string{"skill", "skill_scripts"},
|
||||
Tools: []string{"skill_exec"},
|
||||
},
|
||||
}
|
||||
|
||||
func ToolAllowlistGroups() []ToolAllowlistGroup {
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestExpandToolAllowlistEntries_GroupPrefix(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExpandToolAllowlistEntries_BareGroupAndAlias(t *testing.T) {
|
||||
got := ExpandToolAllowlistEntries([]string{"memory_all", "@subagents"})
|
||||
got := ExpandToolAllowlistEntries([]string{"memory_all", "@subagents", "skill"})
|
||||
contains := map[string]bool{}
|
||||
for _, item := range got {
|
||||
contains[item] = true
|
||||
@@ -28,4 +28,7 @@ func TestExpandToolAllowlistEntries_BareGroupAndAlias(t *testing.T) {
|
||||
if !contains["spawn"] || !contains["subagents"] || !contains["subagent_profile"] {
|
||||
t.Fatalf("subagents alias expansion missing subagent tools: %v", got)
|
||||
}
|
||||
if !contains["skill_exec"] {
|
||||
t.Fatalf("skills alias expansion missing skill_exec: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user