diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index de67170..cc069ec 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -96,6 +96,127 @@ type subagentDigestState struct { dueAt time.Time } +type localNodeActionHandler func(nodes.Request) nodes.Response + +var localNodeActionHandlers = map[string]localNodeActionHandler{ + "run": handleLocalNodeRun, + "agent_task": handleLocalNodeAgentTask, + "camera_snap": handleLocalNodeCameraSnap, + "camera_clip": handleLocalNodeCameraClip, + "screen_snapshot": handleLocalNodeScreenSnapshot, + "screen_record": handleLocalNodeScreenRecord, + "location_get": handleLocalNodeLocationGet, + "canvas_snapshot": handleLocalNodeCanvasSnapshot, + "canvas_action": handleLocalNodeCanvasAction, +} + +var fallbackProviderPriority = map[string]int{ + "claude": 10, + "codex": 20, + "gemini": 30, + "gemini-cli": 40, + "aistudio": 50, + "vertex": 60, + "antigravity": 70, + "qwen": 80, + "kimi": 90, + "iflow": 100, + "openai-compatibility": 110, +} + +func localSimulatedPayload(extra map[string]interface{}) map[string]interface{} { + payload := map[string]interface{}{ + "transport": "relay-local", + "simulated": true, + } + for k, v := range extra { + payload[k] = v + } + return payload +} + +func handleLocalNodeRun(req nodes.Request) nodes.Response { + payload := localSimulatedPayload(nil) + if cmdRaw, ok := req.Args["command"].([]interface{}); ok && len(cmdRaw) > 0 { + parts := make([]string, 0, len(cmdRaw)) + for _, x := range cmdRaw { + parts = append(parts, fmt.Sprint(x)) + } + payload["command"] = parts + } + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: payload} +} + +func handleLocalNodeAgentTask(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "model": req.Model, + "task": req.Task, + "result": "local child-model simulated execution completed", + })} +} + +func handleLocalNodeCameraSnap(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "media_type": "image", + "storage": "inline", + "facing": req.Args["facing"], + "meta": map[string]interface{}{"width": 1280, "height": 720}, + })} +} + +func handleLocalNodeCameraClip(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "media_type": "video", + "storage": "path", + "path": "/tmp/camera_clip.mp4", + "duration_ms": req.Args["duration_ms"], + "meta": map[string]interface{}{"fps": 30}, + })} +} + +func handleLocalNodeScreenSnapshot(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "media_type": "image", + "storage": "inline", + "meta": map[string]interface{}{"width": 1920, "height": 1080}, + })} +} + +func handleLocalNodeScreenRecord(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "media_type": "video", + "storage": "path", + "path": "/tmp/screen_record.mp4", + "duration_ms": req.Args["duration_ms"], + "meta": map[string]interface{}{"fps": 30}, + })} +} + +func handleLocalNodeLocationGet(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "lat": 0.0, + "lng": 0.0, + "accuracy": "simulated", + "meta": map[string]interface{}{"provider": "simulated"}, + })} +} + +func handleLocalNodeCanvasSnapshot(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "image": "data:image/png;base64,", + "media_type": "image", + "storage": "inline", + "meta": map[string]interface{}{"width": 1280, "height": 720}, + })} +} + +func handleLocalNodeCanvasAction(req nodes.Request) nodes.Response { + return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: localSimulatedPayload(map[string]interface{}{ + "applied": true, + "args": req.Args, + })} +} + func (al *AgentLoop) SetConfigPath(path string) { if al == nil { return @@ -170,36 +291,10 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers nodesManager.SetStatePath(filepath.Join(workspace, "memory", "nodes-state.json")) nodesManager.Upsert(nodes.NodeInfo{ID: "local", Name: "local", Capabilities: nodes.Capabilities{Run: true, Invoke: true, Model: true, Camera: true, Screen: true, Location: true, Canvas: true}, Models: []string{"local-sim"}, Online: true}) nodesManager.RegisterHandler("local", func(req nodes.Request) nodes.Response { - switch req.Action { - case "run": - payload := map[string]interface{}{"transport": "relay-local", "simulated": true} - if cmdRaw, ok := req.Args["command"].([]interface{}); ok && len(cmdRaw) > 0 { - parts := make([]string, 0, len(cmdRaw)) - for _, x := range cmdRaw { - parts = append(parts, fmt.Sprint(x)) - } - payload["command"] = parts - } - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: payload} - case "agent_task": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "simulated": true, "model": req.Model, "task": req.Task, "result": "local child-model simulated execution completed"}} - case "camera_snap": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "media_type": "image", "storage": "inline", "facing": req.Args["facing"], "simulated": true, "meta": map[string]interface{}{"width": 1280, "height": 720}}} - case "camera_clip": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "media_type": "video", "storage": "path", "path": "/tmp/camera_clip.mp4", "duration_ms": req.Args["duration_ms"], "simulated": true, "meta": map[string]interface{}{"fps": 30}}} - case "screen_snapshot": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "media_type": "image", "storage": "inline", "simulated": true, "meta": map[string]interface{}{"width": 1920, "height": 1080}}} - case "screen_record": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "media_type": "video", "storage": "path", "path": "/tmp/screen_record.mp4", "duration_ms": req.Args["duration_ms"], "simulated": true, "meta": map[string]interface{}{"fps": 30}}} - case "location_get": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "lat": 0.0, "lng": 0.0, "accuracy": "simulated", "meta": map[string]interface{}{"provider": "simulated"}}} - case "canvas_snapshot": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "image": "data:image/png;base64,", "media_type": "image", "storage": "inline", "simulated": true, "meta": map[string]interface{}{"width": 1280, "height": 720}}} - case "canvas_action": - return nodes.Response{OK: true, Code: "ok", Node: "local", Action: req.Action, Payload: map[string]interface{}{"transport": "relay-local", "applied": true, "simulated": true, "args": req.Args}} - default: - return nodes.Response{OK: false, Code: "unsupported_action", Node: "local", Action: req.Action, Error: "unsupported local simulated action"} + if handler := localNodeActionHandlers[req.Action]; handler != nil { + return handler(req) } + return nodes.Response{OK: false, Code: "unsupported_action", Node: "local", Action: req.Action, Error: "unsupported local simulated action"} }) nodeDispatchPolicy := nodes.DispatchPolicy{ PreferLocal: cfg.Gateway.Nodes.Dispatch.PreferLocal, @@ -733,32 +828,10 @@ func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (provi } func automaticFallbackPriority(name string) int { - switch normalizeFallbackProviderName(name) { - case "claude": - return 10 - case "codex": - return 20 - case "gemini": - return 30 - case "gemini-cli": - return 40 - case "aistudio": - return 50 - case "vertex": - return 60 - case "antigravity": - return 70 - case "qwen": - return 80 - case "kimi": - return 90 - case "iflow": - return 100 - case "openai-compatibility": - return 110 - default: - return 1000 + if priority, ok := fallbackProviderPriority[normalizeFallbackProviderName(name)]; ok { + return priority } + return 1000 } func normalizeFallbackProviderName(name string) string { @@ -871,15 +944,13 @@ func buildAuditTaskID(msg bus.InboundMessage) string { trigger = strings.ToLower(strings.TrimSpace(msg.Metadata["trigger"])) } sessionPart := shortSessionKey(msg.SessionKey) - switch trigger { - case "heartbeat": + if trigger == "heartbeat" { if sessionPart == "" { sessionPart = "default" } return "heartbeat:" + sessionPart - default: - return fmt.Sprintf("%s-%d", sessionPart, time.Now().Unix()%100000) } + return fmt.Sprintf("%s-%d", sessionPart, time.Now().Unix()%100000) } func (al *AgentLoop) appendTaskAudit(taskID string, msg bus.InboundMessage, started time.Time, runErr error, suppressed bool) { @@ -2185,9 +2256,7 @@ func withToolContextArgs(toolName string, args map[string]interface{}, channel, if channel == "" || chatID == "" { return args } - switch toolName { - case "message", "spawn", "remind": - default: + if !toolContextEligibleTool(toolName) { return args } @@ -2254,9 +2323,7 @@ func withToolMemoryNamespaceArgs(toolName string, args map[string]interface{}, n if ns == "main" { return args } - switch strings.TrimSpace(toolName) { - case "memory_search", "memory_get", "memory_write": - default: + if !toolNeedsMemoryNamespace(toolName) { return args } @@ -2356,12 +2423,34 @@ func validateParallelAllowlistArgs(allow map[string]struct{}, args map[string]in } func isImplicitlyAllowedSubagentTool(name string) bool { - switch strings.ToLower(strings.TrimSpace(name)) { - case "skill_exec": - return true - default: - return false - } + _, ok := implicitSubagentToolSet[strings.ToLower(strings.TrimSpace(name))] + return ok +} + +var toolContextEligibleSet = map[string]struct{}{ + "message": {}, + "spawn": {}, + "remind": {}, +} + +func toolContextEligibleTool(name string) bool { + _, ok := toolContextEligibleSet[strings.TrimSpace(name)] + return ok +} + +var toolMemoryNamespaceSet = map[string]struct{}{ + "memory_search": {}, + "memory_get": {}, + "memory_write": {}, +} + +func toolNeedsMemoryNamespace(name string) bool { + _, ok := toolMemoryNamespaceSet[strings.TrimSpace(name)] + return ok +} + +var implicitSubagentToolSet = map[string]struct{}{ + "skill_exec": {}, } func normalizeToolAllowlist(in []string) map[string]struct{} { diff --git a/pkg/agent/runtime_admin.go b/pkg/agent/runtime_admin.go index 248d85d..739fc6e 100644 --- a/pkg/agent/runtime_admin.go +++ b/pkg/agent/runtime_admin.go @@ -15,6 +15,12 @@ import ( "github.com/YspCoder/clawgo/pkg/tools" ) +var subagentRuntimeActionAliases = map[string]string{ + "info": "get", + "create": "spawn", + "trace": "thread", +} + func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, args map[string]interface{}) (interface{}, error) { if al == nil || al.subagentManager == nil { return nil, fmt.Errorf("subagent runtime is not configured") @@ -26,330 +32,384 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a if action == "" { action = "list" } + if canonical := subagentRuntimeActionAliases[action]; canonical != "" { + action = canonical + } + handler := al.subagentRuntimeHandlers()[action] + if handler == nil { + return nil, fmt.Errorf("unsupported action: %s", action) + } + return handler(ctx, args) +} +type runtimeAdminHandler func(context.Context, map[string]interface{}) (interface{}, error) + +func (al *AgentLoop) subagentRuntimeHandlers() map[string]runtimeAdminHandler { sm := al.subagentManager router := al.subagentRouter - switch action { - case "list": - tasks := sm.ListTasks() - items := make([]*tools.SubagentTask, 0, len(tasks)) - for _, task := range tasks { - items = append(items, cloneSubagentTask(task)) - } - sort.Slice(items, func(i, j int) bool { return items[i].Created > items[j].Created }) - return map[string]interface{}{"items": items}, nil - case "snapshot": - limit := runtimeIntArg(args, "limit", 100) - return map[string]interface{}{"snapshot": sm.RuntimeSnapshot(limit)}, nil - case "get", "info": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - task, ok := sm.GetTask(taskID) - if !ok { - return map[string]interface{}{"found": false}, nil - } - return map[string]interface{}{"found": true, "task": cloneSubagentTask(task)}, nil - case "spawn", "create": - taskInput := runtimeStringArg(args, "task") - if taskInput == "" { - return nil, fmt.Errorf("task is required") - } - msg, err := sm.Spawn(ctx, tools.SubagentSpawnOptions{ - Task: taskInput, - Label: runtimeStringArg(args, "label"), - Role: runtimeStringArg(args, "role"), - AgentID: runtimeStringArg(args, "agent_id"), - MaxRetries: runtimeIntArg(args, "max_retries", 0), - RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0), - TimeoutSec: runtimeIntArg(args, "timeout_sec", 0), - MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0), - MaxResultChars: runtimeIntArg(args, "max_result_chars", 0), - OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"), - OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"), - }) - if err != nil { - return nil, err - } - return map[string]interface{}{"message": msg}, nil - case "dispatch_and_wait": - taskInput := runtimeStringArg(args, "task") - if taskInput == "" { - return nil, fmt.Errorf("task is required") - } - task, err := router.DispatchTask(ctx, tools.RouterDispatchRequest{ - Task: taskInput, - Label: runtimeStringArg(args, "label"), - Role: runtimeStringArg(args, "role"), - AgentID: runtimeStringArg(args, "agent_id"), - NotifyMainPolicy: "internal_only", - ThreadID: runtimeStringArg(args, "thread_id"), - CorrelationID: runtimeStringArg(args, "correlation_id"), - ParentRunID: runtimeStringArg(args, "parent_run_id"), - OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"), - OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"), - MaxRetries: runtimeIntArg(args, "max_retries", 0), - RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0), - TimeoutSec: runtimeIntArg(args, "timeout_sec", 0), - MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0), - MaxResultChars: runtimeIntArg(args, "max_result_chars", 0), - }) - if err != nil { - return nil, err - } - waitTimeoutSec := runtimeIntArg(args, "wait_timeout_sec", 120) - waitCtx := ctx - var cancel context.CancelFunc - if waitTimeoutSec > 0 { - waitCtx, cancel = context.WithTimeout(ctx, time.Duration(waitTimeoutSec)*time.Second) - defer cancel() - } - reply, err := router.WaitReply(waitCtx, task.ID, 100*time.Millisecond) - if err != nil { - return nil, err - } - return map[string]interface{}{ - "task": cloneSubagentTask(task), - "reply": reply, - "merged": router.MergeResults([]*tools.RouterReply{reply}), - }, nil - case "registry": - cfg := runtimecfg.Get() - items := make([]map[string]interface{}, 0) - if cfg != nil { - items = make([]map[string]interface{}, 0, len(cfg.Agents.Subagents)) - for agentID, subcfg := range cfg.Agents.Subagents { - promptFileFound := false - if strings.TrimSpace(subcfg.SystemPromptFile) != "" { - if absPath, err := al.resolvePromptFilePath(subcfg.SystemPromptFile); err == nil { - if info, statErr := os.Stat(absPath); statErr == nil && !info.IsDir() { - promptFileFound = true + return map[string]runtimeAdminHandler{ + "list": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + tasks := sm.ListTasks() + items := make([]*tools.SubagentTask, 0, len(tasks)) + for _, task := range tasks { + items = append(items, cloneSubagentTask(task)) + } + sort.Slice(items, func(i, j int) bool { return items[i].Created > items[j].Created }) + return map[string]interface{}{"items": items}, nil + }, + "snapshot": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + limit := runtimeIntArg(args, "limit", 100) + return map[string]interface{}{"snapshot": sm.RuntimeSnapshot(limit)}, nil + }, + "get": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + task, ok := sm.GetTask(taskID) + if !ok { + return map[string]interface{}{"found": false}, nil + } + return map[string]interface{}{"found": true, "task": cloneSubagentTask(task)}, nil + }, + "spawn": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskInput := runtimeStringArg(args, "task") + if taskInput == "" { + return nil, fmt.Errorf("task is required") + } + msg, err := sm.Spawn(ctx, tools.SubagentSpawnOptions{ + Task: taskInput, + Label: runtimeStringArg(args, "label"), + Role: runtimeStringArg(args, "role"), + AgentID: runtimeStringArg(args, "agent_id"), + MaxRetries: runtimeIntArg(args, "max_retries", 0), + RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0), + TimeoutSec: runtimeIntArg(args, "timeout_sec", 0), + MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0), + MaxResultChars: runtimeIntArg(args, "max_result_chars", 0), + OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"), + OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"), + }) + if err != nil { + return nil, err + } + return map[string]interface{}{"message": msg}, nil + }, + "dispatch_and_wait": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskInput := runtimeStringArg(args, "task") + if taskInput == "" { + return nil, fmt.Errorf("task is required") + } + task, err := router.DispatchTask(ctx, tools.RouterDispatchRequest{ + Task: taskInput, + Label: runtimeStringArg(args, "label"), + Role: runtimeStringArg(args, "role"), + AgentID: runtimeStringArg(args, "agent_id"), + NotifyMainPolicy: "internal_only", + ThreadID: runtimeStringArg(args, "thread_id"), + CorrelationID: runtimeStringArg(args, "correlation_id"), + ParentRunID: runtimeStringArg(args, "parent_run_id"), + OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"), + OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"), + MaxRetries: runtimeIntArg(args, "max_retries", 0), + RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0), + TimeoutSec: runtimeIntArg(args, "timeout_sec", 0), + MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0), + MaxResultChars: runtimeIntArg(args, "max_result_chars", 0), + }) + if err != nil { + return nil, err + } + waitTimeoutSec := runtimeIntArg(args, "wait_timeout_sec", 120) + waitCtx := ctx + var cancel context.CancelFunc + if waitTimeoutSec > 0 { + waitCtx, cancel = context.WithTimeout(ctx, time.Duration(waitTimeoutSec)*time.Second) + defer cancel() + } + reply, err := router.WaitReply(waitCtx, task.ID, 100*time.Millisecond) + if err != nil { + return nil, err + } + return map[string]interface{}{ + "task": cloneSubagentTask(task), + "reply": reply, + "merged": router.MergeResults([]*tools.RouterReply{reply}), + }, nil + }, + "registry": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + cfg := runtimecfg.Get() + items := make([]map[string]interface{}, 0) + if cfg != nil { + items = make([]map[string]interface{}, 0, len(cfg.Agents.Subagents)) + for agentID, subcfg := range cfg.Agents.Subagents { + promptFileFound := false + if strings.TrimSpace(subcfg.SystemPromptFile) != "" { + if absPath, err := al.resolvePromptFilePath(subcfg.SystemPromptFile); err == nil { + if info, statErr := os.Stat(absPath); statErr == nil && !info.IsDir() { + promptFileFound = true + } } } - } - toolInfo := al.describeSubagentTools(subcfg.Tools.Allowlist) - items = append(items, map[string]interface{}{ - "agent_id": agentID, - "enabled": subcfg.Enabled, - "type": subcfg.Type, - "transport": fallbackString(strings.TrimSpace(subcfg.Transport), "local"), - "node_id": strings.TrimSpace(subcfg.NodeID), - "parent_agent_id": strings.TrimSpace(subcfg.ParentAgentID), - "notify_main_policy": fallbackString(strings.TrimSpace(subcfg.NotifyMainPolicy), "final_only"), - "display_name": subcfg.DisplayName, - "role": subcfg.Role, - "description": subcfg.Description, - "system_prompt_file": subcfg.SystemPromptFile, - "prompt_file_found": promptFileFound, - "memory_namespace": subcfg.MemoryNamespace, - "tool_allowlist": append([]string(nil), subcfg.Tools.Allowlist...), - "tool_visibility": toolInfo, - "effective_tools": toolInfo["effective_tools"], - "inherited_tools": toolInfo["inherited_tools"], - "routing_keywords": routeKeywordsForRegistry(cfg.Agents.Router.Rules, agentID), - "managed_by": "config.json", - }) - } - } - if store := sm.ProfileStore(); store != nil { - if profiles, err := store.List(); err == nil { - for _, profile := range profiles { - if strings.TrimSpace(profile.ManagedBy) != "node_registry" { - continue - } - toolInfo := al.describeSubagentTools(profile.ToolAllowlist) + toolInfo := al.describeSubagentTools(subcfg.Tools.Allowlist) items = append(items, map[string]interface{}{ - "agent_id": profile.AgentID, - "enabled": strings.EqualFold(strings.TrimSpace(profile.Status), "active"), - "type": "node_branch", - "transport": profile.Transport, - "node_id": profile.NodeID, - "parent_agent_id": profile.ParentAgentID, - "notify_main_policy": fallbackString(strings.TrimSpace(profile.NotifyMainPolicy), "final_only"), - "display_name": profile.Name, - "role": profile.Role, - "description": "Node-registered remote main agent branch", - "system_prompt_file": profile.SystemPromptFile, - "prompt_file_found": false, - "memory_namespace": profile.MemoryNamespace, - "tool_allowlist": append([]string(nil), profile.ToolAllowlist...), + "agent_id": agentID, + "enabled": subcfg.Enabled, + "type": subcfg.Type, + "transport": fallbackString(strings.TrimSpace(subcfg.Transport), "local"), + "node_id": strings.TrimSpace(subcfg.NodeID), + "parent_agent_id": strings.TrimSpace(subcfg.ParentAgentID), + "notify_main_policy": fallbackString(strings.TrimSpace(subcfg.NotifyMainPolicy), "final_only"), + "display_name": subcfg.DisplayName, + "role": subcfg.Role, + "description": subcfg.Description, + "system_prompt_file": subcfg.SystemPromptFile, + "prompt_file_found": promptFileFound, + "memory_namespace": subcfg.MemoryNamespace, + "tool_allowlist": append([]string(nil), subcfg.Tools.Allowlist...), "tool_visibility": toolInfo, "effective_tools": toolInfo["effective_tools"], "inherited_tools": toolInfo["inherited_tools"], - "routing_keywords": []string{}, - "managed_by": profile.ManagedBy, + "routing_keywords": routeKeywordsForRegistry(cfg.Agents.Router.Rules, agentID), + "managed_by": "config.json", }) } } - } - sort.Slice(items, func(i, j int) bool { - left, _ := items[i]["agent_id"].(string) - right, _ := items[j]["agent_id"].(string) - return left < right - }) - return map[string]interface{}{"items": items}, nil - case "set_config_subagent_enabled": - agentID := runtimeStringArg(args, "agent_id") - if agentID == "" { - return nil, fmt.Errorf("agent_id is required") - } - if al.isProtectedMainAgent(agentID) { - return nil, fmt.Errorf("main agent %q cannot be disabled", agentID) - } - enabled, ok := runtimeBoolArg(args, "enabled") - if !ok { - return nil, fmt.Errorf("enabled is required") - } - return tools.UpsertConfigSubagent(al.configPath, map[string]interface{}{ - "agent_id": agentID, - "enabled": enabled, - }) - case "delete_config_subagent": - agentID := runtimeStringArg(args, "agent_id") - if agentID == "" { - return nil, fmt.Errorf("agent_id is required") - } - if al.isProtectedMainAgent(agentID) { - return nil, fmt.Errorf("main agent %q cannot be deleted", agentID) - } - return tools.DeleteConfigSubagent(al.configPath, agentID) - case "upsert_config_subagent": - return tools.UpsertConfigSubagent(al.configPath, args) - case "prompt_file_get": - relPath := runtimeStringArg(args, "path") - if relPath == "" { - return nil, fmt.Errorf("path is required") - } - absPath, err := al.resolvePromptFilePath(relPath) - if err != nil { - return nil, err - } - data, err := os.ReadFile(absPath) - if err != nil { - if os.IsNotExist(err) { - return map[string]interface{}{"found": false, "path": relPath, "content": ""}, nil + if store := sm.ProfileStore(); store != nil { + if profiles, err := store.List(); err == nil { + for _, profile := range profiles { + if strings.TrimSpace(profile.ManagedBy) != "node_registry" { + continue + } + toolInfo := al.describeSubagentTools(profile.ToolAllowlist) + items = append(items, map[string]interface{}{ + "agent_id": profile.AgentID, + "enabled": strings.EqualFold(strings.TrimSpace(profile.Status), "active"), + "type": "node_branch", + "transport": profile.Transport, + "node_id": profile.NodeID, + "parent_agent_id": profile.ParentAgentID, + "notify_main_policy": fallbackString(strings.TrimSpace(profile.NotifyMainPolicy), "final_only"), + "display_name": profile.Name, + "role": profile.Role, + "description": "Node-registered remote main agent branch", + "system_prompt_file": profile.SystemPromptFile, + "prompt_file_found": false, + "memory_namespace": profile.MemoryNamespace, + "tool_allowlist": append([]string(nil), profile.ToolAllowlist...), + "tool_visibility": toolInfo, + "effective_tools": toolInfo["effective_tools"], + "inherited_tools": toolInfo["inherited_tools"], + "routing_keywords": []string{}, + "managed_by": profile.ManagedBy, + }) + } + } } - return nil, err - } - return map[string]interface{}{"found": true, "path": relPath, "content": string(data)}, nil - case "prompt_file_set": - relPath := runtimeStringArg(args, "path") - if relPath == "" { - return nil, fmt.Errorf("path is required") - } - content := runtimeRawStringArg(args, "content") - absPath, err := al.resolvePromptFilePath(relPath) - if err != nil { - return nil, err - } - if err := os.MkdirAll(filepath.Dir(absPath), 0755); err != nil { - return nil, err - } - if err := os.WriteFile(absPath, []byte(content), 0644); err != nil { - return nil, err - } - return map[string]interface{}{"ok": true, "path": relPath, "bytes": len(content)}, nil - case "prompt_file_bootstrap": - agentID := runtimeStringArg(args, "agent_id") - if agentID == "" { - return nil, fmt.Errorf("agent_id is required") - } - relPath := runtimeStringArg(args, "path") - if relPath == "" { - relPath = filepath.ToSlash(filepath.Join("agents", agentID, "AGENT.md")) - } - absPath, err := al.resolvePromptFilePath(relPath) - if err != nil { - return nil, err - } - overwrite, _ := args["overwrite"].(bool) - if _, err := os.Stat(absPath); err == nil && !overwrite { - data, readErr := os.ReadFile(absPath) - if readErr != nil { - return nil, readErr + sort.Slice(items, func(i, j int) bool { + left, _ := items[i]["agent_id"].(string) + right, _ := items[j]["agent_id"].(string) + return left < right + }) + return map[string]interface{}{"items": items}, nil + }, + "set_config_subagent_enabled": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + agentID := runtimeStringArg(args, "agent_id") + if agentID == "" { + return nil, fmt.Errorf("agent_id is required") + } + if al.isProtectedMainAgent(agentID) { + return nil, fmt.Errorf("main agent %q cannot be disabled", agentID) + } + enabled, ok := runtimeBoolArg(args, "enabled") + if !ok { + return nil, fmt.Errorf("enabled is required") + } + return tools.UpsertConfigSubagent(al.configPath, map[string]interface{}{ + "agent_id": agentID, + "enabled": enabled, + }) + }, + "delete_config_subagent": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + agentID := runtimeStringArg(args, "agent_id") + if agentID == "" { + return nil, fmt.Errorf("agent_id is required") + } + if al.isProtectedMainAgent(agentID) { + return nil, fmt.Errorf("main agent %q cannot be deleted", agentID) + } + return tools.DeleteConfigSubagent(al.configPath, agentID) + }, + "upsert_config_subagent": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + return tools.UpsertConfigSubagent(al.configPath, args) + }, + "prompt_file_get": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + relPath := runtimeStringArg(args, "path") + if relPath == "" { + return nil, fmt.Errorf("path is required") + } + absPath, err := al.resolvePromptFilePath(relPath) + if err != nil { + return nil, err + } + data, err := os.ReadFile(absPath) + if err != nil { + if os.IsNotExist(err) { + return map[string]interface{}{"found": false, "path": relPath, "content": ""}, nil + } + return nil, err + } + return map[string]interface{}{"found": true, "path": relPath, "content": string(data)}, nil + }, + "prompt_file_set": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + relPath := runtimeStringArg(args, "path") + if relPath == "" { + return nil, fmt.Errorf("path is required") + } + content := runtimeRawStringArg(args, "content") + absPath, err := al.resolvePromptFilePath(relPath) + if err != nil { + return nil, err + } + if err := os.MkdirAll(filepath.Dir(absPath), 0755); err != nil { + return nil, err + } + if err := os.WriteFile(absPath, []byte(content), 0644); err != nil { + return nil, err + } + return map[string]interface{}{"ok": true, "path": relPath, "bytes": len(content)}, nil + }, + "prompt_file_bootstrap": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + agentID := runtimeStringArg(args, "agent_id") + if agentID == "" { + return nil, fmt.Errorf("agent_id is required") + } + relPath := runtimeStringArg(args, "path") + if relPath == "" { + relPath = filepath.ToSlash(filepath.Join("agents", agentID, "AGENT.md")) + } + absPath, err := al.resolvePromptFilePath(relPath) + if err != nil { + return nil, err + } + overwrite, _ := args["overwrite"].(bool) + if _, err := os.Stat(absPath); err == nil && !overwrite { + data, readErr := os.ReadFile(absPath) + if readErr != nil { + return nil, readErr + } + return map[string]interface{}{ + "ok": true, + "created": false, + "path": relPath, + "content": string(data), + }, nil + } + if err := os.MkdirAll(filepath.Dir(absPath), 0755); err != nil { + return nil, err + } + content := buildPromptTemplate(agentID, runtimeStringArg(args, "role"), runtimeStringArg(args, "display_name")) + if err := os.WriteFile(absPath, []byte(content), 0644); err != nil { + return nil, err } return map[string]interface{}{ "ok": true, - "created": false, + "created": true, "path": relPath, - "content": string(data), + "content": content, }, nil - } - if err := os.MkdirAll(filepath.Dir(absPath), 0755); err != nil { - return nil, err - } - content := buildPromptTemplate(agentID, runtimeStringArg(args, "role"), runtimeStringArg(args, "display_name")) - if err := os.WriteFile(absPath, []byte(content), 0644); err != nil { - return nil, err - } - return map[string]interface{}{ - "ok": true, - "created": true, - "path": relPath, - "content": content, - }, nil - case "kill": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - ok := sm.KillTask(taskID) - return map[string]interface{}{"ok": ok}, nil - case "resume": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - label, ok := sm.ResumeTask(ctx, taskID) - return map[string]interface{}{"ok": ok, "label": label}, nil - case "steer": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - msg := runtimeStringArg(args, "message") - if msg == "" { - return nil, fmt.Errorf("message is required") - } - ok := sm.SteerTask(taskID, msg) - return map[string]interface{}{"ok": ok}, nil - case "send": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - msg := runtimeStringArg(args, "message") - if msg == "" { - return nil, fmt.Errorf("message is required") - } - ok := sm.SendTaskMessage(taskID, msg) - return map[string]interface{}{"ok": ok}, nil - case "reply": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - msg := runtimeStringArg(args, "message") - if msg == "" { - return nil, fmt.Errorf("message is required") - } - ok := sm.ReplyToTask(taskID, runtimeStringArg(args, "message_id"), msg) - return map[string]interface{}{"ok": ok}, nil - case "ack": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - messageID := runtimeStringArg(args, "message_id") - if messageID == "" { - return nil, fmt.Errorf("message_id is required") - } - ok := sm.AckTaskMessage(taskID, messageID) - return map[string]interface{}{"ok": ok}, nil - case "thread", "trace": - threadID := runtimeStringArg(args, "thread_id") - if threadID == "" { + }, + "kill": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + ok := sm.KillTask(taskID) + return map[string]interface{}{"ok": ok}, nil + }, + "resume": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + label, ok := sm.ResumeTask(ctx, taskID) + return map[string]interface{}{"ok": ok, "label": label}, nil + }, + "steer": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + msg := runtimeStringArg(args, "message") + if msg == "" { + return nil, fmt.Errorf("message is required") + } + ok := sm.SteerTask(taskID, msg) + return map[string]interface{}{"ok": ok}, nil + }, + "send": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + msg := runtimeStringArg(args, "message") + if msg == "" { + return nil, fmt.Errorf("message is required") + } + ok := sm.SendTaskMessage(taskID, msg) + return map[string]interface{}{"ok": ok}, nil + }, + "reply": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + msg := runtimeStringArg(args, "message") + if msg == "" { + return nil, fmt.Errorf("message is required") + } + ok := sm.ReplyToTask(taskID, runtimeStringArg(args, "message_id"), msg) + return map[string]interface{}{"ok": ok}, nil + }, + "ack": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + messageID := runtimeStringArg(args, "message_id") + if messageID == "" { + return nil, fmt.Errorf("message_id is required") + } + ok := sm.AckTaskMessage(taskID, messageID) + return map[string]interface{}{"ok": ok}, nil + }, + "thread": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + threadID := runtimeStringArg(args, "thread_id") + if threadID == "" { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + task, ok := sm.GetTask(taskID) + if !ok { + return map[string]interface{}{"found": false}, nil + } + threadID = strings.TrimSpace(task.ThreadID) + } + if threadID == "" { + return nil, fmt.Errorf("thread_id is required") + } + thread, ok := sm.Thread(threadID) + if !ok { + return map[string]interface{}{"found": false}, nil + } + items, err := sm.ThreadMessages(threadID, runtimeIntArg(args, "limit", 50)) + if err != nil { + return nil, err + } + return map[string]interface{}{"found": true, "thread": thread, "messages": items}, nil + }, + "stream": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) if err != nil { return nil, err @@ -358,74 +418,51 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a if !ok { return map[string]interface{}{"found": false}, nil } - threadID = strings.TrimSpace(task.ThreadID) - } - if threadID == "" { - return nil, fmt.Errorf("thread_id is required") - } - thread, ok := sm.Thread(threadID) - if !ok { - return map[string]interface{}{"found": false}, nil - } - items, err := sm.ThreadMessages(threadID, runtimeIntArg(args, "limit", 50)) - if err != nil { - return nil, err - } - return map[string]interface{}{"found": true, "thread": thread, "messages": items}, nil - case "stream": - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) - if err != nil { - return nil, err - } - task, ok := sm.GetTask(taskID) - if !ok { - return map[string]interface{}{"found": false}, nil - } - events, err := sm.Events(taskID, runtimeIntArg(args, "limit", 100)) - if err != nil { - return nil, err - } - var thread *tools.AgentThread - var messages []tools.AgentMessage - if strings.TrimSpace(task.ThreadID) != "" { - if th, ok := sm.Thread(task.ThreadID); ok { - thread = th - } - messages, err = sm.ThreadMessages(task.ThreadID, runtimeIntArg(args, "limit", 100)) + events, err := sm.Events(taskID, runtimeIntArg(args, "limit", 100)) if err != nil { return nil, err } - } - stream := mergeSubagentStream(events, messages) - return map[string]interface{}{ - "found": true, - "task": cloneSubagentTask(task), - "thread": thread, - "items": stream, - }, nil - case "inbox": - agentID := runtimeStringArg(args, "agent_id") - if agentID == "" { - taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + var thread *tools.AgentThread + var messages []tools.AgentMessage + if strings.TrimSpace(task.ThreadID) != "" { + if th, ok := sm.Thread(task.ThreadID); ok { + thread = th + } + messages, err = sm.ThreadMessages(task.ThreadID, runtimeIntArg(args, "limit", 100)) + if err != nil { + return nil, err + } + } + stream := mergeSubagentStream(events, messages) + return map[string]interface{}{ + "found": true, + "task": cloneSubagentTask(task), + "thread": thread, + "items": stream, + }, nil + }, + "inbox": func(ctx context.Context, args map[string]interface{}) (interface{}, error) { + agentID := runtimeStringArg(args, "agent_id") + if agentID == "" { + taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) + if err != nil { + return nil, err + } + task, ok := sm.GetTask(taskID) + if !ok { + return map[string]interface{}{"found": false}, nil + } + agentID = strings.TrimSpace(task.AgentID) + } + if agentID == "" { + return nil, fmt.Errorf("agent_id is required") + } + items, err := sm.Inbox(agentID, runtimeIntArg(args, "limit", 50)) if err != nil { return nil, err } - task, ok := sm.GetTask(taskID) - if !ok { - return map[string]interface{}{"found": false}, nil - } - agentID = strings.TrimSpace(task.AgentID) - } - if agentID == "" { - return nil, fmt.Errorf("agent_id is required") - } - items, err := sm.Inbox(agentID, runtimeIntArg(args, "limit", 50)) - if err != nil { - return nil, err - } - return map[string]interface{}{"found": true, "agent_id": agentID, "messages": items}, nil - default: - return nil, fmt.Errorf("unsupported action: %s", action) + return map[string]interface{}{"found": true, "agent_id": agentID, "messages": items}, nil + }, } } diff --git a/pkg/api/rpc_http.go b/pkg/api/rpc_http.go new file mode 100644 index 0000000..e4649a5 --- /dev/null +++ b/pkg/api/rpc_http.go @@ -0,0 +1,380 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "strings" + + rpcpkg "github.com/YspCoder/clawgo/pkg/rpc" + "github.com/YspCoder/clawgo/pkg/tools" +) + +func (s *Server) handleSubagentRPC(w http.ResponseWriter, r *http.Request) { + s.handleRPC(w, r, s.subagentRPCRegistry()) +} + +func (s *Server) handleNodeRPC(w http.ResponseWriter, r *http.Request) { + s.handleRPC(w, r, s.nodeRPCRegistry()) +} + +func (s *Server) handleProviderRPC(w http.ResponseWriter, r *http.Request) { + s.handleRPC(w, r, s.providerRPCRegistry()) +} + +func (s *Server) handleWorkspaceRPC(w http.ResponseWriter, r *http.Request) { + s.handleRPC(w, r, s.workspaceRPCRegistry()) +} + +func (s *Server) handleConfigRPC(w http.ResponseWriter, r *http.Request) { + s.handleRPC(w, r, s.configRPCRegistry()) +} + +func (s *Server) handleCronRPC(w http.ResponseWriter, r *http.Request) { + s.handleRPC(w, r, s.cronRPCRegistry()) +} + +func (s *Server) handleRPC(w http.ResponseWriter, r *http.Request, registry *rpcpkg.Registry) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var req rpcpkg.Request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeRPCError(w, http.StatusBadRequest, req.RequestID, rpcError("invalid_argument", "invalid json", nil, false)) + return + } + result, rpcErr := registry.Handle(r.Context(), req) + if rpcErr != nil { + writeRPCError(w, rpcHTTPStatus(rpcErr), req.RequestID, rpcErr) + return + } + writeJSON(w, rpcpkg.Response{ + OK: true, + Result: result, + RequestID: strings.TrimSpace(req.RequestID), + }) +} + +func (s *Server) buildSubagentRegistry() *rpcpkg.Registry { + svc := s.subagentRPCService() + reg := rpcpkg.NewRegistry() + rpcpkg.RegisterJSON(reg, "subagent.list", func(ctx context.Context, req rpcpkg.ListSubagentsRequest) (interface{}, *rpcpkg.Error) { + return svc.List(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "subagent.snapshot", func(ctx context.Context, req rpcpkg.SnapshotRequest) (interface{}, *rpcpkg.Error) { + return svc.Snapshot(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "subagent.get", func(ctx context.Context, req rpcpkg.GetSubagentRequest) (interface{}, *rpcpkg.Error) { + return svc.Get(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "subagent.spawn", func(ctx context.Context, req rpcpkg.SpawnSubagentRequest) (interface{}, *rpcpkg.Error) { + return svc.Spawn(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "subagent.dispatch_and_wait", func(ctx context.Context, req rpcpkg.DispatchAndWaitRequest) (interface{}, *rpcpkg.Error) { + return svc.DispatchAndWait(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "subagent.registry", func(ctx context.Context, req rpcpkg.RegistryRequest) (interface{}, *rpcpkg.Error) { + return svc.Registry(ctx, req) + }) + return reg +} + +func (s *Server) subagentRPCRegistry() *rpcpkg.Registry { + if s == nil { + return rpcpkg.NewRegistry() + } + s.subagentRPCOnce.Do(func() { + s.subagentRPCReg = s.buildSubagentRegistry() + }) + if s.subagentRPCReg == nil { + return rpcpkg.NewRegistry() + } + return s.subagentRPCReg +} + +func (s *Server) buildNodeRegistry() *rpcpkg.Registry { + svc := s.nodeRPCService() + reg := rpcpkg.NewRegistry() + rpcpkg.RegisterJSON(reg, "node.register", func(ctx context.Context, req rpcpkg.RegisterNodeRequest) (interface{}, *rpcpkg.Error) { + return svc.Register(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "node.heartbeat", func(ctx context.Context, req rpcpkg.HeartbeatNodeRequest) (interface{}, *rpcpkg.Error) { + return svc.Heartbeat(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "node.dispatch", func(ctx context.Context, req rpcpkg.DispatchNodeRequest) (interface{}, *rpcpkg.Error) { + return svc.Dispatch(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "node.artifact.list", func(ctx context.Context, req rpcpkg.ListNodeArtifactsRequest) (interface{}, *rpcpkg.Error) { + return svc.ListArtifacts(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "node.artifact.get", func(ctx context.Context, req rpcpkg.GetNodeArtifactRequest) (interface{}, *rpcpkg.Error) { + return svc.GetArtifact(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "node.artifact.delete", func(ctx context.Context, req rpcpkg.DeleteNodeArtifactRequest) (interface{}, *rpcpkg.Error) { + return svc.DeleteArtifact(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "node.artifact.prune", func(ctx context.Context, req rpcpkg.PruneNodeArtifactsRequest) (interface{}, *rpcpkg.Error) { + return svc.PruneArtifacts(ctx, req) + }) + return reg +} + +func (s *Server) nodeRPCRegistry() *rpcpkg.Registry { + if s == nil { + return rpcpkg.NewRegistry() + } + s.nodeRPCOnce.Do(func() { + s.nodeRPCReg = s.buildNodeRegistry() + }) + if s.nodeRPCReg == nil { + return rpcpkg.NewRegistry() + } + return s.nodeRPCReg +} + +func (s *Server) buildProviderRegistry() *rpcpkg.Registry { + svc := s.providerRPCService() + reg := rpcpkg.NewRegistry() + rpcpkg.RegisterJSON(reg, "provider.list_models", func(ctx context.Context, req rpcpkg.ListProviderModelsRequest) (interface{}, *rpcpkg.Error) { + return svc.ListModels(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "provider.models.update", func(ctx context.Context, req rpcpkg.UpdateProviderModelsRequest) (interface{}, *rpcpkg.Error) { + return svc.UpdateModels(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "provider.chat", func(ctx context.Context, req rpcpkg.ProviderChatRequest) (interface{}, *rpcpkg.Error) { + return svc.Chat(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "provider.count_tokens", func(ctx context.Context, req rpcpkg.ProviderCountTokensRequest) (interface{}, *rpcpkg.Error) { + return svc.CountTokens(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "provider.runtime.view", func(ctx context.Context, req rpcpkg.ProviderRuntimeViewRequest) (interface{}, *rpcpkg.Error) { + return svc.RuntimeView(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "provider.runtime.action", func(ctx context.Context, req rpcpkg.ProviderRuntimeActionRequest) (interface{}, *rpcpkg.Error) { + return svc.RuntimeAction(ctx, req) + }) + return reg +} + +func (s *Server) providerRPCRegistry() *rpcpkg.Registry { + if s == nil { + return rpcpkg.NewRegistry() + } + s.providerRPCOnce.Do(func() { + s.providerRPCReg = s.buildProviderRegistry() + }) + if s.providerRPCReg == nil { + return rpcpkg.NewRegistry() + } + return s.providerRPCReg +} + +func (s *Server) buildWorkspaceRegistry() *rpcpkg.Registry { + svc := s.workspaceRPCService() + reg := rpcpkg.NewRegistry() + rpcpkg.RegisterJSON(reg, "workspace.list_files", func(ctx context.Context, req rpcpkg.ListWorkspaceFilesRequest) (interface{}, *rpcpkg.Error) { + return svc.ListFiles(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "workspace.read_file", func(ctx context.Context, req rpcpkg.ReadWorkspaceFileRequest) (interface{}, *rpcpkg.Error) { + return svc.ReadFile(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "workspace.write_file", func(ctx context.Context, req rpcpkg.WriteWorkspaceFileRequest) (interface{}, *rpcpkg.Error) { + return svc.WriteFile(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "workspace.delete_file", func(ctx context.Context, req rpcpkg.DeleteWorkspaceFileRequest) (interface{}, *rpcpkg.Error) { + return svc.DeleteFile(ctx, req) + }) + return reg +} + +func (s *Server) workspaceRPCRegistry() *rpcpkg.Registry { + if s == nil { + return rpcpkg.NewRegistry() + } + s.workspaceRPCOnce.Do(func() { + s.workspaceRPCReg = s.buildWorkspaceRegistry() + }) + if s.workspaceRPCReg == nil { + return rpcpkg.NewRegistry() + } + return s.workspaceRPCReg +} + +func (s *Server) buildConfigRegistry() *rpcpkg.Registry { + svc := s.configRPCService() + reg := rpcpkg.NewRegistry() + rpcpkg.RegisterJSON(reg, "config.view", func(ctx context.Context, req rpcpkg.ConfigViewRequest) (interface{}, *rpcpkg.Error) { + return svc.View(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "config.save", func(ctx context.Context, req rpcpkg.ConfigSaveRequest) (interface{}, *rpcpkg.Error) { + return svc.Save(ctx, req) + }) + return reg +} + +func (s *Server) configRPCRegistry() *rpcpkg.Registry { + if s == nil { + return rpcpkg.NewRegistry() + } + s.configRPCOnce.Do(func() { + s.configRPCReg = s.buildConfigRegistry() + }) + if s.configRPCReg == nil { + return rpcpkg.NewRegistry() + } + return s.configRPCReg +} + +func (s *Server) buildCronRegistry() *rpcpkg.Registry { + svc := s.cronRPCService() + reg := rpcpkg.NewRegistry() + rpcpkg.RegisterJSON(reg, "cron.list", func(ctx context.Context, req rpcpkg.ListCronJobsRequest) (interface{}, *rpcpkg.Error) { + return svc.List(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "cron.get", func(ctx context.Context, req rpcpkg.GetCronJobRequest) (interface{}, *rpcpkg.Error) { + return svc.Get(ctx, req) + }) + rpcpkg.RegisterJSON(reg, "cron.mutate", func(ctx context.Context, req rpcpkg.MutateCronJobRequest) (interface{}, *rpcpkg.Error) { + return svc.Mutate(ctx, req) + }) + return reg +} + +func (s *Server) cronRPCRegistry() *rpcpkg.Registry { + if s == nil { + return rpcpkg.NewRegistry() + } + s.cronRPCOnce.Do(func() { + s.cronRPCReg = s.buildCronRegistry() + }) + if s.cronRPCReg == nil { + return rpcpkg.NewRegistry() + } + return s.cronRPCReg +} + +func writeRPCError(w http.ResponseWriter, status int, requestID string, rpcErr *rpcpkg.Error) { + if rpcErr == nil { + rpcErr = rpcError("internal", "rpc error", nil, false) + } + writeJSONStatus(w, status, rpcpkg.Response{ + OK: false, + Error: rpcErr, + RequestID: strings.TrimSpace(requestID), + }) +} + +func (s *Server) handleSubagentLegacyAction(ctx context.Context, action string, args map[string]interface{}) (interface{}, *rpcpkg.Error) { + registry := s.subagentRPCRegistry() + req := rpcpkg.Request{ + Method: legacySubagentActionMethod(action), + Params: mustJSONMarshal(mapSubagentLegacyArgs(action, args)), + } + result, rpcErr := registry.Handle(ctx, req) + if rpcErr != nil && !strings.HasPrefix(strings.TrimSpace(req.Method), "subagent.") { + if s.onSubagents == nil { + return nil, rpcError("unavailable", "subagent runtime handler not configured", nil, false) + } + fallback, err := s.onSubagents(ctx, action, args) + if err != nil { + return nil, rpcErrorFrom(err) + } + return fallback, nil + } + return result, rpcErr +} + +var legacySubagentActionMethods = map[string]string{ + "": "subagent.list", + "list": "subagent.list", + "snapshot": "subagent.snapshot", + "get": "subagent.get", + "info": "subagent.get", + "spawn": "subagent.spawn", + "create": "subagent.spawn", + "dispatch_and_wait": "subagent.dispatch_and_wait", + "registry": "subagent.registry", +} + +func legacySubagentActionMethod(action string) string { + normalized := strings.ToLower(strings.TrimSpace(action)) + if method, ok := legacySubagentActionMethods[normalized]; ok { + return method + } + return strings.TrimSpace(action) +} + +var legacySubagentArgMappers = map[string]func(map[string]interface{}) interface{}{ + "snapshot": func(args map[string]interface{}) interface{} { + return rpcpkg.SnapshotRequest{Limit: tools.MapIntArg(args, "limit", 0)} + }, + "get": func(args map[string]interface{}) interface{} { + return rpcpkg.GetSubagentRequest{ID: tools.MapStringArg(args, "id")} + }, + "info": func(args map[string]interface{}) interface{} { + return rpcpkg.GetSubagentRequest{ID: tools.MapStringArg(args, "id")} + }, + "spawn": buildLegacySpawnSubagentRequest, + "create": func(args map[string]interface{}) interface{} { + return buildLegacySpawnSubagentRequest(args) + }, + "dispatch_and_wait": func(args map[string]interface{}) interface{} { + return rpcpkg.DispatchAndWaitRequest{ + Task: tools.MapStringArg(args, "task"), + Label: tools.MapStringArg(args, "label"), + Role: tools.MapStringArg(args, "role"), + AgentID: tools.MapStringArg(args, "agent_id"), + ThreadID: tools.MapStringArg(args, "thread_id"), + CorrelationID: tools.MapStringArg(args, "correlation_id"), + ParentRunID: tools.MapStringArg(args, "parent_run_id"), + MaxRetries: tools.MapIntArg(args, "max_retries", 0), + RetryBackoffMS: tools.MapIntArg(args, "retry_backoff_ms", 0), + TimeoutSec: tools.MapIntArg(args, "timeout_sec", 0), + MaxTaskChars: tools.MapIntArg(args, "max_task_chars", 0), + MaxResultChars: tools.MapIntArg(args, "max_result_chars", 0), + WaitTimeoutSec: tools.MapIntArg(args, "wait_timeout_sec", 0), + Channel: firstNonEmptyString(tools.MapStringArg(args, "channel"), tools.MapStringArg(args, "origin_channel")), + ChatID: firstNonEmptyString(tools.MapStringArg(args, "chat_id"), tools.MapStringArg(args, "origin_chat_id")), + } + }, +} + +func mapSubagentLegacyArgs(action string, args map[string]interface{}) interface{} { + normalized := strings.ToLower(strings.TrimSpace(action)) + if mapper, ok := legacySubagentArgMappers[normalized]; ok && mapper != nil { + return mapper(args) + } + return args +} + +func buildLegacySpawnSubagentRequest(args map[string]interface{}) interface{} { + return rpcpkg.SpawnSubagentRequest{ + Task: tools.MapStringArg(args, "task"), + Label: tools.MapStringArg(args, "label"), + Role: tools.MapStringArg(args, "role"), + AgentID: tools.MapStringArg(args, "agent_id"), + MaxRetries: tools.MapIntArg(args, "max_retries", 0), + RetryBackoffMS: tools.MapIntArg(args, "retry_backoff_ms", 0), + TimeoutSec: tools.MapIntArg(args, "timeout_sec", 0), + MaxTaskChars: tools.MapIntArg(args, "max_task_chars", 0), + MaxResultChars: tools.MapIntArg(args, "max_result_chars", 0), + Channel: firstNonEmptyString(tools.MapStringArg(args, "channel"), tools.MapStringArg(args, "origin_channel")), + ChatID: firstNonEmptyString(tools.MapStringArg(args, "chat_id"), tools.MapStringArg(args, "origin_chat_id")), + } +} + +func mustJSONMarshal(value interface{}) json.RawMessage { + if value == nil { + return json.RawMessage([]byte("{}")) + } + data, err := json.Marshal(value) + if err != nil { + return json.RawMessage([]byte("{}")) + } + return data +} diff --git a/pkg/api/rpc_services.go b/pkg/api/rpc_services.go new file mode 100644 index 0000000..f5240bb --- /dev/null +++ b/pkg/api/rpc_services.go @@ -0,0 +1,1133 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "sort" + "strings" + "time" + + cfgpkg "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/nodes" + "github.com/YspCoder/clawgo/pkg/providers" + rpcpkg "github.com/YspCoder/clawgo/pkg/rpc" + "github.com/YspCoder/clawgo/pkg/tools" +) + +func mustPrettyJSON(v interface{}) []byte { + out, _ := json.MarshalIndent(v, "", " ") + return out +} + +type subagentRPCAdapter struct { + server *Server +} + +func (a *subagentRPCAdapter) call(ctx context.Context, action string, args map[string]interface{}) (interface{}, *rpcpkg.Error) { + if a == nil || a.server == nil || a.server.onSubagents == nil { + return nil, rpcError("unavailable", "subagent runtime handler not configured", nil, false) + } + result, err := a.server.onSubagents(ctx, action, args) + if err != nil { + return nil, rpcErrorFrom(err) + } + return result, nil +} + +func (a *subagentRPCAdapter) List(ctx context.Context, _ rpcpkg.ListSubagentsRequest) (*rpcpkg.ListSubagentsResponse, *rpcpkg.Error) { + result, rpcErr := a.call(ctx, "list", nil) + if rpcErr != nil { + return nil, rpcErr + } + var payload struct { + Items []*map[string]interface{} `json:"items"` + } + items, err := decodeResultSliceField[resultWrapperSubagentTask](result, "items") + if err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + _ = payload + out := make([]*resultWrapperSubagentTask, 0, len(items)) + for _, item := range items { + if item != nil { + out = append(out, item) + } + } + return &rpcpkg.ListSubagentsResponse{Items: unwrapSubagentTasks(out)}, nil +} + +func (a *subagentRPCAdapter) Snapshot(ctx context.Context, req rpcpkg.SnapshotRequest) (*rpcpkg.SnapshotResponse, *rpcpkg.Error) { + result, rpcErr := a.call(ctx, "snapshot", map[string]interface{}{"limit": req.Limit}) + if rpcErr != nil { + return nil, rpcErr + } + var snapshot struct { + Snapshot json.RawMessage `json:"snapshot"` + } + if err := decodeResultObject(result, &snapshot); err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + var out rpcpkg.SnapshotResponse + if len(snapshot.Snapshot) > 0 { + if err := json.Unmarshal(snapshot.Snapshot, &out.Snapshot); err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + } + return &out, nil +} + +func (a *subagentRPCAdapter) Get(ctx context.Context, req rpcpkg.GetSubagentRequest) (*rpcpkg.GetSubagentResponse, *rpcpkg.Error) { + result, rpcErr := a.call(ctx, "get", map[string]interface{}{"id": req.ID}) + if rpcErr != nil { + return nil, rpcErr + } + var payload struct { + Found bool `json:"found"` + Task *resultWrapperSubagentTask `json:"task"` + } + if err := decodeResultObject(result, &payload); err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + return &rpcpkg.GetSubagentResponse{Found: payload.Found, Task: unwrapSubagentTask(payload.Task)}, nil +} + +func (a *subagentRPCAdapter) Spawn(ctx context.Context, req rpcpkg.SpawnSubagentRequest) (*rpcpkg.SpawnSubagentResponse, *rpcpkg.Error) { + result, rpcErr := a.call(ctx, "spawn", map[string]interface{}{ + "task": req.Task, + "label": req.Label, + "role": req.Role, + "agent_id": req.AgentID, + "max_retries": req.MaxRetries, + "retry_backoff_ms": req.RetryBackoffMS, + "timeout_sec": req.TimeoutSec, + "max_task_chars": req.MaxTaskChars, + "max_result_chars": req.MaxResultChars, + "channel": req.Channel, + "chat_id": req.ChatID, + }) + if rpcErr != nil { + return nil, rpcErr + } + var payload rpcpkg.SpawnSubagentResponse + if err := decodeResultObject(result, &payload); err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + return &payload, nil +} + +func (a *subagentRPCAdapter) DispatchAndWait(ctx context.Context, req rpcpkg.DispatchAndWaitRequest) (*rpcpkg.DispatchAndWaitResponse, *rpcpkg.Error) { + result, rpcErr := a.call(ctx, "dispatch_and_wait", map[string]interface{}{ + "task": req.Task, + "label": req.Label, + "role": req.Role, + "agent_id": req.AgentID, + "thread_id": req.ThreadID, + "correlation_id": req.CorrelationID, + "parent_run_id": req.ParentRunID, + "channel": req.Channel, + "chat_id": req.ChatID, + "max_retries": req.MaxRetries, + "retry_backoff_ms": req.RetryBackoffMS, + "timeout_sec": req.TimeoutSec, + "max_task_chars": req.MaxTaskChars, + "max_result_chars": req.MaxResultChars, + "wait_timeout_sec": req.WaitTimeoutSec, + }) + if rpcErr != nil { + return nil, rpcErr + } + var payload struct { + Task *resultWrapperSubagentTask `json:"task"` + Reply json.RawMessage `json:"reply"` + Merged string `json:"merged"` + } + if err := decodeResultObject(result, &payload); err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + out := &rpcpkg.DispatchAndWaitResponse{ + Task: unwrapSubagentTask(payload.Task), + Merged: payload.Merged, + } + if len(payload.Reply) > 0 { + var reply resultWrapperRouterReply + if err := json.Unmarshal(payload.Reply, &reply); err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + out.Reply = unwrapRouterReply(&reply) + } + return out, nil +} + +func (a *subagentRPCAdapter) Registry(ctx context.Context, _ rpcpkg.RegistryRequest) (*rpcpkg.RegistryResponse, *rpcpkg.Error) { + result, rpcErr := a.call(ctx, "registry", nil) + if rpcErr != nil { + return nil, rpcErr + } + var payload rpcpkg.RegistryResponse + if err := decodeResultObject(result, &payload); err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + return &payload, nil +} + +type nodeRPCAdapter struct { + server *Server +} + +type workspaceRPCAdapter struct { + server *Server +} + +func (a *workspaceRPCAdapter) resolveScopeRoot(scope string) (string, *rpcpkg.Error) { + if a == nil || a.server == nil { + return "", rpcError("unavailable", "server unavailable", nil, false) + } + switch strings.ToLower(strings.TrimSpace(scope)) { + case "", "workspace": + return strings.TrimSpace(a.server.workspacePath), nil + case "memory": + root := filepath.Join(strings.TrimSpace(a.server.workspacePath), "memory") + if err := os.MkdirAll(root, 0755); err != nil { + return "", rpcError("internal", err.Error(), nil, false) + } + return root, nil + default: + return "", rpcError("invalid_argument", "unsupported workspace scope", map[string]interface{}{"scope": scope}, false) + } +} + +func (a *workspaceRPCAdapter) ListFiles(_ context.Context, req rpcpkg.ListWorkspaceFilesRequest) (*rpcpkg.ListWorkspaceFilesResponse, *rpcpkg.Error) { + root, rpcErr := a.resolveScopeRoot(req.Scope) + if rpcErr != nil { + return nil, rpcErr + } + entries, err := os.ReadDir(root) + if err != nil { + return nil, rpcError("internal", err.Error(), nil, false) + } + files := make([]string, 0, len(entries)) + for _, entry := range entries { + if entry.IsDir() { + continue + } + files = append(files, entry.Name()) + } + sort.Strings(files) + return &rpcpkg.ListWorkspaceFilesResponse{Files: files}, nil +} + +func (a *workspaceRPCAdapter) ReadFile(_ context.Context, req rpcpkg.ReadWorkspaceFileRequest) (*rpcpkg.ReadWorkspaceFileResponse, *rpcpkg.Error) { + root, rpcErr := a.resolveScopeRoot(req.Scope) + if rpcErr != nil { + return nil, rpcErr + } + clean, content, found, err := readRelativeTextFile(root, req.Path) + if err != nil { + return nil, rpcError("invalid_argument", err.Error(), nil, false) + } + return &rpcpkg.ReadWorkspaceFileResponse{Path: clean, Found: found, Content: content}, nil +} + +func (a *workspaceRPCAdapter) WriteFile(_ context.Context, req rpcpkg.WriteWorkspaceFileRequest) (*rpcpkg.WriteWorkspaceFileResponse, *rpcpkg.Error) { + root, rpcErr := a.resolveScopeRoot(req.Scope) + if rpcErr != nil { + return nil, rpcErr + } + appendMissing := !strings.EqualFold(strings.TrimSpace(req.Scope), "memory") + clean, err := writeRelativeTextFile(root, req.Path, req.Content, appendMissing) + if err != nil { + return nil, rpcError("invalid_argument", err.Error(), nil, false) + } + return &rpcpkg.WriteWorkspaceFileResponse{Path: clean, Saved: true}, nil +} + +func (a *workspaceRPCAdapter) DeleteFile(_ context.Context, req rpcpkg.DeleteWorkspaceFileRequest) (*rpcpkg.DeleteWorkspaceFileResponse, *rpcpkg.Error) { + root, rpcErr := a.resolveScopeRoot(req.Scope) + if rpcErr != nil { + return nil, rpcErr + } + clean, full, err := resolveRelativeFilePath(root, req.Path) + if err != nil { + return nil, rpcError("invalid_argument", err.Error(), nil, false) + } + if err := os.Remove(full); err != nil { + if errors.Is(err, os.ErrNotExist) { + return &rpcpkg.DeleteWorkspaceFileResponse{Path: clean, Deleted: false}, nil + } + return nil, rpcError("internal", err.Error(), nil, false) + } + return &rpcpkg.DeleteWorkspaceFileResponse{Path: clean, Deleted: true}, nil +} + +func (a *nodeRPCAdapter) Register(_ context.Context, req rpcpkg.RegisterNodeRequest) (*rpcpkg.RegisterNodeResponse, *rpcpkg.Error) { + if a == nil || a.server == nil || a.server.mgr == nil { + return nil, rpcError("unavailable", "nodes manager unavailable", nil, false) + } + if strings.TrimSpace(req.Node.ID) == "" { + return nil, rpcError("invalid_argument", "id required", nil, false) + } + a.server.mgr.Upsert(req.Node) + return &rpcpkg.RegisterNodeResponse{ID: req.Node.ID}, nil +} + +func (a *nodeRPCAdapter) Heartbeat(_ context.Context, req rpcpkg.HeartbeatNodeRequest) (*rpcpkg.HeartbeatNodeResponse, *rpcpkg.Error) { + if a == nil || a.server == nil || a.server.mgr == nil { + return nil, rpcError("unavailable", "nodes manager unavailable", nil, false) + } + id := strings.TrimSpace(req.ID) + if id == "" { + return nil, rpcError("invalid_argument", "id required", nil, false) + } + n, ok := a.server.mgr.Get(id) + if !ok { + return nil, rpcError("not_found", "node not found", nil, false) + } + a.server.mgr.Upsert(n) + return &rpcpkg.HeartbeatNodeResponse{ID: id}, nil +} + +func (a *nodeRPCAdapter) Dispatch(ctx context.Context, req rpcpkg.DispatchNodeRequest) (*rpcpkg.DispatchNodeResponse, *rpcpkg.Error) { + if a == nil || a.server == nil || a.server.onNodeDispatch == nil { + return nil, rpcError("unavailable", "node dispatch handler not configured", nil, false) + } + nodeID := strings.TrimSpace(req.Node) + action := strings.TrimSpace(req.Action) + if nodeID == "" || action == "" { + return nil, rpcError("invalid_argument", "node and action are required", nil, false) + } + resp, err := a.server.onNodeDispatch(ctx, resultNodeRequest{ + Node: nodeID, + Action: action, + Task: req.Task, + Model: req.Model, + Args: req.Args, + }.unwrap(), strings.TrimSpace(req.Mode)) + if err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.DispatchNodeResponse{Result: resp}, nil +} + +func (a *nodeRPCAdapter) ListArtifacts(_ context.Context, req rpcpkg.ListNodeArtifactsRequest) (*rpcpkg.ListNodeArtifactsResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + limit := req.Limit + if limit <= 0 { + limit = 200 + } + if limit > 1000 { + limit = 1000 + } + return &rpcpkg.ListNodeArtifactsResponse{ + Items: a.server.webUINodeArtifactsPayloadFiltered(strings.TrimSpace(req.Node), strings.TrimSpace(req.Action), strings.TrimSpace(req.Kind), limit), + ArtifactRetention: a.server.applyNodeArtifactRetention(), + }, nil +} + +func (a *nodeRPCAdapter) GetArtifact(_ context.Context, req rpcpkg.GetNodeArtifactRequest) (*rpcpkg.GetNodeArtifactResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + id := strings.TrimSpace(req.ID) + if id == "" { + return nil, rpcError("invalid_argument", "id is required", nil, false) + } + item, ok := a.server.findNodeArtifactByID(id) + if !ok { + return &rpcpkg.GetNodeArtifactResponse{Found: false}, nil + } + return &rpcpkg.GetNodeArtifactResponse{Found: true, Artifact: item}, nil +} + +func (a *nodeRPCAdapter) DeleteArtifact(_ context.Context, req rpcpkg.DeleteNodeArtifactRequest) (*rpcpkg.DeleteNodeArtifactResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + id := strings.TrimSpace(req.ID) + if id == "" { + return nil, rpcError("invalid_argument", "id is required", nil, false) + } + deletedFile, deletedAudit, err := a.server.deleteNodeArtifact(id) + if err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.DeleteNodeArtifactResponse{ArtifactDeleteResult: rpcpkg.ArtifactDeleteResult{ + ID: id, DeletedFile: deletedFile, DeletedAudit: deletedAudit, + }}, nil +} + +func (a *nodeRPCAdapter) PruneArtifacts(_ context.Context, req rpcpkg.PruneNodeArtifactsRequest) (*rpcpkg.PruneNodeArtifactsResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + limit := req.Limit + if limit <= 0 || limit > 5000 { + limit = 5000 + } + keepLatest := req.KeepLatest + if keepLatest < 0 { + keepLatest = 0 + } + items := a.server.webUINodeArtifactsPayloadFiltered(strings.TrimSpace(req.Node), strings.TrimSpace(req.Action), strings.TrimSpace(req.Kind), limit) + pruned := 0 + deletedFiles := 0 + for index, item := range items { + if index < keepLatest { + continue + } + deletedFile, deletedAudit, err := a.server.deleteNodeArtifact(strings.TrimSpace(fmt.Sprint(item["id"]))) + if err != nil || !deletedAudit { + continue + } + pruned++ + if deletedFile { + deletedFiles++ + } + } + return &rpcpkg.PruneNodeArtifactsResponse{ArtifactPruneResult: rpcpkg.ArtifactPruneResult{ + Pruned: pruned, DeletedFiles: deletedFiles, Kept: keepLatest, + }}, nil +} + +func (s *Server) subagentRPCService() rpcpkg.SubagentService { + return &subagentRPCAdapter{server: s} +} + +func (s *Server) nodeRPCService() rpcpkg.NodeService { + return &nodeRPCAdapter{server: s} +} + +type providerRPCAdapter struct { + server *Server +} + +func (a *providerRPCAdapter) ListModels(_ context.Context, req rpcpkg.ListProviderModelsRequest) (*rpcpkg.ListProviderModelsResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + cfg, err := a.server.loadConfig() + if err != nil { + return nil, rpcErrorFrom(err) + } + providerName := strings.TrimSpace(req.Provider) + if providerName == "" { + return nil, rpcError("invalid_argument", "provider is required", nil, false) + } + models := providers.GetProviderModels(cfg, providerName) + provider, err := providers.CreateProviderByName(cfg, providerName) + if err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.ListProviderModelsResponse{ + Provider: providerName, + Models: models, + Default: strings.TrimSpace(provider.GetDefaultModel()), + }, nil +} + +func (a *providerRPCAdapter) UpdateModels(_ context.Context, req rpcpkg.UpdateProviderModelsRequest) (*rpcpkg.UpdateProviderModelsResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + cfg, pc, err := a.server.loadProviderConfig(strings.TrimSpace(req.Provider)) + if err != nil { + return nil, rpcErrorFrom(err) + } + models := make([]string, 0, len(req.Models)+1) + for _, model := range req.Models { + models = appendUniqueStrings(models, model) + } + models = appendUniqueStrings(models, req.Model) + if len(models) == 0 { + return nil, rpcError("invalid_argument", "model required", nil, false) + } + pc.Models = models + if err := a.server.saveProviderConfig(cfg, req.Provider, pc); err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.UpdateProviderModelsResponse{Provider: strings.TrimSpace(req.Provider), Models: pc.Models}, nil +} + +func (a *providerRPCAdapter) Chat(ctx context.Context, req rpcpkg.ProviderChatRequest) (*rpcpkg.ProviderChatResponse, *rpcpkg.Error) { + provider, model, messages, toolDefs, err := a.resolveProviderRequest(req.Provider, req.Model, req.Messages, req.Tools) + if err != nil { + return nil, rpcErrorFrom(err) + } + resp, err := provider.Chat(ctx, messages, toolDefs, model, req.Options) + if err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.ProviderChatResponse{ + Content: strings.TrimSpace(resp.Content), + ToolCalls: marshalToolCalls(resp.ToolCalls), + FinishReason: strings.TrimSpace(resp.FinishReason), + Usage: marshalUsage(resp.Usage), + }, nil +} + +func (a *providerRPCAdapter) CountTokens(ctx context.Context, req rpcpkg.ProviderCountTokensRequest) (*rpcpkg.ProviderCountTokensResponse, *rpcpkg.Error) { + provider, model, messages, toolDefs, err := a.resolveProviderRequest(req.Provider, req.Model, req.Messages, req.Tools) + if err != nil { + return nil, rpcErrorFrom(err) + } + counter, ok := provider.(providers.TokenCounter) + if !ok { + return nil, rpcError("unavailable", "provider does not support count_tokens", nil, false) + } + usage, err := counter.CountTokens(ctx, messages, toolDefs, model, req.Options) + if err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.ProviderCountTokensResponse{Usage: marshalUsage(usage)}, nil +} + +func (a *providerRPCAdapter) RuntimeView(_ context.Context, req rpcpkg.ProviderRuntimeViewRequest) (*rpcpkg.ProviderRuntimeViewResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + cfg, err := a.server.loadConfig() + if err != nil { + return nil, rpcErrorFrom(err) + } + query := providers.ProviderRuntimeQuery{ + Provider: strings.TrimSpace(req.Provider), + EventKind: strings.TrimSpace(req.Kind), + Reason: strings.TrimSpace(req.Reason), + Target: strings.TrimSpace(req.Target), + Sort: strings.TrimSpace(req.Sort), + ChangesOnly: req.ChangesOnly, + Limit: req.Limit, + Cursor: req.Cursor, + HealthBelow: req.HealthBelow, + } + if req.WindowSec > 0 { + query.Window = time.Duration(req.WindowSec) * time.Second + } + if req.CooldownUntilBeforeSec > 0 { + query.CooldownBefore = time.Now().Add(time.Duration(req.CooldownUntilBeforeSec) * time.Second) + } + return &rpcpkg.ProviderRuntimeViewResponse{View: providers.GetProviderRuntimeView(cfg, query)}, nil +} + +func (a *providerRPCAdapter) RuntimeAction(_ context.Context, req rpcpkg.ProviderRuntimeActionRequest) (*rpcpkg.ProviderRuntimeActionResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + cfg, providerName, err := a.server.loadRuntimeProviderName(strings.TrimSpace(req.Provider)) + if err != nil { + return nil, rpcErrorFrom(err) + } + action := strings.ToLower(strings.TrimSpace(req.Action)) + result := map[string]interface{}{"provider": providerName} + handler := providerRuntimeActionHandlers[action] + if handler == nil { + return nil, rpcError("invalid_argument", "unsupported action", map[string]interface{}{"action": action}, false) + } + if err := handler(cfg, providerName, req, result); err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.ProviderRuntimeActionResponse{Result: result}, nil +} + +type providerRuntimeActionHandler func(*cfgpkg.Config, string, rpcpkg.ProviderRuntimeActionRequest, map[string]interface{}) error + +var providerRuntimeActionHandlers = map[string]providerRuntimeActionHandler{ + "clear_api_cooldown": func(_ *cfgpkg.Config, providerName string, _ rpcpkg.ProviderRuntimeActionRequest, result map[string]interface{}) error { + providers.ClearProviderAPICooldown(providerName) + result["cleared"] = true + return nil + }, + "clear_history": func(_ *cfgpkg.Config, providerName string, _ rpcpkg.ProviderRuntimeActionRequest, result map[string]interface{}) error { + providers.ClearProviderRuntimeHistory(providerName) + result["cleared"] = true + return nil + }, + "refresh_now": func(cfg *cfgpkg.Config, providerName string, req rpcpkg.ProviderRuntimeActionRequest, result map[string]interface{}) error { + refreshResult, err := providers.RefreshProviderRuntimeNow(cfg, providerName, req.OnlyExpiring) + if err != nil { + return err + } + order, _ := providers.RerankProviderRuntime(cfg, providerName) + summary := providers.GetProviderRuntimeSummary(cfg, providers.ProviderRuntimeQuery{Provider: providerName, HealthBelow: 50}) + result["refreshed"] = true + result["result"] = refreshResult + result["candidate_order"] = order + result["summary"] = summary + return nil + }, + "rerank": func(cfg *cfgpkg.Config, providerName string, _ rpcpkg.ProviderRuntimeActionRequest, result map[string]interface{}) error { + order, err := providers.RerankProviderRuntime(cfg, providerName) + if err != nil { + return err + } + result["reranked"] = true + result["candidate_order"] = order + return nil + }, +} + +func (a *providerRPCAdapter) resolveProviderRequest(providerName, model string, rawMessages []map[string]interface{}, rawTools []map[string]interface{}) (providers.LLMProvider, string, []providers.Message, []providers.ToolDefinition, error) { + if a == nil || a.server == nil { + return nil, "", nil, nil, fmt.Errorf("server unavailable") + } + cfg, err := a.server.loadConfig() + if err != nil { + return nil, "", nil, nil, err + } + providerName = strings.TrimSpace(providerName) + if providerName == "" { + return nil, "", nil, nil, fmt.Errorf("provider is required") + } + provider, err := providers.CreateProviderByName(cfg, providerName) + if err != nil { + return nil, "", nil, nil, err + } + if strings.TrimSpace(model) == "" { + model = provider.GetDefaultModel() + } + messages, err := decodeProviderMessages(rawMessages) + if err != nil { + return nil, "", nil, nil, err + } + tools, err := decodeProviderTools(rawTools) + if err != nil { + return nil, "", nil, nil, err + } + return provider, strings.TrimSpace(model), messages, tools, nil +} + +func (s *Server) providerRPCService() rpcpkg.ProviderService { + return &providerRPCAdapter{server: s} +} + +func (s *Server) workspaceRPCService() rpcpkg.WorkspaceService { + return &workspaceRPCAdapter{server: s} +} + +func (s *Server) configRPCService() rpcpkg.ConfigService { + return &configRPCAdapter{server: s} +} + +func (s *Server) cronRPCService() rpcpkg.CronService { + return &cronRPCAdapter{server: s} +} + +type configRPCAdapter struct { + server *Server +} + +type cronRPCAdapter struct { + server *Server +} + +func (a *configRPCAdapter) View(_ context.Context, req rpcpkg.ConfigViewRequest) (*rpcpkg.ConfigViewResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + if strings.TrimSpace(a.server.configPath) == "" { + return nil, rpcError("unavailable", "config path not set", nil, false) + } + if strings.EqualFold(strings.TrimSpace(req.Mode), "normalized") { + cfg, err := cfgpkg.LoadConfig(a.server.configPath) + if err != nil { + return nil, rpcErrorFrom(err) + } + resp := &rpcpkg.ConfigViewResponse{ + Config: cfg.NormalizedView(), + RawConfig: cfg, + } + if req.IncludeHotReloadInfo { + info := hotReloadFieldInfo() + paths := make([]string, 0, len(info)) + for _, it := range info { + if p := stringFromMap(it, "path"); p != "" { + paths = append(paths, p) + } + } + resp.HotReloadFields = paths + resp.HotReloadFieldDetails = info + } + return resp, nil + } + b, err := os.ReadFile(a.server.configPath) + if err != nil { + return nil, rpcErrorFrom(err) + } + cfgDefault := cfgpkg.DefaultConfig() + defBytes, _ := json.Marshal(cfgDefault) + var merged map[string]interface{} + _ = json.Unmarshal(defBytes, &merged) + var loaded map[string]interface{} + if err := json.Unmarshal(b, &loaded); err != nil { + return nil, rpcErrorFrom(err) + } + merged = mergeJSONMap(merged, loaded) + resp := &rpcpkg.ConfigViewResponse{Config: merged, PrettyText: string(mustPrettyJSON(merged))} + if req.IncludeHotReloadInfo { + info := hotReloadFieldInfo() + paths := make([]string, 0, len(info)) + for _, it := range info { + if p := stringFromMap(it, "path"); p != "" { + paths = append(paths, p) + } + } + resp.HotReloadFields = paths + resp.HotReloadFieldDetails = info + } + return resp, nil +} + +func (a *configRPCAdapter) Save(_ context.Context, req rpcpkg.ConfigSaveRequest) (*rpcpkg.ConfigSaveResponse, *rpcpkg.Error) { + if a == nil || a.server == nil { + return nil, rpcError("unavailable", "server unavailable", nil, false) + } + if strings.TrimSpace(a.server.configPath) == "" { + return nil, rpcError("unavailable", "config path not set", nil, false) + } + body := req.Config + oldCfgRaw, _ := os.ReadFile(a.server.configPath) + var oldMap map[string]interface{} + _ = json.Unmarshal(oldCfgRaw, &oldMap) + riskyOldMap := oldMap + riskyNewMap := body + if strings.EqualFold(strings.TrimSpace(req.Mode), "normalized") { + if loaded, err := cfgpkg.LoadConfig(a.server.configPath); err == nil && loaded != nil { + if raw, err := json.Marshal(loaded.NormalizedView()); err == nil { + _ = json.Unmarshal(raw, &riskyOldMap) + } + } + } + riskyPaths := collectRiskyConfigPaths(riskyOldMap, riskyNewMap) + changedRisky := make([]string, 0) + for _, p := range riskyPaths { + if fmt.Sprintf("%v", getPathValue(riskyOldMap, p)) != fmt.Sprintf("%v", getPathValue(riskyNewMap, p)) { + changedRisky = append(changedRisky, p) + } + } + if len(changedRisky) > 0 && !req.ConfirmRisky { + return &rpcpkg.ConfigSaveResponse{ + Saved: false, + RequiresConfirm: true, + ChangedFields: changedRisky, + }, rpcError("invalid_argument", "risky fields changed; confirmation required", map[string]interface{}{"changed_fields": changedRisky}, false) + } + + cfg := cfgpkg.DefaultConfig() + if strings.EqualFold(strings.TrimSpace(req.Mode), "normalized") { + loaded, err := cfgpkg.LoadConfig(a.server.configPath) + if err != nil { + return nil, rpcErrorFrom(err) + } + cfg = loaded + candidate, err := json.Marshal(body) + if err != nil { + return nil, rpcErrorFrom(err) + } + var normalized cfgpkg.NormalizedConfig + dec := json.NewDecoder(bytes.NewReader(candidate)) + dec.DisallowUnknownFields() + if err := dec.Decode(&normalized); err != nil { + return nil, rpcError("invalid_argument", "normalized config validation failed: "+err.Error(), nil, false) + } + cfg.ApplyNormalizedView(normalized) + } else { + candidate, err := json.Marshal(body) + if err != nil { + return nil, rpcErrorFrom(err) + } + dec := json.NewDecoder(bytes.NewReader(candidate)) + dec.DisallowUnknownFields() + if err := dec.Decode(cfg); err != nil { + return nil, rpcError("invalid_argument", "config schema validation failed: "+err.Error(), nil, false) + } + } + if errs := cfgpkg.Validate(cfg); len(errs) > 0 { + list := make([]string, 0, len(errs)) + for _, e := range errs { + list = append(list, e.Error()) + } + return nil, rpcError("invalid_argument", "config validation failed", list, false) + } + if err := cfgpkg.SaveConfig(a.server.configPath, cfg); err != nil { + return nil, rpcErrorFrom(err) + } + if a.server.onConfigAfter != nil { + if err := a.server.onConfigAfter(); err != nil { + return nil, rpcErrorFrom(err) + } + } else { + if err := requestSelfReloadSignal(); err != nil { + return nil, rpcErrorFrom(err) + } + } + return &rpcpkg.ConfigSaveResponse{Saved: true}, nil +} + +func (a *cronRPCAdapter) List(ctx context.Context, _ rpcpkg.ListCronJobsRequest) (*rpcpkg.ListCronJobsResponse, *rpcpkg.Error) { + if a == nil || a.server == nil || a.server.onCron == nil { + return nil, rpcError("unavailable", "cron handler not configured", nil, false) + } + res, err := a.server.onCron("list", map[string]interface{}{}) + if err != nil { + return nil, rpcErrorFrom(err) + } + jobs := normalizeCronJobs(res) + out := make([]interface{}, 0, len(jobs)) + for _, job := range jobs { + out = append(out, job) + } + return &rpcpkg.ListCronJobsResponse{Jobs: out}, nil +} + +func (a *cronRPCAdapter) Get(ctx context.Context, req rpcpkg.GetCronJobRequest) (*rpcpkg.GetCronJobResponse, *rpcpkg.Error) { + if a == nil || a.server == nil || a.server.onCron == nil { + return nil, rpcError("unavailable", "cron handler not configured", nil, false) + } + res, err := a.server.onCron("get", map[string]interface{}{"id": strings.TrimSpace(req.ID)}) + if err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.GetCronJobResponse{Job: normalizeCronJob(res)}, nil +} + +func (a *cronRPCAdapter) Mutate(ctx context.Context, req rpcpkg.MutateCronJobRequest) (*rpcpkg.MutateCronJobResponse, *rpcpkg.Error) { + if a == nil || a.server == nil || a.server.onCron == nil { + return nil, rpcError("unavailable", "cron handler not configured", nil, false) + } + args := req.Args + if args == nil { + args = map[string]interface{}{} + } + action := strings.ToLower(strings.TrimSpace(req.Action)) + if action == "" { + action = "create" + } + res, err := a.server.onCron(action, args) + if err != nil { + return nil, rpcErrorFrom(err) + } + return &rpcpkg.MutateCronJobResponse{Result: normalizeCronJob(res)}, nil +} + +func rpcError(code, message string, details interface{}, retryable bool) *rpcpkg.Error { + return &rpcpkg.Error{ + Code: strings.TrimSpace(code), + Message: strings.TrimSpace(message), + Details: details, + Retryable: retryable, + } +} + +func rpcErrorFrom(err error) *rpcpkg.Error { + if err == nil { + return nil + } + message := strings.TrimSpace(err.Error()) + code := "internal" + switch { + case errors.Is(err, context.DeadlineExceeded): + code = "timeout" + case strings.Contains(strings.ToLower(message), "not found"): + code = "not_found" + case strings.Contains(strings.ToLower(message), "required"): + code = "invalid_argument" + case strings.Contains(strings.ToLower(message), "not configured"), strings.Contains(strings.ToLower(message), "unavailable"): + code = "unavailable" + } + return rpcError(code, message, nil, false) +} + +func rpcHTTPStatus(err *rpcpkg.Error) int { + if err == nil { + return http.StatusOK + } + switch strings.TrimSpace(err.Code) { + case "invalid_argument": + return http.StatusBadRequest + case "permission_denied": + return http.StatusForbidden + case "not_found": + return http.StatusNotFound + case "timeout": + return http.StatusGatewayTimeout + case "unavailable": + return http.StatusServiceUnavailable + default: + return http.StatusInternalServerError + } +} + +func decodeResultObject(result interface{}, target interface{}) error { + data, err := json.Marshal(result) + if err != nil { + return err + } + return json.Unmarshal(data, target) +} + +func decodeProviderMessages(raw []map[string]interface{}) ([]providers.Message, error) { + if len(raw) == 0 { + return []providers.Message{}, nil + } + data, err := json.Marshal(raw) + if err != nil { + return nil, err + } + var out []providers.Message + if err := json.Unmarshal(data, &out); err != nil { + return nil, err + } + return out, nil +} + +func decodeProviderTools(raw []map[string]interface{}) ([]providers.ToolDefinition, error) { + if len(raw) == 0 { + return []providers.ToolDefinition{}, nil + } + data, err := json.Marshal(raw) + if err != nil { + return nil, err + } + var out []providers.ToolDefinition + if err := json.Unmarshal(data, &out); err != nil { + return nil, err + } + return out, nil +} + +func marshalToolCalls(in []providers.ToolCall) []map[string]interface{} { + if len(in) == 0 { + return []map[string]interface{}{} + } + data, err := json.Marshal(in) + if err != nil { + return []map[string]interface{}{} + } + var out []map[string]interface{} + if err := json.Unmarshal(data, &out); err != nil { + return []map[string]interface{}{} + } + return out +} + +func marshalUsage(in *providers.UsageInfo) map[string]interface{} { + if in == nil { + return nil + } + return map[string]interface{}{ + "prompt_tokens": in.PromptTokens, + "completion_tokens": in.CompletionTokens, + "total_tokens": in.TotalTokens, + } +} + +func decodeResultSliceField[T any](result interface{}, field string) ([]*T, error) { + if strings.TrimSpace(field) == "" { + return nil, fmt.Errorf("field is required") + } + var payload map[string]json.RawMessage + if err := decodeResultObject(result, &payload); err != nil { + return nil, err + } + raw := payload[field] + if len(raw) == 0 { + return []*T{}, nil + } + var items []*T + if err := json.Unmarshal(raw, &items); err != nil { + return nil, err + } + return items, nil +} + +type resultWrapperSubagentTask struct { + ID string `json:"id"` + Task string `json:"task"` + Label string `json:"label"` + Role string `json:"role"` + AgentID string `json:"agent_id"` + Transport string `json:"transport,omitempty"` + NodeID string `json:"node_id,omitempty"` + ParentAgentID string `json:"parent_agent_id,omitempty"` + NotifyMainPolicy string `json:"notify_main_policy,omitempty"` + SessionKey string `json:"session_key"` + MemoryNS string `json:"memory_ns"` + SystemPromptFile string `json:"system_prompt_file,omitempty"` + ToolAllowlist []string `json:"tool_allowlist,omitempty"` + MaxRetries int `json:"max_retries,omitempty"` + RetryBackoff int `json:"retry_backoff,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + ParentRunID string `json:"parent_run_id,omitempty"` + LastMessageID string `json:"last_message_id,omitempty"` + WaitingReply bool `json:"waiting_for_reply,omitempty"` + SharedState map[string]interface{} `json:"shared_state,omitempty"` + OriginChannel string `json:"origin_channel,omitempty"` + OriginChatID string `json:"origin_chat_id,omitempty"` + Status string `json:"status"` + Result string `json:"result,omitempty"` + Steering []string `json:"steering,omitempty"` + Created int64 `json:"created"` + Updated int64 `json:"updated"` +} + +func unwrapSubagentTask(in *resultWrapperSubagentTask) *tools.SubagentTask { + if in == nil { + return nil + } + return &tools.SubagentTask{ + ID: in.ID, + Task: in.Task, + Label: in.Label, + Role: in.Role, + AgentID: in.AgentID, + Transport: in.Transport, + NodeID: in.NodeID, + ParentAgentID: in.ParentAgentID, + NotifyMainPolicy: in.NotifyMainPolicy, + SessionKey: in.SessionKey, + MemoryNS: in.MemoryNS, + SystemPromptFile: in.SystemPromptFile, + ToolAllowlist: append([]string(nil), in.ToolAllowlist...), + MaxRetries: in.MaxRetries, + RetryBackoff: in.RetryBackoff, + TimeoutSec: in.TimeoutSec, + MaxTaskChars: in.MaxTaskChars, + MaxResultChars: in.MaxResultChars, + RetryCount: in.RetryCount, + ThreadID: in.ThreadID, + CorrelationID: in.CorrelationID, + ParentRunID: in.ParentRunID, + LastMessageID: in.LastMessageID, + WaitingReply: in.WaitingReply, + SharedState: in.SharedState, + OriginChannel: in.OriginChannel, + OriginChatID: in.OriginChatID, + Status: in.Status, + Result: in.Result, + Steering: append([]string(nil), in.Steering...), + Created: in.Created, + Updated: in.Updated, + } +} + +func unwrapSubagentTasks(in []*resultWrapperSubagentTask) []*tools.SubagentTask { + out := make([]*tools.SubagentTask, 0, len(in)) + for _, item := range in { + if task := unwrapSubagentTask(item); task != nil { + out = append(out, task) + } + } + return out +} + +type resultWrapperRuntimeError struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Stage string `json:"stage,omitempty"` + Retryable bool `json:"retryable,omitempty"` + Source string `json:"source,omitempty"` +} + +type resultWrapperRunRecord struct { + ID string `json:"id"` + TaskID string `json:"task_id,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + ParentRunID string `json:"parent_run_id,omitempty"` + Kind string `json:"kind,omitempty"` + Status string `json:"status"` + Input string `json:"input,omitempty"` + Output string `json:"output,omitempty"` + Error *resultWrapperRuntimeError `json:"error,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type resultWrapperRouterReply struct { + TaskID string `json:"task_id"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + Status string `json:"status"` + Result string `json:"result,omitempty"` + Run resultWrapperRunRecord `json:"run"` + Error *resultWrapperRuntimeError `json:"error,omitempty"` +} + +func unwrapRuntimeError(in *resultWrapperRuntimeError) *tools.RuntimeError { + if in == nil { + return nil + } + return &tools.RuntimeError{ + Code: in.Code, + Message: in.Message, + Stage: in.Stage, + Retryable: in.Retryable, + Source: in.Source, + } +} + +func unwrapRunRecord(in resultWrapperRunRecord) tools.RunRecord { + return tools.RunRecord{ + ID: in.ID, + TaskID: in.TaskID, + ThreadID: in.ThreadID, + CorrelationID: in.CorrelationID, + AgentID: in.AgentID, + ParentRunID: in.ParentRunID, + Kind: in.Kind, + Status: in.Status, + Input: in.Input, + Output: in.Output, + Error: unwrapRuntimeError(in.Error), + CreatedAt: in.CreatedAt, + UpdatedAt: in.UpdatedAt, + } +} + +func unwrapRouterReply(in *resultWrapperRouterReply) *tools.RouterReply { + if in == nil { + return nil + } + return &tools.RouterReply{ + TaskID: in.TaskID, + ThreadID: in.ThreadID, + CorrelationID: in.CorrelationID, + AgentID: in.AgentID, + Status: in.Status, + Result: in.Result, + Run: unwrapRunRecord(in.Run), + Error: unwrapRuntimeError(in.Error), + } +} + +type resultNodeRequest struct { + Node string + Action string + Task string + Model string + Args map[string]interface{} +} + +func (r resultNodeRequest) unwrap() nodes.Request { + return nodes.Request{ + Node: r.Node, + Action: r.Action, + Task: r.Task, + Model: r.Model, + Args: r.Args, + } +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 151a427..23a7336 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -1,16 +1,10 @@ package api import ( - "archive/tar" - "archive/zip" "bufio" "bytes" - "compress/gzip" "context" - "crypto/sha1" - "encoding/base64" "encoding/json" - "errors" "fmt" "io" "net" @@ -19,7 +13,6 @@ import ( "os" "os/exec" "path/filepath" - "regexp" "runtime" "runtime/debug" "sort" @@ -32,50 +25,63 @@ import ( cfgpkg "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/nodes" "github.com/YspCoder/clawgo/pkg/providers" + rpcpkg "github.com/YspCoder/clawgo/pkg/rpc" "github.com/YspCoder/clawgo/pkg/tools" "github.com/gorilla/websocket" "rsc.io/qr" ) type Server struct { - addr string - token string - mgr *nodes.Manager - server *http.Server - nodeConnMu sync.Mutex - nodeConnIDs map[string]string - nodeSockets map[string]*nodeSocketConn - nodeWebRTC *nodes.WebRTCTransport - nodeP2PStatus func() map[string]interface{} - artifactStatsMu sync.Mutex - artifactStats map[string]interface{} - gatewayVersion string - webuiVersion string - configPath string - workspacePath string - logFilePath string - onChat func(ctx context.Context, sessionKey, content string) (string, error) - onChatHistory func(sessionKey string) []map[string]interface{} - onConfigAfter func() error - onCron func(action string, args map[string]interface{}) (interface{}, error) - onSubagents func(ctx context.Context, action string, args map[string]interface{}) (interface{}, error) - onNodeDispatch func(ctx context.Context, req nodes.Request, mode string) (nodes.Response, error) - onToolsCatalog func() interface{} - webUIDir string - ekgCacheMu sync.Mutex - ekgCachePath string - ekgCacheStamp time.Time - ekgCacheSize int64 - ekgCacheRows []map[string]interface{} - liveRuntimeMu sync.Mutex - liveRuntimeSubs map[chan []byte]struct{} - liveRuntimeOn bool - whatsAppBridge *channels.WhatsAppBridgeService - whatsAppBase string - oauthFlowMu sync.Mutex - oauthFlows map[string]*providers.OAuthPendingFlow - extraRoutesMu sync.RWMutex - extraRoutes map[string]http.Handler + addr string + token string + mgr *nodes.Manager + server *http.Server + nodeConnMu sync.Mutex + nodeConnIDs map[string]string + nodeSockets map[string]*nodeSocketConn + nodeWebRTC *nodes.WebRTCTransport + nodeP2PStatus func() map[string]interface{} + artifactStatsMu sync.Mutex + artifactStats map[string]interface{} + gatewayVersion string + webuiVersion string + configPath string + workspacePath string + logFilePath string + onChat func(ctx context.Context, sessionKey, content string) (string, error) + onChatHistory func(sessionKey string) []map[string]interface{} + onConfigAfter func() error + onCron func(action string, args map[string]interface{}) (interface{}, error) + onSubagents func(ctx context.Context, action string, args map[string]interface{}) (interface{}, error) + onNodeDispatch func(ctx context.Context, req nodes.Request, mode string) (nodes.Response, error) + onToolsCatalog func() interface{} + webUIDir string + ekgCacheMu sync.Mutex + ekgCachePath string + ekgCacheStamp time.Time + ekgCacheSize int64 + ekgCacheRows []map[string]interface{} + liveRuntimeMu sync.Mutex + liveRuntimeSubs map[chan []byte]struct{} + liveRuntimeOn bool + whatsAppBridge *channels.WhatsAppBridgeService + whatsAppBase string + oauthFlowMu sync.Mutex + oauthFlows map[string]*providers.OAuthPendingFlow + extraRoutesMu sync.RWMutex + extraRoutes map[string]http.Handler + subagentRPCOnce sync.Once + subagentRPCReg *rpcpkg.Registry + nodeRPCOnce sync.Once + nodeRPCReg *rpcpkg.Registry + providerRPCOnce sync.Once + providerRPCReg *rpcpkg.Registry + workspaceRPCOnce sync.Once + workspaceRPCReg *rpcpkg.Registry + configRPCOnce sync.Once + configRPCReg *rpcpkg.Registry + cronRPCOnce sync.Once + cronRPCReg *rpcpkg.Registry } var nodesWebsocketUpgrader = websocket.Upgrader{ @@ -109,14 +115,18 @@ type nodeSocketConn struct { mu sync.Mutex } -func (c *nodeSocketConn) Send(msg nodes.WireMessage) error { +func (c *nodeSocketConn) writeJSON(payload interface{}) error { if c == nil || c.conn == nil { return fmt.Errorf("node websocket unavailable") } c.mu.Lock() defer c.mu.Unlock() _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - return c.conn.WriteJSON(msg) + return c.conn.WriteJSON(payload) +} + +func (c *nodeSocketConn) Send(msg nodes.WireMessage) error { + return c.writeJSON(msg) } func publishLiveSnapshot(subs map[chan []byte]struct{}, payload []byte) { @@ -381,10 +391,7 @@ func (s *Server) sendNodeSocketMessage(nodeID string, msg nodes.WireMessage) err if sock == nil || sock.conn == nil { return fmt.Errorf("node %s not connected", strings.TrimSpace(nodeID)) } - sock.mu.Lock() - defer sock.mu.Unlock() - _ = sock.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - return sock.conn.WriteJSON(msg) + return sock.writeJSON(msg) } func (s *Server) Start(ctx context.Context) error { @@ -429,6 +436,12 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/sessions", s.handleWebUISessions) mux.HandleFunc("/api/memory", s.handleWebUIMemory) mux.HandleFunc("/api/workspace_file", s.handleWebUIWorkspaceFile) + mux.HandleFunc("/api/rpc/subagent", s.handleSubagentRPC) + mux.HandleFunc("/api/rpc/node", s.handleNodeRPC) + mux.HandleFunc("/api/rpc/provider", s.handleProviderRPC) + mux.HandleFunc("/api/rpc/workspace", s.handleWorkspaceRPC) + mux.HandleFunc("/api/rpc/config", s.handleConfigRPC) + mux.HandleFunc("/api/rpc/cron", s.handleCronRPC) mux.HandleFunc("/api/subagents_runtime", s.handleWebUISubagentsRuntime) mux.HandleFunc("/api/tool_allowlist_groups", s.handleWebUIToolAllowlistGroups) mux.HandleFunc("/api/tools", s.handleWebUITools) @@ -500,12 +513,12 @@ func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid json", http.StatusBadRequest) return } - if strings.TrimSpace(n.ID) == "" { - http.Error(w, "id required", http.StatusBadRequest) + result, rpcErr := s.nodeRPCService().Register(r.Context(), rpcpkg.RegisterNodeRequest{Node: n}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) return } - s.mgr.Upsert(n) - writeJSON(w, map[string]interface{}{"ok": true, "id": n.ID}) + writeJSON(w, map[string]interface{}{"ok": true, "id": result.ID}) } func (s *Server) handleHeartbeat(w http.ResponseWriter, r *http.Request) { @@ -524,15 +537,12 @@ func (s *Server) handleHeartbeat(w http.ResponseWriter, r *http.Request) { http.Error(w, "id required", http.StatusBadRequest) return } - n, ok := s.mgr.Get(body.ID) - if !ok { - http.Error(w, "node not found", http.StatusNotFound) + result, rpcErr := s.nodeRPCService().Heartbeat(r.Context(), rpcpkg.HeartbeatNodeRequest{ID: body.ID}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) return } - n.LastSeenAt = time.Now().UTC() - n.Online = true - s.mgr.Upsert(n) - writeJSON(w, map[string]interface{}{"ok": true, "id": body.ID}) + writeJSON(w, map[string]interface{}{"ok": true, "id": result.ID}) } func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { @@ -558,6 +568,11 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { }) writeAck := func(ack nodes.WireAck) error { + if strings.TrimSpace(connectedID) != "" { + if sock := s.getNodeSocket(connectedID); sock != nil && sock.connID == connID && sock.conn == conn { + return sock.writeJSON(ack) + } + } _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) return conn.WriteJSON(ack) } @@ -577,87 +592,85 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { if s.mgr != nil && s.mgr.HandleWireMessage(msg) { continue } - switch strings.ToLower(strings.TrimSpace(msg.Type)) { - case "register": - if msg.Node == nil || strings.TrimSpace(msg.Node.ID) == "" { - _ = writeAck(nodes.WireAck{OK: false, Type: "register", Error: "node.id required"}) - continue - } - s.mgr.Upsert(*msg.Node) - connectedID = strings.TrimSpace(msg.Node.ID) - s.rememberNodeConnection(connectedID, connID) - s.bindNodeSocket(connectedID, connID, conn) - if err := writeAck(nodes.WireAck{OK: true, Type: "registered", ID: connectedID}); err != nil { - return - } - case "heartbeat": - id := strings.TrimSpace(msg.ID) - if id == "" { - id = connectedID - } - if id == "" { - _ = writeAck(nodes.WireAck{OK: false, Type: "heartbeat", Error: "id required"}) - continue - } - if msg.Node != nil && strings.TrimSpace(msg.Node.ID) != "" { + type nodeSocketHandler func(nodes.WireMessage) bool + handlers := map[string]nodeSocketHandler{ + "register": func(msg nodes.WireMessage) bool { + if msg.Node == nil || strings.TrimSpace(msg.Node.ID) == "" { + _ = writeAck(nodes.WireAck{OK: false, Type: "register", Error: "node.id required"}) + return true + } s.mgr.Upsert(*msg.Node) connectedID = strings.TrimSpace(msg.Node.ID) s.rememberNodeConnection(connectedID, connID) s.bindNodeSocket(connectedID, connID, conn) - } else if n, ok := s.mgr.Get(id); ok { - s.mgr.Upsert(n) - connectedID = id - s.rememberNodeConnection(connectedID, connID) - s.bindNodeSocket(connectedID, connID, conn) - } else { - _ = writeAck(nodes.WireAck{OK: false, Type: "heartbeat", ID: id, Error: "node not found"}) - continue - } - if err := writeAck(nodes.WireAck{OK: true, Type: "heartbeat", ID: connectedID}); err != nil { - return - } - case "signal_offer", "signal_answer", "signal_candidate": - targetID := strings.TrimSpace(msg.To) - if s.nodeWebRTC != nil && (targetID == "" || strings.EqualFold(targetID, "gateway")) { - if err := s.nodeWebRTC.HandleSignal(msg); err != nil { - if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}); err != nil { - return - } - } else if err := writeAck(nodes.WireAck{OK: true, Type: "signaled", ID: msg.ID}); err != nil { - return - } - continue - } - if strings.TrimSpace(connectedID) == "" { - if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, Error: "node not registered"}); err != nil { - return - } - continue - } - if targetID == "" { - if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "target node required"}); err != nil { - return - } - continue - } - msg.From = connectedID - if err := s.sendNodeSocketMessage(targetID, msg); err != nil { - if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}); err != nil { - return - } - continue - } - if err := writeAck(nodes.WireAck{OK: true, Type: "relayed", ID: msg.ID}); err != nil { - return - } - default: - if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "unsupported message type"}); err != nil { + return writeAck(nodes.WireAck{OK: true, Type: "registered", ID: connectedID}) == nil + }, + "heartbeat": func(msg nodes.WireMessage) bool { + id := strings.TrimSpace(msg.ID) + if id == "" { + id = connectedID + } + if id == "" { + _ = writeAck(nodes.WireAck{OK: false, Type: "heartbeat", Error: "id required"}) + return true + } + if msg.Node != nil && strings.TrimSpace(msg.Node.ID) != "" { + s.mgr.Upsert(*msg.Node) + connectedID = strings.TrimSpace(msg.Node.ID) + s.rememberNodeConnection(connectedID, connID) + s.bindNodeSocket(connectedID, connID, conn) + } else if n, ok := s.mgr.Get(id); ok { + s.mgr.Upsert(n) + connectedID = id + s.rememberNodeConnection(connectedID, connID) + s.bindNodeSocket(connectedID, connID, conn) + } else { + _ = writeAck(nodes.WireAck{OK: false, Type: "heartbeat", ID: id, Error: "node not found"}) + return true + } + return writeAck(nodes.WireAck{OK: true, Type: "heartbeat", ID: connectedID}) == nil + }, + "signal_offer": func(msg nodes.WireMessage) bool { return s.handleNodeSignalMessage(msg, connectedID, writeAck) }, + "signal_answer": func(msg nodes.WireMessage) bool { return s.handleNodeSignalMessage(msg, connectedID, writeAck) }, + "signal_candidate": func(msg nodes.WireMessage) bool { return s.handleNodeSignalMessage(msg, connectedID, writeAck) }, + } + if handler := handlers[strings.ToLower(strings.TrimSpace(msg.Type))]; handler != nil { + if !handler(msg) { return } + continue + } + if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "unsupported message type"}); err != nil { + return } } } +func (s *Server) handleNodeSignalMessage(msg nodes.WireMessage, connectedID string, writeAck func(nodes.WireAck) error) bool { + targetID := strings.TrimSpace(msg.To) + if s.nodeWebRTC != nil && (targetID == "" || strings.EqualFold(targetID, "gateway")) { + if err := s.nodeWebRTC.HandleSignal(msg); err != nil { + _ = writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}) + return true + } + return writeAck(nodes.WireAck{OK: true, Type: "signaled", ID: msg.ID}) == nil + } + if strings.TrimSpace(connectedID) == "" { + _ = writeAck(nodes.WireAck{OK: false, Type: msg.Type, Error: "node not registered"}) + return true + } + if targetID == "" { + _ = writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "target node required"}) + return true + } + msg.From = connectedID + if err := s.sendNodeSocketMessage(targetID, msg); err != nil { + _ = writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}) + return true + } + return writeAck(nodes.WireAck{OK: true, Type: "relayed", ID: msg.ID}) == nil +} + func (s *Server) handleWebUI(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) @@ -733,184 +746,6 @@ func (s *Server) tryServeWebUIDist(w http.ResponseWriter, r *http.Request, reqPa return true } -func (s *Server) handleWebUIConfig(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if strings.TrimSpace(s.configPath) == "" { - http.Error(w, "config path not set", http.StatusInternalServerError) - return - } - switch r.Method { - case http.MethodGet: - if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - payload := map[string]interface{}{ - "ok": true, - "config": cfg.NormalizedView(), - "raw_config": cfg, - } - if r.URL.Query().Get("include_hot_reload_fields") == "1" { - info := hotReloadFieldInfo() - paths := make([]string, 0, len(info)) - for _, it := range info { - if p := stringFromMap(it, "path"); p != "" { - paths = append(paths, p) - } - } - payload["hot_reload_fields"] = paths - payload["hot_reload_field_details"] = info - } - writeJSON(w, payload) - return - } - b, err := os.ReadFile(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - cfgDefault := cfgpkg.DefaultConfig() - defBytes, _ := json.Marshal(cfgDefault) - var merged map[string]interface{} - _ = json.Unmarshal(defBytes, &merged) - var loaded map[string]interface{} - if err := json.Unmarshal(b, &loaded); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - merged = mergeJSONMap(merged, loaded) - - if r.URL.Query().Get("include_hot_reload_fields") == "1" || strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "hot") { - info := hotReloadFieldInfo() - paths := make([]string, 0, len(info)) - for _, it := range info { - if p := stringFromMap(it, "path"); p != "" { - paths = append(paths, p) - } - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "config": merged, - "hot_reload_fields": paths, - "hot_reload_field_details": info, - }) - return - } - w.Header().Set("Content-Type", "application/json") - out, _ := json.MarshalIndent(merged, "", " ") - _, _ = w.Write(out) - case http.MethodPost: - var body map[string]interface{} - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - confirmRisky, _ := tools.MapBoolArg(body, "confirm_risky") - delete(body, "confirm_risky") - - oldCfgRaw, _ := os.ReadFile(s.configPath) - var oldMap map[string]interface{} - _ = json.Unmarshal(oldCfgRaw, &oldMap) - riskyOldMap := oldMap - riskyNewMap := body - if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { - if loaded, err := cfgpkg.LoadConfig(s.configPath); err == nil && loaded != nil { - if raw, err := json.Marshal(loaded.NormalizedView()); err == nil { - _ = json.Unmarshal(raw, &riskyOldMap) - } - } - } - - riskyPaths := collectRiskyConfigPaths(riskyOldMap, riskyNewMap) - changedRisky := make([]string, 0) - for _, p := range riskyPaths { - if fmt.Sprintf("%v", getPathValue(riskyOldMap, p)) != fmt.Sprintf("%v", getPathValue(riskyNewMap, p)) { - changedRisky = append(changedRisky, p) - } - } - if len(changedRisky) > 0 && !confirmRisky { - writeJSONStatus(w, http.StatusBadRequest, map[string]interface{}{ - "ok": false, - "error": "risky fields changed; confirmation required", - "requires_confirm": true, - "changed_fields": changedRisky, - }) - return - } - - cfg := cfgpkg.DefaultConfig() - if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { - loaded, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - cfg = loaded - candidate, err := json.Marshal(body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - var normalized cfgpkg.NormalizedConfig - dec := json.NewDecoder(bytes.NewReader(candidate)) - dec.DisallowUnknownFields() - if err := dec.Decode(&normalized); err != nil { - http.Error(w, "normalized config validation failed: "+err.Error(), http.StatusBadRequest) - return - } - cfg.ApplyNormalizedView(normalized) - } else { - candidate, err := json.Marshal(body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - dec := json.NewDecoder(bytes.NewReader(candidate)) - dec.DisallowUnknownFields() - if err := dec.Decode(cfg); err != nil { - http.Error(w, "config schema validation failed: "+err.Error(), http.StatusBadRequest) - return - } - } - if errs := cfgpkg.Validate(cfg); len(errs) > 0 { - list := make([]string, 0, len(errs)) - for _, e := range errs { - list = append(list, e.Error()) - } - writeJSONStatus(w, http.StatusBadRequest, map[string]interface{}{"ok": false, "error": "config validation failed", "details": list}) - return - } - - if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if s.onConfigAfter != nil { - if err := s.onConfigAfter(); err != nil { - http.Error(w, "config saved but reload failed: "+err.Error(), http.StatusInternalServerError) - return - } - } else { - if err := requestSelfReloadSignal(); err != nil { - http.Error(w, "config saved but reload signal failed: "+err.Error(), http.StatusInternalServerError) - return - } - } - if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { - writeJSON(w, map[string]interface{}{"ok": true, "reloaded": true, "config": cfg.NormalizedView()}) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "reloaded": true}) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - func mergeJSONMap(base, override map[string]interface{}) map[string]interface{} { if base == nil { base = map[string]interface{}{} @@ -1041,550 +876,6 @@ func (s *Server) handleWebUIUpload(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]interface{}{"ok": true, "path": path, "name": h.Filename}) } -func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost && r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - var body struct { - Provider string `json:"provider"` - AccountLabel string `json:"account_label"` - NetworkProxy string `json:"network_proxy"` - ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"` - } - if r.Method == http.MethodPost { - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - } else { - body.Provider = strings.TrimSpace(r.URL.Query().Get("provider")) - body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label")) - body.NetworkProxy = strings.TrimSpace(r.URL.Query().Get("network_proxy")) - } - cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = cfg - timeout := pc.TimeoutSec - if timeout <= 0 { - timeout = 90 - } - loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - flow, err := loginMgr.StartManualFlowWithOptions(providers.OAuthLoginOptions{ - AccountLabel: body.AccountLabel, - NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)), - }) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - flowID := fmt.Sprintf("%d", time.Now().UnixNano()) - s.oauthFlowMu.Lock() - s.oauthFlows[flowID] = flow - s.oauthFlowMu.Unlock() - writeJSON(w, map[string]interface{}{ - "ok": true, - "flow_id": flowID, - "mode": flow.Mode, - "auth_url": flow.AuthURL, - "user_code": flow.UserCode, - "instructions": flow.Instructions, - "account_label": strings.TrimSpace(body.AccountLabel), - "network_proxy": strings.TrimSpace(body.NetworkProxy), - }) -} - -func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - var body struct { - Provider string `json:"provider"` - FlowID string `json:"flow_id"` - CallbackURL string `json:"callback_url"` - AccountLabel string `json:"account_label"` - NetworkProxy string `json:"network_proxy"` - ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - timeout := pc.TimeoutSec - if timeout <= 0 { - timeout = 90 - } - loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - s.oauthFlowMu.Lock() - flow := s.oauthFlows[strings.TrimSpace(body.FlowID)] - delete(s.oauthFlows, strings.TrimSpace(body.FlowID)) - s.oauthFlowMu.Unlock() - if flow == nil { - http.Error(w, "oauth flow not found", http.StatusBadRequest) - return - } - session, models, err := loginMgr.CompleteManualFlowWithOptions(r.Context(), pc.APIBase, flow, body.CallbackURL, providers.OAuthLoginOptions{ - AccountLabel: strings.TrimSpace(body.AccountLabel), - NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)), - }) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if session.CredentialFile != "" { - pc.OAuth.CredentialFile = session.CredentialFile - pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile) - } - if err := s.saveProviderConfig(cfg, body.Provider, pc); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "account": session.Email, - "credential_file": session.CredentialFile, - "network_proxy": session.NetworkProxy, - "models": models, - }) -} - -func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if err := r.ParseMultipartForm(16 << 20); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - providerName := strings.TrimSpace(r.FormValue("provider")) - accountLabel := strings.TrimSpace(r.FormValue("account_label")) - networkProxy := strings.TrimSpace(r.FormValue("network_proxy")) - inlineCfgRaw := strings.TrimSpace(r.FormValue("provider_config")) - var inlineCfg cfgpkg.ProviderConfig - if inlineCfgRaw != "" { - if err := json.Unmarshal([]byte(inlineCfgRaw), &inlineCfg); err != nil { - http.Error(w, "invalid provider_config", http.StatusBadRequest) - return - } - } - cfg, pc, err := s.resolveProviderConfig(providerName, inlineCfg) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - file, header, err := r.FormFile("file") - if err != nil { - http.Error(w, "file required", http.StatusBadRequest) - return - } - defer file.Close() - data, err := io.ReadAll(file) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - timeout := pc.TimeoutSec - if timeout <= 0 { - timeout = 90 - } - loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - session, models, err := loginMgr.ImportAuthJSONWithOptions(r.Context(), pc.APIBase, header.Filename, data, providers.OAuthLoginOptions{ - AccountLabel: accountLabel, - NetworkProxy: firstNonEmptyString(networkProxy, strings.TrimSpace(pc.OAuth.NetworkProxy)), - }) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if session.CredentialFile != "" { - pc.OAuth.CredentialFile = session.CredentialFile - pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile) - } - if err := s.saveProviderConfig(cfg, providerName, pc); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "account": session.Email, - "credential_file": session.CredentialFile, - "network_proxy": session.NetworkProxy, - "models": models, - }) -} - -func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - providerName := strings.TrimSpace(r.URL.Query().Get("provider")) - cfg, pc, err := s.loadProviderConfig(providerName) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = cfg - timeout := pc.TimeoutSec - if timeout <= 0 { - timeout = 90 - } - loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - switch r.Method { - case http.MethodGet: - accounts, err := loginMgr.ListAccounts() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "accounts": accounts}) - case http.MethodPost: - var body struct { - Action string `json:"action"` - CredentialFile string `json:"credential_file"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - switch strings.ToLower(strings.TrimSpace(body.Action)) { - case "refresh": - account, err := loginMgr.RefreshAccount(r.Context(), body.CredentialFile) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "account": account}) - case "delete": - if err := loginMgr.DeleteAccount(body.CredentialFile); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - pc.OAuth.CredentialFiles = removeStringItem(pc.OAuth.CredentialFiles, body.CredentialFile) - if strings.TrimSpace(pc.OAuth.CredentialFile) == strings.TrimSpace(body.CredentialFile) { - pc.OAuth.CredentialFile = "" - if len(pc.OAuth.CredentialFiles) > 0 { - pc.OAuth.CredentialFile = pc.OAuth.CredentialFiles[0] - } - } - if err := s.saveProviderConfig(cfg, providerName, pc); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "deleted": true}) - case "clear_cooldown": - if err := loginMgr.ClearCooldown(body.CredentialFile); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "cleared": true}) - default: - http.Error(w, "unsupported action", http.StatusBadRequest) - } - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - -func (s *Server) handleWebUIProviderModels(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - var body struct { - Provider string `json:"provider"` - Model string `json:"model"` - Models []string `json:"models"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - cfg, pc, err := s.loadProviderConfig(strings.TrimSpace(body.Provider)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - models := make([]string, 0, len(body.Models)+1) - for _, model := range body.Models { - models = appendUniqueStrings(models, model) - } - models = appendUniqueStrings(models, body.Model) - if len(models) == 0 { - http.Error(w, "model required", http.StatusBadRequest) - return - } - pc.Models = models - if err := s.saveProviderConfig(cfg, body.Provider, pc); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "models": pc.Models, - }) -} - -func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method == http.MethodGet { - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - query := providers.ProviderRuntimeQuery{ - Provider: strings.TrimSpace(r.URL.Query().Get("provider")), - EventKind: strings.TrimSpace(r.URL.Query().Get("kind")), - Reason: strings.TrimSpace(r.URL.Query().Get("reason")), - Target: strings.TrimSpace(r.URL.Query().Get("target")), - Sort: strings.TrimSpace(r.URL.Query().Get("sort")), - ChangesOnly: strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("changes_only")), "true"), - } - if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("window_sec"))); secs > 0 { - query.Window = time.Duration(secs) * time.Second - } - if limit, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("limit"))); limit > 0 { - query.Limit = limit - } - if cursor, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cursor"))); cursor >= 0 { - query.Cursor = cursor - } - if healthBelow, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("health_below"))); healthBelow > 0 { - query.HealthBelow = healthBelow - } - if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec"))); secs > 0 { - query.CooldownBefore = time.Now().Add(time.Duration(secs) * time.Second) - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "view": providers.GetProviderRuntimeView(cfg, query), - }) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - var body struct { - Provider string `json:"provider"` - Action string `json:"action"` - OnlyExpiring bool `json:"only_expiring"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - switch strings.ToLower(strings.TrimSpace(body.Action)) { - case "clear_api_cooldown": - cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = cfg - providers.ClearProviderAPICooldown(providerName) - writeJSON(w, map[string]interface{}{"ok": true, "cleared": true}) - case "clear_history": - cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = cfg - providers.ClearProviderRuntimeHistory(providerName) - writeJSON(w, map[string]interface{}{"ok": true, "cleared": true}) - case "refresh_now": - cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - result, err := providers.RefreshProviderRuntimeNow(cfg, providerName, body.OnlyExpiring) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - order, _ := providers.RerankProviderRuntime(cfg, providerName) - summary := providers.GetProviderRuntimeSummary(cfg, providers.ProviderRuntimeQuery{Provider: providerName, HealthBelow: 50}) - writeJSON(w, map[string]interface{}{ - "ok": true, - "provider": providerName, - "refreshed": true, - "result": result, - "candidate_order": order, - "summary": summary, - }) - case "rerank": - cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - order, err := providers.RerankProviderRuntime(cfg, providerName) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "provider": providerName, "reranked": true, "candidate_order": order}) - default: - http.Error(w, "unsupported action", http.StatusBadRequest) - } -} - -func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { - if strings.TrimSpace(s.configPath) == "" { - return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("config path not set") - } - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - return nil, cfgpkg.ProviderConfig{}, err - } - providerName := strings.TrimSpace(name) - if providerName == "" { - providerName = cfgpkg.PrimaryProviderName(cfg) - } - pc, ok := cfgpkg.ProviderConfigByName(cfg, providerName) - if !ok { - return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName) - } - return cfg, pc, nil -} - -func (s *Server) loadRuntimeProviderName(name string) (*cfgpkg.Config, string, error) { - if strings.TrimSpace(s.configPath) == "" { - return nil, "", fmt.Errorf("config path not set") - } - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - return nil, "", err - } - providerName := strings.TrimSpace(name) - if providerName == "" { - providerName = cfgpkg.PrimaryProviderName(cfg) - } - if !cfgpkg.ProviderExists(cfg, providerName) { - return nil, "", fmt.Errorf("provider %q not found", providerName) - } - return cfg, providerName, nil -} - -func (s *Server) resolveProviderConfig(name string, inline cfgpkg.ProviderConfig) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { - if hasInlineProviderConfig(inline) { - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - return nil, cfgpkg.ProviderConfig{}, err - } - return cfg, inline, nil - } - return s.loadProviderConfig(name) -} - -func hasInlineProviderConfig(pc cfgpkg.ProviderConfig) bool { - return strings.TrimSpace(pc.APIBase) != "" || - strings.TrimSpace(pc.APIKey) != "" || - len(pc.Models) > 0 || - strings.TrimSpace(pc.Auth) != "" || - strings.TrimSpace(pc.OAuth.Provider) != "" -} - -func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.ProviderConfig) error { - if cfg == nil { - return fmt.Errorf("config is nil") - } - providerName := strings.TrimSpace(name) - if cfg.Models.Providers == nil { - cfg.Models.Providers = map[string]cfgpkg.ProviderConfig{} - } - cfg.Models.Providers[providerName] = pc - if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil { - return err - } - if s.onConfigAfter != nil { - if err := s.onConfigAfter(); err != nil { - return err - } - } else { - if err := requestSelfReloadSignal(); err != nil { - return err - } - } - return nil -} - -func appendUniqueStrings(values []string, item string) []string { - item = strings.TrimSpace(item) - if item == "" { - return values - } - for _, value := range values { - if strings.TrimSpace(value) == item { - return values - } - } - return append(values, item) -} - -func removeStringItem(values []string, item string) []string { - item = strings.TrimSpace(item) - if item == "" { - return values - } - out := make([]string, 0, len(values)) - for _, value := range values { - if strings.TrimSpace(value) == item { - continue - } - out = append(out, value) - } - return out -} - func (s *Server) handleWebUIChat(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -2282,55 +1573,6 @@ func (s *Server) webUINodeArtifactsPayload(limit int) []map[string]interface{} { return s.webUINodeArtifactsPayloadFiltered("", "", "", limit) } -func (s *Server) webUINodeArtifactsPayloadFiltered(nodeFilter, actionFilter, kindFilter string, limit int) []map[string]interface{} { - nodeFilter = strings.TrimSpace(nodeFilter) - actionFilter = strings.TrimSpace(actionFilter) - kindFilter = strings.TrimSpace(kindFilter) - rows, _ := s.readNodeDispatchAuditRows() - if len(rows) == 0 { - return []map[string]interface{}{} - } - out := make([]map[string]interface{}, 0, limit) - for rowIndex := len(rows) - 1; rowIndex >= 0; rowIndex-- { - row := rows[rowIndex] - artifacts, _ := row["artifacts"].([]interface{}) - for artifactIndex, raw := range artifacts { - artifact, ok := raw.(map[string]interface{}) - if !ok { - continue - } - item := map[string]interface{}{ - "id": buildNodeArtifactID(row, artifact, artifactIndex), - "time": row["time"], - "node": row["node"], - "action": row["action"], - "used_transport": row["used_transport"], - "ok": row["ok"], - "error": row["error"], - } - for _, key := range []string{"name", "kind", "mime_type", "storage", "path", "url", "content_text", "content_base64", "source_path", "size_bytes"} { - if value, ok := artifact[key]; ok { - item[key] = value - } - } - if nodeFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["node"])), nodeFilter) { - continue - } - if actionFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["action"])), actionFilter) { - continue - } - if kindFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["kind"])), kindFilter) { - continue - } - out = append(out, item) - if limit > 0 && len(out) >= limit { - return out - } - } - } - return out -} - func (s *Server) readNodeDispatchAuditRows() ([]map[string]interface{}, string) { path := s.memoryFilePath("nodes-dispatch-audit.jsonl") if path == "" { @@ -2356,100 +1598,6 @@ func (s *Server) readNodeDispatchAuditRows() ([]map[string]interface{}, string) return rows, path } -func buildNodeArtifactID(row, artifact map[string]interface{}, artifactIndex int) string { - seed := fmt.Sprintf("%v|%v|%v|%d|%v|%v|%v", - row["time"], row["node"], row["action"], artifactIndex, - artifact["name"], artifact["source_path"], artifact["path"], - ) - sum := sha1.Sum([]byte(seed)) - return fmt.Sprintf("%x", sum[:8]) -} - -func sanitizeZipEntryName(name string) string { - name = strings.TrimSpace(name) - if name == "" { - return "artifact.bin" - } - name = strings.ReplaceAll(name, "\\", "/") - name = filepath.Base(name) - name = strings.Map(func(r rune) rune { - switch { - case r >= 'a' && r <= 'z': - return r - case r >= 'A' && r <= 'Z': - return r - case r >= '0' && r <= '9': - return r - case r == '.', r == '-', r == '_': - return r - default: - return '_' - } - }, name) - if strings.Trim(name, "._") == "" { - return "artifact.bin" - } - return name -} - -func (s *Server) findNodeArtifactByID(id string) (map[string]interface{}, bool) { - for _, item := range s.webUINodeArtifactsPayload(10000) { - if strings.TrimSpace(fmt.Sprint(item["id"])) == id { - return item, true - } - } - return nil, false -} - -func resolveArtifactPath(workspace, raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return "" - } - if filepath.IsAbs(raw) { - clean := filepath.Clean(raw) - if info, err := os.Stat(clean); err == nil && !info.IsDir() { - return clean - } - return "" - } - root := strings.TrimSpace(workspace) - if root == "" { - return "" - } - clean := filepath.Clean(filepath.Join(root, raw)) - if rel, err := filepath.Rel(root, clean); err != nil || strings.HasPrefix(rel, "..") { - return "" - } - if info, err := os.Stat(clean); err == nil && !info.IsDir() { - return clean - } - return "" -} - -func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, string, error) { - if content := strings.TrimSpace(fmt.Sprint(item["content_base64"])); content != "" { - raw, err := base64.StdEncoding.DecodeString(content) - if err != nil { - return nil, "", err - } - return raw, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil - } - for _, rawPath := range []string{fmt.Sprint(item["source_path"]), fmt.Sprint(item["path"])} { - if path := resolveArtifactPath(workspace, rawPath); path != "" { - b, err := os.ReadFile(path) - if err != nil { - return nil, "", err - } - return b, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil - } - } - if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" { - return []byte(contentText), "text/plain; charset=utf-8", nil - } - return nil, "", fmt.Errorf("artifact content unavailable") -} - func resolveRelativeFilePath(root, raw string) (string, string, error) { root = strings.TrimSpace(root) if root == "" { @@ -2519,37 +1667,6 @@ func (s *Server) memoryFilePath(name string) string { return filepath.Join(workspace, "memory", strings.TrimSpace(name)) } -func (s *Server) filteredNodeDispatches(nodeFilter, actionFilter string, limit int) []map[string]interface{} { - items := s.webUINodesDispatchPayload(limit) - if nodeFilter == "" && actionFilter == "" { - return items - } - out := make([]map[string]interface{}, 0, len(items)) - for _, item := range items { - if nodeFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["node"])), nodeFilter) { - continue - } - if actionFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["action"])), actionFilter) { - continue - } - out = append(out, item) - } - return out -} - -func filteredNodeAlerts(alerts []map[string]interface{}, nodeFilter string) []map[string]interface{} { - if nodeFilter == "" { - return alerts - } - out := make([]map[string]interface{}, 0, len(alerts)) - for _, item := range alerts { - if strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["node"])), nodeFilter) { - out = append(out, item) - } - } - return out -} - func (s *Server) setArtifactStats(summary map[string]interface{}) { s.artifactStatsMu.Lock() defer s.artifactStatsMu.Unlock() @@ -3203,309 +2320,6 @@ func (s *Server) handleWebUINodeDispatches(w http.ResponseWriter, r *http.Reques }) } -func (s *Server) handleWebUINodeDispatchReplay(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if s.onNodeDispatch == nil { - http.Error(w, "node dispatch handler not configured", http.StatusServiceUnavailable) - return - } - var body struct { - Node string `json:"node"` - Action string `json:"action"` - Mode string `json:"mode"` - Task string `json:"task"` - Model string `json:"model"` - Args map[string]interface{} `json:"args"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - req := nodes.Request{ - Node: strings.TrimSpace(body.Node), - Action: strings.TrimSpace(body.Action), - Task: body.Task, - Model: body.Model, - Args: body.Args, - } - if req.Node == "" || req.Action == "" { - http.Error(w, "node and action are required", http.StatusBadRequest) - return - } - resp, err := s.onNodeDispatch(r.Context(), req, strings.TrimSpace(body.Mode)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "result": resp, - }) -} - -func (s *Server) handleWebUINodeArtifacts(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - limit := queryBoundedPositiveInt(r, "limit", 200, 1000) - retentionSummary := s.applyNodeArtifactRetention() - nodeFilter := strings.TrimSpace(r.URL.Query().Get("node")) - actionFilter := strings.TrimSpace(r.URL.Query().Get("action")) - kindFilter := strings.TrimSpace(r.URL.Query().Get("kind")) - writeJSON(w, map[string]interface{}{ - "ok": true, - "items": s.webUINodeArtifactsPayloadFiltered(nodeFilter, actionFilter, kindFilter, limit), - "artifact_retention": retentionSummary, - }) -} - -func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - retentionSummary := s.applyNodeArtifactRetention() - limit := queryBoundedPositiveInt(r, "limit", 200, 1000) - nodeFilter := strings.TrimSpace(r.URL.Query().Get("node")) - actionFilter := strings.TrimSpace(r.URL.Query().Get("action")) - kindFilter := strings.TrimSpace(r.URL.Query().Get("kind")) - artifacts := s.webUINodeArtifactsPayloadFiltered(nodeFilter, actionFilter, kindFilter, limit) - dispatches := s.filteredNodeDispatches(nodeFilter, actionFilter, limit) - payload := s.webUINodesPayload(r.Context()) - nodeList, _ := payload["nodes"].([]nodes.NodeInfo) - p2p, _ := payload["p2p"].(map[string]interface{}) - alerts := filteredNodeAlerts(s.webUINodeAlertsPayload(nodeList, p2p, dispatches), nodeFilter) - - var archive bytes.Buffer - zw := zip.NewWriter(&archive) - writeJSON := func(name string, value interface{}) error { - entry, err := zw.Create(name) - if err != nil { - return err - } - enc := json.NewEncoder(entry) - enc.SetIndent("", " ") - return enc.Encode(value) - } - manifest := map[string]interface{}{ - "generated_at": time.Now().UTC().Format(time.RFC3339), - "filters": map[string]interface{}{ - "node": nodeFilter, - "action": actionFilter, - "kind": kindFilter, - "limit": limit, - }, - "artifact_count": len(artifacts), - "dispatch_count": len(dispatches), - "alert_count": len(alerts), - "retention": retentionSummary, - } - if err := writeJSON("manifest.json", manifest); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := writeJSON("dispatches.json", dispatches); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := writeJSON("alerts.json", alerts); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := writeJSON("artifacts.json", artifacts); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - for _, item := range artifacts { - name := sanitizeZipEntryName(firstNonEmptyString( - fmt.Sprint(item["name"]), - fmt.Sprint(item["source_path"]), - fmt.Sprint(item["path"]), - fmt.Sprintf("%s.bin", fmt.Sprint(item["id"])), - )) - raw, _, err := readArtifactBytes(s.workspacePath, item) - entryName := filepath.ToSlash(filepath.Join("files", fmt.Sprintf("%s-%s", fmt.Sprint(item["id"]), name))) - if err != nil || len(raw) == 0 { - entryName = filepath.ToSlash(filepath.Join("files", fmt.Sprintf("%s-metadata.json", fmt.Sprint(item["id"])))) - raw, err = json.MarshalIndent(item, "", " ") - if err != nil { - continue - } - } - entry, err := zw.Create(entryName) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if _, err := entry.Write(raw); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - if err := zw.Close(); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - filename := "node-artifacts-export.zip" - if nodeFilter != "" { - filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter)) - } - w.Header().Set("Content-Type", "application/zip") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename)) - w.WriteHeader(http.StatusOK) - _, _ = w.Write(archive.Bytes()) -} - -func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - id := strings.TrimSpace(r.URL.Query().Get("id")) - if id == "" { - http.Error(w, "id is required", http.StatusBadRequest) - return - } - item, ok := s.findNodeArtifactByID(id) - if !ok { - http.Error(w, "artifact not found", http.StatusNotFound) - return - } - name := strings.TrimSpace(fmt.Sprint(item["name"])) - if name == "" { - name = "artifact" - } - mimeType := strings.TrimSpace(fmt.Sprint(item["mime_type"])) - if mimeType == "" { - mimeType = "application/octet-stream" - } - if contentB64 := strings.TrimSpace(fmt.Sprint(item["content_base64"])); contentB64 != "" { - payload, err := base64.StdEncoding.DecodeString(contentB64) - if err != nil { - http.Error(w, "invalid inline artifact payload", http.StatusBadRequest) - return - } - w.Header().Set("Content-Type", mimeType) - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) - _, _ = w.Write(payload) - return - } - for _, rawPath := range []string{fmt.Sprint(item["source_path"]), fmt.Sprint(item["path"])} { - if path := resolveArtifactPath(s.workspacePath, rawPath); path != "" { - http.ServeFile(w, r, path) - return - } - } - if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" { - w.Header().Set("Content-Type", mimeType) - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) - _, _ = w.Write([]byte(contentText)) - return - } - http.Error(w, "artifact content unavailable", http.StatusNotFound) -} - -func (s *Server) handleWebUINodeArtifactDelete(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - var body struct { - ID string `json:"id"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - deletedFile, deletedAudit, err := s.deleteNodeArtifact(strings.TrimSpace(body.ID)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "id": strings.TrimSpace(body.ID), - "deleted_file": deletedFile, - "deleted_audit": deletedAudit, - }) -} - -func (s *Server) handleWebUINodeArtifactPrune(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - var body struct { - Node string `json:"node"` - Action string `json:"action"` - Kind string `json:"kind"` - KeepLatest int `json:"keep_latest"` - Limit int `json:"limit"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - limit := body.Limit - if limit <= 0 || limit > 5000 { - limit = 5000 - } - keepLatest := body.KeepLatest - if keepLatest < 0 { - keepLatest = 0 - } - items := s.webUINodeArtifactsPayloadFiltered(strings.TrimSpace(body.Node), strings.TrimSpace(body.Action), strings.TrimSpace(body.Kind), limit) - pruned := 0 - deletedFiles := 0 - for index, item := range items { - if index < keepLatest { - continue - } - deletedFile, deletedAudit, err := s.deleteNodeArtifact(strings.TrimSpace(fmt.Sprint(item["id"]))) - if err != nil || !deletedAudit { - continue - } - pruned++ - if deletedFile { - deletedFiles++ - } - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "pruned": pruned, - "deleted_files": deletedFiles, - "kept": keepLatest, - }) -} - func (s *Server) buildNodeAgentTrees(ctx context.Context, nodeList []nodes.NodeInfo) []map[string]interface{} { trees := make([]map[string]interface{}, 0, len(nodeList)) localRegistry := s.fetchRegistryItems(ctx) @@ -3796,430 +2610,6 @@ func fallbackString(value, fallback string) string { return strings.TrimSpace(fallback) } -func (s *Server) handleWebUICron(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if s.onCron == nil { - http.Error(w, "cron handler not configured", http.StatusInternalServerError) - return - } - - switch r.Method { - case http.MethodGet: - id := strings.TrimSpace(r.URL.Query().Get("id")) - action := "list" - if id != "" { - action = "get" - } - res, err := s.onCron(action, map[string]interface{}{"id": id}) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if action == "list" { - writeJSON(w, map[string]interface{}{"ok": true, "jobs": normalizeCronJobs(res)}) - } else { - writeJSON(w, map[string]interface{}{"ok": true, "job": normalizeCronJob(res)}) - } - case http.MethodPost: - args := map[string]interface{}{} - if r.Body != nil { - _ = json.NewDecoder(r.Body).Decode(&args) - } - if id := strings.TrimSpace(r.URL.Query().Get("id")); id != "" { - args["id"] = id - } - action := "create" - if a := tools.MapStringArg(args, "action"); a != "" { - action = strings.ToLower(strings.TrimSpace(a)) - } - res, err := s.onCron(action, args) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "result": normalizeCronJob(res)}) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - -func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - skillsDir := filepath.Join(s.workspacePath, "skills") - if strings.TrimSpace(skillsDir) == "" { - http.Error(w, "workspace not configured", http.StatusInternalServerError) - return - } - _ = os.MkdirAll(skillsDir, 0755) - - resolveSkillPath := func(name string) (string, error) { - name = strings.TrimSpace(name) - if name == "" { - return "", fmt.Errorf("name required") - } - cands := []string{ - filepath.Join(skillsDir, name), - filepath.Join(skillsDir, name+".disabled"), - filepath.Join("/root/clawgo/workspace/skills", name), - filepath.Join("/root/clawgo/workspace/skills", name+".disabled"), - } - for _, p := range cands { - if st, err := os.Stat(p); err == nil && st.IsDir() { - return p, nil - } - } - return "", fmt.Errorf("skill not found: %s", name) - } - - switch r.Method { - case http.MethodGet: - clawhubPath := strings.TrimSpace(resolveClawHubBinary(r.Context())) - clawhubInstalled := clawhubPath != "" - if id := strings.TrimSpace(r.URL.Query().Get("id")); id != "" { - skillPath, err := resolveSkillPath(id) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - if strings.TrimSpace(r.URL.Query().Get("files")) == "1" { - var files []string - _ = filepath.WalkDir(skillPath, func(path string, d os.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - return nil - } - rel, _ := filepath.Rel(skillPath, path) - if strings.HasPrefix(rel, "..") { - return nil - } - files = append(files, filepath.ToSlash(rel)) - return nil - }) - writeJSON(w, map[string]interface{}{"ok": true, "id": id, "files": files}) - return - } - if f := strings.TrimSpace(r.URL.Query().Get("file")); f != "" { - clean, content, found, err := readRelativeTextFile(skillPath, f) - if err != nil { - http.Error(w, err.Error(), relativeFilePathStatus(err)) - return - } - if !found { - http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "id": id, "file": filepath.ToSlash(clean), "content": content}) - return - } - } - - type skillItem struct { - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Tools []string `json:"tools"` - SystemPrompt string `json:"system_prompt,omitempty"` - Enabled bool `json:"enabled"` - UpdateChecked bool `json:"update_checked"` - RemoteFound bool `json:"remote_found,omitempty"` - RemoteVersion string `json:"remote_version,omitempty"` - CheckError string `json:"check_error,omitempty"` - Source string `json:"source,omitempty"` - } - candDirs := []string{skillsDir, filepath.Join("/root/clawgo/workspace", "skills")} - seenDirs := map[string]struct{}{} - seenSkills := map[string]struct{}{} - items := make([]skillItem, 0) - // Default off to avoid hammering clawhub search API on each UI refresh. - // Enable explicitly with ?check_updates=1 when needed. - checkUpdates := strings.TrimSpace(r.URL.Query().Get("check_updates")) == "1" - - for _, dir := range candDirs { - dir = strings.TrimSpace(dir) - if dir == "" { - continue - } - if _, ok := seenDirs[dir]; ok { - continue - } - seenDirs[dir] = struct{}{} - entries, err := os.ReadDir(dir) - if err != nil { - if os.IsNotExist(err) { - continue - } - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - for _, e := range entries { - if !e.IsDir() { - continue - } - name := e.Name() - enabled := !strings.HasSuffix(name, ".disabled") - baseName := strings.TrimSuffix(name, ".disabled") - if _, ok := seenSkills[baseName]; ok { - continue - } - seenSkills[baseName] = struct{}{} - desc, tools, sys := readSkillMeta(filepath.Join(dir, name, "SKILL.md")) - if desc == "" || len(tools) == 0 || sys == "" { - d2, t2, s2 := readSkillMeta(filepath.Join(dir, baseName, "SKILL.md")) - if desc == "" { - desc = d2 - } - if len(tools) == 0 { - tools = t2 - } - if sys == "" { - sys = s2 - } - } - if tools == nil { - tools = []string{} - } - it := skillItem{ID: baseName, Name: baseName, Description: desc, Tools: tools, SystemPrompt: sys, Enabled: enabled, UpdateChecked: checkUpdates && clawhubInstalled, Source: dir} - if checkUpdates && clawhubInstalled { - found, version, checkErr := queryClawHubSkillVersion(r.Context(), baseName) - it.RemoteFound = found - it.RemoteVersion = version - if checkErr != nil { - it.CheckError = checkErr.Error() - } - } - items = append(items, it) - } - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "skills": items, - "source": "clawhub", - "clawhub_installed": clawhubInstalled, - "clawhub_path": clawhubPath, - }) - - case http.MethodPost: - ct := strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))) - if strings.Contains(ct, "multipart/form-data") { - imported, err := importSkillArchiveFromMultipart(r, skillsDir) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "imported": imported}) - return - } - - var body map[string]interface{} - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - action := strings.ToLower(stringFromMap(body, "action")) - if action == "install_clawhub" { - output, err := ensureClawHubReady(r.Context()) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{ - "ok": true, - "output": output, - "installed": true, - "clawhub_path": resolveClawHubBinary(r.Context()), - }) - return - } - id := stringFromMap(body, "id") - name := stringFromMap(body, "name") - if strings.TrimSpace(name) == "" { - name = id - } - name = strings.TrimSpace(name) - if name == "" { - http.Error(w, "name required", http.StatusBadRequest) - return - } - enabledPath := filepath.Join(skillsDir, name) - disabledPath := enabledPath + ".disabled" - - switch action { - case "install": - clawhubPath := strings.TrimSpace(resolveClawHubBinary(r.Context())) - if clawhubPath == "" { - http.Error(w, "clawhub is not installed. please install clawhub first.", http.StatusPreconditionFailed) - return - } - ignoreSuspicious, _ := tools.MapBoolArg(body, "ignore_suspicious") - args := []string{"install", name} - if ignoreSuspicious { - args = append(args, "--force") - } - cmd := exec.CommandContext(r.Context(), clawhubPath, args...) - cmd.Dir = strings.TrimSpace(s.workspacePath) - out, err := cmd.CombinedOutput() - if err != nil { - outText := string(out) - lower := strings.ToLower(outText) - if strings.Contains(lower, "rate limit exceeded") || strings.Contains(lower, "too many requests") { - http.Error(w, fmt.Sprintf("clawhub rate limit exceeded. please retry later or configure auth token.\n%s", outText), http.StatusTooManyRequests) - return - } - http.Error(w, fmt.Sprintf("install failed: %v\n%s", err, outText), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "installed": name, "output": string(out)}) - case "enable": - if _, err := os.Stat(disabledPath); err == nil { - if err := os.Rename(disabledPath, enabledPath); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - writeJSON(w, map[string]interface{}{"ok": true}) - case "disable": - if _, err := os.Stat(enabledPath); err == nil { - if err := os.Rename(enabledPath, disabledPath); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - writeJSON(w, map[string]interface{}{"ok": true}) - case "write_file": - skillPath, err := resolveSkillPath(name) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - content := rawStringFromMap(body, "content") - filePath := stringFromMap(body, "file") - clean, err := writeRelativeTextFile(skillPath, filePath, content, true) - if err != nil { - http.Error(w, err.Error(), relativeFilePathStatus(err)) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "name": name, "file": filepath.ToSlash(clean)}) - case "create", "update": - desc := rawStringFromMap(body, "description") - sys := rawStringFromMap(body, "system_prompt") - toolsList := stringListFromMap(body, "tools") - if action == "create" { - if _, err := os.Stat(enabledPath); err == nil { - http.Error(w, "skill already exists", http.StatusBadRequest) - return - } - } - if err := os.MkdirAll(filepath.Join(enabledPath, "scripts"), 0755); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - skillMD := buildSkillMarkdown(name, desc, toolsList, sys) - if err := os.WriteFile(filepath.Join(enabledPath, "SKILL.md"), []byte(skillMD), 0644); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true}) - default: - http.Error(w, "unsupported action", http.StatusBadRequest) - } - - case http.MethodDelete: - id := strings.TrimSpace(r.URL.Query().Get("id")) - if id == "" { - http.Error(w, "id required", http.StatusBadRequest) - return - } - pathA := filepath.Join(skillsDir, id) - pathB := pathA + ".disabled" - deleted := false - if err := os.RemoveAll(pathA); err == nil { - deleted = true - } - if err := os.RemoveAll(pathB); err == nil { - deleted = true - } - writeJSON(w, map[string]interface{}{"ok": true, "deleted": deleted, "id": id}) - - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - -func buildSkillMarkdown(name, desc string, tools []string, systemPrompt string) string { - if desc == "" { - desc = "No description provided." - } - if len(tools) == 0 { - tools = []string{""} - } - toolLines := make([]string, 0, len(tools)) - for _, t := range tools { - if t == "" { - continue - } - toolLines = append(toolLines, "- "+t) - } - if len(toolLines) == 0 { - toolLines = []string{"- (none)"} - } - return fmt.Sprintf(`--- -name: %s -description: %s ---- - -# %s - -%s - -## Tools -%s - -## System Prompt -%s -`, name, desc, name, desc, strings.Join(toolLines, "\n"), systemPrompt) -} - -func readSkillMeta(path string) (desc string, tools []string, systemPrompt string) { - b, err := os.ReadFile(path) - if err != nil { - return "", []string{}, "" - } - s := string(b) - reDesc := regexp.MustCompile(`(?m)^description:\s*(.+)$`) - reTools := regexp.MustCompile(`(?m)^##\s*Tools\s*$`) - rePrompt := regexp.MustCompile(`(?m)^##\s*System Prompt\s*$`) - if m := reDesc.FindStringSubmatch(s); len(m) > 1 { - desc = m[1] - } - if loc := reTools.FindStringIndex(s); loc != nil { - block := s[loc[1]:] - if p := rePrompt.FindStringIndex(block); p != nil { - block = block[:p[0]] - } - for _, line := range strings.Split(block, "\n") { - line = strings.TrimPrefix(line, "-") - if line != "" { - tools = append(tools, line) - } - } - } - if tools == nil { - tools = []string{} - } - if loc := rePrompt.FindStringIndex(s); loc != nil { - systemPrompt = s[loc[1]:] - } - return -} - func gatewayBuildVersion() string { if bi, ok := debug.ReadBuildInfo(); ok && bi != nil { ver := strings.TrimSpace(bi.Main.Version) @@ -4361,59 +2751,6 @@ func normalizeCronJobs(v interface{}) []map[string]interface{} { return out } -func queryClawHubSkillVersion(ctx context.Context, skill string) (found bool, version string, err error) { - if skill == "" { - return false, "", fmt.Errorf("skill empty") - } - clawhubPath := strings.TrimSpace(resolveClawHubBinary(ctx)) - if clawhubPath == "" { - return false, "", fmt.Errorf("clawhub not installed") - } - cctx, cancel := context.WithTimeout(ctx, 8*time.Second) - defer cancel() - cmd := exec.CommandContext(cctx, clawhubPath, "search", skill, "--json") - out, runErr := cmd.Output() - if runErr != nil { - return false, "", runErr - } - var payload interface{} - if err := json.Unmarshal(out, &payload); err != nil { - return false, "", err - } - lowerSkill := strings.ToLower(skill) - var walk func(v interface{}) (bool, string) - walk = func(v interface{}) (bool, string) { - switch t := v.(type) { - case map[string]interface{}: - name := strings.ToLower(strings.TrimSpace(anyToString(t["name"]))) - if name == "" { - name = strings.ToLower(strings.TrimSpace(anyToString(t["id"]))) - } - if name == lowerSkill || strings.Contains(name, lowerSkill) { - ver := anyToString(t["version"]) - if ver == "" { - ver = anyToString(t["latest_version"]) - } - return true, ver - } - for _, vv := range t { - if ok, ver := walk(vv); ok { - return ok, ver - } - } - case []interface{}: - for _, vv := range t { - if ok, ver := walk(vv); ok { - return ok, ver - } - } - } - return false, "" - } - ok, ver := walk(payload) - return ok, ver, nil -} - func resolveClawHubBinary(ctx context.Context) string { if p, err := exec.LookPath("clawhub"); err == nil { return p @@ -4575,32 +2912,6 @@ func detectNodeMajor(ctx context.Context, nodePath string) (int, error) { return v, nil } -func ensureClawHubReady(ctx context.Context) (string, error) { - outs := make([]string, 0, 4) - if p := resolveClawHubBinary(ctx); p != "" { - return "clawhub already installed at: " + p, nil - } - nodeOut, err := ensureNodeRuntime(ctx) - if nodeOut != "" { - outs = append(outs, nodeOut) - } - if err != nil { - return strings.Join(outs, "\n"), err - } - clawOut, err := runInstallCommand(ctx, "npm i -g clawhub") - if clawOut != "" { - outs = append(outs, clawOut) - } - if err != nil { - return strings.Join(outs, "\n"), err - } - if p := resolveClawHubBinary(ctx); p != "" { - outs = append(outs, "clawhub installed at: "+p) - return strings.Join(outs, "\n"), nil - } - return strings.Join(outs, "\n"), fmt.Errorf("installed clawhub but executable still not found in PATH") -} - func ensureMCPPackageInstalled(ctx context.Context, pkgName string) (output string, binName string, binPath string, err error) { return ensureMCPPackageInstalledWithInstaller(ctx, pkgName, "npm") } @@ -4743,304 +3054,6 @@ func shellEscapeArg(in string) string { return "'" + strings.ReplaceAll(in, "'", `'\''`) + "'" } -func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]string, error) { - if err := r.ParseMultipartForm(128 << 20); err != nil { - return nil, err - } - f, h, err := r.FormFile("file") - if err != nil { - return nil, fmt.Errorf("file required") - } - defer f.Close() - - uploadDir := filepath.Join(os.TempDir(), "clawgo_skill_uploads") - _ = os.MkdirAll(uploadDir, 0755) - archivePath := filepath.Join(uploadDir, fmt.Sprintf("%d_%s", time.Now().UnixNano(), filepath.Base(h.Filename))) - out, err := os.Create(archivePath) - if err != nil { - return nil, err - } - if _, err := io.Copy(out, f); err != nil { - _ = out.Close() - _ = os.Remove(archivePath) - return nil, err - } - _ = out.Close() - defer os.Remove(archivePath) - - extractDir, err := os.MkdirTemp("", "clawgo_skill_extract_*") - if err != nil { - return nil, err - } - defer os.RemoveAll(extractDir) - - if err := extractArchive(archivePath, extractDir); err != nil { - return nil, err - } - - type candidate struct { - name string - dir string - } - candidates := make([]candidate, 0) - seen := map[string]struct{}{} - err = filepath.WalkDir(extractDir, func(path string, d os.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - return nil - } - if strings.EqualFold(d.Name(), "SKILL.md") { - dir := filepath.Dir(path) - rel, relErr := filepath.Rel(extractDir, dir) - if relErr != nil { - return nil - } - rel = filepath.ToSlash(strings.TrimSpace(rel)) - if rel == "" { - rel = "." - } - name := filepath.Base(rel) - if rel == "." { - name = archiveBaseName(h.Filename) - } - name = sanitizeSkillName(name) - if name == "" { - return nil - } - if _, ok := seen[name]; ok { - return nil - } - seen[name] = struct{}{} - candidates = append(candidates, candidate{name: name, dir: dir}) - } - return nil - }) - if err != nil { - return nil, err - } - if len(candidates) == 0 { - return nil, fmt.Errorf("no SKILL.md found in archive") - } - - imported := make([]string, 0, len(candidates)) - for _, c := range candidates { - dst := filepath.Join(skillsDir, c.name) - if _, err := os.Stat(dst); err == nil { - return nil, fmt.Errorf("skill already exists: %s", c.name) - } - if _, err := os.Stat(dst + ".disabled"); err == nil { - return nil, fmt.Errorf("disabled skill already exists: %s", c.name) - } - if err := copyDir(c.dir, dst); err != nil { - return nil, err - } - imported = append(imported, c.name) - } - sort.Strings(imported) - return imported, nil -} - -func archiveBaseName(filename string) string { - name := filepath.Base(strings.TrimSpace(filename)) - lower := strings.ToLower(name) - switch { - case strings.HasSuffix(lower, ".tar.gz"): - return name[:len(name)-len(".tar.gz")] - case strings.HasSuffix(lower, ".tgz"): - return name[:len(name)-len(".tgz")] - case strings.HasSuffix(lower, ".zip"): - return name[:len(name)-len(".zip")] - case strings.HasSuffix(lower, ".tar"): - return name[:len(name)-len(".tar")] - default: - ext := filepath.Ext(name) - return strings.TrimSuffix(name, ext) - } -} - -func sanitizeSkillName(name string) string { - name = strings.TrimSpace(name) - if name == "" { - return "" - } - var b strings.Builder - lastDash := false - for _, ch := range strings.ToLower(name) { - if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' { - b.WriteRune(ch) - lastDash = false - continue - } - if !lastDash { - b.WriteRune('-') - lastDash = true - } - } - out := strings.Trim(b.String(), "-") - if out == "" || out == "." { - return "" - } - return out -} - -func extractArchive(archivePath, targetDir string) error { - lower := strings.ToLower(archivePath) - switch { - case strings.HasSuffix(lower, ".zip"): - return extractZip(archivePath, targetDir) - case strings.HasSuffix(lower, ".tar.gz"), strings.HasSuffix(lower, ".tgz"): - return extractTarGz(archivePath, targetDir) - case strings.HasSuffix(lower, ".tar"): - return extractTar(archivePath, targetDir) - default: - return fmt.Errorf("unsupported archive format: %s", filepath.Base(archivePath)) - } -} - -func extractZip(archivePath, targetDir string) error { - zr, err := zip.OpenReader(archivePath) - if err != nil { - return err - } - defer zr.Close() - - for _, f := range zr.File { - if err := writeArchivedEntry(targetDir, f.Name, f.FileInfo().IsDir(), func() (io.ReadCloser, error) { - return f.Open() - }); err != nil { - return err - } - } - return nil -} - -func extractTarGz(archivePath, targetDir string) error { - f, err := os.Open(archivePath) - if err != nil { - return err - } - defer f.Close() - gz, err := gzip.NewReader(f) - if err != nil { - return err - } - defer gz.Close() - return extractTarReader(tar.NewReader(gz), targetDir) -} - -func extractTar(archivePath, targetDir string) error { - f, err := os.Open(archivePath) - if err != nil { - return err - } - defer f.Close() - return extractTarReader(tar.NewReader(f), targetDir) -} - -func extractTarReader(tr *tar.Reader, targetDir string) error { - for { - hdr, err := tr.Next() - if errors.Is(err, io.EOF) { - return nil - } - if err != nil { - return err - } - switch hdr.Typeflag { - case tar.TypeDir: - if err := writeArchivedEntry(targetDir, hdr.Name, true, nil); err != nil { - return err - } - case tar.TypeReg, tar.TypeRegA: - name := hdr.Name - if err := writeArchivedEntry(targetDir, name, false, func() (io.ReadCloser, error) { - return io.NopCloser(tr), nil - }); err != nil { - return err - } - } - } -} - -func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error)) error { - clean := filepath.Clean(strings.TrimSpace(name)) - clean = strings.TrimPrefix(clean, string(filepath.Separator)) - clean = strings.TrimPrefix(clean, "/") - for strings.HasPrefix(clean, "../") { - clean = strings.TrimPrefix(clean, "../") - } - if clean == "." || clean == "" { - return nil - } - dst := filepath.Join(targetDir, clean) - absTarget, _ := filepath.Abs(targetDir) - absDst, _ := filepath.Abs(dst) - if !strings.HasPrefix(absDst, absTarget+string(filepath.Separator)) && absDst != absTarget { - return fmt.Errorf("invalid archive entry path: %s", name) - } - if isDir { - return os.MkdirAll(dst, 0755) - } - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - rc, err := opener() - if err != nil { - return err - } - defer rc.Close() - out, err := os.Create(dst) - if err != nil { - return err - } - defer out.Close() - _, err = io.Copy(out, rc) - return err -} - -func copyDir(src, dst string) error { - entries, err := os.ReadDir(src) - if err != nil { - return err - } - if err := os.MkdirAll(dst, 0755); err != nil { - return err - } - for _, e := range entries { - srcPath := filepath.Join(src, e.Name()) - dstPath := filepath.Join(dst, e.Name()) - info, err := e.Info() - if err != nil { - return err - } - if info.IsDir() { - if err := copyDir(srcPath, dstPath); err != nil { - return err - } - continue - } - in, err := os.Open(srcPath) - if err != nil { - return err - } - out, err := os.Create(dstPath) - if err != nil { - _ = in.Close() - return err - } - if _, err := io.Copy(out, in); err != nil { - _ = out.Close() - _ = in.Close() - return err - } - _ = out.Close() - _ = in.Close() - } - return nil -} - func anyToString(v interface{}) string { switch t := v.(type) { case string: @@ -5153,160 +3166,6 @@ func (s *Server) handleWebUIToolAllowlistGroups(w http.ResponseWriter, r *http.R }) } -func (s *Server) handleWebUISubagentsRuntime(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if s.onSubagents == nil { - http.Error(w, "subagent runtime handler not configured", http.StatusServiceUnavailable) - return - } - - action := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("action"))) - args := map[string]interface{}{} - switch r.Method { - case http.MethodGet: - if action == "" { - action = "list" - } - for key, values := range r.URL.Query() { - if key == "action" || key == "token" || len(values) == 0 { - continue - } - args[key] = strings.TrimSpace(values[0]) - } - case http.MethodPost: - var body map[string]interface{} - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - if body == nil { - body = map[string]interface{}{} - } - if action == "" { - if raw := stringFromMap(body, "action"); raw != "" { - action = strings.ToLower(strings.TrimSpace(raw)) - } - } - delete(body, "action") - args = body - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - result, err := s.onSubagents(r.Context(), action, args) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "result": result}) -} - -func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - memoryDir := filepath.Join(s.workspacePath, "memory") - _ = os.MkdirAll(memoryDir, 0755) - switch r.Method { - case http.MethodGet: - path := strings.TrimSpace(r.URL.Query().Get("path")) - if path == "" { - entries, err := os.ReadDir(memoryDir) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - files := make([]string, 0, len(entries)) - for _, e := range entries { - if e.IsDir() { - continue - } - files = append(files, e.Name()) - } - writeJSON(w, map[string]interface{}{"ok": true, "files": files}) - return - } - clean, content, found, err := readRelativeTextFile(memoryDir, path) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !found { - http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "content": content}) - case http.MethodPost: - var body struct { - Path string `json:"path"` - Content string `json:"content"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - clean, err := writeRelativeTextFile(memoryDir, body.Path, body.Content, false) - if err != nil { - http.Error(w, err.Error(), relativeFilePathStatus(err)) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "path": clean}) - case http.MethodDelete: - clean, full, err := resolveRelativeFilePath(memoryDir, r.URL.Query().Get("path")) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := os.Remove(full); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "deleted": true, "path": clean}) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - -func (s *Server) handleWebUIWorkspaceFile(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - workspace := strings.TrimSpace(s.workspacePath) - switch r.Method { - case http.MethodGet: - path := strings.TrimSpace(r.URL.Query().Get("path")) - clean, content, found, err := readRelativeTextFile(workspace, path) - if err != nil { - http.Error(w, err.Error(), relativeFilePathStatus(err)) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "found": found, "content": content}) - case http.MethodPost: - var body struct { - Path string `json:"path"` - Content string `json:"content"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - clean, err := writeRelativeTextFile(workspace, body.Path, body.Content, true) - if err != nil { - http.Error(w, err.Error(), relativeFilePathStatus(err)) - return - } - writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "saved": true}) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - func (s *Server) handleWebUITaskQueue(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) diff --git a/pkg/api/server_node_artifacts.go b/pkg/api/server_node_artifacts.go new file mode 100644 index 0000000..ef86f6d --- /dev/null +++ b/pkg/api/server_node_artifacts.go @@ -0,0 +1,345 @@ +package api + +import ( + "archive/zip" + "bytes" + "crypto/sha1" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/YspCoder/clawgo/pkg/nodes" +) + +func (s *Server) webUINodeArtifactsPayloadFiltered(nodeFilter, actionFilter, kindFilter string, limit int) []map[string]interface{} { + nodeFilter = strings.TrimSpace(nodeFilter) + actionFilter = strings.TrimSpace(actionFilter) + kindFilter = strings.TrimSpace(kindFilter) + rows, _ := s.readNodeDispatchAuditRows() + if len(rows) == 0 { + return []map[string]interface{}{} + } + out := make([]map[string]interface{}, 0, limit) + for rowIndex := len(rows) - 1; rowIndex >= 0; rowIndex-- { + row := rows[rowIndex] + artifacts, _ := row["artifacts"].([]interface{}) + for artifactIndex, raw := range artifacts { + artifact, ok := raw.(map[string]interface{}) + if !ok { + continue + } + item := map[string]interface{}{ + "id": buildNodeArtifactID(row, artifact, artifactIndex), + "time": row["time"], + "node": row["node"], + "action": row["action"], + "used_transport": row["used_transport"], + "ok": row["ok"], + "error": row["error"], + } + for _, key := range []string{"name", "kind", "mime_type", "storage", "path", "url", "content_text", "content_base64", "source_path", "size_bytes"} { + if value, ok := artifact[key]; ok { + item[key] = value + } + } + if nodeFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["node"])), nodeFilter) { + continue + } + if actionFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["action"])), actionFilter) { + continue + } + if kindFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["kind"])), kindFilter) { + continue + } + out = append(out, item) + if limit > 0 && len(out) >= limit { + return out + } + } + } + return out +} + +func buildNodeArtifactID(row, artifact map[string]interface{}, artifactIndex int) string { + seed := fmt.Sprintf("%v|%v|%v|%d|%v|%v|%v", + row["time"], row["node"], row["action"], artifactIndex, + artifact["name"], artifact["source_path"], artifact["path"], + ) + sum := sha1.Sum([]byte(seed)) + return fmt.Sprintf("%x", sum[:8]) +} + +func sanitizeZipEntryName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "artifact.bin" + } + name = strings.ReplaceAll(name, "\\", "/") + name = filepath.Base(name) + name = strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= 'A' && r <= 'Z': + return r + case r >= '0' && r <= '9': + return r + case r == '.', r == '-', r == '_': + return r + default: + return '_' + } + }, name) + if strings.Trim(name, "._") == "" { + return "artifact.bin" + } + return name +} + +func (s *Server) findNodeArtifactByID(id string) (map[string]interface{}, bool) { + for _, item := range s.webUINodeArtifactsPayload(10000) { + if strings.TrimSpace(fmt.Sprint(item["id"])) == id { + return item, true + } + } + return nil, false +} + +func resolveArtifactPath(workspace, raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + if filepath.IsAbs(raw) { + clean := filepath.Clean(raw) + if info, err := os.Stat(clean); err == nil && !info.IsDir() { + return clean + } + return "" + } + root := strings.TrimSpace(workspace) + if root == "" { + return "" + } + clean := filepath.Clean(filepath.Join(root, raw)) + if rel, err := filepath.Rel(root, clean); err != nil || strings.HasPrefix(rel, "..") { + return "" + } + if info, err := os.Stat(clean); err == nil && !info.IsDir() { + return clean + } + return "" +} + +func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, string, error) { + if content := strings.TrimSpace(fmt.Sprint(item["content_base64"])); content != "" { + raw, err := base64.StdEncoding.DecodeString(content) + if err != nil { + return nil, "", err + } + return raw, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil + } + for _, rawPath := range []string{fmt.Sprint(item["source_path"]), fmt.Sprint(item["path"])} { + if path := resolveArtifactPath(workspace, rawPath); path != "" { + b, err := os.ReadFile(path) + if err != nil { + return nil, "", err + } + return b, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil + } + } + if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" { + return []byte(contentText), "text/plain; charset=utf-8", nil + } + return nil, "", fmt.Errorf("artifact content unavailable") +} + +func (s *Server) filteredNodeDispatches(nodeFilter, actionFilter string, limit int) []map[string]interface{} { + items := s.webUINodesDispatchPayload(limit) + if nodeFilter == "" && actionFilter == "" { + return items + } + out := make([]map[string]interface{}, 0, len(items)) + for _, item := range items { + if nodeFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["node"])), nodeFilter) { + continue + } + if actionFilter != "" && !strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["action"])), actionFilter) { + continue + } + out = append(out, item) + } + return out +} + +func filteredNodeAlerts(alerts []map[string]interface{}, nodeFilter string) []map[string]interface{} { + if nodeFilter == "" { + return alerts + } + out := make([]map[string]interface{}, 0, len(alerts)) + for _, item := range alerts { + if strings.EqualFold(strings.TrimSpace(fmt.Sprint(item["node"])), nodeFilter) { + out = append(out, item) + } + } + return out +} + +func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + retentionSummary := s.applyNodeArtifactRetention() + limit := queryBoundedPositiveInt(r, "limit", 200, 1000) + nodeFilter := strings.TrimSpace(r.URL.Query().Get("node")) + actionFilter := strings.TrimSpace(r.URL.Query().Get("action")) + kindFilter := strings.TrimSpace(r.URL.Query().Get("kind")) + artifacts := s.webUINodeArtifactsPayloadFiltered(nodeFilter, actionFilter, kindFilter, limit) + dispatches := s.filteredNodeDispatches(nodeFilter, actionFilter, limit) + payload := s.webUINodesPayload(r.Context()) + nodeList, _ := payload["nodes"].([]nodes.NodeInfo) + p2p, _ := payload["p2p"].(map[string]interface{}) + alerts := filteredNodeAlerts(s.webUINodeAlertsPayload(nodeList, p2p, dispatches), nodeFilter) + + var archive bytes.Buffer + zw := zip.NewWriter(&archive) + writeZipJSON := func(name string, value interface{}) error { + entry, err := zw.Create(name) + if err != nil { + return err + } + enc := json.NewEncoder(entry) + enc.SetIndent("", " ") + return enc.Encode(value) + } + manifest := map[string]interface{}{ + "generated_at": time.Now().UTC().Format(time.RFC3339), + "filters": map[string]interface{}{ + "node": nodeFilter, + "action": actionFilter, + "kind": kindFilter, + "limit": limit, + }, + "artifact_count": len(artifacts), + "dispatch_count": len(dispatches), + "alert_count": len(alerts), + "retention": retentionSummary, + } + if err := writeZipJSON("manifest.json", manifest); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := writeZipJSON("dispatches.json", dispatches); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := writeZipJSON("alerts.json", alerts); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := writeZipJSON("artifacts.json", artifacts); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + for _, item := range artifacts { + name := sanitizeZipEntryName(firstNonEmptyString( + fmt.Sprint(item["name"]), + fmt.Sprint(item["source_path"]), + fmt.Sprint(item["path"]), + fmt.Sprintf("%s.bin", fmt.Sprint(item["id"])), + )) + raw, _, err := readArtifactBytes(s.workspacePath, item) + entryName := filepath.ToSlash(filepath.Join("files", fmt.Sprintf("%s-%s", fmt.Sprint(item["id"]), name))) + if err != nil || len(raw) == 0 { + entryName = filepath.ToSlash(filepath.Join("files", fmt.Sprintf("%s-metadata.json", fmt.Sprint(item["id"])))) + raw, err = json.MarshalIndent(item, "", " ") + if err != nil { + continue + } + } + entry, err := zw.Create(entryName) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if _, err := entry.Write(raw); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + if err := zw.Close(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + filename := "node-artifacts-export.zip" + if nodeFilter != "" { + filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter)) + } + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(archive.Bytes()) +} + +func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + id := strings.TrimSpace(r.URL.Query().Get("id")) + if id == "" { + http.Error(w, "id is required", http.StatusBadRequest) + return + } + item, ok := s.findNodeArtifactByID(id) + if !ok { + http.Error(w, "artifact not found", http.StatusNotFound) + return + } + name := strings.TrimSpace(fmt.Sprint(item["name"])) + if name == "" { + name = "artifact" + } + mimeType := strings.TrimSpace(fmt.Sprint(item["mime_type"])) + if mimeType == "" { + mimeType = "application/octet-stream" + } + if contentB64 := strings.TrimSpace(fmt.Sprint(item["content_base64"])); contentB64 != "" { + payload, err := base64.StdEncoding.DecodeString(contentB64) + if err != nil { + http.Error(w, "invalid inline artifact payload", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) + _, _ = w.Write(payload) + return + } + for _, rawPath := range []string{fmt.Sprint(item["source_path"]), fmt.Sprint(item["path"])} { + if path := resolveArtifactPath(s.workspacePath, rawPath); path != "" { + http.ServeFile(w, r, path) + return + } + } + if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" { + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) + _, _ = w.Write([]byte(contentText)) + return + } + http.Error(w, "artifact content unavailable", http.StatusNotFound) +} diff --git a/pkg/api/server_providers.go b/pkg/api/server_providers.go new file mode 100644 index 0000000..10a201d --- /dev/null +++ b/pkg/api/server_providers.go @@ -0,0 +1,429 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + cfgpkg "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/providers" +) + +func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost && r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Provider string `json:"provider"` + AccountLabel string `json:"account_label"` + NetworkProxy string `json:"network_proxy"` + ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"` + } + if r.Method == http.MethodPost { + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + } else { + body.Provider = strings.TrimSpace(r.URL.Query().Get("provider")) + body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label")) + body.NetworkProxy = strings.TrimSpace(r.URL.Query().Get("network_proxy")) + } + cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = cfg + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + flow, err := loginMgr.StartManualFlowWithOptions(providers.OAuthLoginOptions{ + AccountLabel: body.AccountLabel, + NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)), + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + flowID := fmt.Sprintf("%d", time.Now().UnixNano()) + s.oauthFlowMu.Lock() + s.oauthFlows[flowID] = flow + s.oauthFlowMu.Unlock() + writeJSON(w, map[string]interface{}{ + "ok": true, + "flow_id": flowID, + "mode": flow.Mode, + "auth_url": flow.AuthURL, + "user_code": flow.UserCode, + "instructions": flow.Instructions, + "account_label": strings.TrimSpace(body.AccountLabel), + "network_proxy": strings.TrimSpace(body.NetworkProxy), + }) +} + +func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Provider string `json:"provider"` + FlowID string `json:"flow_id"` + CallbackURL string `json:"callback_url"` + AccountLabel string `json:"account_label"` + NetworkProxy string `json:"network_proxy"` + ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.oauthFlowMu.Lock() + flow := s.oauthFlows[strings.TrimSpace(body.FlowID)] + delete(s.oauthFlows, strings.TrimSpace(body.FlowID)) + s.oauthFlowMu.Unlock() + if flow == nil { + http.Error(w, "oauth flow not found", http.StatusBadRequest) + return + } + session, models, err := loginMgr.CompleteManualFlowWithOptions(r.Context(), pc.APIBase, flow, body.CallbackURL, providers.OAuthLoginOptions{ + AccountLabel: strings.TrimSpace(body.AccountLabel), + NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)), + }) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if session.CredentialFile != "" { + pc.OAuth.CredentialFile = session.CredentialFile + pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile) + } + if err := s.saveProviderConfig(cfg, body.Provider, pc); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "account": session.Email, + "credential_file": session.CredentialFile, + "network_proxy": session.NetworkProxy, + "models": models, + }) +} + +func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseMultipartForm(16 << 20); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + providerName := strings.TrimSpace(r.FormValue("provider")) + accountLabel := strings.TrimSpace(r.FormValue("account_label")) + networkProxy := strings.TrimSpace(r.FormValue("network_proxy")) + inlineCfgRaw := strings.TrimSpace(r.FormValue("provider_config")) + var inlineCfg cfgpkg.ProviderConfig + if inlineCfgRaw != "" { + if err := json.Unmarshal([]byte(inlineCfgRaw), &inlineCfg); err != nil { + http.Error(w, "invalid provider_config", http.StatusBadRequest) + return + } + } + cfg, pc, err := s.resolveProviderConfig(providerName, inlineCfg) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + file, header, err := r.FormFile("file") + if err != nil { + http.Error(w, "file required", http.StatusBadRequest) + return + } + defer file.Close() + data, err := io.ReadAll(file) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + session, models, err := loginMgr.ImportAuthJSONWithOptions(r.Context(), pc.APIBase, header.Filename, data, providers.OAuthLoginOptions{ + AccountLabel: accountLabel, + NetworkProxy: firstNonEmptyString(networkProxy, strings.TrimSpace(pc.OAuth.NetworkProxy)), + }) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if session.CredentialFile != "" { + pc.OAuth.CredentialFile = session.CredentialFile + pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile) + } + if err := s.saveProviderConfig(cfg, providerName, pc); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "account": session.Email, + "credential_file": session.CredentialFile, + "network_proxy": session.NetworkProxy, + "models": models, + }) +} + +func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + providerName := strings.TrimSpace(r.URL.Query().Get("provider")) + cfg, pc, err := s.loadProviderConfig(providerName) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = cfg + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + switch r.Method { + case http.MethodGet: + accounts, err := loginMgr.ListAccounts() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "accounts": accounts}) + case http.MethodPost: + var body struct { + Action string `json:"action"` + CredentialFile string `json:"credential_file"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + action := strings.ToLower(strings.TrimSpace(body.Action)) + handler := providerOAuthAccountActionHandlers[action] + if handler == nil { + http.Error(w, "unsupported action", http.StatusBadRequest) + return + } + result, err := handler(r.Context(), s, loginMgr, cfg, providerName, &pc, strings.TrimSpace(body.CredentialFile)) + if err != nil { + status := http.StatusBadRequest + if action == "delete" && strings.Contains(strings.ToLower(err.Error()), "config") { + status = http.StatusInternalServerError + } + http.Error(w, err.Error(), status) + return + } + writeJSON(w, result) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +type providerOAuthAccountActionHandler func(context.Context, *Server, *providers.OAuthLoginManager, *cfgpkg.Config, string, *cfgpkg.ProviderConfig, string) (map[string]interface{}, error) + +var providerOAuthAccountActionHandlers = map[string]providerOAuthAccountActionHandler{ + "refresh": func(ctx context.Context, _ *Server, loginMgr *providers.OAuthLoginManager, _ *cfgpkg.Config, _ string, _ *cfgpkg.ProviderConfig, credentialFile string) (map[string]interface{}, error) { + account, err := loginMgr.RefreshAccount(ctx, credentialFile) + if err != nil { + return nil, err + } + return map[string]interface{}{"ok": true, "account": account}, nil + }, + "delete": func(_ context.Context, srv *Server, loginMgr *providers.OAuthLoginManager, cfg *cfgpkg.Config, providerName string, pc *cfgpkg.ProviderConfig, credentialFile string) (map[string]interface{}, error) { + if err := loginMgr.DeleteAccount(credentialFile); err != nil { + return nil, err + } + pc.OAuth.CredentialFiles = removeStringItem(pc.OAuth.CredentialFiles, credentialFile) + if strings.TrimSpace(pc.OAuth.CredentialFile) == strings.TrimSpace(credentialFile) { + pc.OAuth.CredentialFile = "" + if len(pc.OAuth.CredentialFiles) > 0 { + pc.OAuth.CredentialFile = pc.OAuth.CredentialFiles[0] + } + } + if err := srv.saveProviderConfig(cfg, providerName, *pc); err != nil { + return nil, err + } + return map[string]interface{}{"ok": true, "deleted": true}, nil + }, + "clear_cooldown": func(_ context.Context, _ *Server, loginMgr *providers.OAuthLoginManager, _ *cfgpkg.Config, _ string, _ *cfgpkg.ProviderConfig, credentialFile string) (map[string]interface{}, error) { + if err := loginMgr.ClearCooldown(credentialFile); err != nil { + return nil, err + } + return map[string]interface{}{"ok": true, "cleared": true}, nil + }, +} + +func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { + if strings.TrimSpace(s.configPath) == "" { + return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("config path not set") + } + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + return nil, cfgpkg.ProviderConfig{}, err + } + providerName := strings.TrimSpace(name) + if providerName == "" { + providerName = cfgpkg.PrimaryProviderName(cfg) + } + pc, ok := cfgpkg.ProviderConfigByName(cfg, providerName) + if !ok { + return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName) + } + return cfg, pc, nil +} + +func (s *Server) loadRuntimeProviderName(name string) (*cfgpkg.Config, string, error) { + if strings.TrimSpace(s.configPath) == "" { + return nil, "", fmt.Errorf("config path not set") + } + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + return nil, "", err + } + providerName := strings.TrimSpace(name) + if providerName == "" { + providerName = cfgpkg.PrimaryProviderName(cfg) + } + if !cfgpkg.ProviderExists(cfg, providerName) { + return nil, "", fmt.Errorf("provider %q not found", providerName) + } + return cfg, providerName, nil +} + +func (s *Server) resolveProviderConfig(name string, inline cfgpkg.ProviderConfig) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { + if hasInlineProviderConfig(inline) { + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + return nil, cfgpkg.ProviderConfig{}, err + } + return cfg, inline, nil + } + return s.loadProviderConfig(name) +} + +func hasInlineProviderConfig(pc cfgpkg.ProviderConfig) bool { + return strings.TrimSpace(pc.APIBase) != "" || + strings.TrimSpace(pc.APIKey) != "" || + len(pc.Models) > 0 || + strings.TrimSpace(pc.Auth) != "" || + strings.TrimSpace(pc.OAuth.Provider) != "" +} + +func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.ProviderConfig) error { + if cfg == nil { + return fmt.Errorf("config is nil") + } + providerName := strings.TrimSpace(name) + if cfg.Models.Providers == nil { + cfg.Models.Providers = map[string]cfgpkg.ProviderConfig{} + } + cfg.Models.Providers[providerName] = pc + if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil { + return err + } + if s.onConfigAfter != nil { + if err := s.onConfigAfter(); err != nil { + return err + } + } else { + if err := requestSelfReloadSignal(); err != nil { + return err + } + } + return nil +} + +func appendUniqueStrings(values []string, item string) []string { + item = strings.TrimSpace(item) + if item == "" { + return values + } + for _, value := range values { + if strings.TrimSpace(value) == item { + return values + } + } + return append(values, item) +} + +func removeStringItem(values []string, item string) []string { + item = strings.TrimSpace(item) + if item == "" { + return values + } + out := make([]string, 0, len(values)) + for _, value := range values { + if strings.TrimSpace(value) == item { + continue + } + out = append(out, value) + } + return out +} + +func atoiDefault(raw string, fallback int) int { + if value, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil { + return value + } + return fallback +} diff --git a/pkg/api/server_rpc_facades.go b/pkg/api/server_rpc_facades.go new file mode 100644 index 0000000..565385f --- /dev/null +++ b/pkg/api/server_rpc_facades.go @@ -0,0 +1,526 @@ +package api + +import ( + "encoding/json" + "net/http" + "os" + "strings" + + rpcpkg "github.com/YspCoder/clawgo/pkg/rpc" + "github.com/YspCoder/clawgo/pkg/tools" +) + +func (s *Server) handleWebUIConfig(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if strings.TrimSpace(s.configPath) == "" { + http.Error(w, "config path not set", http.StatusInternalServerError) + return + } + svc := s.configRPCService() + switch r.Method { + case http.MethodGet: + mode := strings.TrimSpace(r.URL.Query().Get("mode")) + includeHot := r.URL.Query().Get("include_hot_reload_fields") == "1" || strings.EqualFold(mode, "hot") + resp, rpcErr := svc.View(r.Context(), rpcpkg.ConfigViewRequest{ + Mode: mode, + IncludeHotReloadInfo: includeHot, + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + if strings.EqualFold(mode, "normalized") || includeHot { + payload := map[string]interface{}{"ok": true, "config": resp.Config} + if resp.RawConfig != nil { + payload["raw_config"] = resp.RawConfig + } + if len(resp.HotReloadFields) > 0 { + payload["hot_reload_fields"] = resp.HotReloadFields + payload["hot_reload_field_details"] = resp.HotReloadFieldDetails + } + writeJSON(w, payload) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(resp.PrettyText)) + case http.MethodPost: + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + confirmRisky, _ := tools.MapBoolArg(body, "confirm_risky") + delete(body, "confirm_risky") + resp, rpcErr := svc.Save(r.Context(), rpcpkg.ConfigSaveRequest{ + Mode: strings.TrimSpace(r.URL.Query().Get("mode")), + ConfirmRisky: confirmRisky, + Config: body, + }) + if rpcErr != nil { + message := rpcErr.Message + status := rpcHTTPStatus(rpcErr) + if status == http.StatusInternalServerError && strings.TrimSpace(message) != "" && !strings.Contains(strings.ToLower(message), "reload failed") { + message = "config saved but reload failed: " + message + } + payload := map[string]interface{}{"ok": false, "error": message} + if resp != nil && resp.RequiresConfirm { + payload["requires_confirm"] = true + payload["changed_fields"] = resp.ChangedFields + } + if resp != nil && resp.Details != nil { + payload["details"] = resp.Details + } + writeJSONStatus(w, status, payload) + return + } + out := map[string]interface{}{"ok": true, "reloaded": true} + if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { + view, viewErr := svc.View(r.Context(), rpcpkg.ConfigViewRequest{Mode: "normalized"}) + if viewErr == nil { + out["config"] = view.Config + } + } + writeJSON(w, out) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleWebUIProviderModels(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Provider string `json:"provider"` + Model string `json:"model"` + Models []string `json:"models"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + resp, rpcErr := s.providerRPCService().UpdateModels(r.Context(), rpcpkg.UpdateProviderModelsRequest{ + Provider: strings.TrimSpace(body.Provider), + Model: body.Model, + Models: body.Models, + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "models": resp.Models, + }) +} + +func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method == http.MethodGet { + resp, rpcErr := s.providerRPCService().RuntimeView(r.Context(), rpcpkg.ProviderRuntimeViewRequest{ + Provider: strings.TrimSpace(r.URL.Query().Get("provider")), + Kind: strings.TrimSpace(r.URL.Query().Get("kind")), + Reason: strings.TrimSpace(r.URL.Query().Get("reason")), + Target: strings.TrimSpace(r.URL.Query().Get("target")), + Sort: strings.TrimSpace(r.URL.Query().Get("sort")), + ChangesOnly: strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("changes_only")), "true"), + WindowSec: atoiDefault(strings.TrimSpace(r.URL.Query().Get("window_sec")), 0), + Limit: atoiDefault(strings.TrimSpace(r.URL.Query().Get("limit")), 0), + Cursor: atoiDefault(strings.TrimSpace(r.URL.Query().Get("cursor")), 0), + HealthBelow: atoiDefault(strings.TrimSpace(r.URL.Query().Get("health_below")), 0), + CooldownUntilBeforeSec: atoiDefault(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec")), 0), + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "view": resp.View, + }) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Provider string `json:"provider"` + Action string `json:"action"` + OnlyExpiring bool `json:"only_expiring"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + resp, rpcErr := s.providerRPCService().RuntimeAction(r.Context(), rpcpkg.ProviderRuntimeActionRequest{ + Provider: strings.TrimSpace(body.Provider), + Action: strings.TrimSpace(body.Action), + OnlyExpiring: body.OnlyExpiring, + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + result := map[string]interface{}{"ok": true} + for key, value := range resp.Result { + result[key] = value + } + writeJSON(w, result) +} + +func (s *Server) handleWebUINodeDispatchReplay(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Node string `json:"node"` + Action string `json:"action"` + Mode string `json:"mode"` + Task string `json:"task"` + Model string `json:"model"` + Args map[string]interface{} `json:"args"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + resp, rpcErr := s.nodeRPCService().Dispatch(r.Context(), rpcpkg.DispatchNodeRequest{ + Node: strings.TrimSpace(body.Node), + Action: strings.TrimSpace(body.Action), + Mode: strings.TrimSpace(body.Mode), + Task: body.Task, + Model: body.Model, + Args: body.Args, + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "result": resp.Result, + }) +} + +func (s *Server) handleWebUINodeArtifacts(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + resp, rpcErr := s.nodeRPCService().ListArtifacts(r.Context(), rpcpkg.ListNodeArtifactsRequest{ + Node: strings.TrimSpace(r.URL.Query().Get("node")), + Action: strings.TrimSpace(r.URL.Query().Get("action")), + Kind: strings.TrimSpace(r.URL.Query().Get("kind")), + Limit: queryBoundedPositiveInt(r, "limit", 200, 1000), + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "items": resp.Items, + "artifact_retention": resp.ArtifactRetention, + }) +} + +func (s *Server) handleWebUINodeArtifactDelete(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + ID string `json:"id"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + resp, rpcErr := s.nodeRPCService().DeleteArtifact(r.Context(), rpcpkg.DeleteNodeArtifactRequest{ID: strings.TrimSpace(body.ID)}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "id": resp.ID, + "deleted_file": resp.DeletedFile, + "deleted_audit": resp.DeletedAudit, + }) +} + +func (s *Server) handleWebUINodeArtifactPrune(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Node string `json:"node"` + Action string `json:"action"` + Kind string `json:"kind"` + KeepLatest int `json:"keep_latest"` + Limit int `json:"limit"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + resp, rpcErr := s.nodeRPCService().PruneArtifacts(r.Context(), rpcpkg.PruneNodeArtifactsRequest{ + Node: strings.TrimSpace(body.Node), + Action: strings.TrimSpace(body.Action), + Kind: strings.TrimSpace(body.Kind), + KeepLatest: body.KeepLatest, + Limit: body.Limit, + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "pruned": resp.Pruned, + "deleted_files": resp.DeletedFiles, + "kept": resp.Kept, + }) +} + +func (s *Server) handleWebUICron(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if s.onCron == nil { + http.Error(w, "cron handler not configured", http.StatusInternalServerError) + return + } + svc := s.cronRPCService() + switch r.Method { + case http.MethodGet: + id := strings.TrimSpace(r.URL.Query().Get("id")) + if id == "" { + resp, rpcErr := svc.List(r.Context(), rpcpkg.ListCronJobsRequest{}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "jobs": resp.Jobs}) + } else { + resp, rpcErr := svc.Get(r.Context(), rpcpkg.GetCronJobRequest{ID: id}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "job": resp.Job}) + } + case http.MethodPost: + args := map[string]interface{}{} + if r.Body != nil { + _ = json.NewDecoder(r.Body).Decode(&args) + } + if id := strings.TrimSpace(r.URL.Query().Get("id")); id != "" { + args["id"] = id + } + action := "create" + if a := tools.MapStringArg(args, "action"); a != "" { + action = strings.ToLower(strings.TrimSpace(a)) + } + resp, rpcErr := svc.Mutate(r.Context(), rpcpkg.MutateCronJobRequest{ + Action: action, + Args: args, + }) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "result": resp.Result}) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleWebUISubagentsRuntime(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if s.onSubagents == nil { + http.Error(w, "subagent runtime handler not configured", http.StatusServiceUnavailable) + return + } + + action := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("action"))) + args := map[string]interface{}{} + switch r.Method { + case http.MethodGet: + if action == "" { + action = "list" + } + for key, values := range r.URL.Query() { + if key == "action" || key == "token" || len(values) == 0 { + continue + } + args[key] = strings.TrimSpace(values[0]) + } + case http.MethodPost: + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if body == nil { + body = map[string]interface{}{} + } + if action == "" { + if raw := stringFromMap(body, "action"); raw != "" { + action = strings.ToLower(strings.TrimSpace(raw)) + } + } + delete(body, "action") + args = body + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + result, rpcErr := s.handleSubagentLegacyAction(r.Context(), action, args) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "result": result}) +} + +func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + svc := s.workspaceRPCService() + switch r.Method { + case http.MethodGet: + path := strings.TrimSpace(r.URL.Query().Get("path")) + if path == "" { + resp, rpcErr := svc.ListFiles(r.Context(), rpcpkg.ListWorkspaceFilesRequest{Scope: "memory"}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "files": resp.Files}) + return + } + resp, rpcErr := svc.ReadFile(r.Context(), rpcpkg.ReadWorkspaceFileRequest{Scope: "memory", Path: path}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + if !resp.Found { + http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": resp.Path, "content": resp.Content}) + case http.MethodPost: + var body struct { + Path string `json:"path"` + Content string `json:"content"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + resp, rpcErr := svc.WriteFile(r.Context(), rpcpkg.WriteWorkspaceFileRequest{Scope: "memory", Path: body.Path, Content: body.Content}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": resp.Path}) + case http.MethodDelete: + resp, rpcErr := svc.DeleteFile(r.Context(), rpcpkg.DeleteWorkspaceFileRequest{Scope: "memory", Path: r.URL.Query().Get("path")}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "deleted": resp.Deleted, "path": resp.Path}) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleWebUIWorkspaceFile(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + svc := s.workspaceRPCService() + switch r.Method { + case http.MethodGet: + path := strings.TrimSpace(r.URL.Query().Get("path")) + if path == "" { + resp, rpcErr := svc.ListFiles(r.Context(), rpcpkg.ListWorkspaceFilesRequest{Scope: "workspace"}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "files": resp.Files}) + return + } + resp, rpcErr := svc.ReadFile(r.Context(), rpcpkg.ReadWorkspaceFileRequest{Scope: "workspace", Path: path}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": resp.Path, "found": resp.Found, "content": resp.Content}) + case http.MethodPost: + var body struct { + Path string `json:"path"` + Content string `json:"content"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + resp, rpcErr := svc.WriteFile(r.Context(), rpcpkg.WriteWorkspaceFileRequest{Scope: "workspace", Path: body.Path, Content: body.Content}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": resp.Path, "saved": resp.Saved}) + case http.MethodDelete: + resp, rpcErr := svc.DeleteFile(r.Context(), rpcpkg.DeleteWorkspaceFileRequest{Scope: "workspace", Path: r.URL.Query().Get("path")}) + if rpcErr != nil { + http.Error(w, rpcErr.Message, rpcHTTPStatus(rpcErr)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "deleted": resp.Deleted, "path": resp.Path}) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} diff --git a/pkg/api/server_skills.go b/pkg/api/server_skills.go new file mode 100644 index 0000000..8d6a623 --- /dev/null +++ b/pkg/api/server_skills.go @@ -0,0 +1,787 @@ +package api + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strings" + "time" + + "github.com/YspCoder/clawgo/pkg/tools" +) + +func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + skillsDir := filepath.Join(s.workspacePath, "skills") + if strings.TrimSpace(skillsDir) == "" { + http.Error(w, "workspace not configured", http.StatusInternalServerError) + return + } + _ = os.MkdirAll(skillsDir, 0755) + + resolveSkillPath := func(name string) (string, error) { + name = strings.TrimSpace(name) + if name == "" { + return "", fmt.Errorf("name required") + } + cands := []string{ + filepath.Join(skillsDir, name), + filepath.Join(skillsDir, name+".disabled"), + filepath.Join("/root/clawgo/workspace/skills", name), + filepath.Join("/root/clawgo/workspace/skills", name+".disabled"), + } + for _, p := range cands { + if st, err := os.Stat(p); err == nil && st.IsDir() { + return p, nil + } + } + return "", fmt.Errorf("skill not found: %s", name) + } + + switch r.Method { + case http.MethodGet: + clawhubPath := strings.TrimSpace(resolveClawHubBinary(r.Context())) + clawhubInstalled := clawhubPath != "" + if id := strings.TrimSpace(r.URL.Query().Get("id")); id != "" { + skillPath, err := resolveSkillPath(id) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + if strings.TrimSpace(r.URL.Query().Get("files")) == "1" { + var files []string + _ = filepath.WalkDir(skillPath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + return nil + } + rel, _ := filepath.Rel(skillPath, path) + if strings.HasPrefix(rel, "..") { + return nil + } + files = append(files, filepath.ToSlash(rel)) + return nil + }) + writeJSON(w, map[string]interface{}{"ok": true, "id": id, "files": files}) + return + } + if f := strings.TrimSpace(r.URL.Query().Get("file")); f != "" { + clean, content, found, err := readRelativeTextFile(skillPath, f) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return + } + if !found { + http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "id": id, "file": filepath.ToSlash(clean), "content": content}) + return + } + } + + type skillItem struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Tools []string `json:"tools"` + SystemPrompt string `json:"system_prompt,omitempty"` + Enabled bool `json:"enabled"` + UpdateChecked bool `json:"update_checked"` + RemoteFound bool `json:"remote_found,omitempty"` + RemoteVersion string `json:"remote_version,omitempty"` + CheckError string `json:"check_error,omitempty"` + Source string `json:"source,omitempty"` + } + candDirs := []string{skillsDir, filepath.Join("/root/clawgo/workspace", "skills")} + seenDirs := map[string]struct{}{} + seenSkills := map[string]struct{}{} + items := make([]skillItem, 0) + checkUpdates := strings.TrimSpace(r.URL.Query().Get("check_updates")) == "1" + + for _, dir := range candDirs { + dir = strings.TrimSpace(dir) + if dir == "" { + continue + } + if _, ok := seenDirs[dir]; ok { + continue + } + seenDirs[dir] = struct{}{} + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + continue + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + for _, e := range entries { + if !e.IsDir() { + continue + } + name := e.Name() + enabled := !strings.HasSuffix(name, ".disabled") + baseName := strings.TrimSuffix(name, ".disabled") + if _, ok := seenSkills[baseName]; ok { + continue + } + seenSkills[baseName] = struct{}{} + desc, skillTools, sys := readSkillMeta(filepath.Join(dir, name, "SKILL.md")) + if desc == "" || len(skillTools) == 0 || sys == "" { + d2, t2, s2 := readSkillMeta(filepath.Join(dir, baseName, "SKILL.md")) + if desc == "" { + desc = d2 + } + if len(skillTools) == 0 { + skillTools = t2 + } + if sys == "" { + sys = s2 + } + } + if skillTools == nil { + skillTools = []string{} + } + it := skillItem{ID: baseName, Name: baseName, Description: desc, Tools: skillTools, SystemPrompt: sys, Enabled: enabled, UpdateChecked: checkUpdates && clawhubInstalled, Source: dir} + if checkUpdates && clawhubInstalled { + found, version, checkErr := queryClawHubSkillVersion(r.Context(), baseName) + it.RemoteFound = found + it.RemoteVersion = version + if checkErr != nil { + it.CheckError = checkErr.Error() + } + } + items = append(items, it) + } + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "skills": items, + "source": "clawhub", + "clawhub_installed": clawhubInstalled, + "clawhub_path": clawhubPath, + }) + + case http.MethodPost: + ct := strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))) + if strings.Contains(ct, "multipart/form-data") { + imported, err := importSkillArchiveFromMultipart(r, skillsDir) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "imported": imported}) + return + } + + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + action := strings.ToLower(stringFromMap(body, "action")) + if action == "install_clawhub" { + output, err := ensureClawHubReady(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "output": output, + "installed": true, + "clawhub_path": resolveClawHubBinary(r.Context()), + }) + return + } + id := stringFromMap(body, "id") + name := strings.TrimSpace(firstNonEmptyString(stringFromMap(body, "name"), id)) + if name == "" { + http.Error(w, "name required", http.StatusBadRequest) + return + } + enabledPath := filepath.Join(skillsDir, name) + disabledPath := enabledPath + ".disabled" + type skillActionHandler func() bool + handlers := map[string]skillActionHandler{ + "install": func() bool { + clawhubPath := strings.TrimSpace(resolveClawHubBinary(r.Context())) + if clawhubPath == "" { + http.Error(w, "clawhub is not installed. please install clawhub first.", http.StatusPreconditionFailed) + return false + } + ignoreSuspicious, _ := tools.MapBoolArg(body, "ignore_suspicious") + args := []string{"install", name} + if ignoreSuspicious { + args = append(args, "--force") + } + cmd := exec.CommandContext(r.Context(), clawhubPath, args...) + cmd.Dir = strings.TrimSpace(s.workspacePath) + out, err := cmd.CombinedOutput() + if err != nil { + outText := string(out) + lower := strings.ToLower(outText) + if strings.Contains(lower, "rate limit exceeded") || strings.Contains(lower, "too many requests") { + http.Error(w, fmt.Sprintf("clawhub rate limit exceeded. please retry later or configure auth token.\n%s", outText), http.StatusTooManyRequests) + return false + } + http.Error(w, fmt.Sprintf("install failed: %v\n%s", err, outText), http.StatusInternalServerError) + return false + } + writeJSON(w, map[string]interface{}{"ok": true, "installed": name, "output": string(out)}) + return true + }, + "enable": func() bool { + if _, err := os.Stat(disabledPath); err == nil { + if err := os.Rename(disabledPath, enabledPath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return false + } + } + writeJSON(w, map[string]interface{}{"ok": true}) + return true + }, + "disable": func() bool { + if _, err := os.Stat(enabledPath); err == nil { + if err := os.Rename(enabledPath, disabledPath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return false + } + } + writeJSON(w, map[string]interface{}{"ok": true}) + return true + }, + "write_file": func() bool { + skillPath, err := resolveSkillPath(name) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return false + } + content := rawStringFromMap(body, "content") + filePath := stringFromMap(body, "file") + clean, err := writeRelativeTextFile(skillPath, filePath, content, true) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return false + } + writeJSON(w, map[string]interface{}{"ok": true, "name": name, "file": filepath.ToSlash(clean)}) + return true + }, + "create": func() bool { + return createOrUpdateSkill(w, enabledPath, name, body, true) + }, + "update": func() bool { + return createOrUpdateSkill(w, enabledPath, name, body, false) + }, + } + if handler := handlers[action]; handler != nil { + handler() + return + } + http.Error(w, "unsupported action", http.StatusBadRequest) + + case http.MethodDelete: + id := strings.TrimSpace(r.URL.Query().Get("id")) + if id == "" { + http.Error(w, "id required", http.StatusBadRequest) + return + } + pathA := filepath.Join(skillsDir, id) + pathB := pathA + ".disabled" + deleted := false + if err := os.RemoveAll(pathA); err == nil { + deleted = true + } + if err := os.RemoveAll(pathB); err == nil { + deleted = true + } + writeJSON(w, map[string]interface{}{"ok": true, "deleted": deleted, "id": id}) + + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func buildSkillMarkdown(name, desc string, toolsList []string, systemPrompt string) string { + if desc == "" { + desc = "No description provided." + } + if len(toolsList) == 0 { + toolsList = []string{""} + } + toolLines := make([]string, 0, len(toolsList)) + for _, t := range toolsList { + if t == "" { + continue + } + toolLines = append(toolLines, "- "+t) + } + if len(toolLines) == 0 { + toolLines = []string{"- (none)"} + } + return fmt.Sprintf(`--- +name: %s +description: %s +--- + +# %s + +%s + +## Tools +%s + +## System Prompt +%s +`, name, desc, name, desc, strings.Join(toolLines, "\n"), systemPrompt) +} + +func createOrUpdateSkill(w http.ResponseWriter, enabledPath, name string, body map[string]interface{}, checkExists bool) bool { + desc := rawStringFromMap(body, "description") + sys := rawStringFromMap(body, "system_prompt") + toolsList := stringListFromMap(body, "tools") + if checkExists { + if _, err := os.Stat(enabledPath); err == nil { + http.Error(w, "skill already exists", http.StatusBadRequest) + return false + } + } + if err := os.MkdirAll(filepath.Join(enabledPath, "scripts"), 0755); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return false + } + skillMD := buildSkillMarkdown(name, desc, toolsList, sys) + if err := os.WriteFile(filepath.Join(enabledPath, "SKILL.md"), []byte(skillMD), 0644); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return false + } + writeJSON(w, map[string]interface{}{"ok": true}) + return true +} + +func readSkillMeta(path string) (desc string, toolsList []string, systemPrompt string) { + b, err := os.ReadFile(path) + if err != nil { + return "", []string{}, "" + } + s := string(b) + reDesc := regexp.MustCompile(`(?m)^description:\s*(.+)$`) + reTools := regexp.MustCompile(`(?m)^##\s*Tools\s*$`) + rePrompt := regexp.MustCompile(`(?m)^##\s*System Prompt\s*$`) + if m := reDesc.FindStringSubmatch(s); len(m) > 1 { + desc = m[1] + } + if loc := reTools.FindStringIndex(s); loc != nil { + block := s[loc[1]:] + if p := rePrompt.FindStringIndex(block); p != nil { + block = block[:p[0]] + } + for _, line := range strings.Split(block, "\n") { + line = strings.TrimPrefix(line, "-") + if line != "" { + toolsList = append(toolsList, line) + } + } + } + if toolsList == nil { + toolsList = []string{} + } + if loc := rePrompt.FindStringIndex(s); loc != nil { + systemPrompt = s[loc[1]:] + } + return +} + +func queryClawHubSkillVersion(ctx context.Context, skill string) (found bool, version string, err error) { + if skill == "" { + return false, "", fmt.Errorf("skill empty") + } + clawhubPath := strings.TrimSpace(resolveClawHubBinary(ctx)) + if clawhubPath == "" { + return false, "", fmt.Errorf("clawhub not installed") + } + cctx, cancel := context.WithTimeout(ctx, 8*time.Second) + defer cancel() + cmd := exec.CommandContext(cctx, clawhubPath, "search", skill, "--json") + out, runErr := cmd.Output() + if runErr != nil { + return false, "", runErr + } + var payload interface{} + if err := json.Unmarshal(out, &payload); err != nil { + return false, "", err + } + lowerSkill := strings.ToLower(skill) + var walk func(v interface{}) (bool, string) + walk = func(v interface{}) (bool, string) { + switch t := v.(type) { + case map[string]interface{}: + name := strings.ToLower(strings.TrimSpace(anyToString(t["name"]))) + if name == "" { + name = strings.ToLower(strings.TrimSpace(anyToString(t["id"]))) + } + if name == lowerSkill || strings.Contains(name, lowerSkill) { + ver := anyToString(t["version"]) + if ver == "" { + ver = anyToString(t["latest_version"]) + } + return true, ver + } + for _, vv := range t { + if ok, ver := walk(vv); ok { + return ok, ver + } + } + case []interface{}: + for _, vv := range t { + if ok, ver := walk(vv); ok { + return ok, ver + } + } + } + return false, "" + } + ok, ver := walk(payload) + return ok, ver, nil +} + +func ensureClawHubReady(ctx context.Context) (string, error) { + outs := make([]string, 0, 4) + if p := resolveClawHubBinary(ctx); p != "" { + return "clawhub already installed at: " + p, nil + } + nodeOut, err := ensureNodeRuntime(ctx) + if nodeOut != "" { + outs = append(outs, nodeOut) + } + if err != nil { + return strings.Join(outs, "\n"), err + } + clawOut, err := runInstallCommand(ctx, "npm i -g clawhub") + if clawOut != "" { + outs = append(outs, clawOut) + } + if err != nil { + return strings.Join(outs, "\n"), err + } + if p := resolveClawHubBinary(ctx); p != "" { + outs = append(outs, "clawhub installed at: "+p) + return strings.Join(outs, "\n"), nil + } + return strings.Join(outs, "\n"), fmt.Errorf("installed clawhub but executable still not found in PATH") +} + +func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]string, error) { + if err := r.ParseMultipartForm(128 << 20); err != nil { + return nil, err + } + f, h, err := r.FormFile("file") + if err != nil { + return nil, fmt.Errorf("file required") + } + defer f.Close() + + uploadDir := filepath.Join(os.TempDir(), "clawgo_skill_uploads") + _ = os.MkdirAll(uploadDir, 0755) + archivePath := filepath.Join(uploadDir, fmt.Sprintf("%d_%s", time.Now().UnixNano(), filepath.Base(h.Filename))) + out, err := os.Create(archivePath) + if err != nil { + return nil, err + } + if _, err := io.Copy(out, f); err != nil { + _ = out.Close() + _ = os.Remove(archivePath) + return nil, err + } + _ = out.Close() + defer os.Remove(archivePath) + + extractDir, err := os.MkdirTemp("", "clawgo_skill_extract_*") + if err != nil { + return nil, err + } + defer os.RemoveAll(extractDir) + + if err := extractArchive(archivePath, extractDir); err != nil { + return nil, err + } + + type candidate struct { + name string + dir string + } + candidates := make([]candidate, 0) + seen := map[string]struct{}{} + err = filepath.WalkDir(extractDir, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + return nil + } + if strings.EqualFold(d.Name(), "SKILL.md") { + dir := filepath.Dir(path) + rel, relErr := filepath.Rel(extractDir, dir) + if relErr != nil { + return nil + } + rel = filepath.ToSlash(strings.TrimSpace(rel)) + if rel == "" { + rel = "." + } + name := filepath.Base(rel) + if rel == "." { + name = archiveBaseName(h.Filename) + } + name = sanitizeSkillName(name) + if name == "" { + return nil + } + if _, ok := seen[name]; ok { + return nil + } + seen[name] = struct{}{} + candidates = append(candidates, candidate{name: name, dir: dir}) + } + return nil + }) + if err != nil { + return nil, err + } + if len(candidates) == 0 { + return nil, fmt.Errorf("no SKILL.md found in archive") + } + + imported := make([]string, 0, len(candidates)) + for _, c := range candidates { + dst := filepath.Join(skillsDir, c.name) + if _, err := os.Stat(dst); err == nil { + return nil, fmt.Errorf("skill already exists: %s", c.name) + } + if _, err := os.Stat(dst + ".disabled"); err == nil { + return nil, fmt.Errorf("disabled skill already exists: %s", c.name) + } + if err := copyDir(c.dir, dst); err != nil { + return nil, err + } + imported = append(imported, c.name) + } + sort.Strings(imported) + return imported, nil +} + +func archiveBaseName(filename string) string { + name := filepath.Base(strings.TrimSpace(filename)) + lower := strings.ToLower(name) + switch { + case strings.HasSuffix(lower, ".tar.gz"): + return name[:len(name)-len(".tar.gz")] + case strings.HasSuffix(lower, ".tgz"): + return name[:len(name)-len(".tgz")] + case strings.HasSuffix(lower, ".zip"): + return name[:len(name)-len(".zip")] + case strings.HasSuffix(lower, ".tar"): + return name[:len(name)-len(".tar")] + default: + ext := filepath.Ext(name) + return strings.TrimSuffix(name, ext) + } +} + +func sanitizeSkillName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + var b strings.Builder + lastDash := false + for _, ch := range strings.ToLower(name) { + if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' { + b.WriteRune(ch) + lastDash = false + continue + } + if !lastDash { + b.WriteRune('-') + lastDash = true + } + } + out := strings.Trim(b.String(), "-") + if out == "" || out == "." { + return "" + } + return out +} + +func extractArchive(archivePath, targetDir string) error { + lower := strings.ToLower(archivePath) + switch { + case strings.HasSuffix(lower, ".zip"): + return extractZip(archivePath, targetDir) + case strings.HasSuffix(lower, ".tar.gz"), strings.HasSuffix(lower, ".tgz"): + return extractTarGz(archivePath, targetDir) + case strings.HasSuffix(lower, ".tar"): + return extractTar(archivePath, targetDir) + default: + return fmt.Errorf("unsupported archive format: %s", filepath.Base(archivePath)) + } +} + +func extractZip(archivePath, targetDir string) error { + zr, err := zip.OpenReader(archivePath) + if err != nil { + return err + } + defer zr.Close() + + for _, f := range zr.File { + if err := writeArchivedEntry(targetDir, f.Name, f.FileInfo().IsDir(), func() (io.ReadCloser, error) { + return f.Open() + }); err != nil { + return err + } + } + return nil +} + +func extractTarGz(archivePath, targetDir string) error { + f, err := os.Open(archivePath) + if err != nil { + return err + } + defer f.Close() + gz, err := gzip.NewReader(f) + if err != nil { + return err + } + defer gz.Close() + return extractTarReader(tar.NewReader(gz), targetDir) +} + +func extractTar(archivePath, targetDir string) error { + f, err := os.Open(archivePath) + if err != nil { + return err + } + defer f.Close() + return extractTarReader(tar.NewReader(f), targetDir) +} + +func extractTarReader(tr *tar.Reader, targetDir string) error { + for { + hdr, err := tr.Next() + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return err + } + switch hdr.Typeflag { + case tar.TypeDir: + if err := writeArchivedEntry(targetDir, hdr.Name, true, nil); err != nil { + return err + } + case tar.TypeReg, tar.TypeRegA: + name := hdr.Name + if err := writeArchivedEntry(targetDir, name, false, func() (io.ReadCloser, error) { + return io.NopCloser(tr), nil + }); err != nil { + return err + } + } + } +} + +func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error)) error { + clean := filepath.Clean(strings.TrimSpace(name)) + clean = strings.TrimPrefix(clean, string(filepath.Separator)) + clean = strings.TrimPrefix(clean, "/") + for strings.HasPrefix(clean, "../") { + clean = strings.TrimPrefix(clean, "../") + } + if clean == "." || clean == "" { + return nil + } + dst := filepath.Join(targetDir, clean) + absTarget, _ := filepath.Abs(targetDir) + absDst, _ := filepath.Abs(dst) + if !strings.HasPrefix(absDst, absTarget+string(filepath.Separator)) && absDst != absTarget { + return fmt.Errorf("invalid archive entry path: %s", name) + } + if isDir { + return os.MkdirAll(dst, 0755) + } + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return err + } + rc, err := opener() + if err != nil { + return err + } + defer rc.Close() + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + _, err = io.Copy(out, rc) + return err +} + +func copyDir(src, dst string) error { + entries, err := os.ReadDir(src) + if err != nil { + return err + } + if err := os.MkdirAll(dst, 0755); err != nil { + return err + } + for _, e := range entries { + srcPath := filepath.Join(src, e.Name()) + dstPath := filepath.Join(dst, e.Name()) + info, err := e.Info() + if err != nil { + return err + } + if info.IsDir() { + if err := copyDir(srcPath, dstPath); err != nil { + return err + } + continue + } + in, err := os.Open(srcPath) + if err != nil { + return err + } + out, err := os.Create(dstPath) + if err != nil { + _ = in.Close() + return err + } + if _, err := io.Copy(out, in); err != nil { + _ = out.Close() + _ = in.Close() + return err + } + _ = out.Close() + _ = in.Close() + } + return nil +} diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index e5313cc..9b6cb35 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -1106,6 +1106,196 @@ func TestHandleWebUINodeArtifactsListAndDelete(t *testing.T) { } } +func TestHandleSubagentRPCSpawn(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + srv.SetSubagentHandler(func(ctx context.Context, action string, args map[string]interface{}) (interface{}, error) { + if action != "spawn" { + t.Fatalf("unexpected action: %s", action) + } + if fmt.Sprint(args["agent_id"]) != "coder" || fmt.Sprint(args["task"]) != "ship it" { + t.Fatalf("unexpected args: %+v", args) + } + return map[string]interface{}{"message": "spawned"}, nil + }) + + body := `{"method":"subagent.spawn","request_id":"req-1","params":{"agent_id":"coder","task":"ship it","channel":"webui","chat_id":"group"}}` + req := httptest.NewRequest(http.MethodPost, "/api/rpc/subagent", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + srv.handleSubagentRPC(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"request_id":"req-1"`) || !strings.Contains(rec.Body.String(), `"message":"spawned"`) { + t.Fatalf("unexpected rpc body: %s", rec.Body.String()) + } +} + +func TestHandleNodeRPCDispatch(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + srv.SetNodeDispatchHandler(func(ctx context.Context, req nodes.Request, mode string) (nodes.Response, error) { + if req.Node != "edge-a" || req.Action != "screen_snapshot" || mode != "relay" { + t.Fatalf("unexpected request: %+v mode=%s", req, mode) + } + return nodes.Response{ + OK: true, + Node: req.Node, + Action: req.Action, + Payload: map[string]interface{}{ + "used_transport": "relay", + }, + }, nil + }) + + body := `{"method":"node.dispatch","request_id":"req-2","params":{"node":"edge-a","action":"screen_snapshot","mode":"relay","args":{"quality":"high"}}}` + req := httptest.NewRequest(http.MethodPost, "/api/rpc/node", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + srv.handleNodeRPC(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"request_id":"req-2"`) || !strings.Contains(rec.Body.String(), `"used_transport":"relay"`) { + t.Fatalf("unexpected rpc body: %s", rec.Body.String()) + } +} + +func TestHandleProviderRPCListModels(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + pc := cfg.Models.Providers["openai"] + pc.APIBase = "https://example.invalid/v1" + pc.APIKey = "test-key" + pc.Models = []string{"gpt-5.4", "gpt-5.4-mini"} + cfg.Models.Providers["openai"] = pc + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + srv.SetConfigPath(cfgPath) + + body := `{"method":"provider.list_models","request_id":"req-p1","params":{"provider":"openai"}}` + req := httptest.NewRequest(http.MethodPost, "/api/rpc/provider", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + srv.handleProviderRPC(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"gpt-5.4"`) || !strings.Contains(rec.Body.String(), `"request_id":"req-p1"`) { + t.Fatalf("unexpected provider rpc body: %s", rec.Body.String()) + } +} + +func TestHandleProviderRPCCountTokensUnavailable(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + pc := cfg.Models.Providers["openai"] + pc.APIBase = "https://example.invalid/v1" + pc.APIKey = "test-key" + pc.Models = []string{"gpt-5.4"} + cfg.Models.Providers["openai"] = pc + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + srv.SetConfigPath(cfgPath) + + body := `{"method":"provider.count_tokens","request_id":"req-p2","params":{"provider":"openai","messages":[{"role":"user","content":"hello"}]}}` + req := httptest.NewRequest(http.MethodPost, "/api/rpc/provider", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + srv.handleProviderRPC(rec, req) + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("expected 503, got %d: %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"code":"unavailable"`) { + t.Fatalf("expected unavailable rpc error, got: %s", rec.Body.String()) + } +} + +func TestHandleWebUIProviderModelsUsesRPCFacade(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + pc := cfg.Models.Providers["openai"] + pc.APIBase = "https://example.invalid/v1" + pc.APIKey = "test-key" + pc.Models = []string{"gpt-5.4-mini"} + cfg.Models.Providers["openai"] = pc + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + srv.SetConfigPath(cfgPath) + srv.SetConfigAfterHook(func() error { return nil }) + + req := httptest.NewRequest(http.MethodPost, "/api/provider/models", strings.NewReader(`{"provider":"openai","model":"gpt-5.4"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + srv.handleWebUIProviderModels(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"gpt-5.4"`) { + t.Fatalf("unexpected response: %s", rec.Body.String()) + } +} + +func TestHandleWebUIProviderRuntimeUsesRPCFacade(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + pc := cfg.Models.Providers["openai"] + pc.APIBase = "https://example.invalid/v1" + pc.APIKey = "test-key" + pc.Models = []string{"gpt-5.4-mini"} + cfg.Models.Providers["openai"] = pc + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + srv.SetConfigPath(cfgPath) + + req := httptest.NewRequest(http.MethodGet, "/api/provider/runtime?provider=openai&limit=5", nil) + rec := httptest.NewRecorder() + + srv.handleWebUIProviderRuntime(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"view"`) { + t.Fatalf("unexpected response: %s", rec.Body.String()) + } +} + func TestHandleWebUINodeArtifactsExport(t *testing.T) { t.Parallel() diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index e6729ba..d566b2b 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -1173,51 +1173,58 @@ func (c *TelegramChannel) handleAction(ctx context.Context, chatID int64, action if !ok && action != "send" && action != "stream" && action != "finalize" { return fmt.Errorf("message_id required for action=%s", action) } - switch action { - case "edit": - htmlContent := clampTelegramHTML(msg.Content, telegramSafeHTMLMaxRunes) - editCtx, cancel := withTelegramAPITimeout(ctx) - defer cancel() - _, err := c.bot.EditMessageText(editCtx, &telego.EditMessageTextParams{ChatID: telegoutil.ID(chatID), MessageID: messageID, Text: htmlContent, ParseMode: telego.ModeHTML}) - return err - case "stream": - return c.handleStreamAction(ctx, chatID, msg, false) - case "finalize": - if strings.TrimSpace(msg.Content) != "" { - // Final pass to recover rich formatting after conservative plain streaming. - if err := c.handleStreamAction(ctx, chatID, bus.OutboundMessage{ - ChatID: msg.ChatID, - ReplyToID: msg.ReplyToID, - Content: msg.Content, - Action: "stream", - }, true); err != nil { - return err + handlers := map[string]func() error{ + "edit": func() error { + htmlContent := clampTelegramHTML(msg.Content, telegramSafeHTMLMaxRunes) + editCtx, cancel := withTelegramAPITimeout(ctx) + defer cancel() + _, err := c.bot.EditMessageText(editCtx, &telego.EditMessageTextParams{ChatID: telegoutil.ID(chatID), MessageID: messageID, Text: htmlContent, ParseMode: telego.ModeHTML}) + return err + }, + "stream": func() error { + return c.handleStreamAction(ctx, chatID, msg, false) + }, + "finalize": func() error { + if strings.TrimSpace(msg.Content) != "" { + // Final pass to recover rich formatting after conservative plain streaming. + if err := c.handleStreamAction(ctx, chatID, bus.OutboundMessage{ + ChatID: msg.ChatID, + ReplyToID: msg.ReplyToID, + Content: msg.Content, + Action: "stream", + }, true); err != nil { + return err + } } - } - streamKey := telegramStreamKey(chatID, msg.ReplyToID) - c.streamMu.Lock() - delete(c.streamState, streamKey) - c.streamMu.Unlock() - return nil - case "delete": - delCtx, cancel := withTelegramAPITimeout(ctx) - defer cancel() - return c.bot.DeleteMessage(delCtx, &telego.DeleteMessageParams{ChatID: telegoutil.ID(chatID), MessageID: messageID}) - case "react": - reactCtx, cancel := withTelegramAPITimeout(ctx) - defer cancel() - emoji := strings.TrimSpace(msg.Emoji) - if emoji == "" { - return fmt.Errorf("emoji required for react action") - } - return c.bot.SetMessageReaction(reactCtx, &telego.SetMessageReactionParams{ - ChatID: telegoutil.ID(chatID), - MessageID: messageID, - Reaction: []telego.ReactionType{&telego.ReactionTypeEmoji{Emoji: emoji}}, - }) - default: - return fmt.Errorf("unsupported telegram action: %s", action) + streamKey := telegramStreamKey(chatID, msg.ReplyToID) + c.streamMu.Lock() + delete(c.streamState, streamKey) + c.streamMu.Unlock() + return nil + }, + "delete": func() error { + delCtx, cancel := withTelegramAPITimeout(ctx) + defer cancel() + return c.bot.DeleteMessage(delCtx, &telego.DeleteMessageParams{ChatID: telegoutil.ID(chatID), MessageID: messageID}) + }, + "react": func() error { + reactCtx, cancel := withTelegramAPITimeout(ctx) + defer cancel() + emoji := strings.TrimSpace(msg.Emoji) + if emoji == "" { + return fmt.Errorf("emoji required for react action") + } + return c.bot.SetMessageReaction(reactCtx, &telego.SetMessageReactionParams{ + ChatID: telegoutil.ID(chatID), + MessageID: messageID, + Reaction: []telego.ReactionType{&telego.ReactionTypeEmoji{Emoji: emoji}}, + }) + }, } + if handler := handlers[action]; handler != nil { + return handler() + } + return fmt.Errorf("unsupported telegram action: %s", action) } func parseTelegramMessageID(raw string) (int, bool) { diff --git a/pkg/channels/utils.go b/pkg/channels/utils.go index a1f38ea..b5e5d71 100644 --- a/pkg/channels/utils.go +++ b/pkg/channels/utils.go @@ -18,15 +18,6 @@ func truncateString(s string, maxLen int) string { return s[:maxLen] } -func safeCloseSignal(v interface{}) { - ch, ok := v.(chan struct{}) - if !ok || ch == nil { - return - } - defer func() { _ = recover() }() - close(ch) -} - type cancelGuard struct { mu sync.Mutex cancel context.CancelFunc diff --git a/pkg/nodes/manager.go b/pkg/nodes/manager.go index 11ff990..4bd3217 100644 --- a/pkg/nodes/manager.go +++ b/pkg/nodes/manager.go @@ -47,8 +47,53 @@ type DispatchPolicy struct { var defaultManager = NewManager() +var nodeActionCapabilityChecks = map[string]func(Capabilities) bool{ + "run": func(c Capabilities) bool { return c.Run }, + "agent_task": func(c Capabilities) bool { return c.Model }, + "camera_snap": func(c Capabilities) bool { return c.Camera }, + "camera_clip": func(c Capabilities) bool { return c.Camera }, + "screen_record": func(c Capabilities) bool { return c.Screen }, + "screen_snapshot": func(c Capabilities) bool { return c.Screen }, + "location_get": func(c Capabilities) bool { return c.Location }, + "canvas_snapshot": func(c Capabilities) bool { return c.Canvas }, + "canvas_action": func(c Capabilities) bool { return c.Canvas }, +} + +var realtimePreferredActions = map[string]struct{}{ + "camera_snap": {}, + "camera_clip": {}, + "screen_record": {}, + "screen_snapshot": {}, + "canvas_snapshot": {}, + "canvas_action": {}, +} + +var wireMessageHandlers = map[string]func(*Manager, WireMessage) bool{ + "node_response": handleWireNodeResponse, +} + func DefaultManager() *Manager { return defaultManager } +func handleWireNodeResponse(m *Manager, msg WireMessage) bool { + if strings.TrimSpace(msg.ID) == "" { + return false + } + m.mu.Lock() + ch := m.pending[msg.ID] + if ch != nil { + delete(m.pending, msg.ID) + } + m.mu.Unlock() + if ch == nil { + return false + } + select { + case ch <- msg: + default: + } + return true +} + func NewManager() *Manager { m := &Manager{ nodes: map[string]NodeInfo{}, @@ -199,28 +244,10 @@ func (m *Manager) RegisterWireSender(nodeID string, sender WireSender) { } func (m *Manager) HandleWireMessage(msg WireMessage) bool { - switch strings.ToLower(strings.TrimSpace(msg.Type)) { - case "node_response": - if strings.TrimSpace(msg.ID) == "" { - return false - } - m.mu.Lock() - ch := m.pending[msg.ID] - if ch != nil { - delete(m.pending, msg.ID) - } - m.mu.Unlock() - if ch == nil { - return false - } - select { - case ch <- msg: - default: - } - return true - default: - return false + if handler := wireMessageHandlers[strings.ToLower(strings.TrimSpace(msg.Type))]; handler != nil { + return handler(m, msg) } + return false } func (m *Manager) SendWireRequest(ctx context.Context, nodeID string, req Request) (Response, error) { @@ -314,22 +341,10 @@ func nodeSupportsRequest(n NodeInfo, req Request) bool { return false } } - switch action { - case "run": - return n.Capabilities.Run - case "agent_task": - return n.Capabilities.Model - case "camera_snap", "camera_clip": - return n.Capabilities.Camera - case "screen_record", "screen_snapshot": - return n.Capabilities.Screen - case "location_get": - return n.Capabilities.Location - case "canvas_snapshot", "canvas_action": - return n.Capabilities.Canvas - default: - return n.Capabilities.Invoke + if check := nodeActionCapabilityChecks[action]; check != nil { + return check(n.Capabilities) } + return n.Capabilities.Invoke } func (m *Manager) PickFor(action string) (NodeInfo, bool) { @@ -595,12 +610,8 @@ func nodeHasAgent(n NodeInfo, agentID string) bool { } func prefersRealtimeTransport(action string) bool { - switch strings.ToLower(strings.TrimSpace(action)) { - case "camera_snap", "camera_clip", "screen_record", "screen_snapshot", "canvas_snapshot", "canvas_action": - return true - default: - return false - } + _, ok := realtimePreferredActions[strings.ToLower(strings.TrimSpace(action))] + return ok } func (m *Manager) reaperLoop() { diff --git a/pkg/nodes/transport.go b/pkg/nodes/transport.go index e56ad80..58a580e 100644 --- a/pkg/nodes/transport.go +++ b/pkg/nodes/transport.go @@ -111,31 +111,24 @@ type HTTPRelayTransport struct { func (s *HTTPRelayTransport) Name() string { return "relay" } +var actionHTTPPaths = map[string]string{ + "run": "/run", + "invoke": "/invoke", + "agent_task": "/agent/task", + "camera_snap": "/camera/snap", + "camera_clip": "/camera/clip", + "screen_record": "/screen/record", + "screen_snapshot": "/screen/snapshot", + "location_get": "/location/get", + "canvas_snapshot": "/canvas/snapshot", + "canvas_action": "/canvas/action", +} + func actionHTTPPath(action string) string { - switch strings.ToLower(strings.TrimSpace(action)) { - case "run": - return "/run" - case "invoke": - return "/invoke" - case "agent_task": - return "/agent/task" - case "camera_snap": - return "/camera/snap" - case "camera_clip": - return "/camera/clip" - case "screen_record": - return "/screen/record" - case "screen_snapshot": - return "/screen/snapshot" - case "location_get": - return "/location/get" - case "canvas_snapshot": - return "/canvas/snapshot" - case "canvas_action": - return "/canvas/action" - default: - return "/invoke" + if path := actionHTTPPaths[strings.ToLower(strings.TrimSpace(action))]; path != "" { + return path } + return "/invoke" } func DoEndpointRequest(ctx context.Context, client *http.Client, endpoint, token string, req Request) (Response, error) { diff --git a/pkg/nodes/webrtc.go b/pkg/nodes/webrtc.go index 3222276..d17b480 100644 --- a/pkg/nodes/webrtc.go +++ b/pkg/nodes/webrtc.go @@ -186,6 +186,23 @@ func (t *WebRTCTransport) currentSignaler(nodeID string) WireSender { return t.signal[strings.TrimSpace(nodeID)] } +var webRTCSignalHandlers = map[string]func(*gatewayRTCSession, WireMessage) error{ + "signal_answer": func(session *gatewayRTCSession, msg WireMessage) error { + var desc webrtc.SessionDescription + if err := mapInto(msg.Payload, &desc); err != nil { + return err + } + return session.pc.SetRemoteDescription(desc) + }, + "signal_candidate": func(session *gatewayRTCSession, msg WireMessage) error { + var candidate webrtc.ICECandidateInit + if err := mapInto(msg.Payload, &candidate); err != nil { + return err + } + return session.pc.AddICECandidate(candidate) + }, +} + func (t *WebRTCTransport) HandleSignal(msg WireMessage) error { nodeID := strings.TrimSpace(msg.From) if nodeID == "" { @@ -195,22 +212,10 @@ func (t *WebRTCTransport) HandleSignal(msg WireMessage) error { if err != nil { return err } - switch strings.ToLower(strings.TrimSpace(msg.Type)) { - case "signal_answer": - var desc webrtc.SessionDescription - if err := mapInto(msg.Payload, &desc); err != nil { - return err - } - return session.pc.SetRemoteDescription(desc) - case "signal_candidate": - var candidate webrtc.ICECandidateInit - if err := mapInto(msg.Payload, &candidate); err != nil { - return err - } - return session.pc.AddICECandidate(candidate) - default: - return fmt.Errorf("unsupported signal type: %s", msg.Type) + if handler := webRTCSignalHandlers[strings.ToLower(strings.TrimSpace(msg.Type))]; handler != nil { + return handler(session, msg) } + return fmt.Errorf("unsupported signal type: %s", msg.Type) } func (t *WebRTCTransport) Send(ctx context.Context, req Request) (Response, error) { diff --git a/pkg/rpc/admin.go b/pkg/rpc/admin.go new file mode 100644 index 0000000..7d771a2 --- /dev/null +++ b/pkg/rpc/admin.go @@ -0,0 +1,63 @@ +package rpc + +import "context" + +type ConfigService interface { + View(context.Context, ConfigViewRequest) (*ConfigViewResponse, *Error) + Save(context.Context, ConfigSaveRequest) (*ConfigSaveResponse, *Error) +} + +type CronService interface { + List(context.Context, ListCronJobsRequest) (*ListCronJobsResponse, *Error) + Get(context.Context, GetCronJobRequest) (*GetCronJobResponse, *Error) + Mutate(context.Context, MutateCronJobRequest) (*MutateCronJobResponse, *Error) +} + +type ConfigViewRequest struct { + Mode string `json:"mode,omitempty"` + IncludeHotReloadInfo bool `json:"include_hot_reload_info,omitempty"` +} + +type ConfigViewResponse struct { + Config interface{} `json:"config,omitempty"` + RawConfig interface{} `json:"raw_config,omitempty"` + PrettyText string `json:"pretty_text,omitempty"` + HotReloadFields []string `json:"hot_reload_fields,omitempty"` + HotReloadFieldDetails []map[string]interface{} `json:"hot_reload_field_details,omitempty"` +} + +type ConfigSaveRequest struct { + Mode string `json:"mode,omitempty"` + ConfirmRisky bool `json:"confirm_risky,omitempty"` + Config map[string]interface{} `json:"config"` +} + +type ConfigSaveResponse struct { + Saved bool `json:"saved"` + RequiresConfirm bool `json:"requires_confirm,omitempty"` + ChangedFields []string `json:"changed_fields,omitempty"` + Details interface{} `json:"details,omitempty"` +} + +type ListCronJobsRequest struct{} + +type ListCronJobsResponse struct { + Jobs []interface{} `json:"jobs"` +} + +type GetCronJobRequest struct { + ID string `json:"id"` +} + +type GetCronJobResponse struct { + Job interface{} `json:"job,omitempty"` +} + +type MutateCronJobRequest struct { + Action string `json:"action"` + Args map[string]interface{} `json:"args,omitempty"` +} + +type MutateCronJobResponse struct { + Result interface{} `json:"result,omitempty"` +} diff --git a/pkg/rpc/envelope.go b/pkg/rpc/envelope.go new file mode 100644 index 0000000..070486d --- /dev/null +++ b/pkg/rpc/envelope.go @@ -0,0 +1,37 @@ +package rpc + +import ( + "encoding/json" + "strings" +) + +type Request struct { + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` + RequestID string `json:"request_id,omitempty"` +} + +type Response struct { + OK bool `json:"ok"` + Result interface{} `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` + RequestID string `json:"request_id,omitempty"` +} + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + Details interface{} `json:"details,omitempty"` + Retryable bool `json:"retryable,omitempty"` +} + +func (r Request) NormalizedMethod() string { + return strings.ToLower(strings.TrimSpace(r.Method)) +} + +func DecodeParams(raw json.RawMessage, target interface{}) error { + if len(raw) == 0 || string(raw) == "null" { + return nil + } + return json.Unmarshal(raw, target) +} diff --git a/pkg/rpc/node.go b/pkg/rpc/node.go new file mode 100644 index 0000000..1972382 --- /dev/null +++ b/pkg/rpc/node.go @@ -0,0 +1,192 @@ +package rpc + +import ( + "context" + + "github.com/YspCoder/clawgo/pkg/nodes" +) + +type NodeService interface { + Register(context.Context, RegisterNodeRequest) (*RegisterNodeResponse, *Error) + Heartbeat(context.Context, HeartbeatNodeRequest) (*HeartbeatNodeResponse, *Error) + Dispatch(context.Context, DispatchNodeRequest) (*DispatchNodeResponse, *Error) + ListArtifacts(context.Context, ListNodeArtifactsRequest) (*ListNodeArtifactsResponse, *Error) + GetArtifact(context.Context, GetNodeArtifactRequest) (*GetNodeArtifactResponse, *Error) + DeleteArtifact(context.Context, DeleteNodeArtifactRequest) (*DeleteNodeArtifactResponse, *Error) + PruneArtifacts(context.Context, PruneNodeArtifactsRequest) (*PruneNodeArtifactsResponse, *Error) +} + +type RegisterNodeRequest struct { + Node nodes.NodeInfo `json:"node"` +} + +type RegisterNodeResponse struct { + ID string `json:"id"` +} + +type HeartbeatNodeRequest struct { + ID string `json:"id"` +} + +type HeartbeatNodeResponse struct { + ID string `json:"id"` +} + +type DispatchNodeRequest struct { + Node string `json:"node"` + Action string `json:"action"` + Mode string `json:"mode,omitempty"` + Task string `json:"task,omitempty"` + Model string `json:"model,omitempty"` + Args map[string]interface{} `json:"args,omitempty"` +} + +type DispatchNodeResponse struct { + Result nodes.Response `json:"result"` +} + +type ArtifactSummary struct { + Data map[string]interface{} `json:"data"` +} + +type ArtifactContentRef struct { + Data map[string]interface{} `json:"data"` +} + +type ArtifactDeleteResult struct { + ID string `json:"id"` + DeletedFile bool `json:"deleted_file"` + DeletedAudit bool `json:"deleted_audit"` +} + +type ArtifactPruneResult struct { + Pruned int `json:"pruned"` + DeletedFiles int `json:"deleted_files"` + Kept int `json:"kept"` +} + +type ListNodeArtifactsRequest struct { + Node string `json:"node,omitempty"` + Action string `json:"action,omitempty"` + Kind string `json:"kind,omitempty"` + Limit int `json:"limit,omitempty"` +} + +type ListNodeArtifactsResponse struct { + Items []map[string]interface{} `json:"items"` + ArtifactRetention map[string]interface{} `json:"artifact_retention,omitempty"` +} + +type GetNodeArtifactRequest struct { + ID string `json:"id"` +} + +type GetNodeArtifactResponse struct { + Found bool `json:"found"` + Artifact map[string]interface{} `json:"artifact,omitempty"` +} + +type DeleteNodeArtifactRequest struct { + ID string `json:"id"` +} + +type DeleteNodeArtifactResponse struct { + ArtifactDeleteResult +} + +type PruneNodeArtifactsRequest struct { + Node string `json:"node,omitempty"` + Action string `json:"action,omitempty"` + Kind string `json:"kind,omitempty"` + KeepLatest int `json:"keep_latest,omitempty"` + Limit int `json:"limit,omitempty"` +} + +type PruneNodeArtifactsResponse struct { + ArtifactPruneResult +} + +type ProviderService interface { + ListModels(context.Context, ListProviderModelsRequest) (*ListProviderModelsResponse, *Error) + UpdateModels(context.Context, UpdateProviderModelsRequest) (*UpdateProviderModelsResponse, *Error) + Chat(context.Context, ProviderChatRequest) (*ProviderChatResponse, *Error) + CountTokens(context.Context, ProviderCountTokensRequest) (*ProviderCountTokensResponse, *Error) + RuntimeView(context.Context, ProviderRuntimeViewRequest) (*ProviderRuntimeViewResponse, *Error) + RuntimeAction(context.Context, ProviderRuntimeActionRequest) (*ProviderRuntimeActionResponse, *Error) +} + +type ListProviderModelsRequest struct { + Provider string `json:"provider"` +} + +type ListProviderModelsResponse struct { + Provider string `json:"provider"` + Models []string `json:"models,omitempty"` + Default string `json:"default_model,omitempty"` +} + +type UpdateProviderModelsRequest struct { + Provider string `json:"provider"` + Model string `json:"model,omitempty"` + Models []string `json:"models,omitempty"` +} + +type UpdateProviderModelsResponse struct { + Provider string `json:"provider"` + Models []string `json:"models,omitempty"` +} + +type ProviderChatRequest struct { + Provider string `json:"provider"` + Model string `json:"model,omitempty"` + Messages []map[string]interface{} `json:"messages"` + Tools []map[string]interface{} `json:"tools,omitempty"` + Options map[string]interface{} `json:"options,omitempty"` +} + +type ProviderChatResponse struct { + Content string `json:"content,omitempty"` + ToolCalls []map[string]interface{} `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Usage map[string]interface{} `json:"usage,omitempty"` +} + +type ProviderCountTokensRequest struct { + Provider string `json:"provider"` + Model string `json:"model,omitempty"` + Messages []map[string]interface{} `json:"messages"` + Tools []map[string]interface{} `json:"tools,omitempty"` + Options map[string]interface{} `json:"options,omitempty"` +} + +type ProviderCountTokensResponse struct { + Usage map[string]interface{} `json:"usage,omitempty"` +} + +type ProviderRuntimeViewRequest struct { + Provider string `json:"provider,omitempty"` + Kind string `json:"kind,omitempty"` + Reason string `json:"reason,omitempty"` + Target string `json:"target,omitempty"` + Sort string `json:"sort,omitempty"` + ChangesOnly bool `json:"changes_only,omitempty"` + WindowSec int `json:"window_sec,omitempty"` + Limit int `json:"limit,omitempty"` + Cursor int `json:"cursor,omitempty"` + HealthBelow int `json:"health_below,omitempty"` + CooldownUntilBeforeSec int `json:"cooldown_until_before_sec,omitempty"` +} + +type ProviderRuntimeViewResponse struct { + View map[string]interface{} `json:"view"` +} + +type ProviderRuntimeActionRequest struct { + Provider string `json:"provider,omitempty"` + Action string `json:"action"` + OnlyExpiring bool `json:"only_expiring,omitempty"` +} + +type ProviderRuntimeActionResponse struct { + Result map[string]interface{} `json:"result"` +} diff --git a/pkg/rpc/registry.go b/pkg/rpc/registry.go new file mode 100644 index 0000000..93f2d82 --- /dev/null +++ b/pkg/rpc/registry.go @@ -0,0 +1,56 @@ +package rpc + +import ( + "context" + "encoding/json" + "strings" +) + +type MethodHandler func(context.Context, json.RawMessage) (interface{}, *Error) + +type Registry struct { + methods map[string]MethodHandler +} + +func NewRegistry() *Registry { + return &Registry{methods: map[string]MethodHandler{}} +} + +func (r *Registry) Register(method string, handler MethodHandler) { + if r == nil || handler == nil { + return + } + method = strings.ToLower(strings.TrimSpace(method)) + if method == "" { + return + } + r.methods[method] = handler +} + +func (r *Registry) Handle(ctx context.Context, req Request) (interface{}, *Error) { + if r == nil { + return nil, &Error{Code: "internal", Message: "rpc registry unavailable"} + } + handler := r.methods[req.NormalizedMethod()] + if handler == nil { + return nil, &Error{ + Code: "invalid_argument", + Message: "unknown method", + Details: map[string]interface{}{"method": strings.TrimSpace(req.Method)}, + } + } + return handler(ctx, req.Params) +} + +func RegisterJSON[T any](r *Registry, method string, handler func(context.Context, T) (interface{}, *Error)) { + if r == nil || handler == nil { + return + } + r.Register(method, func(ctx context.Context, raw json.RawMessage) (interface{}, *Error) { + var params T + if err := DecodeParams(raw, ¶ms); err != nil { + return nil, &Error{Code: "invalid_argument", Message: err.Error()} + } + return handler(ctx, params) + }) +} diff --git a/pkg/rpc/subagent.go b/pkg/rpc/subagent.go new file mode 100644 index 0000000..d6234f8 --- /dev/null +++ b/pkg/rpc/subagent.go @@ -0,0 +1,87 @@ +package rpc + +import ( + "context" + + "github.com/YspCoder/clawgo/pkg/tools" +) + +type SubagentService interface { + List(context.Context, ListSubagentsRequest) (*ListSubagentsResponse, *Error) + Snapshot(context.Context, SnapshotRequest) (*SnapshotResponse, *Error) + Get(context.Context, GetSubagentRequest) (*GetSubagentResponse, *Error) + Spawn(context.Context, SpawnSubagentRequest) (*SpawnSubagentResponse, *Error) + DispatchAndWait(context.Context, DispatchAndWaitRequest) (*DispatchAndWaitResponse, *Error) + Registry(context.Context, RegistryRequest) (*RegistryResponse, *Error) +} + +type ListSubagentsRequest struct{} + +type ListSubagentsResponse struct { + Items []*tools.SubagentTask `json:"items"` +} + +type SnapshotRequest struct { + Limit int `json:"limit,omitempty"` +} + +type SnapshotResponse struct { + Snapshot tools.RuntimeSnapshot `json:"snapshot"` +} + +type GetSubagentRequest struct { + ID string `json:"id"` +} + +type GetSubagentResponse struct { + Found bool `json:"found"` + Task *tools.SubagentTask `json:"task,omitempty"` +} + +type SpawnSubagentRequest struct { + Task string `json:"task"` + Label string `json:"label,omitempty"` + Role string `json:"role,omitempty"` + AgentID string `json:"agent_id,omitempty"` + MaxRetries int `json:"max_retries,omitempty"` + RetryBackoffMS int `json:"retry_backoff_ms,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` +} + +type SpawnSubagentResponse struct { + Message string `json:"message"` +} + +type DispatchAndWaitRequest struct { + Task string `json:"task"` + Label string `json:"label,omitempty"` + Role string `json:"role,omitempty"` + AgentID string `json:"agent_id,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + ParentRunID string `json:"parent_run_id,omitempty"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` + MaxRetries int `json:"max_retries,omitempty"` + RetryBackoffMS int `json:"retry_backoff_ms,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + WaitTimeoutSec int `json:"wait_timeout_sec,omitempty"` +} + +type DispatchAndWaitResponse struct { + Task *tools.SubagentTask `json:"task,omitempty"` + Reply *tools.RouterReply `json:"reply,omitempty"` + Merged string `json:"merged,omitempty"` +} + +type RegistryRequest struct{} + +type RegistryResponse struct { + Items []map[string]interface{} `json:"items"` +} diff --git a/pkg/rpc/workspace.go b/pkg/rpc/workspace.go new file mode 100644 index 0000000..866540e --- /dev/null +++ b/pkg/rpc/workspace.go @@ -0,0 +1,50 @@ +package rpc + +import "context" + +type WorkspaceService interface { + ListFiles(context.Context, ListWorkspaceFilesRequest) (*ListWorkspaceFilesResponse, *Error) + ReadFile(context.Context, ReadWorkspaceFileRequest) (*ReadWorkspaceFileResponse, *Error) + WriteFile(context.Context, WriteWorkspaceFileRequest) (*WriteWorkspaceFileResponse, *Error) + DeleteFile(context.Context, DeleteWorkspaceFileRequest) (*DeleteWorkspaceFileResponse, *Error) +} + +type ListWorkspaceFilesRequest struct { + Scope string `json:"scope,omitempty"` +} + +type ListWorkspaceFilesResponse struct { + Files []string `json:"files"` +} + +type ReadWorkspaceFileRequest struct { + Scope string `json:"scope,omitempty"` + Path string `json:"path"` +} + +type ReadWorkspaceFileResponse struct { + Path string `json:"path,omitempty"` + Found bool `json:"found,omitempty"` + Content string `json:"content,omitempty"` +} + +type WriteWorkspaceFileRequest struct { + Scope string `json:"scope,omitempty"` + Path string `json:"path"` + Content string `json:"content"` +} + +type WriteWorkspaceFileResponse struct { + Path string `json:"path,omitempty"` + Saved bool `json:"saved,omitempty"` +} + +type DeleteWorkspaceFileRequest struct { + Scope string `json:"scope,omitempty"` + Path string `json:"path"` +} + +type DeleteWorkspaceFileResponse struct { + Path string `json:"path,omitempty"` + Deleted bool `json:"deleted"` +} diff --git a/pkg/tools/browser.go b/pkg/tools/browser.go index d3fc0bb..eed8547 100644 --- a/pkg/tools/browser.go +++ b/pkg/tools/browser.go @@ -50,15 +50,14 @@ func (t *BrowserTool) Parameters() map[string]interface{} { func (t *BrowserTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { action := MapStringArg(args, "action") url := MapStringArg(args, "url") - - switch action { - case "screenshot": - return t.takeScreenshot(ctx, url) - case "content": - return t.fetchDynamicContent(ctx, url) - default: - return "", fmt.Errorf("unknown browser action: %s", action) + handlers := map[string]func(context.Context, string) (string, error){ + "screenshot": t.takeScreenshot, + "content": t.fetchDynamicContent, } + if handler := handlers[action]; handler != nil { + return handler(ctx, url) + } + return "", fmt.Errorf("unknown browser action: %s", action) } func (t *BrowserTool) takeScreenshot(ctx context.Context, url string) (string, error) { diff --git a/pkg/tools/cron_tool.go b/pkg/tools/cron_tool.go index ef68b4d..49ba9ec 100644 --- a/pkg/tools/cron_tool.go +++ b/pkg/tools/cron_tool.go @@ -34,6 +34,7 @@ func (t *CronTool) Parameters() map[string]interface{} { } func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + _ = ctx if t.cs == nil { return "Error: cron service not available", nil } @@ -42,40 +43,41 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st action = "list" } id := MapStringArg(args, "id") - - switch action { - case "list": - jobs := t.cs.ListJobs(true) - b, _ := json.Marshal(jobs) - return string(b), nil - case "delete": - if id == "" { - return "", fmt.Errorf("%w: id for action=delete", ErrMissingField) - } - ok := t.cs.RemoveJob(id) - if !ok { - return fmt.Sprintf("job not found: %s", id), nil - } - return fmt.Sprintf("deleted job: %s", id), nil - case "enable": - if id == "" { - return "", fmt.Errorf("%w: id for action=enable", ErrMissingField) - } - job := t.cs.EnableJob(id, true) - if job == nil { - return fmt.Sprintf("job not found: %s", id), nil - } - return fmt.Sprintf("enabled job: %s", id), nil - case "disable": - if id == "" { - return "", fmt.Errorf("%w: id for action=disable", ErrMissingField) - } - job := t.cs.EnableJob(id, false) - if job == nil { - return fmt.Sprintf("job not found: %s", id), nil - } - return fmt.Sprintf("disabled job: %s", id), nil - default: - return "", fmt.Errorf("%w: %s", ErrUnsupportedAction, action) + handlers := map[string]func() (string, error){ + "list": func() (string, error) { + b, _ := json.Marshal(t.cs.ListJobs(true)) + return string(b), nil + }, + "delete": func() (string, error) { + if id == "" { + return "", fmt.Errorf("%w: id for action=delete", ErrMissingField) + } + if !t.cs.RemoveJob(id) { + return fmt.Sprintf("job not found: %s", id), nil + } + return fmt.Sprintf("deleted job: %s", id), nil + }, + "enable": func() (string, error) { + if id == "" { + return "", fmt.Errorf("%w: id for action=enable", ErrMissingField) + } + if t.cs.EnableJob(id, true) == nil { + return fmt.Sprintf("job not found: %s", id), nil + } + return fmt.Sprintf("enabled job: %s", id), nil + }, + "disable": func() (string, error) { + if id == "" { + return "", fmt.Errorf("%w: id for action=disable", ErrMissingField) + } + if t.cs.EnableJob(id, false) == nil { + return fmt.Sprintf("job not found: %s", id), nil + } + return fmt.Sprintf("disabled job: %s", id), nil + }, } + if handler := handlers[action]; handler != nil { + return handler() + } + return "", fmt.Errorf("%w: %s", ErrUnsupportedAction, action) } diff --git a/pkg/tools/mcp.go b/pkg/tools/mcp.go index c0649e8..4532080 100644 --- a/pkg/tools/mcp.go +++ b/pkg/tools/mcp.go @@ -131,66 +131,72 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) (str return "", err } defer client.Close() - - switch action { - case "list_tools": - out, err := client.listAll(callCtx, "tools/list", "tools") - if err != nil { - return "", err - } - return prettyJSON(out) - case "call_tool": - toolName := strings.TrimSpace(mcpStringArg(args, "tool")) - if toolName == "" { - return "", fmt.Errorf("tool is required for action=call_tool") - } - params := map[string]interface{}{ - "name": toolName, - "arguments": mcpObjectArg(args, "arguments"), - } - out, err := client.request(callCtx, "tools/call", params) - if err != nil { - return "", err - } - return prettyJSON(out) - case "list_resources": - out, err := client.listAll(callCtx, "resources/list", "resources") - if err != nil { - return "", err - } - return prettyJSON(out) - case "read_resource": - resourceURI := strings.TrimSpace(mcpStringArg(args, "uri")) - if resourceURI == "" { - return "", fmt.Errorf("uri is required for action=read_resource") - } - out, err := client.request(callCtx, "resources/read", map[string]interface{}{"uri": resourceURI}) - if err != nil { - return "", err - } - return prettyJSON(out) - case "list_prompts": - out, err := client.listAll(callCtx, "prompts/list", "prompts") - if err != nil { - return "", err - } - return prettyJSON(out) - case "get_prompt": - promptName := strings.TrimSpace(mcpStringArg(args, "prompt")) - if promptName == "" { - return "", fmt.Errorf("prompt is required for action=get_prompt") - } - out, err := client.request(callCtx, "prompts/get", map[string]interface{}{ - "name": promptName, - "arguments": mcpObjectArg(args, "arguments"), - }) - if err != nil { - return "", err - } - return prettyJSON(out) - default: - return "", fmt.Errorf("unsupported action %q", action) + handlers := map[string]func() (string, error){ + "list_tools": func() (string, error) { + out, err := client.listAll(callCtx, "tools/list", "tools") + if err != nil { + return "", err + } + return prettyJSON(out) + }, + "call_tool": func() (string, error) { + toolName := strings.TrimSpace(mcpStringArg(args, "tool")) + if toolName == "" { + return "", fmt.Errorf("tool is required for action=call_tool") + } + out, err := client.request(callCtx, "tools/call", map[string]interface{}{ + "name": toolName, + "arguments": mcpObjectArg(args, "arguments"), + }) + if err != nil { + return "", err + } + return prettyJSON(out) + }, + "list_resources": func() (string, error) { + out, err := client.listAll(callCtx, "resources/list", "resources") + if err != nil { + return "", err + } + return prettyJSON(out) + }, + "read_resource": func() (string, error) { + resourceURI := strings.TrimSpace(mcpStringArg(args, "uri")) + if resourceURI == "" { + return "", fmt.Errorf("uri is required for action=read_resource") + } + out, err := client.request(callCtx, "resources/read", map[string]interface{}{"uri": resourceURI}) + if err != nil { + return "", err + } + return prettyJSON(out) + }, + "list_prompts": func() (string, error) { + out, err := client.listAll(callCtx, "prompts/list", "prompts") + if err != nil { + return "", err + } + return prettyJSON(out) + }, + "get_prompt": func() (string, error) { + promptName := strings.TrimSpace(mcpStringArg(args, "prompt")) + if promptName == "" { + return "", fmt.Errorf("prompt is required for action=get_prompt") + } + out, err := client.request(callCtx, "prompts/get", map[string]interface{}{ + "name": promptName, + "arguments": mcpObjectArg(args, "arguments"), + }) + if err != nil { + return "", err + } + return prettyJSON(out) + }, } + if handler := handlers[action]; handler != nil { + return handler() + } + return "", fmt.Errorf("unsupported action %q", action) } func (t *MCPTool) DiscoverTools(ctx context.Context) []Tool { diff --git a/pkg/tools/message.go b/pkg/tools/message.go index f293460..f8615e0 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -144,25 +144,37 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) } messageID := MapStringArg(args, "message_id") emoji := MapStringArg(args, "emoji") - - switch action { - case "send": - if content == "" && media == "" { - return "", fmt.Errorf("%w: message/content or media for action=send", ErrMissingField) + validators := map[string]func() error{ + "send": func() error { + if content == "" && media == "" { + return fmt.Errorf("%w: message/content or media for action=send", ErrMissingField) + } + return nil + }, + "edit": func() error { + if messageID == "" || content == "" { + return fmt.Errorf("%w: message_id and message/content for action=edit", ErrMissingField) + } + return nil + }, + "delete": func() error { + if messageID == "" { + return fmt.Errorf("%w: message_id for action=delete", ErrMissingField) + } + return nil + }, + "react": func() error { + if messageID == "" || emoji == "" { + return fmt.Errorf("%w: message_id and emoji for action=react", ErrMissingField) + } + return nil + }, + } + if validate := validators[action]; validate != nil { + if err := validate(); err != nil { + return "", err } - case "edit": - if messageID == "" || content == "" { - return "", fmt.Errorf("%w: message_id and message/content for action=edit", ErrMissingField) - } - case "delete": - if messageID == "" { - return "", fmt.Errorf("%w: message_id for action=delete", ErrMissingField) - } - case "react": - if messageID == "" || emoji == "" { - return "", fmt.Errorf("%w: message_id and emoji for action=react", ErrMissingField) - } - default: + } else { return "", fmt.Errorf("%w: %s", ErrUnsupportedAction, action) } diff --git a/pkg/tools/nodes_tool.go b/pkg/tools/nodes_tool.go index fc86433..4f522d0 100644 --- a/pkg/tools/nodes_tool.go +++ b/pkg/tools/nodes_tool.go @@ -54,81 +54,86 @@ func (t *NodesTool) Execute(ctx context.Context, args map[string]interface{}) (s if t.manager == nil { return "", fmt.Errorf("nodes manager not configured") } - - switch action { - case "status", "describe": - if nodeID != "" { - n, ok := t.manager.Get(nodeID) - if !ok { - return "", fmt.Errorf("node not found: %s", nodeID) + var statusHandler func() (string, error) + handlers := map[string]func() (string, error){ + "status": func() (string, error) { + if nodeID != "" { + n, ok := t.manager.Get(nodeID) + if !ok { + return "", fmt.Errorf("node not found: %s", nodeID) + } + b, _ := json.Marshal(n) + return string(b), nil } - b, _ := json.Marshal(n) + b, _ := json.Marshal(t.manager.List()) return string(b), nil - } - b, _ := json.Marshal(t.manager.List()) - return string(b), nil - default: - reqArgs := map[string]interface{}{} - if raw, ok := args["args"].(map[string]interface{}); ok { - for k, v := range raw { - reqArgs[k] = v - } - } - if rawPaths := MapStringListArg(args, "artifact_paths"); len(rawPaths) > 0 { - reqArgs["artifact_paths"] = rawPaths - } - if cmd, ok := args["command"].([]interface{}); ok && len(cmd) > 0 { - reqArgs["command"] = cmd - } - if facing := MapStringArg(args, "facing"); facing != "" { - f := strings.ToLower(strings.TrimSpace(facing)) - if f != "front" && f != "back" && f != "both" { - return "", fmt.Errorf("invalid_args: facing must be front|back|both") - } - reqArgs["facing"] = f - } - if di := MapIntArg(args, "duration_ms", 0); di > 0 { - if di <= 0 || di > 300000 { - return "", fmt.Errorf("invalid_args: duration_ms must be in 1..300000") - } - reqArgs["duration_ms"] = di - } - task := MapStringArg(args, "task") - model := MapStringArg(args, "model") - if action == "agent_task" && strings.TrimSpace(task) == "" { - return "", fmt.Errorf("invalid_args: agent_task requires task") - } - if action == "canvas_action" { - if act := MapStringArg(reqArgs, "action"); act == "" { - return "", fmt.Errorf("invalid_args: canvas_action requires args.action") - } - } - if nodeID == "" { - if picked, ok := t.manager.PickRequest(nodes.Request{Action: action, Task: task, Model: model, Args: reqArgs}, mode); ok { - nodeID = picked.ID - } - } - if nodeID == "" { - return "", fmt.Errorf("no eligible node found for action=%s", action) - } - req := nodes.Request{Action: action, Node: nodeID, Task: task, Model: model, Args: reqArgs} - if !t.manager.SupportsRequest(nodeID, req) { - return "", fmt.Errorf("node %s does not support action=%s", nodeID, action) - } - if t.router == nil { - return "", fmt.Errorf("nodes transport router not configured") - } - started := time.Now() - resp, err := t.router.Dispatch(ctx, req, mode) - durationMs := int(time.Since(started).Milliseconds()) - if err != nil { - t.writeAudit(req, nodes.Response{OK: false, Code: "transport_error", Error: err.Error(), Node: nodeID, Action: action}, mode, durationMs) - return "", err - } - t.writeAudit(req, resp, mode, durationMs) - b, _ := json.Marshal(resp) - return string(b), nil + }, } + statusHandler = handlers["status"] + handlers["describe"] = func() (string, error) { return statusHandler() } + if handler := handlers[action]; handler != nil { + return handler() + } + reqArgs := map[string]interface{}{} + if raw, ok := args["args"].(map[string]interface{}); ok { + for k, v := range raw { + reqArgs[k] = v + } + } + if rawPaths := MapStringListArg(args, "artifact_paths"); len(rawPaths) > 0 { + reqArgs["artifact_paths"] = rawPaths + } + if cmd, ok := args["command"].([]interface{}); ok && len(cmd) > 0 { + reqArgs["command"] = cmd + } + if facing := MapStringArg(args, "facing"); facing != "" { + f := strings.ToLower(strings.TrimSpace(facing)) + if f != "front" && f != "back" && f != "both" { + return "", fmt.Errorf("invalid_args: facing must be front|back|both") + } + reqArgs["facing"] = f + } + if di := MapIntArg(args, "duration_ms", 0); di > 0 { + if di <= 0 || di > 300000 { + return "", fmt.Errorf("invalid_args: duration_ms must be in 1..300000") + } + reqArgs["duration_ms"] = di + } + task := MapStringArg(args, "task") + model := MapStringArg(args, "model") + if action == "agent_task" && strings.TrimSpace(task) == "" { + return "", fmt.Errorf("invalid_args: agent_task requires task") + } + if action == "canvas_action" { + if act := MapStringArg(reqArgs, "action"); act == "" { + return "", fmt.Errorf("invalid_args: canvas_action requires args.action") + } + } + if nodeID == "" { + if picked, ok := t.manager.PickRequest(nodes.Request{Action: action, Task: task, Model: model, Args: reqArgs}, mode); ok { + nodeID = picked.ID + } + } + if nodeID == "" { + return "", fmt.Errorf("no eligible node found for action=%s", action) + } + req := nodes.Request{Action: action, Node: nodeID, Task: task, Model: model, Args: reqArgs} + if !t.manager.SupportsRequest(nodeID, req) { + return "", fmt.Errorf("node %s does not support action=%s", nodeID, action) + } + if t.router == nil { + return "", fmt.Errorf("nodes transport router not configured") + } + started := time.Now() + resp, err := t.router.Dispatch(ctx, req, mode) + durationMs := int(time.Since(started).Milliseconds()) + if err != nil { + t.writeAudit(req, nodes.Response{OK: false, Code: "transport_error", Error: err.Error(), Node: nodeID, Action: action}, mode, durationMs) + return "", err + } + t.writeAudit(req, resp, mode, durationMs) + b, _ := json.Marshal(resp) + return string(b), nil } func (t *NodesTool) writeAudit(req nodes.Request, resp nodes.Response, mode string, durationMs int) { diff --git a/pkg/tools/process_tool.go b/pkg/tools/process_tool.go index 59b62b8..4518811 100644 --- a/pkg/tools/process_tool.go +++ b/pkg/tools/process_tool.go @@ -29,54 +29,58 @@ func (t *ProcessTool) Execute(ctx context.Context, args map[string]interface{}) if sid == "" { sid = MapStringArg(args, "sessionId") } - switch action { - case "list": - b, _ := json.Marshal(t.m.List()) - return string(b), nil - case "log": - off := MapIntArg(args, "offset", 0) - lim := MapIntArg(args, "limit", 0) - return t.m.Log(sid, off, lim) - case "kill": - if err := t.m.Kill(sid); err != nil { - return "", err - } - return "killed", nil - case "poll": - timeout := MapIntArg(args, "timeout_ms", 0) - if timeout < 0 { - timeout = 0 - } - s, ok := t.m.Get(sid) - if !ok { - return "", nil - } - if timeout > 0 { - select { - case <-s.done: - case <-time.After(time.Duration(timeout) * time.Millisecond): - case <-ctx.Done(): + handlers := map[string]func() (string, error){ + "list": func() (string, error) { + b, _ := json.Marshal(t.m.List()) + return string(b), nil + }, + "log": func() (string, error) { + return t.m.Log(sid, MapIntArg(args, "offset", 0), MapIntArg(args, "limit", 0)) + }, + "kill": func() (string, error) { + if err := t.m.Kill(sid); err != nil { + return "", err } - } - off := MapIntArg(args, "offset", 0) - lim := MapIntArg(args, "limit", 0) - if lim <= 0 { - lim = 1200 - } - if off < 0 { - off = 0 - } - chunk, _ := t.m.Log(sid, off, lim) - s.mu.RLock() - defer s.mu.RUnlock() - resp := map[string]interface{}{"id": s.ID, "running": s.ExitCode == nil, "started_at": s.StartedAt.Format(time.RFC3339), "log": chunk, "next_offset": off + len(chunk)} - if s.ExitCode != nil { - resp["exit_code"] = *s.ExitCode - resp["ended_at"] = s.EndedAt.Format(time.RFC3339) - } - b, _ := json.Marshal(resp) - return string(b), nil - default: - return "", nil + return "killed", nil + }, + "poll": func() (string, error) { + timeout := MapIntArg(args, "timeout_ms", 0) + if timeout < 0 { + timeout = 0 + } + s, ok := t.m.Get(sid) + if !ok { + return "", nil + } + if timeout > 0 { + select { + case <-s.done: + case <-time.After(time.Duration(timeout) * time.Millisecond): + case <-ctx.Done(): + } + } + off := MapIntArg(args, "offset", 0) + lim := MapIntArg(args, "limit", 0) + if lim <= 0 { + lim = 1200 + } + if off < 0 { + off = 0 + } + chunk, _ := t.m.Log(sid, off, lim) + s.mu.RLock() + defer s.mu.RUnlock() + resp := map[string]interface{}{"id": s.ID, "running": s.ExitCode == nil, "started_at": s.StartedAt.Format(time.RFC3339), "log": chunk, "next_offset": off + len(chunk)} + if s.ExitCode != nil { + resp["exit_code"] = *s.ExitCode + resp["ended_at"] = s.EndedAt.Format(time.RFC3339) + } + b, _ := json.Marshal(resp) + return string(b), nil + }, } + if handler := handlers[action]; handler != nil { + return handler() + } + return "", nil } diff --git a/pkg/tools/sessions_tool.go b/pkg/tools/sessions_tool.go index 39ba5e4..82cc9c7 100644 --- a/pkg/tools/sessions_tool.go +++ b/pkg/tools/sessions_tool.go @@ -73,185 +73,189 @@ func (t *SessionsTool) Execute(ctx context.Context, args map[string]interface{}) kindFilter[s] = struct{}{} } } + type sessionActionHandler func() (string, error) + handlers := map[string]sessionActionHandler{ + "list": func() (string, error) { + if t.listFn == nil { + return "sessions list unavailable", nil + } + items := t.listFn(limit * 3) + if len(items) == 0 { + return "No sessions.", nil + } + if len(kindFilter) > 0 { + filtered := make([]SessionInfo, 0, len(items)) + for _, s := range items { + k := strings.ToLower(strings.TrimSpace(s.Kind)) + if _, ok := kindFilter[k]; ok { + filtered = append(filtered, s) + } + } + items = filtered + } + if activeMinutes > 0 { + cutoff := time.Now().Add(-time.Duration(activeMinutes) * time.Minute) + filtered := make([]SessionInfo, 0, len(items)) + for _, s := range items { + if s.UpdatedAt.After(cutoff) { + filtered = append(filtered, s) + } + } + items = filtered + } + if query != "" { + filtered := make([]SessionInfo, 0, len(items)) + for _, s := range items { + blob := strings.ToLower(s.Key + "\n" + s.Kind + "\n" + s.Summary) + if strings.Contains(blob, query) { + filtered = append(filtered, s) + } + } + items = filtered + } + if len(items) == 0 { + return "No sessions (after filters).", nil + } + sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt.After(items[j].UpdatedAt) }) + if len(items) > limit { + items = items[:limit] + } + var sb strings.Builder + sb.WriteString("Sessions:\n") + for _, s := range items { + sb.WriteString(fmt.Sprintf("- %s kind=%s compactions=%d updated=%s\n", s.Key, s.Kind, s.CompactionCount, s.UpdatedAt.Format(time.RFC3339))) + } + return sb.String(), nil + }, + "history": func() (string, error) { + if t.historyFn == nil { + return "sessions history unavailable", nil + } + key := MapStringArg(args, "key") + if key == "" { + return "key is required for history", nil + } + raw := t.historyFn(key, 0) + if len(raw) == 0 { + return "No history.", nil + } + type indexedMsg struct { + idx int + msg providers.Message + } + window := make([]indexedMsg, 0, len(raw)) + for i, m := range raw { + window = append(window, indexedMsg{idx: i + 1, msg: m}) + } - switch action { - case "list": - if t.listFn == nil { - return "sessions list unavailable", nil - } - items := t.listFn(limit * 3) - if len(items) == 0 { - return "No sessions.", nil - } - if len(kindFilter) > 0 { - filtered := make([]SessionInfo, 0, len(items)) - for _, s := range items { - k := strings.ToLower(strings.TrimSpace(s.Kind)) - if _, ok := kindFilter[k]; ok { - filtered = append(filtered, s) + // Window selectors are 1-indexed (human-friendly) + if around > 0 { + center := around - 1 + if center < 0 { + center = 0 } - } - items = filtered - } - if activeMinutes > 0 { - cutoff := time.Now().Add(-time.Duration(activeMinutes) * time.Minute) - filtered := make([]SessionInfo, 0, len(items)) - for _, s := range items { - if s.UpdatedAt.After(cutoff) { - filtered = append(filtered, s) + if center >= len(window) { + center = len(window) - 1 } - } - items = filtered - } - if query != "" { - filtered := make([]SessionInfo, 0, len(items)) - for _, s := range items { - blob := strings.ToLower(s.Key + "\n" + s.Kind + "\n" + s.Summary) - if strings.Contains(blob, query) { - filtered = append(filtered, s) + half := limit / 2 + if half < 1 { + half = 1 } - } - items = filtered - } - if len(items) == 0 { - return "No sessions (after filters).", nil - } - sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt.After(items[j].UpdatedAt) }) - if len(items) > limit { - items = items[:limit] - } - var sb strings.Builder - sb.WriteString("Sessions:\n") - for _, s := range items { - sb.WriteString(fmt.Sprintf("- %s kind=%s compactions=%d updated=%s\n", s.Key, s.Kind, s.CompactionCount, s.UpdatedAt.Format(time.RFC3339))) - } - return sb.String(), nil - case "history": - if t.historyFn == nil { - return "sessions history unavailable", nil - } - key := MapStringArg(args, "key") - if key == "" { - return "key is required for history", nil - } - raw := t.historyFn(key, 0) - if len(raw) == 0 { - return "No history.", nil - } - type indexedMsg struct { - idx int - msg providers.Message - } - window := make([]indexedMsg, 0, len(raw)) - for i, m := range raw { - window = append(window, indexedMsg{idx: i + 1, msg: m}) - } - - // Window selectors are 1-indexed (human-friendly) - if around > 0 { - center := around - 1 - if center < 0 { - center = 0 - } - if center >= len(window) { - center = len(window) - 1 - } - half := limit / 2 - if half < 1 { - half = 1 - } - start := center - half - if start < 0 { - start = 0 - } - end := center + half + 1 - if end > len(window) { - end = len(window) - } - window = window[start:end] - } else { - start := 0 - end := len(window) - if after > 0 { - start = after - if start > len(window) { - start = len(window) - } - } - if before > 0 { - end = before - 1 - if end < 0 { - end = 0 + start := center - half + if start < 0 { + start = 0 } + end := center + half + 1 if end > len(window) { end = len(window) } + window = window[start:end] + } else { + start := 0 + end := len(window) + if after > 0 { + start = after + if start > len(window) { + start = len(window) + } + } + if before > 0 { + end = before - 1 + if end < 0 { + end = 0 + } + if end > len(window) { + end = len(window) + } + } + if start > end { + start = end + } + window = window[start:end] } - if start > end { - start = end - } - window = window[start:end] - } - if !includeTools { - filtered := make([]indexedMsg, 0, len(window)) - for _, m := range window { - if strings.ToLower(m.msg.Role) == "tool" { - continue - } - filtered = append(filtered, m) - } - window = filtered - } - if roleFilter != "" { - filtered := make([]indexedMsg, 0, len(window)) - for _, m := range window { - if strings.ToLower(m.msg.Role) == roleFilter { + if !includeTools { + filtered := make([]indexedMsg, 0, len(window)) + for _, m := range window { + if strings.ToLower(m.msg.Role) == "tool" { + continue + } filtered = append(filtered, m) } + window = filtered } - window = filtered - } - if fromMeSet { - targetRole := "user" - if fromMe { - targetRole = "assistant" - } - filtered := make([]indexedMsg, 0, len(window)) - for _, m := range window { - if strings.ToLower(m.msg.Role) == targetRole { - filtered = append(filtered, m) + if roleFilter != "" { + filtered := make([]indexedMsg, 0, len(window)) + for _, m := range window { + if strings.ToLower(m.msg.Role) == roleFilter { + filtered = append(filtered, m) + } } + window = filtered } - window = filtered - } - if query != "" { - filtered := make([]indexedMsg, 0, len(window)) - for _, m := range window { - blob := strings.ToLower(m.msg.Role + "\n" + m.msg.Content) - if strings.Contains(blob, query) { - filtered = append(filtered, m) + if fromMeSet { + targetRole := "user" + if fromMe { + targetRole = "assistant" } + filtered := make([]indexedMsg, 0, len(window)) + for _, m := range window { + if strings.ToLower(m.msg.Role) == targetRole { + filtered = append(filtered, m) + } + } + window = filtered } - window = filtered - } - if len(window) == 0 { - return "No history (after filters).", nil - } - if len(window) > limit { - window = window[len(window)-limit:] - } - var sb strings.Builder - sb.WriteString(fmt.Sprintf("History for %s:\n", key)) - for _, item := range window { - content := item.msg.Content - if len(content) > 180 { - content = content[:180] + "..." + if query != "" { + filtered := make([]indexedMsg, 0, len(window)) + for _, m := range window { + blob := strings.ToLower(m.msg.Role + "\n" + m.msg.Content) + if strings.Contains(blob, query) { + filtered = append(filtered, m) + } + } + window = filtered } - sb.WriteString(fmt.Sprintf("- [#%d][%s] %s\n", item.idx, item.msg.Role, content)) - } - return sb.String(), nil - default: - return "unsupported action", nil + if len(window) == 0 { + return "No history (after filters).", nil + } + if len(window) > limit { + window = window[len(window)-limit:] + } + var sb strings.Builder + sb.WriteString(fmt.Sprintf("History for %s:\n", key)) + for _, item := range window { + content := item.msg.Content + if len(content) > 180 { + content = content[:180] + "..." + } + sb.WriteString(fmt.Sprintf("- [#%d][%s] %s\n", item.idx, item.msg.Role, content)) + } + return sb.String(), nil + }, } + if handler := handlers[action]; handler != nil { + return handler() + } + return "unsupported action", nil } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 00d0936..c0e99a4 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -358,7 +358,6 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { CreatedAt: task.Updated, }) sm.persistTaskLocked(task, "failed", task.Result) - sm.notifyTaskWaitersLocked(task.ID) } else { task.Status = RuntimeStatusCompleted task.Result = applySubagentResultQuota(result, task.MaxResultChars) @@ -376,7 +375,6 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { CreatedAt: task.Updated, }) sm.persistTaskLocked(task, "completed", task.Result) - sm.notifyTaskWaitersLocked(task.ID) } sm.mu.Unlock() @@ -405,6 +403,9 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { }, }) } + sm.mu.Lock() + sm.notifyTaskWaitersLocked(task.ID) + sm.mu.Unlock() } func (sm *SubagentManager) recordEKG(task *SubagentTask, runErr error) { diff --git a/pkg/tools/subagent_config_tool.go b/pkg/tools/subagent_config_tool.go index 8352941..77c37c0 100644 --- a/pkg/tools/subagent_config_tool.go +++ b/pkg/tools/subagent_config_tool.go @@ -69,16 +69,20 @@ func (t *SubagentConfigTool) SetConfigPath(path string) { func (t *SubagentConfigTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { _ = ctx - switch stringArgFromMap(args, "action") { - case "upsert": - result, err := UpsertConfigSubagent(t.getConfigPath(), cloneSubagentConfigArgs(args)) - if err != nil { - return "", err - } - return marshalSubagentConfigPayload(result) - default: - return "", fmt.Errorf("unsupported action") + action := stringArgFromMap(args, "action") + handlers := map[string]func() (string, error){ + "upsert": func() (string, error) { + result, err := UpsertConfigSubagent(t.getConfigPath(), cloneSubagentConfigArgs(args)) + if err != nil { + return "", err + } + return marshalSubagentConfigPayload(result) + }, } + if handler := handlers[action]; handler != nil { + return handler() + } + return "", fmt.Errorf("%w: %s", ErrUnsupportedAction, action) } func (t *SubagentConfigTool) getConfigPath() string { diff --git a/pkg/tools/subagent_profile.go b/pkg/tools/subagent_profile.go index 8dc04e4..5c88bc8 100644 --- a/pkg/tools/subagent_profile.go +++ b/pkg/tools/subagent_profile.go @@ -563,149 +563,171 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter } action := strings.ToLower(MapStringArg(args, "action")) agentID := normalizeSubagentIdentifier(MapStringArg(args, "agent_id")) - - switch action { - case "list": - items, err := t.store.List() - if err != nil { - return "", err - } - if len(items) == 0 { - return "No subagent profiles.", nil - } - var sb strings.Builder - sb.WriteString("Subagent Profiles:\n") - for i, p := range items { - sb.WriteString(fmt.Sprintf("- #%d %s [%s] role=%s memory_ns=%s\n", i+1, p.AgentID, p.Status, p.Role, p.MemoryNamespace)) - } - return strings.TrimSpace(sb.String()), nil - case "get": - if agentID == "" { - return "agent_id is required", nil - } - p, ok, err := t.store.Get(agentID) - if err != nil { - return "", err - } - if !ok { - return "subagent profile not found", nil - } - b, _ := json.MarshalIndent(p, "", " ") - return string(b), nil - case "create": - if agentID == "" { - return "agent_id is required", nil - } - if _, ok, err := t.store.Get(agentID); err != nil { - return "", err - } else if ok { - return "subagent profile already exists", nil - } - p := SubagentProfile{ - AgentID: agentID, - Name: stringArg(args, "name"), - NotifyMainPolicy: stringArg(args, "notify_main_policy"), - Role: stringArg(args, "role"), - SystemPromptFile: stringArg(args, "system_prompt_file"), - MemoryNamespace: stringArg(args, "memory_namespace"), - Status: stringArg(args, "status"), - ToolAllowlist: parseStringList(args["tool_allowlist"]), - MaxRetries: profileIntArg(args, "max_retries"), - RetryBackoff: profileIntArg(args, "retry_backoff_ms"), - TimeoutSec: profileIntArg(args, "timeout_sec"), - MaxTaskChars: profileIntArg(args, "max_task_chars"), - MaxResultChars: profileIntArg(args, "max_result_chars"), - } - saved, err := t.store.Upsert(p) - if err != nil { - return "", err - } - return fmt.Sprintf("Created subagent profile: %s (role=%s status=%s)", saved.AgentID, saved.Role, saved.Status), nil - case "update": - if agentID == "" { - return "agent_id is required", nil - } - existing, ok, err := t.store.Get(agentID) - if err != nil { - return "", err - } - if !ok { - return "subagent profile not found", nil - } - next := *existing - if _, ok := args["name"]; ok { - next.Name = stringArg(args, "name") - } - if _, ok := args["role"]; ok { - next.Role = stringArg(args, "role") - } - if _, ok := args["notify_main_policy"]; ok { - next.NotifyMainPolicy = stringArg(args, "notify_main_policy") - } - if _, ok := args["system_prompt_file"]; ok { - next.SystemPromptFile = stringArg(args, "system_prompt_file") - } - if _, ok := args["memory_namespace"]; ok { - next.MemoryNamespace = stringArg(args, "memory_namespace") - } - if _, ok := args["status"]; ok { - next.Status = stringArg(args, "status") - } - if _, ok := args["tool_allowlist"]; ok { - next.ToolAllowlist = parseStringList(args["tool_allowlist"]) - } - if _, ok := args["max_retries"]; ok { - next.MaxRetries = profileIntArg(args, "max_retries") - } - if _, ok := args["retry_backoff_ms"]; ok { - next.RetryBackoff = profileIntArg(args, "retry_backoff_ms") - } - if _, ok := args["timeout_sec"]; ok { - next.TimeoutSec = profileIntArg(args, "timeout_sec") - } - if _, ok := args["max_task_chars"]; ok { - next.MaxTaskChars = profileIntArg(args, "max_task_chars") - } - if _, ok := args["max_result_chars"]; ok { - next.MaxResultChars = profileIntArg(args, "max_result_chars") - } - saved, err := t.store.Upsert(next) - if err != nil { - return "", err - } - return fmt.Sprintf("Updated subagent profile: %s (role=%s status=%s)", saved.AgentID, saved.Role, saved.Status), nil - case "enable", "disable": - if agentID == "" { - return "agent_id is required", nil - } - existing, ok, err := t.store.Get(agentID) - if err != nil { - return "", err - } - if !ok { - return "subagent profile not found", nil - } - if action == "enable" { + type subagentProfileActionHandler func() (string, error) + handlers := map[string]subagentProfileActionHandler{ + "list": func() (string, error) { + items, err := t.store.List() + if err != nil { + return "", err + } + if len(items) == 0 { + return "No subagent profiles.", nil + } + var sb strings.Builder + sb.WriteString("Subagent Profiles:\n") + for i, p := range items { + sb.WriteString(fmt.Sprintf("- #%d %s [%s] role=%s memory_ns=%s\n", i+1, p.AgentID, p.Status, p.Role, p.MemoryNamespace)) + } + return strings.TrimSpace(sb.String()), nil + }, + "get": func() (string, error) { + if agentID == "" { + return "agent_id is required", nil + } + p, ok, err := t.store.Get(agentID) + if err != nil { + return "", err + } + if !ok { + return "subagent profile not found", nil + } + b, _ := json.MarshalIndent(p, "", " ") + return string(b), nil + }, + "create": func() (string, error) { + if agentID == "" { + return "agent_id is required", nil + } + if _, ok, err := t.store.Get(agentID); err != nil { + return "", err + } else if ok { + return "subagent profile already exists", nil + } + p := SubagentProfile{ + AgentID: agentID, + Name: stringArg(args, "name"), + NotifyMainPolicy: stringArg(args, "notify_main_policy"), + Role: stringArg(args, "role"), + SystemPromptFile: stringArg(args, "system_prompt_file"), + MemoryNamespace: stringArg(args, "memory_namespace"), + Status: stringArg(args, "status"), + ToolAllowlist: parseStringList(args["tool_allowlist"]), + MaxRetries: profileIntArg(args, "max_retries"), + RetryBackoff: profileIntArg(args, "retry_backoff_ms"), + TimeoutSec: profileIntArg(args, "timeout_sec"), + MaxTaskChars: profileIntArg(args, "max_task_chars"), + MaxResultChars: profileIntArg(args, "max_result_chars"), + } + saved, err := t.store.Upsert(p) + if err != nil { + return "", err + } + return fmt.Sprintf("Created subagent profile: %s (role=%s status=%s)", saved.AgentID, saved.Role, saved.Status), nil + }, + "update": func() (string, error) { + if agentID == "" { + return "agent_id is required", nil + } + existing, ok, err := t.store.Get(agentID) + if err != nil { + return "", err + } + if !ok { + return "subagent profile not found", nil + } + next := *existing + if _, ok := args["name"]; ok { + next.Name = stringArg(args, "name") + } + if _, ok := args["role"]; ok { + next.Role = stringArg(args, "role") + } + if _, ok := args["notify_main_policy"]; ok { + next.NotifyMainPolicy = stringArg(args, "notify_main_policy") + } + if _, ok := args["system_prompt_file"]; ok { + next.SystemPromptFile = stringArg(args, "system_prompt_file") + } + if _, ok := args["memory_namespace"]; ok { + next.MemoryNamespace = stringArg(args, "memory_namespace") + } + if _, ok := args["status"]; ok { + next.Status = stringArg(args, "status") + } + if _, ok := args["tool_allowlist"]; ok { + next.ToolAllowlist = parseStringList(args["tool_allowlist"]) + } + if _, ok := args["max_retries"]; ok { + next.MaxRetries = profileIntArg(args, "max_retries") + } + if _, ok := args["retry_backoff_ms"]; ok { + next.RetryBackoff = profileIntArg(args, "retry_backoff_ms") + } + if _, ok := args["timeout_sec"]; ok { + next.TimeoutSec = profileIntArg(args, "timeout_sec") + } + if _, ok := args["max_task_chars"]; ok { + next.MaxTaskChars = profileIntArg(args, "max_task_chars") + } + if _, ok := args["max_result_chars"]; ok { + next.MaxResultChars = profileIntArg(args, "max_result_chars") + } + saved, err := t.store.Upsert(next) + if err != nil { + return "", err + } + return fmt.Sprintf("Updated subagent profile: %s (role=%s status=%s)", saved.AgentID, saved.Role, saved.Status), nil + }, + "enable": func() (string, error) { + if agentID == "" { + return "agent_id is required", nil + } + existing, ok, err := t.store.Get(agentID) + if err != nil { + return "", err + } + if !ok { + return "subagent profile not found", nil + } existing.Status = "active" - } else { + saved, err := t.store.Upsert(*existing) + if err != nil { + return "", err + } + return fmt.Sprintf("Subagent profile %s set to %s", saved.AgentID, saved.Status), nil + }, + "disable": func() (string, error) { + if agentID == "" { + return "agent_id is required", nil + } + existing, ok, err := t.store.Get(agentID) + if err != nil { + return "", err + } + if !ok { + return "subagent profile not found", nil + } existing.Status = "disabled" - } - saved, err := t.store.Upsert(*existing) - if err != nil { - return "", err - } - return fmt.Sprintf("Subagent profile %s set to %s", saved.AgentID, saved.Status), nil - case "delete": - if agentID == "" { - return "agent_id is required", nil - } - if err := t.store.Delete(agentID); err != nil { - return "", err - } - return fmt.Sprintf("Deleted subagent profile: %s", agentID), nil - default: - return "unsupported action", nil + saved, err := t.store.Upsert(*existing) + if err != nil { + return "", err + } + return fmt.Sprintf("Subagent profile %s set to %s", saved.AgentID, saved.Status), nil + }, + "delete": func() (string, error) { + if agentID == "" { + return "agent_id is required", nil + } + if err := t.store.Delete(agentID); err != nil { + return "", err + } + return fmt.Sprintf("Deleted subagent profile: %s", agentID), nil + }, } + if handler := handlers[action]; handler != nil { + return handler() + } + return "unsupported action", nil } func stringArg(args map[string]interface{}, key string) string { diff --git a/pkg/tools/subagents_tool.go b/pkg/tools/subagents_tool.go index 70fd29f..6c41416 100644 --- a/pkg/tools/subagents_tool.go +++ b/pkg/tools/subagents_tool.go @@ -53,164 +53,202 @@ func (t *SubagentsTool) Execute(ctx context.Context, args map[string]interface{} agentID := MapStringArg(args, "agent_id") limit := MapIntArg(args, "limit", 20) recentMinutes := MapIntArg(args, "recent_minutes", 0) - - switch action { - case "list": - tasks := t.filterRecent(t.manager.ListTasks(), recentMinutes) - if len(tasks) == 0 { - return "No subagents.", nil - } - var sb strings.Builder - sb.WriteString("Subagents:\n") - sort.Slice(tasks, func(i, j int) bool { return tasks[i].Created > tasks[j].Created }) - for i, task := range tasks { - sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s session=%s allowlist=%d retry=%d timeout=%ds\n", - i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, task.SessionKey, len(task.ToolAllowlist), task.MaxRetries, task.TimeoutSec)) - } - return strings.TrimSpace(sb.String()), nil - case "info": - if strings.EqualFold(strings.TrimSpace(id), "all") { + type subagentActionHandler func() (string, error) + var threadHandler subagentActionHandler + handlers := map[string]subagentActionHandler{ + "list": func() (string, error) { tasks := t.filterRecent(t.manager.ListTasks(), recentMinutes) if len(tasks) == 0 { return "No subagents.", nil } + var sb strings.Builder + sb.WriteString("Subagents:\n") sort.Slice(tasks, func(i, j int) bool { return tasks[i].Created > tasks[j].Created }) - var sb strings.Builder - sb.WriteString("Subagents Summary:\n") for i, task := range tasks { - sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s steering=%d allowlist=%d retry=%d timeout=%ds\n", - i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, len(task.Steering), len(task.ToolAllowlist), task.MaxRetries, task.TimeoutSec)) + sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s session=%s allowlist=%d retry=%d timeout=%ds\n", + i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, task.SessionKey, len(task.ToolAllowlist), task.MaxRetries, task.TimeoutSec)) } return strings.TrimSpace(sb.String()), nil - } - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - task, ok := t.manager.GetTask(resolvedID) - if !ok { - return "subagent not found", nil - } - info := fmt.Sprintf("ID: %s\nStatus: %s\nLabel: %s\nAgent ID: %s\nRole: %s\nSession Key: %s\nThread ID: %s\nCorrelation ID: %s\nWaiting Reply: %t\nMemory Namespace: %s\nTool Allowlist: %v\nMax Retries: %d\nRetry Count: %d\nRetry Backoff(ms): %d\nTimeout(s): %d\nMax Task Chars: %d\nMax Result Chars: %d\nCreated: %d\nUpdated: %d\nSteering Count: %d\nTask: %s\nResult:\n%s", - task.ID, task.Status, task.Label, task.AgentID, task.Role, task.SessionKey, task.ThreadID, task.CorrelationID, task.WaitingReply, task.MemoryNS, - task.ToolAllowlist, task.MaxRetries, task.RetryCount, task.RetryBackoff, task.TimeoutSec, task.MaxTaskChars, task.MaxResultChars, - task.Created, task.Updated, len(task.Steering), task.Task, task.Result) - if events, err := t.manager.Events(task.ID, 6); err == nil && len(events) > 0 { + }, + "info": func() (string, error) { + if strings.EqualFold(strings.TrimSpace(id), "all") { + tasks := t.filterRecent(t.manager.ListTasks(), recentMinutes) + if len(tasks) == 0 { + return "No subagents.", nil + } + sort.Slice(tasks, func(i, j int) bool { return tasks[i].Created > tasks[j].Created }) + var sb strings.Builder + sb.WriteString("Subagents Summary:\n") + for i, task := range tasks { + sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s steering=%d allowlist=%d retry=%d timeout=%ds\n", + i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, len(task.Steering), len(task.ToolAllowlist), task.MaxRetries, task.TimeoutSec)) + } + return strings.TrimSpace(sb.String()), nil + } + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + task, ok := t.manager.GetTask(resolvedID) + if !ok { + return "subagent not found", nil + } + info := fmt.Sprintf("ID: %s\nStatus: %s\nLabel: %s\nAgent ID: %s\nRole: %s\nSession Key: %s\nThread ID: %s\nCorrelation ID: %s\nWaiting Reply: %t\nMemory Namespace: %s\nTool Allowlist: %v\nMax Retries: %d\nRetry Count: %d\nRetry Backoff(ms): %d\nTimeout(s): %d\nMax Task Chars: %d\nMax Result Chars: %d\nCreated: %d\nUpdated: %d\nSteering Count: %d\nTask: %s\nResult:\n%s", + task.ID, task.Status, task.Label, task.AgentID, task.Role, task.SessionKey, task.ThreadID, task.CorrelationID, task.WaitingReply, task.MemoryNS, + task.ToolAllowlist, task.MaxRetries, task.RetryCount, task.RetryBackoff, task.TimeoutSec, task.MaxTaskChars, task.MaxResultChars, + task.Created, task.Updated, len(task.Steering), task.Task, task.Result) + if events, err := t.manager.Events(task.ID, 6); err == nil && len(events) > 0 { + var sb strings.Builder + sb.WriteString(info) + sb.WriteString("\nEvents:\n") + for _, evt := range events { + sb.WriteString(formatSubagentEventLog(evt) + "\n") + } + return strings.TrimSpace(sb.String()), nil + } + return info, nil + }, + "kill": func() (string, error) { + if strings.EqualFold(strings.TrimSpace(id), "all") { + tasks := t.filterRecent(t.manager.ListTasks(), recentMinutes) + if len(tasks) == 0 { + return "No subagents.", nil + } + killed := 0 + for _, task := range tasks { + if t.manager.KillTask(task.ID) { + killed++ + } + } + return fmt.Sprintf("subagent kill requested for %d tasks", killed), nil + } + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + if !t.manager.KillTask(resolvedID) { + return "subagent not found", nil + } + return "subagent kill requested", nil + }, + "steer": func() (string, error) { + if message == "" { + return "message is required for steer", nil + } + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + if !t.manager.SteerTask(resolvedID, message) { + return "subagent not found", nil + } + return "steering message accepted", nil + }, + "send": func() (string, error) { + if message == "" { + return "message is required for send", nil + } + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + if !t.manager.SendTaskMessage(resolvedID, message) { + return "subagent not found", nil + } + return "message sent", nil + }, + "reply": func() (string, error) { + if message == "" { + return "message is required for reply", nil + } + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + if !t.manager.ReplyToTask(resolvedID, messageID, message) { + return "subagent not found", nil + } + return "reply sent", nil + }, + "ack": func() (string, error) { + if messageID == "" { + return "message_id is required for ack", nil + } + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + if !t.manager.AckTaskMessage(resolvedID, messageID) { + return "subagent or message not found", nil + } + return "message acked", nil + }, + "thread": func() (string, error) { + if threadID == "" { + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + task, ok := t.manager.GetTask(resolvedID) + if !ok { + return "subagent not found", nil + } + threadID = task.ThreadID + } + if threadID == "" { + return "thread_id is required", nil + } + thread, ok := t.manager.Thread(threadID) + if !ok { + return "thread not found", nil + } + msgs, err := t.manager.ThreadMessages(threadID, limit) + if err != nil { + return "", err + } var sb strings.Builder - sb.WriteString(info) - sb.WriteString("\nEvents:\n") - for _, evt := range events { - sb.WriteString(formatSubagentEventLog(evt) + "\n") - } - return strings.TrimSpace(sb.String()), nil - } - return info, nil - case "kill": - if strings.EqualFold(strings.TrimSpace(id), "all") { - tasks := t.filterRecent(t.manager.ListTasks(), recentMinutes) - if len(tasks) == 0 { - return "No subagents.", nil - } - killed := 0 - for _, task := range tasks { - if t.manager.KillTask(task.ID) { - killed++ + sb.WriteString(fmt.Sprintf("Thread: %s\nOwner: %s\nStatus: %s\nParticipants: %s\nTopic: %s\n", + thread.ThreadID, thread.Owner, thread.Status, strings.Join(thread.Participants, ","), thread.Topic)) + if len(msgs) > 0 { + sb.WriteString("Messages:\n") + for _, msg := range msgs { + sb.WriteString(fmt.Sprintf("- %s %s -> %s type=%s reply_to=%s status=%s\n %s\n", + msg.MessageID, msg.FromAgent, msg.ToAgent, msg.Type, msg.ReplyTo, msg.Status, msg.Content)) } } - return fmt.Sprintf("subagent kill requested for %d tasks", killed), nil - } - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - if !t.manager.KillTask(resolvedID) { - return "subagent not found", nil - } - return "subagent kill requested", nil - case "steer": - if message == "" { - return "message is required for steer", nil - } - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - if !t.manager.SteerTask(resolvedID, message) { - return "subagent not found", nil - } - return "steering message accepted", nil - case "send": - if message == "" { - return "message is required for send", nil - } - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - if !t.manager.SendTaskMessage(resolvedID, message) { - return "subagent not found", nil - } - return "message sent", nil - case "reply": - if message == "" { - return "message is required for reply", nil - } - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - if !t.manager.ReplyToTask(resolvedID, messageID, message) { - return "subagent not found", nil - } - return "reply sent", nil - case "ack": - if messageID == "" { - return "message_id is required for ack", nil - } - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - if !t.manager.AckTaskMessage(resolvedID, messageID) { - return "subagent or message not found", nil - } - return "message acked", nil - case "thread", "trace": - if threadID == "" { - resolvedID, err := t.resolveTaskID(id) + return strings.TrimSpace(sb.String()), nil + }, + "inbox": func() (string, error) { + if agentID == "" { + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + task, ok := t.manager.GetTask(resolvedID) + if !ok { + return "subagent not found", nil + } + agentID = task.AgentID + } + if agentID == "" { + return "agent_id is required", nil + } + msgs, err := t.manager.Inbox(agentID, limit) if err != nil { - return err.Error(), nil + return "", err } - task, ok := t.manager.GetTask(resolvedID) - if !ok { - return "subagent not found", nil + if len(msgs) == 0 { + return "No inbox messages.", nil } - threadID = task.ThreadID - } - if threadID == "" { - return "thread_id is required", nil - } - thread, ok := t.manager.Thread(threadID) - if !ok { - return "thread not found", nil - } - msgs, err := t.manager.ThreadMessages(threadID, limit) - if err != nil { - return "", err - } - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Thread: %s\nOwner: %s\nStatus: %s\nParticipants: %s\nTopic: %s\n", - thread.ThreadID, thread.Owner, thread.Status, strings.Join(thread.Participants, ","), thread.Topic)) - if len(msgs) > 0 { - sb.WriteString("Messages:\n") + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Inbox for %s:\n", agentID)) for _, msg := range msgs { - sb.WriteString(fmt.Sprintf("- %s %s -> %s type=%s reply_to=%s status=%s\n %s\n", - msg.MessageID, msg.FromAgent, msg.ToAgent, msg.Type, msg.ReplyTo, msg.Status, msg.Content)) + sb.WriteString(fmt.Sprintf("- %s thread=%s from=%s type=%s status=%s\n %s\n", + msg.MessageID, msg.ThreadID, msg.FromAgent, msg.Type, msg.Status, msg.Content)) } - } - return strings.TrimSpace(sb.String()), nil - case "inbox": - if agentID == "" { + return strings.TrimSpace(sb.String()), nil + }, + "log": func() (string, error) { resolvedID, err := t.resolveTaskID(id) if err != nil { return err.Error(), nil @@ -219,72 +257,50 @@ func (t *SubagentsTool) Execute(ctx context.Context, args map[string]interface{} if !ok { return "subagent not found", nil } - agentID = task.AgentID - } - if agentID == "" { - return "agent_id is required", nil - } - msgs, err := t.manager.Inbox(agentID, limit) - if err != nil { - return "", err - } - if len(msgs) == 0 { - return "No inbox messages.", nil - } - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Inbox for %s:\n", agentID)) - for _, msg := range msgs { - sb.WriteString(fmt.Sprintf("- %s thread=%s from=%s type=%s status=%s\n %s\n", - msg.MessageID, msg.ThreadID, msg.FromAgent, msg.Type, msg.Status, msg.Content)) - } - return strings.TrimSpace(sb.String()), nil - case "log": - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - task, ok := t.manager.GetTask(resolvedID) - if !ok { - return "subagent not found", nil - } - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Subagent %s Log\n", task.ID)) - sb.WriteString(fmt.Sprintf("Status: %s\n", task.Status)) - sb.WriteString(fmt.Sprintf("Agent ID: %s\nRole: %s\nSession Key: %s\nThread ID: %s\nCorrelation ID: %s\nWaiting Reply: %t\nTool Allowlist: %v\nMax Retries: %d\nRetry Count: %d\nRetry Backoff(ms): %d\nTimeout(s): %d\n", - task.AgentID, task.Role, task.SessionKey, task.ThreadID, task.CorrelationID, task.WaitingReply, task.ToolAllowlist, task.MaxRetries, task.RetryCount, task.RetryBackoff, task.TimeoutSec)) - if len(task.Steering) > 0 { - sb.WriteString("Steering Messages:\n") - for _, m := range task.Steering { - sb.WriteString("- " + m + "\n") + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Subagent %s Log\n", task.ID)) + sb.WriteString(fmt.Sprintf("Status: %s\n", task.Status)) + sb.WriteString(fmt.Sprintf("Agent ID: %s\nRole: %s\nSession Key: %s\nThread ID: %s\nCorrelation ID: %s\nWaiting Reply: %t\nTool Allowlist: %v\nMax Retries: %d\nRetry Count: %d\nRetry Backoff(ms): %d\nTimeout(s): %d\n", + task.AgentID, task.Role, task.SessionKey, task.ThreadID, task.CorrelationID, task.WaitingReply, task.ToolAllowlist, task.MaxRetries, task.RetryCount, task.RetryBackoff, task.TimeoutSec)) + if len(task.Steering) > 0 { + sb.WriteString("Steering Messages:\n") + for _, m := range task.Steering { + sb.WriteString("- " + m + "\n") + } } - } - if events, err := t.manager.Events(task.ID, 20); err == nil && len(events) > 0 { - sb.WriteString("Events:\n") - for _, evt := range events { - sb.WriteString(formatSubagentEventLog(evt) + "\n") + if events, err := t.manager.Events(task.ID, 20); err == nil && len(events) > 0 { + sb.WriteString("Events:\n") + for _, evt := range events { + sb.WriteString(formatSubagentEventLog(evt) + "\n") + } } - } - if strings.TrimSpace(task.Result) != "" { - result := strings.TrimSpace(task.Result) - if len(result) > 500 { - result = result[:500] + "..." + if strings.TrimSpace(task.Result) != "" { + result := strings.TrimSpace(task.Result) + if len(result) > 500 { + result = result[:500] + "..." + } + sb.WriteString("Result Preview:\n" + result) } - sb.WriteString("Result Preview:\n" + result) - } - return strings.TrimSpace(sb.String()), nil - case "resume": - resolvedID, err := t.resolveTaskID(id) - if err != nil { - return err.Error(), nil - } - label, ok := t.manager.ResumeTask(ctx, resolvedID) - if !ok { - return "subagent resume failed", nil - } - return fmt.Sprintf("subagent resumed as %s", label), nil - default: - return "unsupported action", nil + return strings.TrimSpace(sb.String()), nil + }, + "resume": func() (string, error) { + resolvedID, err := t.resolveTaskID(id) + if err != nil { + return err.Error(), nil + } + label, ok := t.manager.ResumeTask(ctx, resolvedID) + if !ok { + return "subagent resume failed", nil + } + return fmt.Sprintf("subagent resumed as %s", label), nil + }, } + threadHandler = handlers["thread"] + handlers["trace"] = func() (string, error) { return threadHandler() } + if handler := handlers[action]; handler != nil { + return handler() + } + return "unsupported action", nil } func (t *SubagentsTool) resolveTaskID(idOrIndex string) (string, error) {