feat: expand mcp transports and skill execution

This commit is contained in:
lpf
2026-03-08 11:08:41 +08:00
parent db86b3471d
commit f043de5384
21 changed files with 1447 additions and 84 deletions

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"hash/fnv"
"io"
"net/http"
"net/url"
"os"
"os/exec"
@@ -39,6 +40,12 @@ type MCPRemoteTool struct {
parameters map[string]interface{}
}
type mcpRPCClient interface {
listAll(ctx context.Context, method, field string) (map[string]interface{}, error)
request(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error)
Close() error
}
func NewMCPTool(workspace string, cfg config.MCPToolsConfig) *MCPTool {
if cfg.RequestTimeoutSec <= 0 {
cfg.RequestTimeoutSec = 20
@@ -54,7 +61,7 @@ func (t *MCPTool) Name() string {
}
func (t *MCPTool) Description() string {
return "Call configured MCP servers over stdio. Supports listing servers, tools, resources, prompts, and invoking remote MCP tools."
return "Call configured MCP servers over stdio or HTTP transports. Supports listing servers, tools, resources, prompts, and invoking remote MCP tools."
}
func (t *MCPTool) Parameters() map[string]interface{} {
@@ -119,7 +126,7 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) (str
callCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
client, err := newMCPStdioClient(callCtx, t.workspace, serverName, serverCfg)
client, err := newMCPClient(callCtx, t.workspace, serverName, serverCfg)
if err != nil {
return "", err
}
@@ -201,7 +208,7 @@ func (t *MCPTool) DiscoverTools(ctx context.Context) []Tool {
seen := map[string]int{}
for _, serverName := range names {
serverCfg := t.cfg.Servers[serverName]
client, err := newMCPStdioClient(ctx, t.workspace, serverName, serverCfg)
client, err := newMCPClient(ctx, t.workspace, serverName, serverCfg)
if err != nil {
continue
}
@@ -249,7 +256,7 @@ func (t *MCPTool) callServerTool(ctx context.Context, serverName, remoteToolName
}
callCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
client, err := newMCPStdioClient(callCtx, t.workspace, serverName, serverCfg)
client, err := newMCPClient(callCtx, t.workspace, serverName, serverCfg)
if err != nil {
return "", err
}
@@ -268,6 +275,8 @@ func (t *MCPTool) listServers() string {
type item struct {
Name string `json:"name"`
Transport string `json:"transport"`
URL string `json:"url,omitempty"`
Permission string `json:"permission,omitempty"`
Command string `json:"command"`
WorkingDir string `json:"working_dir,omitempty"`
Description string `json:"description,omitempty"`
@@ -286,9 +295,15 @@ func (t *MCPTool) listServers() string {
if transport == "" {
transport = "stdio"
}
permission := strings.TrimSpace(server.Permission)
if permission == "" {
permission = "workspace"
}
items = append(items, item{
Name: name,
Transport: transport,
URL: strings.TrimSpace(server.URL),
Permission: permission,
Command: server.Command,
WorkingDir: server.WorkingDir,
Description: server.Description,
@@ -326,6 +341,7 @@ func (t *MCPRemoteTool) CatalogEntry() map[string]interface{} {
type mcpClient struct {
workspace string
workingDir string
serverName string
cmd *exec.Cmd
stdin io.WriteCloser
@@ -355,14 +371,35 @@ type mcpResponseWaiter struct {
ch chan mcpInbound
}
func newMCPClient(ctx context.Context, workspace, serverName string, cfg config.MCPServerConfig) (mcpRPCClient, error) {
transport := strings.ToLower(strings.TrimSpace(cfg.Transport))
if transport == "" {
transport = "stdio"
}
switch transport {
case "stdio":
return newMCPStdioClient(ctx, workspace, serverName, cfg)
case "sse":
return newMCPSSEClient(ctx, workspace, serverName, cfg)
case "http", "streamable_http":
return newMCPHTTPClient(ctx, serverName, cfg)
default:
return nil, fmt.Errorf("unsupported mcp transport %q", transport)
}
}
func newMCPStdioClient(ctx context.Context, workspace, serverName string, cfg config.MCPServerConfig) (*mcpClient, error) {
command := strings.TrimSpace(cfg.Command)
if command == "" {
return nil, fmt.Errorf("mcp server %q command is empty", serverName)
}
workingDir, err := resolveMCPWorkingDir(workspace, cfg)
if err != nil {
return nil, err
}
cmd := exec.CommandContext(ctx, command, cfg.Args...)
cmd.Env = buildMCPEnv(cfg.Env)
cmd.Dir = resolveMCPWorkingDir(workspace, cfg.WorkingDir)
cmd.Dir = workingDir
stdin, err := cmd.StdinPipe()
if err != nil {
@@ -374,6 +411,7 @@ func newMCPStdioClient(ctx context.Context, workspace, serverName string, cfg co
}
client := &mcpClient{
workspace: workspace,
workingDir: workingDir,
serverName: serverName,
cmd: cmd,
stdin: stdin,
@@ -391,6 +429,450 @@ func newMCPStdioClient(ctx context.Context, workspace, serverName string, cfg co
return client, nil
}
type mcpHTTPClient struct {
serverName string
baseURL string
client *http.Client
nextID atomic.Int64
}
type mcpSSEClient struct {
workspace string
serverName string
baseURL string
endpointURL string
client *http.Client
cancel context.CancelFunc
respBody io.ReadCloser
writeMu sync.Mutex
waiters sync.Map
nextID atomic.Int64
endpointOnce sync.Once
endpointCh chan string
errCh chan error
}
func newMCPHTTPClient(ctx context.Context, serverName string, cfg config.MCPServerConfig) (*mcpHTTPClient, error) {
baseURL := strings.TrimSpace(cfg.URL)
if baseURL == "" {
return nil, fmt.Errorf("mcp server %q url is empty", serverName)
}
client := &mcpHTTPClient{
serverName: serverName,
baseURL: baseURL,
client: &http.Client{Timeout: 30 * time.Second},
}
if err := client.initialize(ctx); err != nil {
return nil, err
}
return client, nil
}
func newMCPSSEClient(ctx context.Context, workspace, serverName string, cfg config.MCPServerConfig) (*mcpSSEClient, error) {
baseURL := strings.TrimSpace(cfg.URL)
if baseURL == "" {
return nil, fmt.Errorf("mcp server %q url is empty", serverName)
}
streamCtx, cancel := context.WithCancel(context.Background())
client := &mcpSSEClient{
workspace: workspace,
serverName: serverName,
baseURL: baseURL,
client: &http.Client{Timeout: 0},
cancel: cancel,
endpointCh: make(chan string, 1),
errCh: make(chan error, 1),
}
req, err := http.NewRequestWithContext(streamCtx, http.MethodGet, baseURL, nil)
if err != nil {
cancel()
return nil, err
}
req.Header.Set("Accept", "text/event-stream")
resp, err := client.client.Do(req)
if err != nil {
cancel()
return nil, fmt.Errorf("connect sse for mcp server %q: %w", serverName, err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
resp.Body.Close()
cancel()
return nil, fmt.Errorf("connect sse for mcp server %q failed: http %d %s", serverName, resp.StatusCode, strings.TrimSpace(string(data)))
}
client.respBody = resp.Body
go client.readLoop()
select {
case endpoint := <-client.endpointCh:
client.endpointURL = endpoint
case err := <-client.errCh:
client.Close()
return nil, err
case <-ctx.Done():
client.Close()
return nil, ctx.Err()
}
if err := client.initialize(ctx); err != nil {
client.Close()
return nil, err
}
return client, nil
}
func (c *mcpSSEClient) Close() error {
if c == nil {
return nil
}
if c.cancel != nil {
c.cancel()
}
if c.respBody != nil {
_ = c.respBody.Close()
}
return nil
}
func (c *mcpSSEClient) initialize(ctx context.Context) error {
result, err := c.request(ctx, "initialize", map[string]interface{}{
"protocolVersion": mcpProtocolVersion,
"capabilities": map[string]interface{}{
"roots": map[string]interface{}{
"listChanged": false,
},
},
"clientInfo": map[string]interface{}{
"name": "clawgo",
"version": "dev",
},
})
if err != nil {
return err
}
if _, ok := result["protocolVersion"]; !ok {
return fmt.Errorf("mcp server %q initialize missing protocolVersion", c.serverName)
}
return c.notify("notifications/initialized", map[string]interface{}{})
}
func (c *mcpSSEClient) listAll(ctx context.Context, method, field string) (map[string]interface{}, error) {
items := make([]interface{}, 0)
cursor := ""
for {
params := map[string]interface{}{}
if strings.TrimSpace(cursor) != "" {
params["cursor"] = cursor
}
result, err := c.request(ctx, method, params)
if err != nil {
return nil, err
}
batch, _ := result[field].([]interface{})
items = append(items, batch...)
next, _ := result["nextCursor"].(string)
if strings.TrimSpace(next) == "" {
return map[string]interface{}{field: items}, nil
}
cursor = next
}
}
func (c *mcpSSEClient) request(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
id := strconv.FormatInt(c.nextID.Add(1), 10)
waiter := &mcpResponseWaiter{ch: make(chan mcpInbound, 1)}
c.waiters.Store(id, waiter)
defer c.waiters.Delete(id)
if err := c.postMessage(map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
}); err != nil {
return nil, err
}
select {
case resp := <-waiter.ch:
if resp.Error != nil {
return nil, fmt.Errorf("mcp %s %s failed: %s", c.serverName, method, resp.Error.Message)
}
var out map[string]interface{}
if len(resp.Result) == 0 {
return map[string]interface{}{}, nil
}
if err := json.Unmarshal(resp.Result, &out); err != nil {
return nil, fmt.Errorf("decode mcp %s %s result: %w", c.serverName, method, err)
}
return out, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (c *mcpSSEClient) notify(method string, params map[string]interface{}) error {
return c.postMessage(map[string]interface{}{
"jsonrpc": "2.0",
"method": method,
"params": params,
})
}
func (c *mcpSSEClient) postMessage(payload map[string]interface{}) error {
data, err := json.Marshal(payload)
if err != nil {
return err
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
req, err := http.NewRequest(http.MethodPost, c.endpointURL, bytes.NewReader(data))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return fmt.Errorf("mcp %s post failed: http %d %s", c.serverName, resp.StatusCode, strings.TrimSpace(string(body)))
}
return nil
}
func (c *mcpSSEClient) readLoop() {
reader := bufio.NewReader(c.respBody)
var eventName string
var dataLines []string
emit := func() bool {
if len(dataLines) == 0 && strings.TrimSpace(eventName) == "" {
eventName = ""
dataLines = nil
return true
}
event := strings.TrimSpace(eventName)
if event == "" {
event = "message"
}
data := strings.Join(dataLines, "\n")
if err := c.handleSSEEvent(event, data); err != nil {
c.signalErr(err)
return false
}
eventName = ""
dataLines = nil
return true
}
for {
line, err := reader.ReadString('\n')
if err != nil {
if err != io.EOF {
c.signalErr(err)
}
return
}
line = strings.TrimRight(line, "\r\n")
if line == "" {
if !emit() {
return
}
continue
}
if strings.HasPrefix(line, ":") {
continue
}
if strings.HasPrefix(line, "event:") {
eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
continue
}
if strings.HasPrefix(line, "data:") {
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
}
}
}
func (c *mcpSSEClient) handleSSEEvent(eventName, data string) error {
switch eventName {
case "endpoint":
endpoint := strings.TrimSpace(data)
if endpoint == "" {
return fmt.Errorf("mcp server %q sent empty endpoint event", c.serverName)
}
resolved, err := resolveRelativeURL(c.baseURL, endpoint)
if err != nil {
return err
}
c.endpointOnce.Do(func() {
c.endpointURL = resolved
c.endpointCh <- resolved
})
return nil
case "message":
var msg mcpInbound
if err := json.Unmarshal([]byte(data), &msg); err != nil {
return err
}
if msg.Method != "" && msg.ID != nil {
return c.handleServerRequest(msg)
}
if msg.Method != "" {
return nil
}
if key, ok := normalizeMCPID(msg.ID); ok {
if raw, ok := c.waiters.Load(key); ok {
raw.(*mcpResponseWaiter).ch <- msg
}
}
return nil
default:
return nil
}
}
func (c *mcpSSEClient) handleServerRequest(msg mcpInbound) error {
method := strings.TrimSpace(msg.Method)
switch method {
case "roots/list":
root := resolveMCPDefaultRoot(c.workspace)
return c.postMessage(map[string]interface{}{
"jsonrpc": "2.0",
"id": msg.ID,
"result": map[string]interface{}{
"roots": []map[string]interface{}{
{"uri": fileURI(root), "name": filepath.Base(root)},
},
},
})
case "ping":
return c.postMessage(map[string]interface{}{
"jsonrpc": "2.0",
"id": msg.ID,
"result": map[string]interface{}{},
})
default:
return c.postMessage(map[string]interface{}{
"jsonrpc": "2.0",
"id": msg.ID,
"error": map[string]interface{}{
"code": -32601,
"message": "method not supported by clawgo mcp client",
},
})
}
}
func (c *mcpSSEClient) signalErr(err error) {
select {
case c.errCh <- err:
default:
}
c.waiters.Range(func(_, value interface{}) bool {
value.(*mcpResponseWaiter).ch <- mcpInbound{
Error: &mcpResponseError{Message: err.Error()},
}
return true
})
}
func resolveRelativeURL(baseURL, ref string) (string, error) {
base, err := url.Parse(baseURL)
if err != nil {
return "", err
}
target, err := url.Parse(ref)
if err != nil {
return "", err
}
return base.ResolveReference(target).String(), nil
}
func (c *mcpHTTPClient) Close() error { return nil }
func (c *mcpHTTPClient) initialize(ctx context.Context) error {
result, err := c.request(ctx, "initialize", map[string]interface{}{
"protocolVersion": mcpProtocolVersion,
"capabilities": map[string]interface{}{},
"clientInfo": map[string]interface{}{
"name": "clawgo",
"version": "dev",
},
})
if err != nil {
return err
}
if _, ok := result["protocolVersion"]; !ok {
return fmt.Errorf("mcp server %q initialize missing protocolVersion", c.serverName)
}
_, _ = c.request(ctx, "notifications/initialized", map[string]interface{}{})
return nil
}
func (c *mcpHTTPClient) listAll(ctx context.Context, method, field string) (map[string]interface{}, error) {
items := make([]interface{}, 0)
cursor := ""
for {
params := map[string]interface{}{}
if strings.TrimSpace(cursor) != "" {
params["cursor"] = cursor
}
result, err := c.request(ctx, method, params)
if err != nil {
return nil, err
}
batch, _ := result[field].([]interface{})
items = append(items, batch...)
next, _ := result["nextCursor"].(string)
if strings.TrimSpace(next) == "" {
return map[string]interface{}{field: items}, nil
}
cursor = next
}
}
func (c *mcpHTTPClient) request(ctx context.Context, method string, params map[string]interface{}) (map[string]interface{}, error) {
id := strconv.FormatInt(c.nextID.Add(1), 10)
payload := map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("mcp %s %s failed: %w", c.serverName, method, err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("mcp %s %s failed: http %d %s", c.serverName, method, resp.StatusCode, strings.TrimSpace(string(data)))
}
var msg mcpInbound
if err := json.NewDecoder(resp.Body).Decode(&msg); err != nil {
return nil, fmt.Errorf("decode mcp %s %s result: %w", c.serverName, method, err)
}
if msg.Error != nil {
return nil, fmt.Errorf("mcp %s %s failed: %s", c.serverName, method, msg.Error.Message)
}
if len(msg.Result) == 0 {
return map[string]interface{}{}, nil
}
var out map[string]interface{}
if err := json.Unmarshal(msg.Result, &out); err != nil {
return nil, fmt.Errorf("decode mcp %s %s result: %w", c.serverName, method, err)
}
return out, nil
}
func (c *mcpClient) Close() error {
if c == nil || c.cmd == nil {
return nil
@@ -536,11 +1018,15 @@ func (c *mcpClient) handleServerRequest(msg mcpInbound) error {
method := strings.TrimSpace(msg.Method)
switch method {
case "roots/list":
rootDir := c.workingDir
if strings.TrimSpace(rootDir) == "" {
rootDir = resolveMCPDefaultRoot(c.workspace)
}
return c.reply(msg.ID, map[string]interface{}{
"roots": []map[string]interface{}{
{
"uri": fileURI(resolveMCPWorkingDir(c.workspace, "")),
"name": filepath.Base(resolveMCPWorkingDir(c.workspace, "")),
"uri": fileURI(rootDir),
"name": filepath.Base(rootDir),
},
},
})
@@ -631,11 +1117,34 @@ func buildMCPEnv(overrides map[string]string) []string {
return env
}
func resolveMCPWorkingDir(workspace, wd string) string {
wd = strings.TrimSpace(wd)
if wd != "" {
return wd
func resolveMCPWorkingDir(workspace string, cfg config.MCPServerConfig) (string, error) {
root := resolveMCPDefaultRoot(workspace)
permission := strings.ToLower(strings.TrimSpace(cfg.Permission))
if permission == "" {
permission = "workspace"
}
wd := strings.TrimSpace(cfg.WorkingDir)
if wd == "" {
return root, nil
}
if permission == "full" {
if !filepath.IsAbs(wd) {
return "", fmt.Errorf("mcp server %q working_dir must be absolute when permission=full", strings.TrimSpace(cfg.Command))
}
return filepath.Clean(wd), nil
}
if filepath.IsAbs(wd) {
clean := filepath.Clean(wd)
rel, err := filepath.Rel(root, clean)
if err != nil || strings.HasPrefix(rel, "..") {
return "", fmt.Errorf("mcp working_dir %q must stay within workspace root %q unless permission=full", clean, root)
}
return clean, nil
}
return filepath.Clean(filepath.Join(root, wd)), nil
}
func resolveMCPDefaultRoot(workspace string) string {
if abs, err := filepath.Abs(workspace); err == nil {
return abs
}

View File

@@ -6,7 +6,10 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
@@ -290,9 +293,9 @@ func TestValidateMCPTools(t *testing.T) {
cfg.Tools.MCP.Servers = map[string]config.MCPServerConfig{
"bad": {
Enabled: true,
Transport: "http",
Transport: "ws",
Command: "",
WorkingDir: "relative",
WorkingDir: "/outside-workspace",
},
}
errs := config.Validate(cfg)
@@ -306,9 +309,7 @@ func TestValidateMCPTools(t *testing.T) {
joined := strings.Join(got, "\n")
for _, want := range []string{
"tools.mcp.request_timeout_sec must be > 0 when tools.mcp.enabled=true",
"tools.mcp.servers.bad.transport must be 'stdio'",
"tools.mcp.servers.bad.command is required when enabled=true",
"tools.mcp.servers.bad.working_dir must be an absolute path",
"tools.mcp.servers.bad.transport must be one of: stdio, http, streamable_http, sse",
} {
if !strings.Contains(joined, want) {
t.Fatalf("expected validation error %q in:\n%s", want, joined)
@@ -316,6 +317,31 @@ func TestValidateMCPTools(t *testing.T) {
}
}
func TestValidateMCPToolsFullPermissionRequiresAbsolutePath(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.MCP.Enabled = true
cfg.Tools.MCP.Servers = map[string]config.MCPServerConfig{
"full": {
Enabled: true,
Transport: "stdio",
Command: "demo",
Permission: "full",
WorkingDir: "relative",
},
}
errs := config.Validate(cfg)
if len(errs) == 0 {
t.Fatal("expected validation errors")
}
joined := ""
for _, err := range errs {
joined += err.Error() + "\n"
}
if !strings.Contains(joined, "tools.mcp.servers.full.working_dir must be an absolute path when permission=full") {
t.Fatalf("unexpected validation errors:\n%s", joined)
}
}
func TestMCPToolListTools(t *testing.T) {
tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{
Enabled: true,
@@ -346,9 +372,219 @@ func TestMCPToolListTools(t *testing.T) {
}
}
func TestMCPToolHTTPTransport(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
var req map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
method, _ := req["method"].(string)
id := req["id"]
var resp map[string]interface{}
switch method {
case "initialize":
resp = map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"result": map[string]interface{}{
"protocolVersion": mcpProtocolVersion,
},
}
case "notifications/initialized":
resp = map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"result": map[string]interface{}{},
}
case "tools/list":
resp = map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"result": map[string]interface{}{
"tools": []map[string]interface{}{
{
"name": "echo",
"description": "Echo text",
"inputSchema": map[string]interface{}{"type": "object"},
},
},
},
}
case "tools/call":
resp = map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"result": map[string]interface{}{
"content": []map[string]interface{}{
{"type": "text", "text": "echo:http"},
},
},
}
default:
resp = map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"error": map[string]interface{}{"code": -32601, "message": "unsupported"},
}
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{
Enabled: true,
RequestTimeoutSec: 5,
Servers: map[string]config.MCPServerConfig{
"httpdemo": {
Enabled: true,
Transport: "http",
URL: server.URL,
},
},
})
out, err := tool.Execute(context.Background(), map[string]interface{}{
"action": "list_tools",
"server": "httpdemo",
})
if err != nil {
t.Fatalf("http list_tools returned error: %v", err)
}
if !strings.Contains(out, `"name": "echo"`) {
t.Fatalf("expected http tool listing, got: %s", out)
}
}
func TestMCPToolSSETransport(t *testing.T) {
messageCh := make(chan string, 8)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/sse":
w.Header().Set("Content-Type", "text/event-stream")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "flush unsupported", http.StatusInternalServerError)
return
}
_, _ = fmt.Fprintf(w, "event: endpoint\n")
_, _ = fmt.Fprintf(w, "data: /messages\n\n")
flusher.Flush()
notify := r.Context().Done()
for {
select {
case payload := <-messageCh:
_, _ = fmt.Fprintf(w, "event: message\n")
_, _ = fmt.Fprintf(w, "data: %s\n\n", payload)
flusher.Flush()
case <-notify:
return
}
}
case r.Method == http.MethodPost && r.URL.Path == "/messages":
defer r.Body.Close()
var req map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
method, _ := req["method"].(string)
id := req["id"]
switch method {
case "initialize":
payload, _ := json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"result": map[string]interface{}{
"protocolVersion": mcpProtocolVersion,
},
})
messageCh <- string(payload)
case "tools/list":
payload, _ := json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"result": map[string]interface{}{
"tools": []map[string]interface{}{
{"name": "echo", "description": "Echo text", "inputSchema": map[string]interface{}{"type": "object"}},
},
},
})
messageCh <- string(payload)
case "tools/call":
payload, _ := json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
"result": map[string]interface{}{
"content": []map[string]interface{}{{"type": "text", "text": "echo:sse"}},
},
})
messageCh <- string(payload)
}
w.WriteHeader(http.StatusAccepted)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
tool := NewMCPTool(t.TempDir(), config.MCPToolsConfig{
Enabled: true,
RequestTimeoutSec: 5,
Servers: map[string]config.MCPServerConfig{
"ssedemo": {
Enabled: true,
Transport: "sse",
URL: server.URL + "/sse",
},
},
})
out, err := tool.Execute(context.Background(), map[string]interface{}{
"action": "list_tools",
"server": "ssedemo",
})
if err != nil {
t.Fatalf("sse list_tools returned error: %v", err)
}
if !strings.Contains(out, `"name": "echo"`) {
t.Fatalf("expected sse tool listing, got: %s", out)
}
}
func TestBuildMCPDynamicToolName(t *testing.T) {
got := buildMCPDynamicToolName("Context7 Server", "resolve-library.id")
if got != "mcp__context7_server__resolve_library_id" {
t.Fatalf("unexpected tool name: %s", got)
}
}
func TestResolveMCPWorkingDirWorkspaceScoped(t *testing.T) {
workspace := t.TempDir()
dir, err := resolveMCPWorkingDir(workspace, config.MCPServerConfig{WorkingDir: "tools/context7"})
if err != nil {
t.Fatalf("resolveMCPWorkingDir returned error: %v", err)
}
want := filepath.Join(workspace, "tools", "context7")
if dir != want {
t.Fatalf("unexpected working dir: got %q want %q", dir, want)
}
}
func TestResolveMCPWorkingDirRejectsOutsideWorkspaceWithoutFullPermission(t *testing.T) {
workspace := t.TempDir()
_, err := resolveMCPWorkingDir(workspace, config.MCPServerConfig{WorkingDir: "/"})
if err == nil {
t.Fatal("expected outside-workspace path to be rejected")
}
}
func TestResolveMCPWorkingDirAllowsAbsolutePathWithFullPermission(t *testing.T) {
dir, err := resolveMCPWorkingDir(t.TempDir(), config.MCPServerConfig{Permission: "full", WorkingDir: "/"})
if err != nil {
t.Fatalf("resolveMCPWorkingDir returned error: %v", err)
}
if dir != "/" {
t.Fatalf("unexpected working dir: %q", dir)
}
}

View File

@@ -62,13 +62,23 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
skill, _ := args["skill"].(string)
script, _ := args["script"].(string)
reason, _ := args["reason"].(string)
callerAgent, _ := args["caller_agent"].(string)
callerScope, _ := args["caller_scope"].(string)
reason = strings.TrimSpace(reason)
if reason == "" {
reason = "unspecified"
}
callerAgent = strings.TrimSpace(callerAgent)
if callerAgent == "" {
callerAgent = "main"
}
callerScope = strings.TrimSpace(callerScope)
if callerScope == "" {
callerScope = "main_agent"
}
if strings.TrimSpace(skill) == "" || strings.TrimSpace(script) == "" {
err := fmt.Errorf("skill and script are required")
t.writeAudit(skill, script, reason, false, err.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
return "", err
}
if strings.TrimSpace(t.workspace) != "" {
@@ -77,31 +87,31 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
skillDir, err := t.resolveSkillDir(skill)
if err != nil {
t.writeAudit(skill, script, reason, false, err.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
return "", err
}
if _, err := os.Stat(filepath.Join(skillDir, "SKILL.md")); err != nil {
err = fmt.Errorf("SKILL.md missing for skill: %s", skill)
t.writeAudit(skill, script, reason, false, err.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
return "", err
}
relScript := filepath.Clean(script)
if strings.Contains(relScript, "..") || filepath.IsAbs(relScript) {
err := fmt.Errorf("script must be relative path inside skill directory")
t.writeAudit(skill, script, reason, false, err.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
return "", err
}
if !strings.HasPrefix(relScript, "scripts"+string(os.PathSeparator)) {
err := fmt.Errorf("script must be under scripts/ directory")
t.writeAudit(skill, script, reason, false, err.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
return "", err
}
scriptPath := filepath.Join(skillDir, relScript)
if _, err := os.Stat(scriptPath); err != nil {
err = fmt.Errorf("script not found: %s", scriptPath)
t.writeAudit(skill, script, reason, false, err.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
return "", err
}
@@ -124,7 +134,7 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
for attempt := 0; attempt <= policy.MaxRestarts; attempt++ {
cmd, err := buildSkillCommand(ctx, scriptPath, cmdArgs)
if err != nil {
t.writeAudit(skill, script, reason, false, err.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, err.Error())
return "", err
}
cmd.Dir = skillDir
@@ -158,7 +168,7 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
}
output := merged.String()
if runErr != nil {
t.writeAudit(skill, script, reason, false, runErr.Error())
t.writeAudit(skill, script, reason, callerAgent, callerScope, false, runErr.Error())
return "", fmt.Errorf("skill execution failed: %w\n%s", runErr, output)
}
@@ -166,21 +176,23 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}
if out == "" {
out = "(no output)"
}
t.writeAudit(skill, script, reason, true, "")
t.writeAudit(skill, script, reason, callerAgent, callerScope, true, "")
return out, nil
}
func (t *SkillExecTool) writeAudit(skill, script, reason string, ok bool, errText string) {
func (t *SkillExecTool) writeAudit(skill, script, reason, callerAgent, callerScope string, ok bool, errText string) {
if strings.TrimSpace(t.workspace) == "" {
return
}
memDir := filepath.Join(t.workspace, "memory")
_ = os.MkdirAll(memDir, 0755)
row := fmt.Sprintf("{\"time\":%q,\"skill\":%q,\"script\":%q,\"reason\":%q,\"ok\":%t,\"error\":%q}\n",
row := fmt.Sprintf("{\"time\":%q,\"skill\":%q,\"script\":%q,\"reason\":%q,\"caller_agent\":%q,\"caller_scope\":%q,\"ok\":%t,\"error\":%q}\n",
time.Now().UTC().Format(time.RFC3339),
strings.TrimSpace(skill),
strings.TrimSpace(script),
strings.TrimSpace(reason),
strings.TrimSpace(callerAgent),
strings.TrimSpace(callerScope),
ok,
strings.TrimSpace(errText),
)

View File

@@ -0,0 +1,27 @@
package tools
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestSkillExecWriteAuditIncludesCallerIdentity(t *testing.T) {
workspace := t.TempDir()
tool := NewSkillExecTool(workspace)
tool.writeAudit("demo", "scripts/run.sh", "test", "coder", "subagent", true, "")
data, err := os.ReadFile(filepath.Join(workspace, "memory", "skill-audit.jsonl"))
if err != nil {
t.Fatalf("read audit file: %v", err)
}
text := string(data)
if !strings.Contains(text, `"caller_agent":"coder"`) {
t.Fatalf("expected caller_agent in audit row, got: %s", text)
}
if !strings.Contains(text, `"caller_scope":"subagent"`) {
t.Fatalf("expected caller_scope in audit row, got: %s", text)
}
}

View File

@@ -49,6 +49,12 @@ var defaultToolAllowlistGroups = []ToolAllowlistGroup{
Aliases: []string{"subagent", "agent_runtime"},
Tools: []string{"spawn", "subagents", "subagent_profile"},
},
{
Name: "skills",
Description: "Skill script execution tools",
Aliases: []string{"skill", "skill_scripts"},
Tools: []string{"skill_exec"},
},
}
func ToolAllowlistGroups() []ToolAllowlistGroup {

View File

@@ -17,7 +17,7 @@ func TestExpandToolAllowlistEntries_GroupPrefix(t *testing.T) {
}
func TestExpandToolAllowlistEntries_BareGroupAndAlias(t *testing.T) {
got := ExpandToolAllowlistEntries([]string{"memory_all", "@subagents"})
got := ExpandToolAllowlistEntries([]string{"memory_all", "@subagents", "skill"})
contains := map[string]bool{}
for _, item := range got {
contains[item] = true
@@ -28,4 +28,7 @@ func TestExpandToolAllowlistEntries_BareGroupAndAlias(t *testing.T) {
if !contains["spawn"] || !contains["subagents"] || !contains["subagent_profile"] {
t.Fatalf("subagents alias expansion missing subagent tools: %v", got)
}
if !contains["skill_exec"] {
t.Fatalf("skills alias expansion missing skill_exec: %v", got)
}
}