diff --git a/Makefile b/Makefile index ca3951c..6e65d7f 100644 --- a/Makefile +++ b/Makefile @@ -269,7 +269,9 @@ sync-embed-workspace: ## cleanup-embed-workspace: Remove synced embed workspace artifacts cleanup-embed-workspace: - @rm -rf "$(EMBED_WORKSPACE_DIR)" + @if [ -d "$(EMBED_WORKSPACE_DIR)" ]; then \ + find "$(EMBED_WORKSPACE_DIR)" -mindepth 1 ! -name 'embedkeep.txt' -exec rm -rf {} +; \ + fi @echo "✓ Cleaned embedded workspace artifacts" ## install: Install clawgo to system and copy builtin skills diff --git a/README.md b/README.md index 33abee8..12c270c 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,14 @@ user -> main -> worker -> main -> user ## 配置结构 -当前推荐结构: +当前有两层配置视图: + +- 落盘文件仍然使用下面的原始结构 +- WebUI 与运行时接口优先使用标准化视图: + - `core` + - `runtime` + +原始配置的推荐结构: ```json { @@ -227,6 +234,13 @@ user -> main -> worker -> main -> user - `agents.defaults.execution` - `agents.defaults.summary_policy` - `agents.router.policy` +- WebUI 配置保存优先走 normalized schema: + - `core.default_provider` + - `core.main_agent_id` + - `core.subagents` + - `runtime.router` + - `runtime.providers` +- 运行态面板优先消费统一 `runtime snapshot / runtime live` - 启用中的本地 subagent 必须配置 `system_prompt_file` - 远端分支需要: - `transport: "node"` diff --git a/README_EN.md b/README_EN.md index b660614..186fbb1 100644 --- a/README_EN.md +++ b/README_EN.md @@ -180,7 +180,14 @@ ClawGo currently has four layers: ## Config Layout -Recommended structure: +There are now two configuration views: + +- the persisted file still uses the raw structure shown below +- the WebUI and runtime-facing APIs prefer a normalized view: + - `core` + - `runtime` + +Recommended raw structure: ```json { @@ -217,6 +224,13 @@ Notes: - `agents.defaults.execution` - `agents.defaults.summary_policy` - `agents.router.policy` +- the WebUI now saves through the normalized schema first: + - `core.default_provider` + - `core.main_agent_id` + - `core.subagents` + - `runtime.router` + - `runtime.providers` +- runtime panels now consume the unified `runtime snapshot / runtime live` - enabled local subagents must define `system_prompt_file` - remote branches require: - `transport: "node"` diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index fbf11d6..de67170 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -537,11 +537,11 @@ func nodeAgentTaskResult(payload map[string]interface{}) string { if len(payload) == 0 { return "" } - if result, _ := payload["result"].(string); strings.TrimSpace(result) != "" { - return strings.TrimSpace(result) + if result := tools.MapStringArg(payload, "result"); result != "" { + return result } - if content, _ := payload["content"].(string); strings.TrimSpace(content) != "" { - return strings.TrimSpace(content) + if content := tools.MapStringArg(payload, "content"); content != "" { + return content } return "" } @@ -1804,8 +1804,8 @@ func (al *AgentLoop) buildProviderToolDefs(toolDefs []map[string]interface{}) [] if !ok { continue } - name, _ := fnRaw["name"].(string) - description, _ := fnRaw["description"].(string) + name := tools.MapStringArg(fnRaw, "name") + description := tools.MapStringArg(fnRaw, "description") params, _ := fnRaw["parameters"].(map[string]interface{}) if strings.TrimSpace(name) == "" { continue @@ -1844,8 +1844,7 @@ func filterToolDefinitionsByContext(ctx context.Context, toolDefs []map[string]i if !ok { continue } - name, _ := fnRaw["name"].(string) - name = strings.ToLower(strings.TrimSpace(name)) + name := strings.ToLower(tools.MapStringArg(fnRaw, "name")) if name == "" { continue } @@ -2261,7 +2260,7 @@ func withToolMemoryNamespaceArgs(toolName string, args map[string]interface{}, n return args } - if raw, ok := args["namespace"].(string); ok && strings.TrimSpace(raw) != "" { + if raw := tools.MapStringArg(args, "namespace"); raw != "" { return args } next := make(map[string]interface{}, len(args)+1) @@ -2344,7 +2343,7 @@ func validateParallelAllowlistArgs(allow map[string]struct{}, args map[string]in if !ok { continue } - tool, _ := m["tool"].(string) + tool := tools.MapStringArg(m, "tool") name := strings.ToLower(strings.TrimSpace(tool)) if name == "" { continue @@ -2422,8 +2421,7 @@ func shouldSuppressSelfMessageSend(toolName string, args map[string]interface{}, if strings.TrimSpace(toolName) != "message" { return false } - action, _ := args["action"].(string) - action = strings.ToLower(strings.TrimSpace(action)) + action := strings.ToLower(tools.MapStringArg(args, "action")) if action == "" { action = "send" } @@ -2436,17 +2434,15 @@ func shouldSuppressSelfMessageSend(toolName string, args map[string]interface{}, } func resolveMessageToolTarget(args map[string]interface{}, fallbackChannel, fallbackChatID string) (string, string) { - channel, _ := args["channel"].(string) - channel = strings.TrimSpace(channel) + channel := tools.MapStringArg(args, "channel") if channel == "" { channel = strings.TrimSpace(fallbackChannel) } - chatID, _ := args["chat_id"].(string) - if to, _ := args["to"].(string); strings.TrimSpace(to) != "" { + chatID := tools.MapStringArg(args, "chat_id") + if to := tools.MapStringArg(args, "to"); to != "" { chatID = to } - chatID = strings.TrimSpace(chatID) if chatID == "" { chatID = strings.TrimSpace(fallbackChatID) } diff --git a/pkg/agent/loop_allowlist_test.go b/pkg/agent/loop_allowlist_test.go index 550784f..673befc 100644 --- a/pkg/agent/loop_allowlist_test.go +++ b/pkg/agent/loop_allowlist_test.go @@ -53,6 +53,15 @@ func TestEnsureToolAllowedByContextParallelNested(t *testing.T) { if err := ensureToolAllowedByContext(restricted, "parallel", skillArgs); err != nil { t.Fatalf("expected parallel with nested skill_exec to pass, got: %v", err) } + + stringToolArgs := map[string]interface{}{ + "calls": []interface{}{ + map[string]interface{}{"tool": "read_file", "arguments": map[string]interface{}{"path": "README.md"}}, + }, + } + if err := ensureToolAllowedByContext(restricted, "parallel", stringToolArgs); err != nil { + t.Fatalf("expected parallel with string tool key to pass, got: %v", err) + } } func TestEnsureToolAllowedByContext_GroupAllowlist(t *testing.T) { diff --git a/pkg/agent/loop_message_target_test.go b/pkg/agent/loop_message_target_test.go new file mode 100644 index 0000000..b6302bd --- /dev/null +++ b/pkg/agent/loop_message_target_test.go @@ -0,0 +1,30 @@ +package agent + +import "testing" + +func TestResolveMessageToolTargetUsesStringHelpers(t *testing.T) { + channel, chat := resolveMessageToolTarget(map[string]interface{}{ + "channel": "telegram", + "to": "chat-2", + }, "whatsapp", "chat-1") + if channel != "telegram" || chat != "chat-2" { + t.Fatalf("unexpected target: %s %s", channel, chat) + } + + channel, chat = resolveMessageToolTarget(map[string]interface{}{}, "whatsapp", "chat-1") + if channel != "whatsapp" || chat != "chat-1" { + t.Fatalf("unexpected fallback target: %s %s", channel, chat) + } +} + +func TestNodeAgentTaskResultUsesCompatiblePayloadFields(t *testing.T) { + if got := nodeAgentTaskResult(map[string]interface{}{"result": "done"}); got != "done" { + t.Fatalf("unexpected result field extraction: %q", got) + } + if got := nodeAgentTaskResult(map[string]interface{}{"content": "fallback"}); got != "fallback" { + t.Fatalf("unexpected content field extraction: %q", got) + } + if got := nodeAgentTaskResult(nil); got != "" { + t.Fatalf("expected empty result for nil payload, got %q", got) + } +} diff --git a/pkg/agent/router_dispatch.go b/pkg/agent/router_dispatch.go index b00d245..c6c8740 100644 --- a/pkg/agent/router_dispatch.go +++ b/pkg/agent/router_dispatch.go @@ -27,8 +27,8 @@ func (al *AgentLoop) maybeAutoRoute(ctx context.Context, msg bus.InboundMessage) if cfg == nil || !cfg.Agents.Router.Enabled { return "", false, nil } - agentID, taskText := resolveAutoRouteTarget(cfg, msg.Content) - if agentID == "" || strings.TrimSpace(taskText) == "" { + decision := resolveDispatchDecision(cfg, msg.Content) + if !decision.Valid() { return "", false, nil } waitTimeout := cfg.Agents.Router.DefaultTimeoutSec @@ -38,8 +38,9 @@ func (al *AgentLoop) maybeAutoRoute(ctx context.Context, msg bus.InboundMessage) waitCtx, cancel := context.WithTimeout(ctx, time.Duration(waitTimeout)*time.Second) defer cancel() task, err := al.subagentRouter.DispatchTask(waitCtx, tools.RouterDispatchRequest{ - Task: taskText, - AgentID: agentID, + Task: decision.TaskText, + AgentID: decision.TargetAgent, + Decision: &decision, NotifyMainPolicy: "internal_only", OriginChannel: msg.Channel, OriginChatID: msg.ChatID, @@ -55,16 +56,21 @@ func (al *AgentLoop) maybeAutoRoute(ctx context.Context, msg bus.InboundMessage) } func resolveAutoRouteTarget(cfg *config.Config, raw string) (string, string) { + decision := resolveDispatchDecision(cfg, raw) + return decision.TargetAgent, decision.TaskText +} + +func resolveDispatchDecision(cfg *config.Config, raw string) tools.DispatchDecision { if cfg == nil { - return "", "" + return tools.DispatchDecision{} } content := strings.TrimSpace(raw) if content == "" || len(cfg.Agents.Subagents) == 0 { - return "", "" + return tools.DispatchDecision{} } maxChars := cfg.Agents.Router.Policy.IntentMaxInputChars if maxChars > 0 && len([]rune(content)) > maxChars { - return "", "" + return tools.DispatchDecision{} } lower := strings.ToLower(content) for agentID, subcfg := range cfg.Agents.Subagents { @@ -73,19 +79,37 @@ func resolveAutoRouteTarget(cfg *config.Config, raw string) (string, string) { } marker := "@" + strings.ToLower(strings.TrimSpace(agentID)) if strings.HasPrefix(lower, marker+" ") || lower == marker { - return agentID, strings.TrimSpace(content[len(marker):]) + return tools.DispatchDecision{ + TargetAgent: agentID, + Reason: "explicit agent mention", + Confidence: 1, + TaskText: strings.TrimSpace(content[len(marker):]), + RouteSource: "explicit", + } } prefix := "agent:" + strings.ToLower(strings.TrimSpace(agentID)) if strings.HasPrefix(lower, prefix+" ") || lower == prefix { - return agentID, strings.TrimSpace(content[len(prefix):]) + return tools.DispatchDecision{ + TargetAgent: agentID, + Reason: "explicit agent prefix", + Confidence: 1, + TaskText: strings.TrimSpace(content[len(prefix):]), + RouteSource: "explicit", + } } } if strings.EqualFold(strings.TrimSpace(cfg.Agents.Router.Strategy), "rules_first") { if agentID := selectAgentByRules(cfg, content); agentID != "" { - return agentID, content + return tools.DispatchDecision{ + TargetAgent: agentID, + Reason: "matched router rules or role keywords", + Confidence: 0.8, + TaskText: content, + RouteSource: "rules", + } } } - return "", "" + return tools.DispatchDecision{} } func selectAgentByRules(cfg *config.Config, content string) string { diff --git a/pkg/agent/router_dispatch_test.go b/pkg/agent/router_dispatch_test.go index 0b03221..0b3b851 100644 --- a/pkg/agent/router_dispatch_test.go +++ b/pkg/agent/router_dispatch_test.go @@ -29,7 +29,7 @@ func TestResolveAutoRouteTargetRulesFirst(t *testing.T) { cfg.Agents.Subagents["tester"] = config.SubagentConfig{Enabled: true, Role: "testing", SystemPromptFile: "agents/tester/AGENT.md"} cfg.Agents.Router.Rules = []config.AgentRouteRule{{AgentID: "coder", Keywords: []string{"鐧诲綍", "bug"}}} - agentID, task := resolveAutoRouteTarget(cfg, "璇峰府鎴戜慨澶嶇櫥褰曟帴鍙g殑 bug 骞舵敼浠g爜") + agentID, task := resolveAutoRouteTarget(cfg, "please fix the login bug and update the code") if agentID != "coder" || task == "" { t.Fatalf("expected coder route, got %s / %s", agentID, task) } @@ -113,7 +113,7 @@ func TestMaybeAutoRouteDispatchesRulesFirstMatch(t *testing.T) { Channel: "cli", ChatID: "direct", SessionKey: "main", - Content: "璇峰仛涓€娆″洖褰掓祴璇曞苟楠岃瘉杩欎釜淇", + Content: "please run regression testing and verify this fix", }) if err != nil { t.Fatalf("rules-first auto route failed: %v", err) @@ -126,6 +126,21 @@ func TestMaybeAutoRouteDispatchesRulesFirstMatch(t *testing.T) { } } +func TestResolveDispatchDecisionIncludesReason(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Router.Enabled = true + cfg.Agents.Router.Strategy = "rules_first" + cfg.Agents.Subagents["tester"] = config.SubagentConfig{Enabled: true, Role: "testing", SystemPromptFile: "agents/tester/AGENT.md"} + + decision := resolveDispatchDecision(cfg, "run regression testing for this change") + if !decision.Valid() { + t.Fatalf("expected valid decision") + } + if decision.TargetAgent != "tester" || decision.RouteSource == "" || decision.Reason == "" { + t.Fatalf("unexpected decision: %+v", decision) + } +} + func TestResolveAutoRouteTargetSkipsOversizedIntent(t *testing.T) { cfg := config.DefaultConfig() cfg.Agents.Router.Enabled = true diff --git a/pkg/agent/runtime_admin.go b/pkg/agent/runtime_admin.go index 66a02ee..248d85d 100644 --- a/pkg/agent/runtime_admin.go +++ b/pkg/agent/runtime_admin.go @@ -38,6 +38,9 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a } sort.Slice(items, func(i, j int) bool { return items[i].Created > items[j].Created }) return map[string]interface{}{"items": items}, nil + case "snapshot": + limit := runtimeIntArg(args, "limit", 100) + return map[string]interface{}{"snapshot": sm.RuntimeSnapshot(limit)}, nil case "get", "info": taskID, err := resolveSubagentTaskIDForRuntime(sm, runtimeStringArg(args, "id")) if err != nil { @@ -194,7 +197,7 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a if al.isProtectedMainAgent(agentID) { return nil, fmt.Errorf("main agent %q cannot be disabled", agentID) } - enabled, ok := args["enabled"].(bool) + enabled, ok := runtimeBoolArg(args, "enabled") if !ok { return nil, fmt.Errorf("enabled is required") } @@ -400,22 +403,6 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a "thread": thread, "items": stream, }, nil - case "stream_all": - tasks := sm.ListTasks() - sort.Slice(tasks, func(i, j int) bool { - left := maxInt64(tasks[i].Updated, tasks[i].Created) - right := maxInt64(tasks[j].Updated, tasks[j].Created) - if left != right { - return left > right - } - return tasks[i].ID > tasks[j].ID - }) - taskLimit := runtimeIntArg(args, "task_limit", 16) - if taskLimit > 0 && len(tasks) > taskLimit { - tasks = tasks[:taskLimit] - } - items := mergeAllSubagentStreams(sm, tasks, runtimeIntArg(args, "limit", 200)) - return map[string]interface{}{"found": true, "items": items}, nil case "inbox": agentID := runtimeStringArg(args, "agent_id") if agentID == "" { @@ -528,90 +515,6 @@ func mergeSubagentStream(events []tools.SubagentRunEvent, messages []tools.Agent return items } -func mergeAllSubagentStreams(sm *tools.SubagentManager, tasks []*tools.SubagentTask, limit int) []map[string]interface{} { - if sm == nil || len(tasks) == 0 { - return nil - } - items := make([]map[string]interface{}, 0) - seenEvents := map[string]struct{}{} - seenMessages := map[string]struct{}{} - for _, task := range tasks { - if task == nil { - continue - } - if events, err := sm.Events(task.ID, limit); err == nil { - for _, evt := range events { - key := fmt.Sprintf("%s:%s:%d:%s", evt.RunID, evt.Type, evt.At, evt.Message) - if _, ok := seenEvents[key]; ok { - continue - } - seenEvents[key] = struct{}{} - items = append(items, map[string]interface{}{ - "kind": "event", - "at": evt.At, - "task_id": task.ID, - "label": task.Label, - "run_id": evt.RunID, - "agent_id": firstNonEmptyString(evt.AgentID, task.AgentID), - "event_type": evt.Type, - "status": evt.Status, - "message": evt.Message, - "retry_count": evt.RetryCount, - }) - } - } - if strings.TrimSpace(task.ThreadID) == "" { - continue - } - if messages, err := sm.ThreadMessages(task.ThreadID, limit); err == nil { - for _, msg := range messages { - if _, ok := seenMessages[msg.MessageID]; ok { - continue - } - seenMessages[msg.MessageID] = struct{}{} - items = append(items, map[string]interface{}{ - "kind": "message", - "at": msg.CreatedAt, - "task_id": task.ID, - "label": task.Label, - "message_id": msg.MessageID, - "thread_id": msg.ThreadID, - "from_agent": msg.FromAgent, - "to_agent": msg.ToAgent, - "reply_to": msg.ReplyTo, - "correlation_id": msg.CorrelationID, - "message_type": msg.Type, - "content": msg.Content, - "status": msg.Status, - "requires_reply": msg.RequiresReply, - }) - } - } - } - sort.Slice(items, func(i, j int) bool { - left, _ := items[i]["at"].(int64) - right, _ := items[j]["at"].(int64) - if left != right { - return left < right - } - return fmt.Sprintf("%v", items[i]["task_id"]) < fmt.Sprintf("%v", items[j]["task_id"]) - }) - if limit > 0 && len(items) > limit { - items = items[len(items)-limit:] - } - return items -} - -func maxInt64(values ...int64) int64 { - var out int64 - for _, v := range values { - if v > out { - out = v - } - } - return out -} - func firstNonEmptyString(values ...string) string { for _, v := range values { if strings.TrimSpace(v) != "" { @@ -665,40 +568,19 @@ func resolveSubagentTaskIDForRuntime(sm *tools.SubagentManager, raw string) (str } func runtimeStringArg(args map[string]interface{}, key string) string { - if args == nil { - return "" - } - v, _ := args[key].(string) - return strings.TrimSpace(v) + return tools.MapStringArg(args, key) } func runtimeRawStringArg(args map[string]interface{}, key string) string { - if args == nil { - return "" - } - v, _ := args[key].(string) - return v + return tools.MapRawStringArg(args, key) } func runtimeIntArg(args map[string]interface{}, key string, fallback int) int { - if args == nil { - return fallback - } - switch v := args[key].(type) { - case int: - if v > 0 { - return v - } - case int64: - if v > 0 { - return int(v) - } - case float64: - if v > 0 { - return int(v) - } - } - return fallback + return tools.MapIntArg(args, key, fallback) +} + +func runtimeBoolArg(args map[string]interface{}, key string) (bool, bool) { + return tools.MapBoolArg(args, key) } func fallbackString(v, fallback string) string { diff --git a/pkg/agent/runtime_admin_test.go b/pkg/agent/runtime_admin_test.go index 042eb06..47c3299 100644 --- a/pkg/agent/runtime_admin_test.go +++ b/pkg/agent/runtime_admin_test.go @@ -48,6 +48,7 @@ func TestHandleSubagentRuntimeDispatchAndWait(t *testing.T) { if merged == "" { t.Fatalf("expected merged output") } + time.Sleep(20 * time.Millisecond) } func TestHandleSubagentRuntimeUpsertConfigSubagent(t *testing.T) { @@ -247,6 +248,50 @@ func TestHandleSubagentRuntimeDeleteConfigSubagent(t *testing.T) { } } +func TestHandleSubagentRuntimeToggleEnabledParsesStringBool(t *testing.T) { + workspace := t.TempDir() + configPath := filepath.Join(workspace, "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Router.Enabled = true + cfg.Agents.Subagents["main"] = config.SubagentConfig{ + Enabled: true, + Type: "router", + Role: "orchestrator", + SystemPromptFile: "agents/main/AGENT.md", + } + cfg.Agents.Subagents["tester"] = config.SubagentConfig{ + Enabled: true, + Type: "worker", + Role: "testing", + SystemPromptFile: "agents/tester/AGENT.md", + } + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("save config failed: %v", err) + } + runtimecfg.Set(cfg) + t.Cleanup(func() { runtimecfg.Set(config.DefaultConfig()) }) + + manager := tools.NewSubagentManager(nil, workspace, nil) + loop := &AgentLoop{ + configPath: configPath, + subagentManager: manager, + subagentRouter: tools.NewSubagentRouter(manager), + } + if _, err := loop.HandleSubagentRuntime(context.Background(), "set_config_subagent_enabled", map[string]interface{}{ + "agent_id": "tester", + "enabled": "false", + }); err != nil { + t.Fatalf("toggle enabled failed: %v", err) + } + reloaded, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("reload config failed: %v", err) + } + if reloaded.Agents.Subagents["tester"].Enabled { + t.Fatalf("expected tester to be disabled") + } +} + func TestHandleSubagentRuntimePromptFileGetSetBootstrap(t *testing.T) { workspace := t.TempDir() manager := tools.NewSubagentManager(nil, workspace, nil) @@ -437,19 +482,4 @@ func TestHandleSubagentRuntimeStreamAll(t *testing.T) { time.Sleep(10 * time.Millisecond) } - out, err := loop.HandleSubagentRuntime(context.Background(), "stream_all", map[string]interface{}{ - "limit": 100, - "task_limit": 10, - }) - if err != nil { - t.Fatalf("stream_all failed: %v", err) - } - payload, ok := out.(map[string]interface{}) - if !ok || payload["found"] != true { - t.Fatalf("unexpected stream_all payload: %#v", out) - } - items, ok := payload["items"].([]map[string]interface{}) - if !ok || len(items) == 0 { - t.Fatalf("expected grouped stream items, got %#v", payload["items"]) - } } diff --git a/pkg/api/server.go b/pkg/api/server.go index c52786a..151a427 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -70,8 +70,6 @@ type Server struct { liveRuntimeMu sync.Mutex liveRuntimeSubs map[chan []byte]struct{} liveRuntimeOn bool - liveSubagentMu sync.Mutex - liveSubagents map[string]*liveSubagentGroup whatsAppBridge *channels.WhatsAppBridgeService whatsAppBase string oauthFlowMu sync.Mutex @@ -100,7 +98,6 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server nodeSockets: map[string]*nodeSocketConn{}, artifactStats: map[string]interface{}{}, liveRuntimeSubs: map[chan []byte]struct{}{}, - liveSubagents: map[string]*liveSubagentGroup{}, oauthFlows: map[string]*providers.OAuthPendingFlow{}, extraRoutes: map[string]http.Handler{}, } @@ -112,13 +109,6 @@ type nodeSocketConn struct { mu sync.Mutex } -type liveSubagentGroup struct { - taskID string - previewTaskID string - subs map[chan []byte]struct{} - stopCh chan struct{} -} - func (c *nodeSocketConn) Send(msg nodes.WireMessage) error { if c == nil || c.conn == nil { return fmt.Errorf("node websocket unavailable") @@ -210,87 +200,6 @@ func (s *Server) publishRuntimeSnapshot(ctx context.Context) bool { return true } -func buildSubagentLiveKey(taskID, previewTaskID string) string { - return strings.TrimSpace(taskID) + "\x00" + strings.TrimSpace(previewTaskID) -} - -func (s *Server) subscribeSubagentLive(ctx context.Context, taskID, previewTaskID string) chan []byte { - ch := make(chan []byte, 1) - key := buildSubagentLiveKey(taskID, previewTaskID) - s.liveSubagentMu.Lock() - group := s.liveSubagents[key] - if group == nil { - group = &liveSubagentGroup{ - taskID: strings.TrimSpace(taskID), - previewTaskID: strings.TrimSpace(previewTaskID), - subs: map[chan []byte]struct{}{}, - stopCh: make(chan struct{}), - } - s.liveSubagents[key] = group - go s.subagentLiveLoop(key, group) - } - group.subs[ch] = struct{}{} - s.liveSubagentMu.Unlock() - go func() { - <-ctx.Done() - s.unsubscribeSubagentLive(key, ch) - }() - return ch -} - -func (s *Server) unsubscribeSubagentLive(key string, ch chan []byte) { - s.liveSubagentMu.Lock() - group := s.liveSubagents[key] - if group == nil { - s.liveSubagentMu.Unlock() - return - } - delete(group.subs, ch) - if len(group.subs) == 0 { - delete(s.liveSubagents, key) - close(group.stopCh) - } - s.liveSubagentMu.Unlock() -} - -func (s *Server) subagentLiveLoop(key string, group *liveSubagentGroup) { - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - for { - if !s.publishSubagentLiveSnapshot(context.Background(), key, group.taskID, group.previewTaskID) { - return - } - select { - case <-group.stopCh: - return - case <-ticker.C: - } - } -} - -func (s *Server) publishSubagentLiveSnapshot(ctx context.Context, key, taskID, previewTaskID string) bool { - if s == nil { - return false - } - payload := map[string]interface{}{ - "ok": true, - "type": "subagents_live", - "payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID), - } - data, err := json.Marshal(payload) - if err != nil { - return false - } - s.liveSubagentMu.Lock() - defer s.liveSubagentMu.Unlock() - group := s.liveSubagents[key] - if group == nil || len(group.subs) == 0 { - return false - } - publishLiveSnapshot(group.subs, data) - return true -} - func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) } func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) } func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) } @@ -371,6 +280,35 @@ func joinServerRoute(base, endpoint string) string { return base + "/" + strings.TrimPrefix(endpoint, "/") } +func writeJSON(w http.ResponseWriter, payload interface{}) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(payload) +} + +func writeJSONStatus(w http.ResponseWriter, code int, payload interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(payload) +} + +func queryBoundedPositiveInt(r *http.Request, key string, fallback int, max int) int { + if r == nil { + return fallback + } + value := strings.TrimSpace(r.URL.Query().Get(strings.TrimSpace(key))) + if value == "" { + return fallback + } + n, err := strconv.Atoi(value) + if err != nil || n <= 0 { + return fallback + } + if max > 0 && n > max { + return max + } + return n +} + func (s *Server) rememberNodeConnection(nodeID, connID string) { nodeID = strings.TrimSpace(nodeID) connID = strings.TrimSpace(connID) @@ -465,7 +403,6 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/config", s.handleWebUIConfig) mux.HandleFunc("/api/chat", s.handleWebUIChat) mux.HandleFunc("/api/chat/history", s.handleWebUIChatHistory) - mux.HandleFunc("/api/chat/stream", s.handleWebUIChatStream) mux.HandleFunc("/api/chat/live", s.handleWebUIChatLive) mux.HandleFunc("/api/runtime", s.handleWebUIRuntime) mux.HandleFunc("/api/version", s.handleWebUIVersion) @@ -475,7 +412,6 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/provider/oauth/accounts", s.handleWebUIProviderOAuthAccounts) mux.HandleFunc("/api/provider/models", s.handleWebUIProviderModels) mux.HandleFunc("/api/provider/runtime", s.handleWebUIProviderRuntime) - mux.HandleFunc("/api/provider/runtime/summary", s.handleWebUIProviderRuntimeSummary) mux.HandleFunc("/api/whatsapp/status", s.handleWebUIWhatsAppStatus) mux.HandleFunc("/api/whatsapp/logout", s.handleWebUIWhatsAppLogout) mux.HandleFunc("/api/whatsapp/qr.svg", s.handleWebUIWhatsAppQR) @@ -492,17 +428,13 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/skills", s.handleWebUISkills) mux.HandleFunc("/api/sessions", s.handleWebUISessions) mux.HandleFunc("/api/memory", s.handleWebUIMemory) - mux.HandleFunc("/api/subagent_profiles", s.handleWebUISubagentProfiles) + mux.HandleFunc("/api/workspace_file", s.handleWebUIWorkspaceFile) mux.HandleFunc("/api/subagents_runtime", s.handleWebUISubagentsRuntime) - mux.HandleFunc("/api/subagents_runtime/live", s.handleWebUISubagentsRuntimeLive) mux.HandleFunc("/api/tool_allowlist_groups", s.handleWebUIToolAllowlistGroups) mux.HandleFunc("/api/tools", s.handleWebUITools) mux.HandleFunc("/api/mcp/install", s.handleWebUIMCPInstall) - mux.HandleFunc("/api/task_audit", s.handleWebUITaskAudit) mux.HandleFunc("/api/task_queue", s.handleWebUITaskQueue) mux.HandleFunc("/api/ekg_stats", s.handleWebUIEKGStats) - mux.HandleFunc("/api/exec_approvals", s.handleWebUIExecApprovals) - mux.HandleFunc("/api/logs/stream", s.handleWebUILogsStream) mux.HandleFunc("/api/logs/live", s.handleWebUILogsLive) mux.HandleFunc("/api/logs/recent", s.handleWebUILogsRecent) s.extraRoutesMu.RLock() @@ -573,8 +505,7 @@ func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { return } s.mgr.Upsert(n) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "id": n.ID}) + writeJSON(w, map[string]interface{}{"ok": true, "id": n.ID}) } func (s *Server) handleHeartbeat(w http.ResponseWriter, r *http.Request) { @@ -601,8 +532,7 @@ func (s *Server) handleHeartbeat(w http.ResponseWriter, r *http.Request) { n.LastSeenAt = time.Now().UTC() n.Online = true s.mgr.Upsert(n) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "id": body.ID}) + writeJSON(w, map[string]interface{}{"ok": true, "id": body.ID}) } func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { @@ -814,6 +744,31 @@ func (s *Server) handleWebUIConfig(w http.ResponseWriter, r *http.Request) { } switch r.Method { case http.MethodGet: + if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + payload := map[string]interface{}{ + "ok": true, + "config": cfg.NormalizedView(), + "raw_config": cfg, + } + if r.URL.Query().Get("include_hot_reload_fields") == "1" { + info := hotReloadFieldInfo() + paths := make([]string, 0, len(info)) + for _, it := range info { + if p := stringFromMap(it, "path"); p != "" { + paths = append(paths, p) + } + } + payload["hot_reload_fields"] = paths + payload["hot_reload_field_details"] = info + } + writeJSON(w, payload) + return + } b, err := os.ReadFile(s.configPath) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -831,15 +786,14 @@ func (s *Server) handleWebUIConfig(w http.ResponseWriter, r *http.Request) { merged = mergeJSONMap(merged, loaded) if r.URL.Query().Get("include_hot_reload_fields") == "1" || strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "hot") { - w.Header().Set("Content-Type", "application/json") info := hotReloadFieldInfo() paths := make([]string, 0, len(info)) for _, it := range info { - if p, ok := it["path"].(string); ok && strings.TrimSpace(p) != "" { + if p := stringFromMap(it, "path"); p != "" { paths = append(paths, p) } } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "config": merged, "hot_reload_fields": paths, @@ -856,23 +810,31 @@ func (s *Server) handleWebUIConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid json", http.StatusBadRequest) return } - confirmRisky, _ := body["confirm_risky"].(bool) + confirmRisky, _ := tools.MapBoolArg(body, "confirm_risky") delete(body, "confirm_risky") oldCfgRaw, _ := os.ReadFile(s.configPath) var oldMap map[string]interface{} _ = json.Unmarshal(oldCfgRaw, &oldMap) + riskyOldMap := oldMap + riskyNewMap := body + if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { + if loaded, err := cfgpkg.LoadConfig(s.configPath); err == nil && loaded != nil { + if raw, err := json.Marshal(loaded.NormalizedView()); err == nil { + _ = json.Unmarshal(raw, &riskyOldMap) + } + } + } - riskyPaths := collectRiskyConfigPaths(oldMap, body) + riskyPaths := collectRiskyConfigPaths(riskyOldMap, riskyNewMap) changedRisky := make([]string, 0) for _, p := range riskyPaths { - if fmt.Sprintf("%v", getPathValue(oldMap, p)) != fmt.Sprintf("%v", getPathValue(body, p)) { + if fmt.Sprintf("%v", getPathValue(riskyOldMap, p)) != fmt.Sprintf("%v", getPathValue(riskyNewMap, p)) { changedRisky = append(changedRisky, p) } } if len(changedRisky) > 0 && !confirmRisky { - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSONStatus(w, http.StatusBadRequest, map[string]interface{}{ "ok": false, "error": "risky fields changed; confirmation required", "requires_confirm": true, @@ -881,39 +843,50 @@ func (s *Server) handleWebUIConfig(w http.ResponseWriter, r *http.Request) { return } - candidate, err := json.Marshal(body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } cfg := cfgpkg.DefaultConfig() - dec := json.NewDecoder(bytes.NewReader(candidate)) - dec.DisallowUnknownFields() - if err := dec.Decode(cfg); err != nil { - http.Error(w, "config schema validation failed: "+err.Error(), http.StatusBadRequest) - return + if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { + loaded, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + cfg = loaded + candidate, err := json.Marshal(body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var normalized cfgpkg.NormalizedConfig + dec := json.NewDecoder(bytes.NewReader(candidate)) + dec.DisallowUnknownFields() + if err := dec.Decode(&normalized); err != nil { + http.Error(w, "normalized config validation failed: "+err.Error(), http.StatusBadRequest) + return + } + cfg.ApplyNormalizedView(normalized) + } else { + candidate, err := json.Marshal(body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + dec := json.NewDecoder(bytes.NewReader(candidate)) + dec.DisallowUnknownFields() + if err := dec.Decode(cfg); err != nil { + http.Error(w, "config schema validation failed: "+err.Error(), http.StatusBadRequest) + return + } } if errs := cfgpkg.Validate(cfg); len(errs) > 0 { list := make([]string, 0, len(errs)) for _, e := range errs { list = append(list, e.Error()) } - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": false, "error": "config validation failed", "details": list}) + writeJSONStatus(w, http.StatusBadRequest, map[string]interface{}{"ok": false, "error": "config validation failed", "details": list}) return } - b, err := json.MarshalIndent(body, "", " ") - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - tmp := s.configPath + ".tmp" - if err := os.WriteFile(tmp, b, 0644); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := os.Rename(tmp, s.configPath); err != nil { + if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -928,7 +901,11 @@ func (s *Server) handleWebUIConfig(w http.ResponseWriter, r *http.Request) { return } } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "reloaded": true}) + if strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("mode")), "normalized") { + writeJSON(w, map[string]interface{}{"ok": true, "reloaded": true, "config": cfg.NormalizedView()}) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "reloaded": true}) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } @@ -975,6 +952,8 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string { "channels.telegram.allow_chats", "models.providers.openai.api_base", "models.providers.openai.api_key", + "runtime.providers.openai.api_base", + "runtime.providers.openai.api_key", "gateway.token", "gateway.port", } @@ -989,6 +968,11 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string { paths = append(paths, path) seen[path] = true } + normalizedPath := "runtime.providers." + name + "." + field + if !seen[normalizedPath] { + paths = append(paths, normalizedPath) + seen[normalizedPath] = true + } } } return paths @@ -1007,6 +991,15 @@ func collectProviderNames(maps ...map[string]interface{}) []string { seen[name] = true names = append(names, name) } + runtimeMap, _ := root["runtime"].(map[string]interface{}) + runtimeProviders, _ := runtimeMap["providers"].(map[string]interface{}) + for name := range runtimeProviders { + if strings.TrimSpace(name) == "" || seen[name] { + continue + } + seen[name] = true + names = append(names, name) + } } sort.Strings(names) return names @@ -1045,7 +1038,7 @@ func (s *Server) handleWebUIUpload(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "path": path, "name": h.Filename}) + writeJSON(w, map[string]interface{}{"ok": true, "path": path, "name": h.Filename}) } func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Request) { @@ -1100,7 +1093,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re s.oauthFlowMu.Lock() s.oauthFlows[flowID] = flow s.oauthFlowMu.Unlock() - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "flow_id": flowID, "mode": flow.Mode, @@ -1171,7 +1164,7 @@ func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "account": session.Email, "credential_file": session.CredentialFile, @@ -1245,7 +1238,7 @@ func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.R http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "account": session.Email, "credential_file": session.CredentialFile, @@ -1282,7 +1275,7 @@ func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "accounts": accounts}) + writeJSON(w, map[string]interface{}{"ok": true, "accounts": accounts}) case http.MethodPost: var body struct { Action string `json:"action"` @@ -1299,7 +1292,7 @@ func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "account": account}) + writeJSON(w, map[string]interface{}{"ok": true, "account": account}) case "delete": if err := loginMgr.DeleteAccount(body.CredentialFile); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -1316,13 +1309,13 @@ func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": true}) + writeJSON(w, map[string]interface{}{"ok": true, "deleted": true}) case "clear_cooldown": if err := loginMgr.ClearCooldown(body.CredentialFile); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) + writeJSON(w, map[string]interface{}{"ok": true, "cleared": true}) default: http.Error(w, "unsupported action", http.StatusBadRequest) } @@ -1368,7 +1361,7 @@ func (s *Server) handleWebUIProviderModels(w http.ResponseWriter, r *http.Reques http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "models": pc.Models, }) @@ -1408,7 +1401,7 @@ func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Reque if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec"))); secs > 0 { query.CooldownBefore = time.Now().Add(time.Duration(secs) * time.Second) } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "view": providers.GetProviderRuntimeView(cfg, query), }) @@ -1436,7 +1429,7 @@ func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Reque } _ = cfg providers.ClearProviderAPICooldown(providerName) - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) + writeJSON(w, map[string]interface{}{"ok": true, "cleared": true}) case "clear_history": cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) if err != nil { @@ -1445,7 +1438,7 @@ func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Reque } _ = cfg providers.ClearProviderRuntimeHistory(providerName) - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) + writeJSON(w, map[string]interface{}{"ok": true, "cleared": true}) case "refresh_now": cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) if err != nil { @@ -1459,7 +1452,7 @@ func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Reque } order, _ := providers.RerankProviderRuntime(cfg, providerName) summary := providers.GetProviderRuntimeSummary(cfg, providers.ProviderRuntimeQuery{Provider: providerName, HealthBelow: 50}) - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "provider": providerName, "refreshed": true, @@ -1478,49 +1471,12 @@ func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Reque http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "provider": providerName, "reranked": true, "candidate_order": order}) + writeJSON(w, map[string]interface{}{"ok": true, "provider": providerName, "reranked": true, "candidate_order": order}) default: http.Error(w, "unsupported action", http.StatusBadRequest) } } -func (s *Server) handleWebUIProviderRuntimeSummary(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - query := providers.ProviderRuntimeQuery{ - Provider: strings.TrimSpace(r.URL.Query().Get("provider")), - Reason: strings.TrimSpace(r.URL.Query().Get("reason")), - Target: strings.TrimSpace(r.URL.Query().Get("target")), - } - if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("window_sec"))); secs > 0 { - query.Window = time.Duration(secs) * time.Second - } - if healthBelow, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("health_below"))); healthBelow > 0 { - query.HealthBelow = healthBelow - } - if query.HealthBelow <= 0 { - query.HealthBelow = 50 - } - if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec"))); secs > 0 { - query.CooldownBefore = time.Now().Add(time.Duration(secs) * time.Second) - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ - "ok": true, - "summary": providers.GetProviderRuntimeSummary(cfg, query), - }) -} - func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { if strings.TrimSpace(s.configPath) == "" { return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("config path not set") @@ -1670,7 +1626,7 @@ func (s *Server) handleWebUIChat(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "reply": resp, "session": session}) + writeJSON(w, map[string]interface{}{"ok": true, "reply": resp, "session": session}) } func (s *Server) handleWebUIChatHistory(w http.ResponseWriter, r *http.Request) { @@ -1687,74 +1643,10 @@ func (s *Server) handleWebUIChatHistory(w http.ResponseWriter, r *http.Request) session = "main" } if s.onChatHistory == nil { - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "session": session, "messages": []interface{}{}}) + writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": []interface{}{}}) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "session": session, "messages": s.onChatHistory(session)}) -} - -func (s *Server) handleWebUIChatStream(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Deprecation", "true") - w.Header().Set("X-Clawgo-Replaced-By", "/api/chat/live") - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if s.onChat == nil { - http.Error(w, "chat handler not configured", http.StatusInternalServerError) - return - } - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "stream unsupported", http.StatusInternalServerError) - return - } - var body struct { - Session string `json:"session"` - Message string `json:"message"` - Media string `json:"media"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - session := body.Session - if session == "" { - session = r.URL.Query().Get("session") - } - if session == "" { - session = "main" - } - prompt := body.Message - if body.Media != "" { - if prompt != "" { - prompt += "\n" - } - prompt += "[file: " + body.Media + "]" - } - - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - resp, err := s.onChat(r.Context(), session, prompt) - if err != nil { - _, _ = w.Write([]byte("Error: " + err.Error())) - flusher.Flush() - return - } - chunk := 180 - for i := 0; i < len(resp); i += chunk { - end := i + chunk - if end > len(resp) { - end = len(resp) - } - _, _ = w.Write([]byte(resp[i:end])) - flusher.Flush() - } + writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": s.onChatHistory(session)}) } func (s *Server) handleWebUIChatLive(w http.ResponseWriter, r *http.Request) { @@ -1837,7 +1729,7 @@ func (s *Server) handleWebUIVersion(w http.ResponseWriter, r *http.Request) { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "gateway_version": firstNonEmptyString(s.gatewayVersion, gatewayBuildVersion()), "webui_version": firstNonEmptyString(s.webuiVersion, detectWebUIVersion(strings.TrimSpace(s.webUIDir))), @@ -1855,9 +1747,7 @@ func (s *Server) handleWebUIWhatsAppStatus(w http.ResponseWriter, r *http.Reques return } payload, code := s.webUIWhatsAppStatusPayload(r.Context()) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - _ = json.NewEncoder(w).Encode(payload) + writeJSONStatus(w, code, payload) } func (s *Server) handleWebUIWhatsAppLogout(w http.ResponseWriter, r *http.Request) { @@ -1906,7 +1796,7 @@ func (s *Server) handleWebUIWhatsAppQR(w http.ResponseWriter, r *http.Request) { status, _ := payload["status"].(map[string]interface{}) qrCode := "" if status != nil { - qrCode, _ = status["qr_code"].(string) + qrCode = stringFromMap(status, "qr_code") } if code != http.StatusOK || strings.TrimSpace(qrCode) == "" { http.Error(w, "qr unavailable", http.StatusNotFound) @@ -2136,52 +2026,36 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) { func (s *Server) buildWebUIRuntimeSnapshot(ctx context.Context) map[string]interface{} { var providerPayload map[string]interface{} + var normalizedConfig interface{} if strings.TrimSpace(s.configPath) != "" { if cfg, err := cfgpkg.LoadConfig(strings.TrimSpace(s.configPath)); err == nil { providerPayload = providers.GetProviderRuntimeSnapshot(cfg) + normalizedConfig = cfg.NormalizedView() } } if providerPayload == nil { providerPayload = map[string]interface{}{"items": []interface{}{}} } + runtimePayload := map[string]interface{}{} + if s.onSubagents != nil { + if res, err := s.onSubagents(ctx, "snapshot", map[string]interface{}{"limit": 200}); err == nil { + if m, ok := res.(map[string]interface{}); ok { + runtimePayload = m + } + } + } return map[string]interface{}{ "version": s.webUIVersionPayload(), + "config": normalizedConfig, + "runtime": runtimePayload, "nodes": s.webUINodesPayload(ctx), "sessions": s.webUISessionsPayload(), "task_queue": s.webUITaskQueuePayload(false), "ekg": s.webUIEKGSummaryPayload("24h"), - "subagents": s.webUISubagentsRuntimePayload(ctx), "providers": providerPayload, } } -func (s *Server) webUISubagentsRuntimePayload(ctx context.Context) map[string]interface{} { - if s.onSubagents == nil { - return map[string]interface{}{ - "items": []interface{}{}, - "registry": []interface{}{}, - "stream": []interface{}{}, - } - } - call := func(action string, args map[string]interface{}) interface{} { - res, err := s.onSubagents(ctx, action, args) - if err != nil { - return []interface{}{} - } - if m, ok := res.(map[string]interface{}); ok { - if items, ok := m["items"]; ok { - return items - } - } - return []interface{}{} - } - return map[string]interface{}{ - "items": call("list", map[string]interface{}{}), - "registry": call("registry", map[string]interface{}{}), - "stream": call("stream_all", map[string]interface{}{"limit": 300, "task_limit": 36}), - } -} - func (s *Server) webUIVersionPayload() map[string]interface{} { return map[string]interface{}{ "gateway_version": firstNonEmptyString(s.gatewayVersion, gatewayBuildVersion()), @@ -2303,7 +2177,7 @@ func (s *Server) webUINodeAlertsPayload(nodeList []nodes.NodeInfo, p2p map[strin if nodeID == "" { continue } - if ok, _ := row["ok"].(bool); ok { + if ok, _ := tools.MapBoolArg(row, "ok"); ok { continue } failuresByNode[nodeID]++ @@ -2374,11 +2248,10 @@ func int64Value(v interface{}) int64 { } func (s *Server) webUINodesDispatchPayload(limit int) []map[string]interface{} { - workspace := strings.TrimSpace(s.workspacePath) - if workspace == "" { + path := s.memoryFilePath("nodes-dispatch-audit.jsonl") + if path == "" { return []map[string]interface{}{} } - path := filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl") data, err := os.ReadFile(path) if err != nil { return []map[string]interface{}{} @@ -2459,11 +2332,10 @@ func (s *Server) webUINodeArtifactsPayloadFiltered(nodeFilter, actionFilter, kin } func (s *Server) readNodeDispatchAuditRows() ([]map[string]interface{}, string) { - workspace := strings.TrimSpace(s.workspacePath) - if workspace == "" { + path := s.memoryFilePath("nodes-dispatch-audit.jsonl") + if path == "" { return nil, "" } - path := filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl") data, err := os.ReadFile(path) if err != nil { return nil, path @@ -2578,6 +2450,75 @@ func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, s return nil, "", fmt.Errorf("artifact content unavailable") } +func resolveRelativeFilePath(root, raw string) (string, string, error) { + root = strings.TrimSpace(root) + if root == "" { + return "", "", fmt.Errorf("workspace not configured") + } + clean := filepath.Clean(strings.TrimSpace(raw)) + if clean == "." || clean == "" || strings.HasPrefix(clean, "..") || filepath.IsAbs(clean) { + return "", "", fmt.Errorf("invalid path") + } + full := filepath.Join(root, clean) + cleanRoot := filepath.Clean(root) + if full != cleanRoot { + prefix := cleanRoot + string(os.PathSeparator) + if !strings.HasPrefix(filepath.Clean(full), prefix) { + return "", "", fmt.Errorf("invalid path") + } + } + return clean, full, nil +} + +func relativeFilePathStatus(err error) int { + if err == nil { + return http.StatusOK + } + if err.Error() == "workspace not configured" { + return http.StatusInternalServerError + } + return http.StatusBadRequest +} + +func readRelativeTextFile(root, raw string) (string, string, bool, error) { + clean, full, err := resolveRelativeFilePath(root, raw) + if err != nil { + return "", "", false, err + } + b, err := os.ReadFile(full) + if err != nil { + if os.IsNotExist(err) { + return clean, "", false, nil + } + return clean, "", false, err + } + return clean, string(b), true, nil +} + +func writeRelativeTextFile(root, raw string, content string, ensureDir bool) (string, error) { + clean, full, err := resolveRelativeFilePath(root, raw) + if err != nil { + return "", err + } + if ensureDir { + if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { + return "", err + } + } + if err := os.WriteFile(full, []byte(content), 0644); err != nil { + return "", err + } + return clean, nil +} + +func (s *Server) memoryFilePath(name string) string { + workspace := strings.TrimSpace(s.workspacePath) + if workspace == "" { + return "" + } + return filepath.Join(workspace, "memory", strings.TrimSpace(name)) +} + func (s *Server) filteredNodeDispatches(nodeFilter, actionFilter string, limit int) []map[string]interface{} { items := s.webUINodesDispatchPayload(limit) if nodeFilter == "" && actionFilter == "" { @@ -2821,7 +2762,7 @@ func (s *Server) webUISessionsPayload() map[string]interface{} { } func (s *Server) webUITaskQueuePayload(includeHeartbeat bool) map[string]interface{} { - path := filepath.Join(strings.TrimSpace(s.workspacePath), "memory", "task-audit.jsonl") + path := s.memoryFilePath("task-audit.jsonl") b, err := os.ReadFile(path) lines := []string{} if err == nil { @@ -2875,7 +2816,7 @@ func (s *Server) webUITaskQueuePayload(includeHeartbeat bool) map[string]interfa running = append(running, row) } } - queuePath := filepath.Join(strings.TrimSpace(s.workspacePath), "memory", "task_queue.json") + queuePath := s.memoryFilePath("task_queue.json") if qb, qErr := os.ReadFile(queuePath); qErr == nil { var q map[string]interface{} if json.Unmarshal(qb, &q) == nil { @@ -2901,8 +2842,7 @@ func (s *Server) webUITaskQueuePayload(includeHeartbeat bool) map[string]interfa } func (s *Server) webUIEKGSummaryPayload(window string) map[string]interface{} { - workspace := strings.TrimSpace(s.workspacePath) - ekgPath := filepath.Join(workspace, "memory", "ekg-events.jsonl") + ekgPath := s.memoryFilePath("ekg-events.jsonl") window = strings.ToLower(strings.TrimSpace(window)) windowDur := 24 * time.Hour switch window { @@ -3052,8 +2992,7 @@ func (s *Server) handleWebUITools(w http.ResponseWriter, r *http.Request) { serverChecks = buildMCPServerChecks(cfg) } } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "tools": toolsList, "mcp_tools": mcpItems, "mcp_server_checks": serverChecks, @@ -3203,7 +3142,7 @@ func (s *Server) handleWebUIMCPInstall(w http.ResponseWriter, r *http.Request) { http.Error(w, strings.TrimSpace(msg), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "package": pkgName, "output": out, @@ -3221,7 +3160,7 @@ func (s *Server) handleWebUINodes(w http.ResponseWriter, r *http.Request) { case http.MethodGet: payload := s.webUINodesPayload(r.Context()) payload["ok"] = true - _ = json.NewEncoder(w).Encode(payload) + writeJSON(w, payload) case http.MethodPost: var body struct { Action string `json:"action"` @@ -3242,7 +3181,7 @@ func (s *Server) handleWebUINodes(w http.ResponseWriter, r *http.Request) { } id := body.ID ok := s.mgr.Remove(id) - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": ok, "id": id}) + writeJSON(w, map[string]interface{}{"ok": true, "deleted": ok, "id": id}) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } @@ -3257,16 +3196,8 @@ func (s *Server) handleWebUINodeDispatches(w http.ResponseWriter, r *http.Reques http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - limit := 50 - if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" { - if n, err := strconv.Atoi(raw); err == nil && n > 0 { - if n > 500 { - n = 500 - } - limit = n - } - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + limit := queryBoundedPositiveInt(r, "limit", 50, 500) + writeJSON(w, map[string]interface{}{ "ok": true, "items": s.webUINodesDispatchPayload(limit), }) @@ -3313,7 +3244,7 @@ func (s *Server) handleWebUINodeDispatchReplay(w http.ResponseWriter, r *http.Re http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "result": resp, }) @@ -3328,20 +3259,12 @@ func (s *Server) handleWebUINodeArtifacts(w http.ResponseWriter, r *http.Request http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - limit := 200 - if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" { - if n, err := strconv.Atoi(raw); err == nil && n > 0 { - if n > 1000 { - n = 1000 - } - limit = n - } - } + limit := queryBoundedPositiveInt(r, "limit", 200, 1000) retentionSummary := s.applyNodeArtifactRetention() nodeFilter := strings.TrimSpace(r.URL.Query().Get("node")) actionFilter := strings.TrimSpace(r.URL.Query().Get("action")) kindFilter := strings.TrimSpace(r.URL.Query().Get("kind")) - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "items": s.webUINodeArtifactsPayloadFiltered(nodeFilter, actionFilter, kindFilter, limit), "artifact_retention": retentionSummary, @@ -3358,15 +3281,7 @@ func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.R return } retentionSummary := s.applyNodeArtifactRetention() - limit := 200 - if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" { - if n, err := strconv.Atoi(raw); err == nil && n > 0 { - if n > 1000 { - n = 1000 - } - limit = n - } - } + limit := queryBoundedPositiveInt(r, "limit", 200, 1000) nodeFilter := strings.TrimSpace(r.URL.Query().Get("node")) actionFilter := strings.TrimSpace(r.URL.Query().Get("action")) kindFilter := strings.TrimSpace(r.URL.Query().Get("kind")) @@ -3531,7 +3446,7 @@ func (s *Server) handleWebUINodeArtifactDelete(w http.ResponseWriter, r *http.Re http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "id": strings.TrimSpace(body.ID), "deleted_file": deletedFile, @@ -3583,7 +3498,7 @@ func (s *Server) handleWebUINodeArtifactPrune(w http.ResponseWriter, r *http.Req deletedFiles++ } } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "pruned": pruned, "deleted_files": deletedFiles, @@ -3650,6 +3565,43 @@ func (s *Server) fetchRegistryItems(ctx context.Context) []map[string]interface{ } func (s *Server) fetchRemoteNodeRegistry(ctx context.Context, node nodes.NodeInfo) ([]map[string]interface{}, error) { + baseURL := nodeWebUIBaseURL(node) + if baseURL == "" { + return nil, fmt.Errorf("node %s endpoint missing", strings.TrimSpace(node.ID)) + } + reqURL := baseURL + "/api/config?mode=normalized" + if tok := strings.TrimSpace(node.Token); tok != "" { + reqURL += "&token=" + url.QueryEscape(tok) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + return s.fetchRemoteNodeRegistryLegacy(ctx, node) + } + var payload struct { + OK bool `json:"ok"` + Config cfgpkg.NormalizedConfig `json:"config"` + RawConfig map[string]interface{} `json:"raw_config"` + } + if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&payload); err != nil { + return s.fetchRemoteNodeRegistryLegacy(ctx, node) + } + items := buildRegistryItemsFromNormalizedConfig(payload.Config) + if len(items) > 0 { + return items, nil + } + return s.fetchRemoteNodeRegistryLegacy(ctx, node) +} + +func (s *Server) fetchRemoteNodeRegistryLegacy(ctx context.Context, node nodes.NodeInfo) ([]map[string]interface{}, error) { baseURL := nodeWebUIBaseURL(node) if baseURL == "" { return nil, fmt.Errorf("node %s endpoint missing", strings.TrimSpace(node.ID)) @@ -3683,6 +3635,50 @@ func (s *Server) fetchRemoteNodeRegistry(ctx context.Context, node nodes.NodeInf return payload.Result.Items, nil } +func buildRegistryItemsFromNormalizedConfig(view cfgpkg.NormalizedConfig) []map[string]interface{} { + items := make([]map[string]interface{}, 0, len(view.Core.Subagents)) + for agentID, subcfg := range view.Core.Subagents { + if strings.TrimSpace(agentID) == "" { + continue + } + items = append(items, map[string]interface{}{ + "agent_id": agentID, + "enabled": subcfg.Enabled, + "type": "subagent", + "transport": fallbackString(strings.TrimSpace(subcfg.RuntimeClass), "local"), + "node_id": "", + "parent_agent_id": "", + "notify_main_policy": "final_only", + "display_name": "", + "role": strings.TrimSpace(subcfg.Role), + "description": "", + "system_prompt_file": strings.TrimSpace(subcfg.Prompt), + "prompt_file_found": false, + "memory_namespace": "", + "tool_allowlist": append([]string(nil), subcfg.ToolAllowlist...), + "tool_visibility": map[string]interface{}{}, + "effective_tools": []string{}, + "inherited_tools": []string{}, + "routing_keywords": routeKeywordsForRegistry(view.Runtime.Router.Rules, agentID), + "managed_by": "config.json", + }) + } + sort.Slice(items, func(i, j int) bool { + return stringFromMap(items[i], "agent_id") < stringFromMap(items[j], "agent_id") + }) + return items +} + +func routeKeywordsForRegistry(rules []cfgpkg.AgentRouteRule, agentID string) []string { + agentID = strings.TrimSpace(agentID) + for _, rule := range rules { + if strings.TrimSpace(rule.AgentID) == agentID { + return append([]string(nil), rule.Keywords...) + } + } + return nil +} + func nodeWebUIBaseURL(node nodes.NodeInfo) string { endpoint := strings.TrimSpace(node.Endpoint) if endpoint == "" || strings.EqualFold(endpoint, "gateway") { @@ -3769,21 +3765,29 @@ func buildAgentTreeRoot(nodeID string, items []map[string]interface{}) map[strin } func stringFromMap(item map[string]interface{}, key string) string { - if item == nil { - return "" - } - v, _ := item[key].(string) - return strings.TrimSpace(v) + return tools.MapStringArg(item, key) } func boolFromMap(item map[string]interface{}, key string) bool { if item == nil { return false } - v, _ := item[key].(bool) + v, _ := tools.MapBoolArg(item, key) return v } +func rawStringFromMap(item map[string]interface{}, key string) string { + return tools.MapRawStringArg(item, key) +} + +func stringListFromMap(item map[string]interface{}, key string) []string { + return tools.MapStringListArg(item, key) +} + +func intFromMap(item map[string]interface{}, key string, fallback int) int { + return tools.MapIntArg(item, key, fallback) +} + func fallbackString(value, fallback string) string { value = strings.TrimSpace(value) if value != "" { @@ -3815,9 +3819,9 @@ func (s *Server) handleWebUICron(w http.ResponseWriter, r *http.Request) { return } if action == "list" { - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "jobs": normalizeCronJobs(res)}) + writeJSON(w, map[string]interface{}{"ok": true, "jobs": normalizeCronJobs(res)}) } else { - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "job": normalizeCronJob(res)}) + writeJSON(w, map[string]interface{}{"ok": true, "job": normalizeCronJob(res)}) } case http.MethodPost: args := map[string]interface{}{} @@ -3828,7 +3832,7 @@ func (s *Server) handleWebUICron(w http.ResponseWriter, r *http.Request) { args["id"] = id } action := "create" - if a, ok := args["action"].(string); ok && strings.TrimSpace(a) != "" { + if a := tools.MapStringArg(args, "action"); a != "" { action = strings.ToLower(strings.TrimSpace(a)) } res, err := s.onCron(action, args) @@ -3836,7 +3840,7 @@ func (s *Server) handleWebUICron(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "result": normalizeCronJob(res)}) + writeJSON(w, map[string]interface{}{"ok": true, "result": normalizeCronJob(res)}) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } @@ -3899,22 +3903,20 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { files = append(files, filepath.ToSlash(rel)) return nil }) - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "id": id, "files": files}) + writeJSON(w, map[string]interface{}{"ok": true, "id": id, "files": files}) return } if f := strings.TrimSpace(r.URL.Query().Get("file")); f != "" { - clean := filepath.Clean(f) - if strings.HasPrefix(clean, "..") { - http.Error(w, "invalid file path", http.StatusBadRequest) - return - } - full := filepath.Join(skillPath, clean) - b, err := os.ReadFile(full) + clean, content, found, err := readRelativeTextFile(skillPath, f) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), relativeFilePathStatus(err)) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "id": id, "file": filepath.ToSlash(clean), "content": string(b)}) + if !found { + http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "id": id, "file": filepath.ToSlash(clean), "content": content}) return } } @@ -3996,7 +3998,7 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { items = append(items, it) } } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "skills": items, "source": "clawhub", @@ -4012,7 +4014,7 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "imported": imported}) + writeJSON(w, map[string]interface{}{"ok": true, "imported": imported}) return } @@ -4021,15 +4023,14 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid json", http.StatusBadRequest) return } - action, _ := body["action"].(string) - action = strings.ToLower(strings.TrimSpace(action)) + action := strings.ToLower(stringFromMap(body, "action")) if action == "install_clawhub" { output, err := ensureClawHubReady(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "output": output, "installed": true, @@ -4037,8 +4038,8 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { }) return } - id, _ := body["id"].(string) - name, _ := body["name"].(string) + id := stringFromMap(body, "id") + name := stringFromMap(body, "name") if strings.TrimSpace(name) == "" { name = id } @@ -4057,13 +4058,7 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { http.Error(w, "clawhub is not installed. please install clawhub first.", http.StatusPreconditionFailed) return } - ignoreSuspicious := false - switch v := body["ignore_suspicious"].(type) { - case bool: - ignoreSuspicious = v - case string: - ignoreSuspicious = strings.EqualFold(strings.TrimSpace(v), "true") || strings.TrimSpace(v) == "1" - } + ignoreSuspicious, _ := tools.MapBoolArg(body, "ignore_suspicious") args := []string{"install", name} if ignoreSuspicious { args = append(args, "--force") @@ -4081,7 +4076,7 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("install failed: %v\n%s", err, outText), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "installed": name, "output": string(out)}) + writeJSON(w, map[string]interface{}{"ok": true, "installed": name, "output": string(out)}) case "enable": if _, err := os.Stat(disabledPath); err == nil { if err := os.Rename(disabledPath, enabledPath); err != nil { @@ -4089,7 +4084,7 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { return } } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true}) + writeJSON(w, map[string]interface{}{"ok": true}) case "disable": if _, err := os.Stat(enabledPath); err == nil { if err := os.Rename(enabledPath, disabledPath); err != nil { @@ -4097,41 +4092,25 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { return } } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true}) + writeJSON(w, map[string]interface{}{"ok": true}) case "write_file": skillPath, err := resolveSkillPath(name) if err != nil { http.Error(w, err.Error(), http.StatusNotFound) return } - filePath, _ := body["file"].(string) - clean := filepath.Clean(strings.TrimSpace(filePath)) - if clean == "" || strings.HasPrefix(clean, "..") { - http.Error(w, "invalid file path", http.StatusBadRequest) + content := rawStringFromMap(body, "content") + filePath := stringFromMap(body, "file") + clean, err := writeRelativeTextFile(skillPath, filePath, content, true) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) return } - content, _ := body["content"].(string) - full := filepath.Join(skillPath, clean) - if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := os.WriteFile(full, []byte(content), 0644); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "name": name, "file": filepath.ToSlash(clean)}) + writeJSON(w, map[string]interface{}{"ok": true, "name": name, "file": filepath.ToSlash(clean)}) case "create", "update": - desc, _ := body["description"].(string) - sys, _ := body["system_prompt"].(string) - var toolsList []string - if arr, ok := body["tools"].([]interface{}); ok { - for _, v := range arr { - if sv, ok := v.(string); ok && strings.TrimSpace(sv) != "" { - toolsList = append(toolsList, strings.TrimSpace(sv)) - } - } - } + desc := rawStringFromMap(body, "description") + sys := rawStringFromMap(body, "system_prompt") + toolsList := stringListFromMap(body, "tools") if action == "create" { if _, err := os.Stat(enabledPath); err == nil { http.Error(w, "skill already exists", http.StatusBadRequest) @@ -4147,7 +4126,7 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true}) + writeJSON(w, map[string]interface{}{"ok": true}) default: http.Error(w, "unsupported action", http.StatusBadRequest) } @@ -4167,7 +4146,7 @@ func (s *Server) handleWebUISkills(w http.ResponseWriter, r *http.Request) { if err := os.RemoveAll(pathB); err == nil { deleted = true } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": deleted, "id": id}) + writeJSON(w, map[string]interface{}{"ok": true, "deleted": deleted, "id": id}) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) @@ -4336,15 +4315,15 @@ func normalizeCronJob(v interface{}) map[string]interface{} { out[k] = val } if sch, ok := m["schedule"].(map[string]interface{}); ok { - kind, _ := sch["kind"].(string) - if expr, ok := sch["expr"].(string); ok && expr != "" { + kind := stringFromMap(sch, "kind") + if expr := stringFromMap(sch, "expr"); expr != "" { out["expr"] = expr } else if strings.EqualFold(strings.TrimSpace(kind), "every") { - if every, ok := sch["everyMs"].(float64); ok && every > 0 { - out["expr"] = fmt.Sprintf("@every %s", (time.Duration(int64(every)) * time.Millisecond).String()) + if every := intFromMap(sch, "everyMs", 0); every > 0 { + out["expr"] = fmt.Sprintf("@every %s", (time.Duration(every) * time.Millisecond).String()) } } else if strings.EqualFold(strings.TrimSpace(kind), "at") { - if at, ok := sch["atMs"].(float64); ok && at > 0 { + if at := intFromMap(sch, "atMs", 0); at > 0 { out["expr"] = time.UnixMilli(int64(at)).Format(time.RFC3339) } } @@ -5133,7 +5112,7 @@ func (s *Server) handleWebUISessions(w http.ResponseWriter, r *http.Request) { if len(out) == 0 { out = append(out, item{Key: "main", Channel: "main"}) } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "sessions": out}) + writeJSON(w, map[string]interface{}{"ok": true, "sessions": out}) } func isUserFacingSessionKey(key string) bool { @@ -5159,206 +5138,6 @@ func isUserFacingSessionKey(key string) bool { } } -func (s *Server) handleWebUISubagentProfiles(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - workspace := strings.TrimSpace(s.workspacePath) - if workspace == "" { - http.Error(w, "workspace path not set", http.StatusInternalServerError) - return - } - store := tools.NewSubagentProfileStore(workspace) - - switch r.Method { - case http.MethodGet: - agentID := strings.TrimSpace(r.URL.Query().Get("agent_id")) - if agentID != "" { - profile, ok, err := store.Get(agentID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "found": ok, "profile": profile}) - return - } - profiles, err := store.List() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "profiles": profiles}) - case http.MethodDelete: - agentID := strings.TrimSpace(r.URL.Query().Get("agent_id")) - if agentID == "" { - http.Error(w, "agent_id required", http.StatusBadRequest) - return - } - if err := store.Delete(agentID); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": true, "agent_id": agentID}) - case http.MethodPost: - var body struct { - Action string `json:"action"` - AgentID string `json:"agent_id"` - Name string `json:"name"` - NotifyMainPolicy string `json:"notify_main_policy"` - Role string `json:"role"` - SystemPromptFile string `json:"system_prompt_file"` - MemoryNamespace string `json:"memory_namespace"` - Status string `json:"status"` - ToolAllowlist []string `json:"tool_allowlist"` - MaxRetries *int `json:"max_retries"` - RetryBackoffMS *int `json:"retry_backoff_ms"` - TimeoutSec *int `json:"timeout_sec"` - MaxTaskChars *int `json:"max_task_chars"` - MaxResultChars *int `json:"max_result_chars"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - action := strings.ToLower(strings.TrimSpace(body.Action)) - if action == "" { - action = "upsert" - } - agentID := strings.TrimSpace(body.AgentID) - if agentID == "" { - http.Error(w, "agent_id required", http.StatusBadRequest) - return - } - - switch action { - case "create": - if _, ok, err := store.Get(agentID); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } else if ok { - http.Error(w, "subagent profile already exists", http.StatusConflict) - return - } - profile, err := store.Upsert(tools.SubagentProfile{ - AgentID: agentID, - Name: body.Name, - NotifyMainPolicy: body.NotifyMainPolicy, - Role: body.Role, - SystemPromptFile: body.SystemPromptFile, - MemoryNamespace: body.MemoryNamespace, - Status: body.Status, - ToolAllowlist: body.ToolAllowlist, - MaxRetries: derefInt(body.MaxRetries), - RetryBackoff: derefInt(body.RetryBackoffMS), - TimeoutSec: derefInt(body.TimeoutSec), - MaxTaskChars: derefInt(body.MaxTaskChars), - MaxResultChars: derefInt(body.MaxResultChars), - }) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "profile": profile}) - case "update": - existing, ok, err := store.Get(agentID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if !ok || existing == nil { - http.Error(w, "subagent profile not found", http.StatusNotFound) - return - } - next := *existing - next.Name = body.Name - next.NotifyMainPolicy = body.NotifyMainPolicy - next.Role = body.Role - next.SystemPromptFile = body.SystemPromptFile - next.MemoryNamespace = body.MemoryNamespace - if body.Status != "" { - next.Status = body.Status - } - if body.ToolAllowlist != nil { - next.ToolAllowlist = body.ToolAllowlist - } - if body.MaxRetries != nil { - next.MaxRetries = *body.MaxRetries - } - if body.RetryBackoffMS != nil { - next.RetryBackoff = *body.RetryBackoffMS - } - if body.TimeoutSec != nil { - next.TimeoutSec = *body.TimeoutSec - } - if body.MaxTaskChars != nil { - next.MaxTaskChars = *body.MaxTaskChars - } - if body.MaxResultChars != nil { - next.MaxResultChars = *body.MaxResultChars - } - profile, err := store.Upsert(next) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "profile": profile}) - case "enable", "disable": - existing, ok, err := store.Get(agentID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if !ok || existing == nil { - http.Error(w, "subagent profile not found", http.StatusNotFound) - return - } - if action == "enable" { - existing.Status = "active" - } else { - existing.Status = "disabled" - } - profile, err := store.Upsert(*existing) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "profile": profile}) - case "delete": - if err := store.Delete(agentID); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": true, "agent_id": agentID}) - case "upsert": - profile, err := store.Upsert(tools.SubagentProfile{ - AgentID: agentID, - Name: body.Name, - NotifyMainPolicy: body.NotifyMainPolicy, - Role: body.Role, - SystemPromptFile: body.SystemPromptFile, - MemoryNamespace: body.MemoryNamespace, - Status: body.Status, - ToolAllowlist: body.ToolAllowlist, - MaxRetries: derefInt(body.MaxRetries), - RetryBackoff: derefInt(body.RetryBackoffMS), - TimeoutSec: derefInt(body.TimeoutSec), - MaxTaskChars: derefInt(body.MaxTaskChars), - MaxResultChars: derefInt(body.MaxResultChars), - }) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "profile": profile}) - default: - http.Error(w, "unsupported action", http.StatusBadRequest) - } - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - func (s *Server) handleWebUIToolAllowlistGroups(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -5368,7 +5147,7 @@ func (s *Server) handleWebUIToolAllowlistGroups(w http.ResponseWriter, r *http.R http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "groups": tools.ToolAllowlistGroups(), }) @@ -5407,7 +5186,7 @@ func (s *Server) handleWebUISubagentsRuntime(w http.ResponseWriter, r *http.Requ body = map[string]interface{}{} } if action == "" { - if raw, _ := body["action"].(string); raw != "" { + if raw := stringFromMap(body, "action"); raw != "" { action = strings.ToLower(strings.TrimSpace(raw)) } } @@ -5423,72 +5202,7 @@ func (s *Server) handleWebUISubagentsRuntime(w http.ResponseWriter, r *http.Requ http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "result": result}) -} - -func (s *Server) handleWebUISubagentsRuntimeLive(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if s.onSubagents == nil { - http.Error(w, "subagent runtime handler not configured", http.StatusServiceUnavailable) - return - } - conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - ctx := r.Context() - taskID := strings.TrimSpace(r.URL.Query().Get("task_id")) - previewTaskID := strings.TrimSpace(r.URL.Query().Get("preview_task_id")) - sub := s.subscribeSubagentLive(ctx, taskID, previewTaskID) - initial := map[string]interface{}{ - "ok": true, - "type": "subagents_live", - "payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID), - } - _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteJSON(initial); err != nil { - return - } - for { - select { - case <-ctx.Done(): - return - case payload := <-sub: - _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil { - return - } - } - } -} - -func (s *Server) buildSubagentsLivePayload(ctx context.Context, taskID, previewTaskID string) map[string]interface{} { - call := func(action string, args map[string]interface{}) map[string]interface{} { - res, err := s.onSubagents(ctx, action, args) - if err != nil { - return map[string]interface{}{} - } - if m, ok := res.(map[string]interface{}); ok { - return m - } - return map[string]interface{}{} - } - payload := map[string]interface{}{} - taskID = strings.TrimSpace(taskID) - previewTaskID = strings.TrimSpace(previewTaskID) - if taskID != "" { - payload["thread"] = call("thread", map[string]interface{}{"id": taskID, "limit": 50}) - payload["inbox"] = call("inbox", map[string]interface{}{"id": taskID, "limit": 50}) - } - if previewTaskID != "" { - payload["preview"] = call("stream", map[string]interface{}{"id": previewTaskID, "limit": 12}) - } - return payload + writeJSON(w, map[string]interface{}{"ok": true, "result": result}) } func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { @@ -5514,21 +5228,19 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { } files = append(files, e.Name()) } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "files": files}) + writeJSON(w, map[string]interface{}{"ok": true, "files": files}) return } - clean := filepath.Clean(path) - if strings.HasPrefix(clean, "..") { - http.Error(w, "invalid path", http.StatusBadRequest) - return - } - full := filepath.Join(memoryDir, clean) - b, err := os.ReadFile(full) + clean, content, found, err := readRelativeTextFile(memoryDir, path) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "path": clean, "content": string(b)}) + if !found { + http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "content": content}) case http.MethodPost: var body struct { Path string `json:"path"` @@ -5538,81 +5250,61 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid json", http.StatusBadRequest) return } - clean := filepath.Clean(body.Path) - if clean == "" || strings.HasPrefix(clean, "..") { - http.Error(w, "invalid path", http.StatusBadRequest) + clean, err := writeRelativeTextFile(memoryDir, body.Path, body.Content, false) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) return } - full := filepath.Join(memoryDir, clean) - if err := os.WriteFile(full, []byte(body.Content), 0644); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "path": clean}) + writeJSON(w, map[string]interface{}{"ok": true, "path": clean}) case http.MethodDelete: - path := filepath.Clean(r.URL.Query().Get("path")) - if path == "" || strings.HasPrefix(path, "..") { - http.Error(w, "invalid path", http.StatusBadRequest) + clean, full, err := resolveRelativeFilePath(memoryDir, r.URL.Query().Get("path")) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return } - full := filepath.Join(memoryDir, path) if err := os.Remove(full); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": true, "path": path}) + writeJSON(w, map[string]interface{}{"ok": true, "deleted": true, "path": clean}) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } } -func (s *Server) handleWebUITaskAudit(w http.ResponseWriter, r *http.Request) { +func (s *Server) handleWebUIWorkspaceFile(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) return } - if r.Method != http.MethodGet { + workspace := strings.TrimSpace(s.workspacePath) + switch r.Method { + case http.MethodGet: + path := strings.TrimSpace(r.URL.Query().Get("path")) + clean, content, found, err := readRelativeTextFile(workspace, path) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "found": found, "content": content}) + case http.MethodPost: + var body struct { + Path string `json:"path"` + Content string `json:"content"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + clean, err := writeRelativeTextFile(workspace, body.Path, body.Content, true) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "saved": true}) + default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return } - path := filepath.Join(strings.TrimSpace(s.workspacePath), "memory", "task-audit.jsonl") - includeHeartbeat := r.URL.Query().Get("include_heartbeat") == "1" - limit := 100 - if v := r.URL.Query().Get("limit"); v != "" { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - if n > 500 { - n = 500 - } - limit = n - } - } - b, err := os.ReadFile(path) - if err != nil { - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "items": []map[string]interface{}{}}) - return - } - lines := strings.Split(string(b), "\n") - if len(lines) > 0 && lines[len(lines)-1] == "" { - lines = lines[:len(lines)-1] - } - if len(lines) > limit { - lines = lines[len(lines)-limit:] - } - items := make([]map[string]interface{}, 0, len(lines)) - for _, ln := range lines { - if ln == "" { - continue - } - var row map[string]interface{} - if err := json.Unmarshal([]byte(ln), &row); err == nil { - source := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", row["source"]))) - if !includeHeartbeat && source == "heartbeat" { - continue - } - items = append(items, row) - } - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "items": items}) } func (s *Server) handleWebUITaskQueue(w http.ResponseWriter, r *http.Request) { @@ -5624,7 +5316,7 @@ func (s *Server) handleWebUITaskQueue(w http.ResponseWriter, r *http.Request) { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - path := filepath.Join(strings.TrimSpace(s.workspacePath), "memory", "task-audit.jsonl") + path := s.memoryFilePath("task-audit.jsonl") includeHeartbeat := r.URL.Query().Get("include_heartbeat") == "1" b, err := os.ReadFile(path) lines := []string{} @@ -5681,7 +5373,7 @@ func (s *Server) handleWebUITaskQueue(w http.ResponseWriter, r *http.Request) { } // Merge command watchdog queue from memory/task_queue.json for visibility. - queuePath := filepath.Join(strings.TrimSpace(s.workspacePath), "memory", "task_queue.json") + queuePath := s.memoryFilePath("task_queue.json") if qb, qErr := os.ReadFile(queuePath); qErr == nil { var q map[string]interface{} if json.Unmarshal(qb, &q) == nil { @@ -5780,7 +5472,7 @@ func (s *Server) handleWebUITaskQueue(w http.ResponseWriter, r *http.Request) { sort.Slice(items, func(i, j int) bool { return fmt.Sprintf("%v", items[i]["time"]) > fmt.Sprintf("%v", items[j]["time"]) }) stats := map[string]int{"total": len(items), "running": len(running)} - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "running": running, "items": items, "stats": stats}) + writeJSON(w, map[string]interface{}{"ok": true, "running": running, "items": items, "stats": stats}) } func (s *Server) loadEKGRowsCached(path string, maxLines int) []map[string]interface{} { @@ -5834,8 +5526,7 @@ func (s *Server) handleWebUIEKGStats(w http.ResponseWriter, r *http.Request) { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - workspace := strings.TrimSpace(s.workspacePath) - ekgPath := filepath.Join(workspace, "memory", "ekg-events.jsonl") + ekgPath := s.memoryFilePath("ekg-events.jsonl") window := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("window"))) windowDur := 24 * time.Hour switch window { @@ -5929,7 +5620,7 @@ func (s *Server) handleWebUIEKGStats(w http.ResponseWriter, r *http.Request) { } return out } - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSON(w, map[string]interface{}{ "ok": true, "window": selectedWindow, "provider_top": toTopScore(providerScore, 5), @@ -5943,72 +5634,6 @@ func (s *Server) handleWebUIEKGStats(w http.ResponseWriter, r *http.Request) { }) } -func (s *Server) handleWebUIExecApprovals(w http.ResponseWriter, r *http.Request) { - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if strings.TrimSpace(s.configPath) == "" { - http.Error(w, "config path not set", http.StatusInternalServerError) - return - } - b, err := os.ReadFile(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - var cfg map[string]interface{} - if err := json.Unmarshal(b, &cfg); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if r.Method == http.MethodGet { - toolsMap, _ := cfg["tools"].(map[string]interface{}) - shellMap, _ := toolsMap["shell"].(map[string]interface{}) - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "exec_approvals": shellMap}) - return - } - if r.Method == http.MethodPost { - var body map[string]interface{} - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - toolsMap, _ := cfg["tools"].(map[string]interface{}) - if toolsMap == nil { - toolsMap = map[string]interface{}{} - cfg["tools"] = toolsMap - } - shellMap, _ := toolsMap["shell"].(map[string]interface{}) - if shellMap == nil { - shellMap = map[string]interface{}{} - toolsMap["shell"] = shellMap - } - for k, v := range body { - shellMap[k] = v - } - out, _ := json.MarshalIndent(cfg, "", " ") - tmp := s.configPath + ".tmp" - if err := os.WriteFile(tmp, out, 0644); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := os.Rename(tmp, s.configPath); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if s.onConfigAfter != nil { - if err := s.onConfigAfter(); err != nil { - http.Error(w, "config saved but reload failed: "+err.Error(), http.StatusInternalServerError) - return - } - } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "reloaded": true}) - return - } - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) -} - func (s *Server) handleWebUILogsRecent(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -6023,12 +5648,7 @@ func (s *Server) handleWebUILogsRecent(w http.ResponseWriter, r *http.Request) { http.Error(w, "log path not configured", http.StatusInternalServerError) return } - limit := 10 - if v := strings.TrimSpace(r.URL.Query().Get("limit")); v != "" { - if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 { - limit = n - } - } + limit := queryBoundedPositiveInt(r, "limit", 10, 200) b, err := os.ReadFile(path) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -6048,7 +5668,7 @@ func (s *Server) handleWebUILogsRecent(w http.ResponseWriter, r *http.Request) { out = append(out, parsed) } } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "logs": out}) + writeJSON(w, map[string]interface{}{"ok": true, "logs": out}) } func parseLogLine(line string) (map[string]interface{}, bool) { @@ -6120,62 +5740,6 @@ func (s *Server) handleWebUILogsLive(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) handleWebUILogsStream(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Deprecation", "true") - w.Header().Set("X-Clawgo-Replaced-By", "/api/logs/live") - if !s.checkAuth(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - path := strings.TrimSpace(s.logFilePath) - if path == "" { - http.Error(w, "log path not configured", http.StatusInternalServerError) - return - } - f, err := os.Open(path) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer f.Close() - fi, _ := f.Stat() - if fi != nil { - _, _ = f.Seek(fi.Size(), io.SeekStart) - } - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "stream unsupported", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/x-ndjson") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - - reader := bufio.NewReader(f) - for { - select { - case <-r.Context().Done(): - return - default: - line, err := reader.ReadString('\n') - if len(line) > 0 { - if parsed, ok := parseLogLine(line); ok { - b, _ := json.Marshal(parsed) - _, _ = w.Write(append(b, '\n')) - flusher.Flush() - } - } - if err != nil { - time.Sleep(500 * time.Millisecond) - } - } - } -} - func (s *Server) checkAuth(r *http.Request) bool { if s.token == "" { return true diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index f8e6680..e5313cc 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -269,6 +269,7 @@ func TestHandleWebUIConfigRequiresConfirmForProviderAPIBaseChange(t *testing.T) srv := NewServer("127.0.0.1", 0, "", nil) srv.SetConfigPath(cfgPath) + srv.SetConfigAfterHook(func() error { return nil }) req := httptest.NewRequest(http.MethodPost, "/api/config", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -287,6 +288,74 @@ func TestHandleWebUIConfigRequiresConfirmForProviderAPIBaseChange(t *testing.T) } } +func TestHandleWebUIConfigAcceptsStringConfirmRisky(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + pc := cfg.Models.Providers["openai"] + pc.APIBase = "https://old.example/v1" + pc.APIKey = "test-key" + cfg.Models.Providers["openai"] = pc + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + bodyCfg := cfgpkg.DefaultConfig() + bodyCfg.Logging.Enabled = false + bodyPC := bodyCfg.Models.Providers["openai"] + bodyPC.APIBase = "https://new.example/v1" + bodyPC.APIKey = "test-key" + bodyCfg.Models.Providers["openai"] = bodyPC + bodyMap := map[string]interface{}{} + raw, err := json.Marshal(bodyCfg) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + if err := json.Unmarshal(raw, &bodyMap); err != nil { + t.Fatalf("unmarshal body map: %v", err) + } + bodyMap["confirm_risky"] = "true" + body, err := json.Marshal(bodyMap) + if err != nil { + t.Fatalf("marshal request body: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "", nil) + srv.SetConfigPath(cfgPath) + srv.SetConfigAfterHook(func() error { return nil }) + + req := httptest.NewRequest(http.MethodPost, "/api/config", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + srv.handleWebUIConfig(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestNormalizeCronJobParsesStringScheduleValues(t *testing.T) { + t.Parallel() + + job := normalizeCronJob(map[string]interface{}{ + "schedule": map[string]interface{}{ + "kind": "every", + "everyMs": "60000", + }, + "payload": map[string]interface{}{ + "message": "hello", + }, + }) + if got, _ := job["expr"].(string); got == "" || !strings.Contains(got, "@every") { + t.Fatalf("expected normalized @every expr, got %#v", job["expr"]) + } +} + func TestHandleWebUIConfigRequiresConfirmForCustomProviderSecretChange(t *testing.T) { t.Parallel() @@ -414,6 +483,127 @@ func TestHandleWebUIConfigReturnsReloadHookError(t *testing.T) { } } +func TestHandleWebUIConfigNormalizedGet(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + cfg.Agents.Subagents["coder"] = cfgpkg.SubagentConfig{ + Enabled: true, + Role: "coding", + SystemPromptFile: "agents/coder/AGENT.md", + } + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "", nil) + srv.SetConfigPath(cfgPath) + req := httptest.NewRequest(http.MethodGet, "/api/config?mode=normalized", nil) + rec := httptest.NewRecorder() + + srv.handleWebUIConfig(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + var payload map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["ok"] != true { + t.Fatalf("expected ok=true, got %#v", payload) + } + configMap, _ := payload["config"].(map[string]interface{}) + coreMap, _ := configMap["core"].(map[string]interface{}) + if strings.TrimSpace(fmt.Sprintf("%v", coreMap["main_agent_id"])) != "main" { + t.Fatalf("unexpected normalized config: %#v", payload) + } +} + +func TestHandleWebUIConfigNormalizedPost(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + body := map[string]interface{}{ + "confirm_risky": true, + "core": map[string]interface{}{ + "default_provider": "openai", + "default_model": "gpt-5.4", + "main_agent_id": "main", + "subagents": map[string]interface{}{ + "reviewer": map[string]interface{}{ + "enabled": true, + "role": "testing", + "prompt": "agents/reviewer/AGENT.md", + "provider": "openai", + "tool_allowlist": []interface{}{"shell"}, + "runtime_class": "provider_bound", + }, + }, + "tools": map[string]interface{}{"shell_enabled": true, "mcp_enabled": false}, + "gateway": map[string]interface{}{"host": "127.0.0.1", "port": float64(18790)}, + }, + "runtime": map[string]interface{}{ + "router": map[string]interface{}{ + "enabled": true, + "strategy": "rules_first", + "allow_direct_agent_chat": false, + "max_hops": float64(6), + "default_timeout_sec": float64(600), + "default_wait_reply": true, + "sticky_thread_owner": true, + "rules": []interface{}{ + map[string]interface{}{"agent_id": "reviewer", "keywords": []interface{}{"review"}}, + }, + }, + "providers": map[string]interface{}{ + "openai": map[string]interface{}{ + "auth": "bearer", + "api_base": "https://api.openai.com/v1", + "timeout_sec": float64(30), + }, + }, + }, + } + raw, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "", nil) + srv.SetConfigPath(cfgPath) + srv.SetConfigAfterHook(func() error { return nil }) + + req := httptest.NewRequest(http.MethodPost, "/api/config?mode=normalized", bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.handleWebUIConfig(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + loaded, err := cfgpkg.LoadConfig(cfgPath) + if err != nil { + t.Fatalf("reload config: %v", err) + } + if !loaded.Agents.Router.Enabled { + t.Fatalf("expected router to be enabled") + } + if _, ok := loaded.Agents.Subagents["reviewer"]; !ok { + t.Fatalf("expected reviewer subagent, got %+v", loaded.Agents.Subagents) + } +} + func TestHandleNodeConnectRegistersAndHeartbeatsNode(t *testing.T) { t.Parallel() @@ -627,62 +817,6 @@ func TestHandleWebUISessionsHidesInternalSessionsByDefault(t *testing.T) { } } -func TestHandleWebUISubagentsRuntimeLive(t *testing.T) { - t.Parallel() - - srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) - srv.SetSubagentHandler(func(ctx context.Context, action string, args map[string]interface{}) (interface{}, error) { - switch action { - case "thread": - return map[string]interface{}{ - "thread": map[string]interface{}{"thread_id": "thread-1"}, - "messages": []map[string]interface{}{ - {"message_id": "msg-1", "content": "hello"}, - }, - }, nil - case "inbox": - return map[string]interface{}{ - "messages": []map[string]interface{}{ - {"message_id": "msg-2", "content": "reply"}, - }, - }, nil - case "stream": - return map[string]interface{}{ - "task": map[string]interface{}{"id": "subagent-1"}, - "items": []map[string]interface{}{ - {"kind": "event", "message": "progress"}, - }, - }, nil - default: - return map[string]interface{}{}, nil - } - }) - - mux := http.NewServeMux() - mux.HandleFunc("/api/subagents_runtime/live", srv.handleWebUISubagentsRuntimeLive) - httpSrv := httptest.NewServer(mux) - defer httpSrv.Close() - - wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/api/subagents_runtime/live?task_id=subagent-1&preview_task_id=subagent-1" - conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - if err != nil { - t.Fatalf("dial websocket: %v", err) - } - defer conn.Close() - - var msg map[string]interface{} - if err := conn.ReadJSON(&msg); err != nil { - t.Fatalf("read live snapshot: %v", err) - } - payload, _ := msg["payload"].(map[string]interface{}) - thread, _ := payload["thread"].(map[string]interface{}) - inbox, _ := payload["inbox"].(map[string]interface{}) - preview, _ := payload["preview"].(map[string]interface{}) - if thread == nil || inbox == nil || preview == nil { - t.Fatalf("expected thread/inbox/preview payload, got: %+v", msg) - } -} - func TestHandleWebUIChatLive(t *testing.T) { t.Parallel() diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 0134c73..1e91432 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -28,25 +28,13 @@ func NewMessageBus() *MessageBus { func (mb *MessageBus) PublishInbound(msg InboundMessage) { mb.mu.RLock() + defer mb.mu.RUnlock() if mb.closed { - mb.mu.RUnlock() return } - ch := mb.inbound - mb.mu.RUnlock() - - defer func() { - if recover() != nil { - logger.WarnCF("bus", logger.C0129, map[string]interface{}{ - logger.FieldChannel: msg.Channel, - logger.FieldChatID: msg.ChatID, - "session_key": msg.SessionKey, - }) - } - }() select { - case ch <- msg: + case mb.inbound <- msg: case <-time.After(queueWriteTimeout): logger.ErrorCF("bus", logger.C0130, map[string]interface{}{ logger.FieldChannel: msg.Channel, @@ -67,24 +55,13 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { mb.mu.RLock() + defer mb.mu.RUnlock() if mb.closed { - mb.mu.RUnlock() return } - ch := mb.outbound - mb.mu.RUnlock() - - defer func() { - if recover() != nil { - logger.WarnCF("bus", logger.C0131, map[string]interface{}{ - logger.FieldChannel: msg.Channel, - logger.FieldChatID: msg.ChatID, - }) - } - }() select { - case ch <- msg: + case mb.outbound <- msg: case <-time.After(queueWriteTimeout): logger.ErrorCF("bus", logger.C0132, map[string]interface{}{ logger.FieldChannel: msg.Channel, diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go new file mode 100644 index 0000000..afd1c08 --- /dev/null +++ b/pkg/bus/bus_test.go @@ -0,0 +1,61 @@ +package bus + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestMessageBusPublishAfterCloseDoesNotPanic(t *testing.T) { + t.Parallel() + + mb := NewMessageBus() + mb.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + mb.PublishInbound(InboundMessage{Channel: "test"}) + mb.PublishOutbound(OutboundMessage{Channel: "test"}) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("publish after close blocked") + } +} + +func TestMessageBusCloseWhilePublishingDoesNotPanic(t *testing.T) { + t.Parallel() + + mb := NewMessageBus() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + if _, ok := mb.ConsumeInbound(ctx); !ok { + return + } + } + }() + + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 50; j++ { + mb.PublishInbound(InboundMessage{Channel: "test"}) + mb.PublishOutbound(OutboundMessage{Channel: "test"}) + } + }() + } + + time.Sleep(20 * time.Millisecond) + mb.Close() + cancel() + wg.Wait() +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 468e881..9da678c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -440,7 +440,7 @@ func GetConfigDir() string { func DefaultConfig() *Config { configDir := GetConfigDir() - return &Config{ + cfg := &Config{ Agents: AgentsConfig{ Defaults: AgentDefaults{ Workspace: filepath.Join(configDir, "workspace"), @@ -657,6 +657,8 @@ func DefaultConfig() *Config { }, }, } + cfg.Normalize() + return cfg } func normalizeProviderNameAlias(name string) string { @@ -767,6 +769,7 @@ func LoadConfig(path string) (*Config, error) { if err := env.Parse(cfg); err != nil { return nil, err } + cfg.Normalize() return cfg, nil } diff --git a/pkg/config/normalized.go b/pkg/config/normalized.go new file mode 100644 index 0000000..61cb163 --- /dev/null +++ b/pkg/config/normalized.go @@ -0,0 +1,250 @@ +package config + +import "strings" + +type NormalizedConfig struct { + Core NormalizedCoreConfig `json:"core"` + Runtime NormalizedRuntimeConfig `json:"runtime"` +} + +type NormalizedCoreConfig struct { + DefaultProvider string `json:"default_provider,omitempty"` + DefaultModel string `json:"default_model,omitempty"` + MainAgentID string `json:"main_agent_id,omitempty"` + Subagents map[string]NormalizedSubagentConfig `json:"subagents,omitempty"` + Tools NormalizedCoreToolsConfig `json:"tools,omitempty"` + Gateway NormalizedCoreGatewayConfig `json:"gateway,omitempty"` +} + +type NormalizedSubagentConfig struct { + Enabled bool `json:"enabled"` + Role string `json:"role,omitempty"` + Prompt string `json:"prompt,omitempty"` + Provider string `json:"provider,omitempty"` + ToolAllowlist []string `json:"tool_allowlist,omitempty"` + RuntimeClass string `json:"runtime_class,omitempty"` +} + +type NormalizedCoreToolsConfig struct { + ShellEnabled bool `json:"shell_enabled"` + MCPEnabled bool `json:"mcp_enabled"` +} + +type NormalizedCoreGatewayConfig struct { + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` +} + +type NormalizedRuntimeConfig struct { + Router NormalizedRuntimeRouterConfig `json:"router,omitempty"` + Providers map[string]NormalizedRuntimeProviderConfig `json:"providers,omitempty"` +} + +type NormalizedRuntimeRouterConfig struct { + Enabled bool `json:"enabled"` + Strategy string `json:"strategy,omitempty"` + AllowDirectAgentChat bool `json:"allow_direct_agent_chat,omitempty"` + MaxHops int `json:"max_hops,omitempty"` + DefaultTimeoutSec int `json:"default_timeout_sec,omitempty"` + DefaultWaitReply bool `json:"default_wait_reply,omitempty"` + StickyThreadOwner bool `json:"sticky_thread_owner,omitempty"` + Rules []AgentRouteRule `json:"rules,omitempty"` +} + +type NormalizedRuntimeProviderConfig struct { + Auth string `json:"auth,omitempty"` + APIBase string `json:"api_base,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + OAuth ProviderOAuthConfig `json:"oauth,omitempty"` + RuntimePersist bool `json:"runtime_persist,omitempty"` + RuntimeHistoryFile string `json:"runtime_history_file,omitempty"` + RuntimeHistoryMax int `json:"runtime_history_max,omitempty"` + Responses ProviderResponsesConfig `json:"responses,omitempty"` +} + +func (c *Config) Normalize() { + if c == nil { + return + } + if strings.TrimSpace(c.Agents.Router.MainAgentID) == "" { + c.Agents.Router.MainAgentID = "main" + } + if c.Agents.Subagents == nil { + c.Agents.Subagents = map[string]SubagentConfig{} + } + if c.Agents.Router.Enabled { + mainID := strings.TrimSpace(c.Agents.Router.MainAgentID) + if mainID == "" { + mainID = "main" + c.Agents.Router.MainAgentID = mainID + } + main := c.Agents.Subagents[mainID] + if !main.Enabled { + main.Enabled = true + } + if strings.TrimSpace(main.Role) == "" { + main.Role = "orchestrator" + } + if strings.TrimSpace(main.Type) == "" { + main.Type = "router" + } + if strings.TrimSpace(main.SystemPromptFile) == "" { + main.SystemPromptFile = "agents/main/AGENT.md" + } + c.Agents.Subagents[mainID] = main + } + if provider, model := ParseProviderModelRef(c.Agents.Defaults.Model.Primary); provider != "" && model != "" { + c.Agents.Defaults.Model.Primary = provider + "/" + model + } +} + +func (c *Config) NormalizedView() NormalizedConfig { + view := NormalizedConfig{ + Core: NormalizedCoreConfig{ + MainAgentID: strings.TrimSpace(c.Agents.Router.MainAgentID), + Subagents: map[string]NormalizedSubagentConfig{}, + Tools: NormalizedCoreToolsConfig{ + ShellEnabled: c.Tools.Shell.Enabled, + MCPEnabled: c.Tools.MCP.Enabled, + }, + Gateway: NormalizedCoreGatewayConfig{ + Host: c.Gateway.Host, + Port: c.Gateway.Port, + }, + }, + Runtime: NormalizedRuntimeConfig{ + Router: NormalizedRuntimeRouterConfig{ + Enabled: c.Agents.Router.Enabled, + Strategy: c.Agents.Router.Strategy, + AllowDirectAgentChat: c.Agents.Router.AllowDirectAgentChat, + MaxHops: c.Agents.Router.MaxHops, + DefaultTimeoutSec: c.Agents.Router.DefaultTimeoutSec, + DefaultWaitReply: c.Agents.Router.DefaultWaitReply, + StickyThreadOwner: c.Agents.Router.StickyThreadOwner, + Rules: append([]AgentRouteRule(nil), c.Agents.Router.Rules...), + }, + Providers: map[string]NormalizedRuntimeProviderConfig{}, + }, + } + view.Core.DefaultProvider, view.Core.DefaultModel = ParseProviderModelRef(c.Agents.Defaults.Model.Primary) + if view.Core.DefaultProvider == "" { + view.Core.DefaultProvider = PrimaryProviderName(c) + view.Core.DefaultModel = strings.TrimSpace(c.Agents.Defaults.Model.Primary) + } + for id, subcfg := range c.Agents.Subagents { + view.Core.Subagents[id] = NormalizedSubagentConfig{ + Enabled: subcfg.Enabled, + Role: subcfg.Role, + Prompt: subcfg.SystemPromptFile, + Provider: subcfg.Runtime.Provider, + ToolAllowlist: append([]string(nil), subcfg.Tools.Allowlist...), + RuntimeClass: firstNonEmptyRuntimeClass(subcfg), + } + } + for name, pc := range c.Models.Providers { + view.Runtime.Providers[name] = NormalizedRuntimeProviderConfig{ + Auth: pc.Auth, + APIBase: pc.APIBase, + TimeoutSec: pc.TimeoutSec, + OAuth: pc.OAuth, + RuntimePersist: pc.RuntimePersist, + RuntimeHistoryFile: pc.RuntimeHistoryFile, + RuntimeHistoryMax: pc.RuntimeHistoryMax, + Responses: pc.Responses, + } + } + return view +} + +func firstNonEmptyRuntimeClass(subcfg SubagentConfig) string { + switch { + case strings.TrimSpace(subcfg.Runtime.Provider) != "": + return "provider_bound" + case strings.TrimSpace(subcfg.Transport) != "": + return strings.TrimSpace(subcfg.Transport) + default: + return "default" + } +} + +func (c *Config) ApplyNormalizedView(view NormalizedConfig) { + if c == nil { + return + } + defaultProvider := strings.TrimSpace(view.Core.DefaultProvider) + defaultModel := strings.TrimSpace(view.Core.DefaultModel) + if defaultProvider != "" && defaultModel != "" { + c.Agents.Defaults.Model.Primary = normalizeProviderNameAlias(defaultProvider) + "/" + defaultModel + } + if strings.TrimSpace(view.Core.MainAgentID) != "" { + c.Agents.Router.MainAgentID = strings.TrimSpace(view.Core.MainAgentID) + } + c.Tools.Shell.Enabled = view.Core.Tools.ShellEnabled + c.Tools.MCP.Enabled = view.Core.Tools.MCPEnabled + if strings.TrimSpace(view.Core.Gateway.Host) != "" { + c.Gateway.Host = strings.TrimSpace(view.Core.Gateway.Host) + } + if view.Core.Gateway.Port > 0 { + c.Gateway.Port = view.Core.Gateway.Port + } + + nextSubagents := map[string]SubagentConfig{} + for id, current := range c.Agents.Subagents { + nextSubagents[id] = current + } + for id, item := range view.Core.Subagents { + current := c.Agents.Subagents[id] + current.Enabled = item.Enabled + current.Role = strings.TrimSpace(item.Role) + current.SystemPromptFile = strings.TrimSpace(item.Prompt) + current.Tools.Allowlist = append([]string(nil), item.ToolAllowlist...) + current.Runtime.Provider = strings.TrimSpace(item.Provider) + switch strings.TrimSpace(item.RuntimeClass) { + case "", "default": + case "provider_bound": + if strings.TrimSpace(current.Transport) == "" { + current.Transport = "local" + } + default: + current.Transport = strings.TrimSpace(item.RuntimeClass) + } + nextSubagents[id] = current + } + c.Agents.Subagents = nextSubagents + + c.Agents.Router.Enabled = view.Runtime.Router.Enabled + if strings.TrimSpace(view.Runtime.Router.Strategy) != "" { + c.Agents.Router.Strategy = strings.TrimSpace(view.Runtime.Router.Strategy) + } + c.Agents.Router.AllowDirectAgentChat = view.Runtime.Router.AllowDirectAgentChat + if view.Runtime.Router.MaxHops > 0 { + c.Agents.Router.MaxHops = view.Runtime.Router.MaxHops + } + if view.Runtime.Router.DefaultTimeoutSec > 0 { + c.Agents.Router.DefaultTimeoutSec = view.Runtime.Router.DefaultTimeoutSec + } + c.Agents.Router.DefaultWaitReply = view.Runtime.Router.DefaultWaitReply + c.Agents.Router.StickyThreadOwner = view.Runtime.Router.StickyThreadOwner + c.Agents.Router.Rules = append([]AgentRouteRule(nil), view.Runtime.Router.Rules...) + + nextProviders := map[string]ProviderConfig{} + for name, current := range c.Models.Providers { + nextProviders[name] = current + } + for name, item := range view.Runtime.Providers { + current := c.Models.Providers[name] + current.Auth = strings.TrimSpace(item.Auth) + current.APIBase = strings.TrimSpace(item.APIBase) + if item.TimeoutSec > 0 { + current.TimeoutSec = item.TimeoutSec + } + current.OAuth = item.OAuth + current.RuntimePersist = item.RuntimePersist + current.RuntimeHistoryFile = item.RuntimeHistoryFile + current.RuntimeHistoryMax = item.RuntimeHistoryMax + current.Responses = item.Responses + nextProviders[name] = current + } + c.Models.Providers = nextProviders + c.Normalize() +} diff --git a/pkg/config/normalized_test.go b/pkg/config/normalized_test.go new file mode 100644 index 0000000..31e26a7 --- /dev/null +++ b/pkg/config/normalized_test.go @@ -0,0 +1,30 @@ +package config + +import "testing" + +func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) { + cfg := DefaultConfig() + cfg.Agents.Router.Enabled = true + cfg.Agents.Subagents["coder"] = SubagentConfig{ + Enabled: true, + Role: "coding", + SystemPromptFile: "agents/coder/AGENT.md", + Tools: SubagentToolsConfig{Allowlist: []string{"shell"}}, + Runtime: SubagentRuntimeConfig{Provider: "openai"}, + } + + view := cfg.NormalizedView() + if view.Core.DefaultProvider != "openai" || view.Core.DefaultModel != "gpt-5.4" { + t.Fatalf("unexpected default model projection: %+v", view.Core) + } + subcfg, ok := view.Core.Subagents["coder"] + if !ok { + t.Fatalf("expected normalized subagent") + } + if subcfg.Prompt != "agents/coder/AGENT.md" || subcfg.Provider != "openai" { + t.Fatalf("unexpected normalized subagent: %+v", subcfg) + } + if !view.Runtime.Router.Enabled || view.Runtime.Router.Strategy != "rules_first" { + t.Fatalf("unexpected runtime router: %+v", view.Runtime.Router) + } +} diff --git a/pkg/configops/configops.go b/pkg/configops/configops.go index 17c468d..e9fa062 100644 --- a/pkg/configops/configops.go +++ b/pkg/configops/configops.go @@ -40,13 +40,22 @@ func LoadConfigAsMap(path string) (map[string]interface{}, error) { func NormalizeConfigPath(path string) string { p := strings.TrimSpace(path) p = strings.Trim(p, ".") - parts := strings.Split(p, ".") - for i, part := range parts { - if part == "enable" { - parts[i] = "enabled" - } + if p == "" { + return "" } - return strings.Join(parts, ".") + parts := strings.Split(p, ".") + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if part == "enable" { + part = "enabled" + } + out = append(out, part) + } + return strings.Join(out, ".") } func ParseConfigValue(raw string) interface{} { @@ -70,10 +79,20 @@ func ParseConfigValue(raw string) interface{} { if len(v) >= 2 && ((v[0] == '"' && v[len(v)-1] == '"') || (v[0] == '\'' && v[len(v)-1] == '\'')) { return v[1 : len(v)-1] } + if strings.HasPrefix(v, "{") || strings.HasPrefix(v, "[") { + var parsed interface{} + if err := json.Unmarshal([]byte(v), &parsed); err == nil { + return parsed + } + } return v } func SetMapValueByPath(root map[string]interface{}, path string, value interface{}) error { + if root == nil { + return fmt.Errorf("root is nil") + } + path = NormalizeConfigPath(path) if path == "" { return fmt.Errorf("path is empty") } @@ -106,6 +125,10 @@ func SetMapValueByPath(root map[string]interface{}, path string, value interface } func GetMapValueByPath(root map[string]interface{}, path string) (interface{}, bool) { + if root == nil { + return nil, false + } + path = NormalizeConfigPath(path) if path == "" { return nil, false } diff --git a/pkg/configops/configops_test.go b/pkg/configops/configops_test.go new file mode 100644 index 0000000..ee02890 --- /dev/null +++ b/pkg/configops/configops_test.go @@ -0,0 +1,48 @@ +package configops + +import ( + "reflect" + "testing" +) + +func TestNormalizeConfigPath(t *testing.T) { + t.Parallel() + + if got := NormalizeConfigPath(".agents..enable."); got != "agents.enabled" { + t.Fatalf("unexpected normalized path: %q", got) + } +} + +func TestParseConfigValueJSON(t *testing.T) { + t.Parallel() + + got := ParseConfigValue(`{"enabled":true,"count":2}`) + row, ok := got.(map[string]interface{}) + if !ok { + t.Fatalf("expected map value, got %#v", got) + } + if enabled, _ := row["enabled"].(bool); !enabled { + t.Fatalf("expected enabled=true, got %#v", row["enabled"]) + } +} + +func TestSetMapValueByPathRejectsNilRoot(t *testing.T) { + t.Parallel() + + if err := SetMapValueByPath(nil, "agents.main.enabled", true); err == nil { + t.Fatal("expected error for nil root") + } +} + +func TestSetAndGetMapValueByPath(t *testing.T) { + t.Parallel() + + root := map[string]interface{}{} + if err := SetMapValueByPath(root, ".agents.main.enable.", true); err != nil { + t.Fatalf("set value failed: %v", err) + } + got, ok := GetMapValueByPath(root, "agents.main.enabled") + if !ok || !reflect.DeepEqual(got, true) { + t.Fatalf("unexpected get result: %#v, %v", got, ok) + } +} diff --git a/pkg/nodes/transport.go b/pkg/nodes/transport.go index 3cd6b60..e56ad80 100644 --- a/pkg/nodes/transport.go +++ b/pkg/nodes/transport.go @@ -240,29 +240,29 @@ func normalizeArtifacts(payload map[string]interface{}, action string) []map[str } artifact := map[string]interface{}{} - if mediaType, _ := payload["media_type"].(string); strings.TrimSpace(mediaType) != "" { - artifact["kind"] = strings.TrimSpace(mediaType) + if mediaType := payloadString(payload, "media_type"); mediaType != "" { + artifact["kind"] = mediaType } - if mimeType, _ := payload["mime_type"].(string); strings.TrimSpace(mimeType) != "" { - artifact["mime_type"] = strings.TrimSpace(mimeType) + if mimeType := payloadString(payload, "mime_type"); mimeType != "" { + artifact["mime_type"] = mimeType } - if storage, _ := payload["storage"].(string); strings.TrimSpace(storage) != "" { - artifact["storage"] = strings.TrimSpace(storage) + if storage := payloadString(payload, "storage"); storage != "" { + artifact["storage"] = storage } - if path, _ := payload["path"].(string); strings.TrimSpace(path) != "" { - artifact["path"] = filepath.Clean(strings.TrimSpace(path)) + if path := payloadString(payload, "path"); path != "" { + artifact["path"] = filepath.Clean(path) } - if url, _ := payload["url"].(string); strings.TrimSpace(url) != "" { - artifact["url"] = strings.TrimSpace(url) + if url := payloadString(payload, "url"); url != "" { + artifact["url"] = url } - if image, _ := payload["image"].(string); strings.TrimSpace(image) != "" { - artifact["content_base64"] = strings.TrimSpace(image) + if image := payloadString(payload, "image"); image != "" { + artifact["content_base64"] = image } - if text, _ := payload["content_text"].(string); strings.TrimSpace(text) != "" { + if text := payloadString(payload, "content_text"); text != "" { artifact["content_text"] = text } - if name, _ := payload["name"].(string); strings.TrimSpace(name) != "" { - artifact["name"] = strings.TrimSpace(name) + if name := payloadString(payload, "name"); name != "" { + artifact["name"] = name } if size := int64FromPayload(payload["size_bytes"]); size > 0 { artifact["size_bytes"] = size @@ -277,34 +277,52 @@ func normalizeArtifacts(payload map[string]interface{}, action string) []map[str } func normalizeArtifactList(raw interface{}) []map[string]interface{} { - items, ok := raw.([]interface{}) - if !ok { - return []map[string]interface{}{} - } - out := make([]map[string]interface{}, 0, len(items)) - for _, item := range items { - row, ok := item.(map[string]interface{}) - if !ok || len(row) == 0 { - continue - } - normalized := map[string]interface{}{} - for _, key := range []string{"id", "name", "kind", "mime_type", "storage", "path", "url", "content_text", "content_base64", "source_path"} { - if value, ok := row[key]; ok && strings.TrimSpace(fmt.Sprint(value)) != "" { - normalized[key] = value + switch items := raw.(type) { + case []map[string]interface{}: + out := make([]map[string]interface{}, 0, len(items)) + for _, item := range items { + if normalized, ok := normalizeArtifactRow(item); ok { + out = append(out, normalized) } } - if truncated, ok := row["truncated"].(bool); ok && truncated { - normalized["truncated"] = true + return out + case []interface{}: + out := make([]map[string]interface{}, 0, len(items)) + for _, item := range items { + row, ok := item.(map[string]interface{}) + if !ok { + continue + } + if normalized, ok := normalizeArtifactRow(row); ok { + out = append(out, normalized) + } } - if size := int64FromPayload(row["size_bytes"]); size > 0 { - normalized["size_bytes"] = size - } - if len(normalized) == 0 { - continue - } - out = append(out, normalized) + return out + default: + return []map[string]interface{}{} } - return out +} + +func normalizeArtifactRow(row map[string]interface{}) (map[string]interface{}, bool) { + if len(row) == 0 { + return nil, false + } + normalized := map[string]interface{}{} + for _, key := range []string{"id", "name", "kind", "mime_type", "storage", "path", "url", "content_text", "content_base64", "source_path"} { + if value, ok := row[key]; ok && strings.TrimSpace(fmt.Sprint(value)) != "" { + normalized[key] = value + } + } + if truncated, ok := row["truncated"].(bool); ok && truncated { + normalized["truncated"] = true + } + if size := int64FromPayload(row["size_bytes"]); size > 0 { + normalized["size_bytes"] = size + } + if len(normalized) == 0 { + return nil, false + } + return normalized, true } func int64FromPayload(v interface{}) int64 { @@ -318,7 +336,21 @@ func int64FromPayload(v interface{}) int64 { case json.Number: n, _ := value.Int64() return n + case string: + n, _ := json.Number(strings.TrimSpace(value)).Int64() + return n default: return 0 } } + +func payloadString(payload map[string]interface{}, key string) string { + if payload == nil { + return "" + } + v, ok := payload[key] + if !ok { + return "" + } + return strings.TrimSpace(fmt.Sprint(v)) +} diff --git a/pkg/nodes/transport_test.go b/pkg/nodes/transport_test.go index d662996..8566d02 100644 --- a/pkg/nodes/transport_test.go +++ b/pkg/nodes/transport_test.go @@ -92,6 +92,27 @@ func TestNormalizeDevicePayloadBuildsArtifacts(t *testing.T) { } } +func TestNormalizeDevicePayloadNormalizesExistingArtifactRows(t *testing.T) { + t.Parallel() + + payload := normalizeDevicePayload("screen_snapshot", map[string]interface{}{ + "artifacts": []map[string]interface{}{ + { + "path": "/tmp/screen.png", + "kind": "image", + "size_bytes": "42", + }, + }, + }) + artifacts, ok := payload["artifacts"].([]map[string]interface{}) + if !ok || len(artifacts) != 1 { + t.Fatalf("expected one normalized artifact, got %+v", payload["artifacts"]) + } + if got := artifacts[0]["size_bytes"]; got != int64(42) { + t.Fatalf("expected normalized size_bytes, got %#v", got) + } +} + func TestWebRTCTransportSendEndToEnd(t *testing.T) { t.Parallel() diff --git a/pkg/providers/execution.go b/pkg/providers/execution.go new file mode 100644 index 0000000..92ae768 --- /dev/null +++ b/pkg/providers/execution.go @@ -0,0 +1,117 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" +) + +func newProviderExecutionError(code, message, stage string, retryable bool, source string) *ProviderExecutionError { + return &ProviderExecutionError{ + Code: code, + Message: message, + Stage: stage, + Retryable: retryable, + Source: source, + } +} + +func (p *HTTPProvider) executeJSONAttempts(ctx context.Context, endpoint string, payload interface{}, mutate func(*http.Request, authAttempt), classify func(int, []byte) (oauthFailureReason, bool)) (ProviderExecutionResult, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return ProviderExecutionResult{Error: newProviderExecutionError("marshal_failed", err.Error(), "marshal", false, p.providerName)}, fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := p.authAttempts(ctx) + if err != nil { + return ProviderExecutionResult{Error: newProviderExecutionError("auth_unavailable", err.Error(), "auth", false, p.providerName)}, err + } + var last ProviderExecutionResult + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return ProviderExecutionResult{Error: newProviderExecutionError("request_build_failed", err.Error(), "request", false, p.providerName)}, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p, false) + if mutate != nil { + mutate(req, attempt) + } + body, status, ctype, err := p.doJSONAttempt(req, attempt) + if err != nil { + return ProviderExecutionResult{ + StatusCode: status, + ContentType: ctype, + AttemptKind: attempt.kind, + Error: newProviderExecutionError("request_failed", err.Error(), "request", false, p.providerName), + }, err + } + reason, retry := classify(status, body) + last = ProviderExecutionResult{ + Body: body, + StatusCode: status, + ContentType: ctype, + AttemptKind: attempt.kind, + Retryable: retry, + Failure: reason, + } + if !retry { + p.markAttemptSuccess(attempt) + return last, nil + } + applyAttemptFailure(p, attempt, reason, nil) + } + return last, nil +} + +func (p *HTTPProvider) executeStreamAttempts(ctx context.Context, endpoint string, payload interface{}, mutate func(*http.Request, authAttempt), onEvent func(string)) (ProviderExecutionResult, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return ProviderExecutionResult{Error: newProviderExecutionError("marshal_failed", err.Error(), "marshal", false, p.providerName)}, fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := p.authAttempts(ctx) + if err != nil { + return ProviderExecutionResult{Error: newProviderExecutionError("auth_unavailable", err.Error(), "auth", false, p.providerName)}, err + } + var last ProviderExecutionResult + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return ProviderExecutionResult{Error: newProviderExecutionError("request_build_failed", err.Error(), "request", false, p.providerName)}, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p, true) + if mutate != nil { + mutate(req, attempt) + } + body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent) + if err != nil { + return ProviderExecutionResult{ + StatusCode: status, + ContentType: ctype, + AttemptKind: attempt.kind, + Error: newProviderExecutionError("stream_failed", err.Error(), "request", false, p.providerName), + }, err + } + reason, _ := classifyOAuthFailure(status, body) + last = ProviderExecutionResult{ + Body: body, + StatusCode: status, + ContentType: ctype, + AttemptKind: attempt.kind, + Retryable: quotaHit, + Failure: reason, + } + if !quotaHit { + p.markAttemptSuccess(attempt) + return last, nil + } + applyAttemptFailure(p, attempt, reason, nil) + } + return last, nil +} diff --git a/pkg/providers/execution_test.go b/pkg/providers/execution_test.go new file mode 100644 index 0000000..c2fbef8 --- /dev/null +++ b/pkg/providers/execution_test.go @@ -0,0 +1,33 @@ +package providers + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestHTTPProviderExecuteJSONAttemptsReturnsEnvelope(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer token" { + t.Fatalf("authorization = %q", got) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + provider := NewHTTPProvider("test", "token", server.URL, "gpt-test", false, "bearer", 5*time.Second, nil) + result, err := provider.executeJSONAttempts(context.Background(), server.URL, map[string]any{"hello": "world"}, nil, classifyOAuthFailure) + if err != nil { + t.Fatalf("executeJSONAttempts error: %v", err) + } + if result.StatusCode != http.StatusOK || result.ContentType != "application/json" || result.AttemptKind != "api_key" { + t.Fatalf("unexpected envelope: %+v", result) + } + if string(result.Body) != `{"ok":true}` { + t.Fatalf("unexpected body: %s", string(result.Body)) + } +} diff --git a/pkg/providers/gemini_cli_provider_test.go b/pkg/providers/gemini_cli_provider_test.go index d7eb50a..3f8f808 100644 --- a/pkg/providers/gemini_cli_provider_test.go +++ b/pkg/providers/gemini_cli_provider_test.go @@ -39,7 +39,7 @@ func TestGeminiCLIProviderChatUsesCloudCodeEndpoint(t *testing.T) { if got := r.Header.Get("X-Goog-Api-Client"); got != geminiCLIApiClient { t.Fatalf("x-goog-api-client = %q", got) } - if got := r.Header.Get("User-Agent"); got != "GeminiCLI/gemini-2.5-pro" { + if got := r.Header.Get("User-Agent"); got != geminiCLIUserAgent("gemini-2.5-pro") { t.Fatalf("user-agent = %q", got) } var payload map[string]any diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index cc48d49..f83e848 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -2,7 +2,6 @@ package providers import ( "bufio" - "bytes" "context" "crypto/rand" "encoding/json" @@ -709,90 +708,19 @@ func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Messa } func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payload interface{}, onEvent func(string)) ([]byte, int, string, error) { - jsonData, err := json.Marshal(payload) - if err != nil { - return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) - } - attempts, err := p.authAttempts(ctx) + result, err := p.executeStreamAttempts(ctx, endpoint, payload, nil, onEvent) if err != nil { return nil, 0, "", err } - var lastBody []byte - var lastStatus int - var lastType string - for _, attempt := range attempts { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) - if err != nil { - return nil, 0, "", fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - applyAttemptAuth(req, attempt) - applyAttemptProviderHeaders(req, attempt, p, true) - - body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent) - if err != nil { - return nil, 0, "", err - } - if !quotaHit { - p.markAttemptSuccess(attempt) - return body, status, ctype, nil - } - lastBody, lastStatus, lastType = body, status, ctype - if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { - reason, _ := classifyOAuthFailure(status, body) - p.oauth.markExhausted(attempt.session, reason) - recordProviderOAuthError(p.providerName, attempt.session, reason) - } - if attempt.kind == "api_key" { - reason, _ := classifyOAuthFailure(status, body) - p.markAPIKeyFailure(reason) - } - } - return lastBody, lastStatus, lastType, nil + return result.Body, result.StatusCode, result.ContentType, nil } func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload interface{}) ([]byte, int, string, error) { - jsonData, err := json.Marshal(payload) - if err != nil { - return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) - } - attempts, err := p.authAttempts(ctx) + result, err := p.executeJSONAttempts(ctx, endpoint, payload, nil, classifyOAuthFailure) if err != nil { return nil, 0, "", err } - var lastBody []byte - var lastStatus int - var lastType string - for _, attempt := range attempts { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) - if err != nil { - return nil, 0, "", fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - applyAttemptAuth(req, attempt) - applyAttemptProviderHeaders(req, attempt, p, false) - - body, status, ctype, err := p.doJSONAttempt(req, attempt) - if err != nil { - return nil, 0, "", err - } - reason, retry := classifyOAuthFailure(status, body) - if !retry { - p.markAttemptSuccess(attempt) - return body, status, ctype, nil - } - lastBody, lastStatus, lastType = body, status, ctype - if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { - p.oauth.markExhausted(attempt.session, reason) - recordProviderOAuthError(p.providerName, attempt.session, reason) - } - if attempt.kind == "api_key" { - p.markAPIKeyFailure(reason) - } - } - return lastBody, lastStatus, lastType, nil + return result.Body, result.StatusCode, result.ContentType, nil } type authAttempt struct { diff --git a/pkg/providers/types.go b/pkg/providers/types.go index daccb4a..d07a7f7 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -92,3 +92,21 @@ type ToolFunctionDefinition struct { Parameters map[string]interface{} `json:"parameters"` Strict *bool `json:"strict,omitempty"` } + +type ProviderExecutionError struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Stage string `json:"stage,omitempty"` + Retryable bool `json:"retryable,omitempty"` + Source string `json:"source,omitempty"` +} + +type ProviderExecutionResult struct { + Body []byte `json:"-"` + StatusCode int `json:"status_code,omitempty"` + ContentType string `json:"content_type,omitempty"` + AttemptKind string `json:"attempt_kind,omitempty"` + Retryable bool `json:"retryable,omitempty"` + Failure oauthFailureReason `json:"failure_reason,omitempty"` + Error *ProviderExecutionError `json:"error,omitempty"` +} diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 9dd9d67..56a9cec 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -5,6 +5,7 @@ import ( "crypto/sha1" "encoding/hex" "encoding/json" + "fmt" "os" "path/filepath" "strconv" @@ -312,7 +313,7 @@ func (sm *SessionManager) rewriteSessionFileLocked(session *Session) error { e := toOpenClawMessageEvent(msg) b, err := json.Marshal(e) if err != nil { - continue + return fmt.Errorf("marshal session message: %w", err) } if _, err := f.Write(append(b, '\n')); err != nil { return err @@ -671,6 +672,7 @@ func (sm *SessionManager) loadSessions() error { continue } scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) session.mu.Lock() for scanner.Scan() { if msg, ok := fromJSONLLine(scanner.Bytes()); ok { @@ -678,7 +680,13 @@ func (sm *SessionManager) loadSessions() error { } } session.mu.Unlock() - f.Close() + closeErr := f.Close() + if err := scanner.Err(); err != nil { + return fmt.Errorf("scan session file %s: %w", file.Name(), err) + } + if closeErr != nil { + return fmt.Errorf("close session file %s: %w", file.Name(), closeErr) + } } return sm.writeOpenClawSessionsIndex() diff --git a/pkg/session/manager_test.go b/pkg/session/manager_test.go new file mode 100644 index 0000000..3530c51 --- /dev/null +++ b/pkg/session/manager_test.go @@ -0,0 +1,64 @@ +package session + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/YspCoder/clawgo/pkg/providers" +) + +func TestLoadSessionsReturnsScannerErrorForOversizedLine(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + line := `{"role":"user","content":"` + strings.Repeat("x", 2*1024*1024) + `"}` + if err := os.WriteFile(filepath.Join(storage, "huge.jsonl"), []byte(line+"\n"), 0644); err != nil { + t.Fatalf("write session file failed: %v", err) + } + + sm := &SessionManager{ + sessions: map[string]*Session{}, + storage: storage, + } + if err := sm.loadSessions(); err != nil { + t.Fatalf("expected oversized line to load with expanded scanner buffer, got %v", err) + } +} + +func TestFromJSONLLineParsesOpenClawToolResult(t *testing.T) { + t.Parallel() + + line := []byte(`{"type":"message","message":{"role":"toolResult","content":[{"type":"text","text":"done"}],"toolCallId":"call-1"}}`) + msg, ok := fromJSONLLine(line) + if !ok { + t.Fatal("expected line to parse") + } + if msg.Role != "tool" || msg.ToolCallID != "call-1" || msg.Content != "done" { + t.Fatalf("unexpected parsed message: %+v", msg) + } +} + +func TestRewriteSessionFileLockedPersistsMessages(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := &SessionManager{storage: storage} + session := &Session{ + Key: "abc", + Messages: []providers.Message{ + {Role: "user", Content: "hello"}, + }, + } + if err := sm.rewriteSessionFileLocked(session); err != nil { + t.Fatalf("rewrite session failed: %v", err) + } + data, err := os.ReadFile(filepath.Join(storage, "abc.jsonl")) + if err != nil { + t.Fatalf("read rewritten session failed: %v", err) + } + if !strings.Contains(string(data), `"role":"user"`) { + t.Fatalf("unexpected rewritten session contents: %s", string(data)) + } +} diff --git a/pkg/tools/arg_helpers.go b/pkg/tools/arg_helpers.go new file mode 100644 index 0000000..c2bf267 --- /dev/null +++ b/pkg/tools/arg_helpers.go @@ -0,0 +1,168 @@ +package tools + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" +) + +func MapStringArg(args map[string]interface{}, key string) string { + return strings.TrimSpace(MapRawStringArg(args, key)) +} + +func MapRawStringArg(args map[string]interface{}, key string) string { + if args == nil { + return "" + } + raw, ok := args[key] + if !ok || raw == nil { + return "" + } + switch v := raw.(type) { + case string: + return v + case []byte: + return string(v) + case json.Number: + return v.String() + case fmt.Stringer: + return v.String() + case int, int8, int16, int32, int64, float32, float64, bool, uint, uint8, uint16, uint32, uint64: + return fmt.Sprint(v) + default: + return "" + } +} + +func MapBoolArg(args map[string]interface{}, key string) (bool, bool) { + if args == nil { + return false, false + } + raw, ok := args[key] + if !ok || raw == nil { + return false, false + } + switch v := raw.(type) { + case bool: + return v, true + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "true", "1", "yes", "on": + return true, true + case "false", "0", "no", "off": + return false, true + } + case json.Number: + if n, err := v.Int64(); err == nil { + return n != 0, true + } + case float64: + return v != 0, true + case int: + return v != 0, true + case int64: + return v != 0, true + } + return false, false +} + +func MapIntArg(args map[string]interface{}, key string, fallback int) int { + if args == nil { + return fallback + } + raw, ok := args[key] + if !ok || raw == nil { + return fallback + } + switch v := raw.(type) { + case int: + if v > 0 { + return v + } + case int64: + if v > 0 { + return int(v) + } + case float64: + if v > 0 { + return int(v) + } + case json.Number: + if n, err := v.Int64(); err == nil && n > 0 { + return int(n) + } + case string: + if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil && n > 0 { + return n + } + } + return fallback +} + +func MapStringListArg(args map[string]interface{}, key string) []string { + if args == nil { + return nil + } + raw, ok := args[key] + if !ok || raw == nil { + return nil + } + switch v := raw.(type) { + case []string: + return normalizeArgStringList(v) + case []interface{}: + items := make([]string, 0, len(v)) + for _, item := range v { + if s := strings.TrimSpace(fmt.Sprint(item)); s != "" && s != "" { + items = append(items, s) + } + } + return normalizeArgStringList(items) + case string: + if strings.TrimSpace(v) == "" { + return nil + } + return normalizeArgStringList(strings.Split(v, ",")) + default: + return nil + } +} + +func MapObjectArg(args map[string]interface{}, key string) map[string]interface{} { + if args == nil { + return map[string]interface{}{} + } + raw, ok := args[key] + if !ok || raw == nil { + return map[string]interface{}{} + } + obj, _ := raw.(map[string]interface{}) + if obj == nil { + return map[string]interface{}{} + } + return obj +} + +func normalizeArgStringList(items []string) []string { + if len(items) == 0 { + return nil + } + out := make([]string, 0, len(items)) + seen := make(map[string]struct{}, len(items)) + for _, item := range items { + trimmed := strings.TrimSpace(item) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + out = append(out, trimmed) + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/pkg/tools/arg_helpers_object_test.go b/pkg/tools/arg_helpers_object_test.go new file mode 100644 index 0000000..e7c2878 --- /dev/null +++ b/pkg/tools/arg_helpers_object_test.go @@ -0,0 +1,23 @@ +package tools + +import "testing" + +func TestMapObjectArgReturnsEmptyMapForMissingValue(t *testing.T) { + t.Parallel() + + got := MapObjectArg(nil, "arguments") + if got == nil || len(got) != 0 { + t.Fatalf("expected empty map, got %#v", got) + } +} + +func TestMapObjectArgReturnsObjectValue(t *testing.T) { + t.Parallel() + + got := MapObjectArg(map[string]interface{}{ + "arguments": map[string]interface{}{"path": "README.md"}, + }, "arguments") + if got["path"] != "README.md" { + t.Fatalf("unexpected object arg: %#v", got) + } +} diff --git a/pkg/tools/arg_helpers_test.go b/pkg/tools/arg_helpers_test.go new file mode 100644 index 0000000..3afbd38 --- /dev/null +++ b/pkg/tools/arg_helpers_test.go @@ -0,0 +1,29 @@ +package tools + +import "testing" + +func TestMapBoolArgParsesStringValues(t *testing.T) { + t.Parallel() + + got, ok := MapBoolArg(map[string]interface{}{"enabled": "true"}, "enabled") + if !ok || !got { + t.Fatalf("expected string true to parse, got %v %v", got, ok) + } +} + +func TestMapIntArgParsesStringValues(t *testing.T) { + t.Parallel() + + if got := MapIntArg(map[string]interface{}{"limit": "25"}, "limit", 5); got != 25 { + t.Fatalf("expected parsed int, got %d", got) + } +} + +func TestMapStringListArgParsesCommaSeparatedValues(t *testing.T) { + t.Parallel() + + got := MapStringListArg(map[string]interface{}{"tools": "shell, sessions, shell"}, "tools") + if len(got) != 2 || got[0] != "shell" || got[1] != "sessions" { + t.Fatalf("unexpected string list: %#v", got) + } +} diff --git a/pkg/tools/browser.go b/pkg/tools/browser.go index 19cfe88..d3fc0bb 100644 --- a/pkg/tools/browser.go +++ b/pkg/tools/browser.go @@ -48,8 +48,8 @@ func (t *BrowserTool) Parameters() map[string]interface{} { } func (t *BrowserTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - action, _ := args["action"].(string) - url, _ := args["url"].(string) + action := MapStringArg(args, "action") + url := MapStringArg(args, "url") switch action { case "screenshot": diff --git a/pkg/tools/camera.go b/pkg/tools/camera.go index 686b789..9a533b2 100644 --- a/pkg/tools/camera.go +++ b/pkg/tools/camera.go @@ -40,10 +40,8 @@ func (t *CameraTool) Parameters() map[string]interface{} { } func (t *CameraTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - filename := "" - if v, ok := args["filename"].(string); ok && v != "" { - filename = v - } else { + filename := MapStringArg(args, "filename") + if filename == "" { filename = fmt.Sprintf("snap_%d.jpg", time.Now().Unix()) } diff --git a/pkg/tools/camera_test.go b/pkg/tools/camera_test.go new file mode 100644 index 0000000..5218283 --- /dev/null +++ b/pkg/tools/camera_test.go @@ -0,0 +1,22 @@ +package tools + +import ( + "context" + "strings" + "testing" +) + +func TestCameraToolParsesFilenameArg(t *testing.T) { + t.Parallel() + + tool := NewCameraTool(t.TempDir()) + _, err := tool.Execute(context.Background(), map[string]interface{}{ + "filename": "custom.jpg", + }) + if err == nil { + t.Fatal("expected camera access to fail in test environment") + } + if !strings.Contains(err.Error(), "/dev/video0") && !strings.Contains(err.Error(), "camera device") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/pkg/tools/compat_alias.go b/pkg/tools/compat_alias.go index c6079f2..8330910 100644 --- a/pkg/tools/compat_alias.go +++ b/pkg/tools/compat_alias.go @@ -26,14 +26,21 @@ func (t *AliasTool) Description() string { func (t *AliasTool) Parameters() map[string]interface{} { return t.base.Parameters() } func (t *AliasTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + if args == nil { + args = map[string]interface{}{} + } + normalized := make(map[string]interface{}, len(args)+len(t.argMap)) + for key, value := range args { + normalized[key] = value + } if len(t.argMap) > 0 { for from, to := range t.argMap { - if v, ok := args[from]; ok { - if _, exists := args[to]; !exists { - args[to] = v + if v, ok := normalized[from]; ok { + if _, exists := normalized[to]; !exists { + normalized[to] = v } } } } - return t.base.Execute(ctx, args) + return t.base.Execute(ctx, normalized) } diff --git a/pkg/tools/compat_alias_test.go b/pkg/tools/compat_alias_test.go new file mode 100644 index 0000000..a39d277 --- /dev/null +++ b/pkg/tools/compat_alias_test.go @@ -0,0 +1,37 @@ +package tools + +import ( + "context" + "testing" +) + +type captureAliasTool struct { + args map[string]interface{} +} + +func (t *captureAliasTool) Name() string { return "capture" } +func (t *captureAliasTool) Description() string { return "capture" } +func (t *captureAliasTool) Parameters() map[string]interface{} { + return map[string]interface{}{} +} +func (t *captureAliasTool) Execute(_ context.Context, args map[string]interface{}) (string, error) { + t.args = args + return "ok", nil +} + +func TestAliasToolExecuteDoesNotMutateCallerArgs(t *testing.T) { + t.Parallel() + + base := &captureAliasTool{} + tool := NewAliasTool("read", "", base, map[string]string{"file_path": "path"}) + original := map[string]interface{}{"file_path": "README.md"} + if _, err := tool.Execute(context.Background(), original); err != nil { + t.Fatalf("execute failed: %v", err) + } + if _, ok := original["path"]; ok { + t.Fatalf("caller args were mutated: %+v", original) + } + if got, _ := base.args["path"].(string); got != "README.md" { + t.Fatalf("expected translated arg, got %+v", base.args) + } +} diff --git a/pkg/tools/cron_tool.go b/pkg/tools/cron_tool.go index d865118..ef68b4d 100644 --- a/pkg/tools/cron_tool.go +++ b/pkg/tools/cron_tool.go @@ -37,13 +37,11 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st if t.cs == nil { return "Error: cron service not available", nil } - action, _ := args["action"].(string) - action = strings.ToLower(strings.TrimSpace(action)) + action := strings.ToLower(MapStringArg(args, "action")) if action == "" { action = "list" } - id, _ := args["id"].(string) - id = strings.TrimSpace(id) + id := MapStringArg(args, "id") switch action { case "list": diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 639af5c..4447830 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -61,8 +61,8 @@ func (t *ReadFileTool) Parameters() map[string]interface{} { } func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - path, ok := args["path"].(string) - if !ok { + path := MapStringArg(args, "path") + if path == "" { return "", fmt.Errorf("path is required") } @@ -83,12 +83,12 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) } offset := int64(0) - if o, ok := args["offset"].(float64); ok { + if o := MapIntArg(args, "offset", 0); o > 0 { offset = int64(o) } limit := int64(stat.Size()) - if l, ok := args["limit"].(float64); ok { + if l := MapIntArg(args, "limit", 0); l > 0 { limit = int64(l) } @@ -145,16 +145,23 @@ func (t *WriteFileTool) Parameters() map[string]interface{} { } func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - path, ok := args["path"].(string) - if !ok { + path := MapStringArg(args, "path") + if path == "" { return "", fmt.Errorf("path is required") } - content, ok := args["content"].(string) + if args == nil { + return "", fmt.Errorf("content is required") + } + rawContent, ok := args["content"] if !ok { return "", fmt.Errorf("content is required") } - appendMode, _ := args["append"].(bool) + content, ok := rawContent.(string) + if !ok { + return "", fmt.Errorf("content is required") + } + appendMode, _ := MapBoolArg(args, "append") resolvedPath, err := resolveToolPath(t.allowedDir, path) if err != nil { @@ -217,12 +224,12 @@ func (t *ListDirTool) Parameters() map[string]interface{} { } func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - path, ok := args["path"].(string) - if !ok { + path := MapStringArg(args, "path") + if path == "" { return "", fmt.Errorf("path is required") } - recursive, _ := args["recursive"].(bool) + recursive, _ := MapBoolArg(args, "recursive") resolvedPath, err := resolveToolPath(t.allowedDir, path) if err != nil { @@ -310,17 +317,25 @@ func (t *EditFileTool) Parameters() map[string]interface{} { } func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - path, ok := args["path"].(string) - if !ok { + path := MapStringArg(args, "path") + if path == "" { return "", fmt.Errorf("path is required") } - oldText, ok := args["old_text"].(string) + rawOldText, ok := args["old_text"] if !ok { return "", fmt.Errorf("old_text is required") } + oldText, ok := rawOldText.(string) + if !ok || oldText == "" { + return "", fmt.Errorf("old_text is required") + } - newText, ok := args["new_text"].(string) + rawNewText, ok := args["new_text"] + if !ok { + return "", fmt.Errorf("new_text is required") + } + newText, ok := rawNewText.(string) if !ok { return "", fmt.Errorf("new_text is required") } diff --git a/pkg/tools/highlevel_arg_parsing_test.go b/pkg/tools/highlevel_arg_parsing_test.go new file mode 100644 index 0000000..64efbe5 --- /dev/null +++ b/pkg/tools/highlevel_arg_parsing_test.go @@ -0,0 +1,98 @@ +package tools + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/nodes" + "github.com/YspCoder/clawgo/pkg/providers" +) + +func TestSessionsToolParsesStringArguments(t *testing.T) { + t.Parallel() + + tool := NewSessionsTool(func(limit int) []SessionInfo { + return []SessionInfo{ + {Key: "cron:1", Kind: "cron", UpdatedAt: time.Now()}, + {Key: "main:1", Kind: "main", UpdatedAt: time.Now()}, + } + }, func(key string, limit int) []providers.Message { return nil }) + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "list", + "limit": "1", + "active_minutes": "60", + "kinds": "cron", + }) + if err != nil { + t.Fatalf("sessions execute failed: %v", err) + } + if !strings.Contains(out, "cron:1") || strings.Contains(out, "main:1") { + t.Fatalf("unexpected filtered output: %s", out) + } +} + +func TestSubagentsToolParsesStringLimits(t *testing.T) { + manager := NewSubagentManager(nil, t.TempDir(), nil) + _, err := manager.Spawn(context.Background(), SubagentSpawnOptions{ + Task: "check", + Label: "demo", + AgentID: "coder", + }) + if err != nil { + t.Fatalf("spawn failed: %v", err) + } + tool := NewSubagentsTool(manager) + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "list", + "limit": "1", + }) + if err != nil { + t.Fatalf("subagents execute failed: %v", err) + } + if !strings.Contains(out, "Subagents:") { + t.Fatalf("unexpected output: %s", out) + } + time.Sleep(50 * time.Millisecond) +} + +func TestNodesToolParsesStringDurationAndArtifactPaths(t *testing.T) { + t.Parallel() + + manager := nodes.NewManager() + manager.Upsert(nodes.NodeInfo{ + ID: "local", + Online: true, + Capabilities: nodes.Capabilities{ + Camera: true, + }, + }) + manager.RegisterHandler("local", func(req nodes.Request) nodes.Response { + return nodes.Response{ + OK: true, + Code: "ok", + Node: "local", + Action: req.Action, + Payload: map[string]interface{}{ + "artifacts": []map[string]interface{}{ + {"kind": "video", "path": "/tmp/demo.mp4"}, + }, + }, + } + }) + tool := NewNodesTool(manager, &nodes.Router{Relay: &nodes.HTTPRelayTransport{Manager: manager}}, "") + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "camera_clip", + "node": "local", + "duration_ms": "1000", + "artifact_paths": "memory/demo.md", + }) + if err != nil { + t.Fatalf("nodes execute failed: %v", err) + } + if !strings.Contains(out, `"ok":true`) { + t.Fatalf("unexpected output: %s", out) + } +} diff --git a/pkg/tools/io_arg_parsing_test.go b/pkg/tools/io_arg_parsing_test.go new file mode 100644 index 0000000..9f46e81 --- /dev/null +++ b/pkg/tools/io_arg_parsing_test.go @@ -0,0 +1,132 @@ +package tools + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/config" +) + +func TestFilesystemToolsParseStringArgs(t *testing.T) { + t.Parallel() + + root := t.TempDir() + write := NewWriteFileTool(root) + read := NewReadFileTool(root) + list := NewListDirTool(root) + + if _, err := write.Execute(context.Background(), map[string]interface{}{ + "path": "demo.txt", + "content": "hello world", + "append": "false", + }); err != nil { + t.Fatalf("write failed: %v", err) + } + + out, err := read.Execute(context.Background(), map[string]interface{}{ + "path": "demo.txt", + "offset": "6", + "limit": "5", + }) + if err != nil { + t.Fatalf("read failed: %v", err) + } + if out != "world" { + t.Fatalf("unexpected read output: %q", out) + } + + listOut, err := list.Execute(context.Background(), map[string]interface{}{ + "path": ".", + "recursive": "true", + }) + if err != nil { + t.Fatalf("list failed: %v", err) + } + if !strings.Contains(listOut, "demo.txt") { + t.Fatalf("unexpected list output: %s", listOut) + } + + edit := NewEditFileTool(root) + if _, err := edit.Execute(context.Background(), map[string]interface{}{ + "path": "demo.txt", + "old_text": "world", + "new_text": "", + }); err != nil { + t.Fatalf("edit with empty new_text failed: %v", err) + } + emptyWrite, err := write.Execute(context.Background(), map[string]interface{}{ + "path": "empty.txt", + "content": "", + }) + if err != nil { + t.Fatalf("write empty content failed: %v", err) + } + if !strings.Contains(emptyWrite, "empty.txt") { + t.Fatalf("unexpected empty write output: %s", emptyWrite) + } +} + +func TestSpawnToolParsesStringNumbers(t *testing.T) { + manager := NewSubagentManager(nil, t.TempDir(), nil) + manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) { + return "ok", nil + }) + tool := NewSpawnTool(manager) + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "task": "implement check", + "agent_id": "coder", + "max_retries": "2", + "retry_backoff_ms": "100", + "timeout_sec": "5", + }) + if err != nil { + t.Fatalf("spawn failed: %v", err) + } + if !strings.Contains(out, "spawned") && !strings.Contains(strings.ToLower(out), "subagent") { + t.Fatalf("unexpected spawn output: %s", out) + } + time.Sleep(50 * time.Millisecond) +} + +func TestExecBrowserWebToolsParseStringArgs(t *testing.T) { + t.Parallel() + + execTool := NewExecTool(configShellForTest(), t.TempDir(), NewProcessManager(t.TempDir())) + execOut, err := execTool.Execute(context.Background(), map[string]interface{}{ + "command": "printf hi", + "background": "false", + }) + if err != nil { + t.Fatalf("exec failed: %v", err) + } + if !strings.Contains(execOut, "hi") { + t.Fatalf("unexpected exec output: %s", execOut) + } + + browserTool := NewBrowserTool() + if _, err := browserTool.Execute(context.Background(), map[string]interface{}{ + "action": "unknown", + "url": "https://example.com", + }); err == nil { + t.Fatal("expected browser tool to reject unknown action") + } + + search := NewWebSearchTool("", 5) + searchOut, err := search.Execute(context.Background(), map[string]interface{}{ + "query": "golang", + "count": "3", + }) + if err != nil { + t.Fatalf("web search failed: %v", err) + } + if !strings.Contains(searchOut, "BRAVE_API_KEY") { + t.Fatalf("unexpected web search output: %s", searchOut) + } +} + +func configShellForTest() config.ShellConfig { + return config.ShellConfig{} +} diff --git a/pkg/tools/mcp.go b/pkg/tools/mcp.go index 3680b65..c0649e8 100644 --- a/pkg/tools/mcp.go +++ b/pkg/tools/mcp.go @@ -570,7 +570,7 @@ func (c *mcpSSEClient) listAll(ctx context.Context, method, field string) (map[s } batch, _ := result[field].([]interface{}) items = append(items, batch...) - next, _ := result["nextCursor"].(string) + next := mcpStringArg(result, "nextCursor") if strings.TrimSpace(next) == "" { return map[string]interface{}{field: items}, nil } @@ -821,7 +821,7 @@ func (c *mcpHTTPClient) listAll(ctx context.Context, method, field string) (map[ } batch, _ := result[field].([]interface{}) items = append(items, batch...) - next, _ := result["nextCursor"].(string) + next := mcpStringArg(result, "nextCursor") if strings.TrimSpace(next) == "" { return map[string]interface{}{field: items}, nil } @@ -956,7 +956,7 @@ func (c *mcpClient) listAll(ctx context.Context, method, field string) (map[stri } batch, _ := result[field].([]interface{}) items = append(items, batch...) - next, _ := result["nextCursor"].(string) + next := mcpStringArg(result, "nextCursor") if strings.TrimSpace(next) == "" { return map[string]interface{}{field: items}, nil } @@ -1322,14 +1322,9 @@ func renderMCPToolCallResult(result map[string]interface{}) (string, error) { } func mcpStringArg(args map[string]interface{}, key string) string { - v, _ := args[key].(string) - return v + return MapRawStringArg(args, key) } func mcpObjectArg(args map[string]interface{}, key string) map[string]interface{} { - v, _ := args[key].(map[string]interface{}) - if v == nil { - return map[string]interface{}{} - } - return v + return MapObjectArg(args, key) } diff --git a/pkg/tools/memory.go b/pkg/tools/memory.go index 841c63d..0084ef9 100644 --- a/pkg/tools/memory.go +++ b/pkg/tools/memory.go @@ -77,16 +77,13 @@ type searchResult struct { } func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - query, ok := args["query"].(string) - if !ok || query == "" { + query := MapStringArg(args, "query") + if query == "" { return "", fmt.Errorf("query is required") } namespace := parseMemoryNamespaceArg(args) - maxResults := 5 - if m, ok := args["maxResults"].(float64); ok { - maxResults = int(m) - } + maxResults := MapIntArg(args, "maxResults", 5) if maxResults <= 0 { maxResults = 5 } diff --git a/pkg/tools/memory_cron_arg_parsing_test.go b/pkg/tools/memory_cron_arg_parsing_test.go new file mode 100644 index 0000000..a96d368 --- /dev/null +++ b/pkg/tools/memory_cron_arg_parsing_test.go @@ -0,0 +1,62 @@ +package tools + +import ( + "context" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/cron" +) + +func TestCronToolParsesStringArgs(t *testing.T) { + t.Parallel() + + cs := cron.NewCronService(filepath.Join(t.TempDir(), "jobs.json"), nil) + at := time.Now().Add(time.Minute).UnixMilli() + if _, err := cs.AddJob("demo", cron.CronSchedule{Kind: "at", AtMS: &at}, "hello", true, "telegram", "chat-1"); err != nil { + t.Fatalf("add job failed: %v", err) + } + tool := NewCronTool(cs) + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "list", + }) + if err != nil { + t.Fatalf("cron execute failed: %v", err) + } + if !strings.Contains(out, "demo") { + t.Fatalf("unexpected cron output: %s", out) + } +} + +func TestMemoryGetAndWriteParseStringArgs(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + write := NewMemoryWriteTool(workspace) + get := NewMemoryGetTool(workspace) + + if _, err := write.Execute(context.Background(), map[string]interface{}{ + "content": "remember this", + "kind": "longterm", + "importance": "high", + "source": "user", + "tags": "preference,decision", + "append": "true", + }); err != nil { + t.Fatalf("memory write failed: %v", err) + } + + out, err := get.Execute(context.Background(), map[string]interface{}{ + "path": "MEMORY.md", + "from": "1", + "lines": "5", + }) + if err != nil { + t.Fatalf("memory get failed: %v", err) + } + if !strings.Contains(out, "remember this") { + t.Fatalf("unexpected memory get output: %s", out) + } +} diff --git a/pkg/tools/memory_get.go b/pkg/tools/memory_get.go index a0df901..32db72b 100644 --- a/pkg/tools/memory_get.go +++ b/pkg/tools/memory_get.go @@ -54,8 +54,7 @@ func (t *MemoryGetTool) Parameters() map[string]interface{} { } func (t *MemoryGetTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - rawPath, _ := args["path"].(string) - rawPath = strings.TrimSpace(rawPath) + rawPath := MapStringArg(args, "path") if rawPath == "" { return "", fmt.Errorf("path is required") } @@ -64,14 +63,8 @@ func (t *MemoryGetTool) Execute(ctx context.Context, args map[string]interface{} } namespace := parseMemoryNamespaceArg(args) - from := 1 - if v, ok := args["from"].(float64); ok && int(v) > 0 { - from = int(v) - } - lines := 80 - if v, ok := args["lines"].(float64); ok && int(v) > 0 { - lines = int(v) - } + from := MapIntArg(args, "from", 1) + lines := MapIntArg(args, "lines", 80) if lines > 500 { lines = 500 } diff --git a/pkg/tools/memory_namespace.go b/pkg/tools/memory_namespace.go index a18f434..af0669f 100644 --- a/pkg/tools/memory_namespace.go +++ b/pkg/tools/memory_namespace.go @@ -22,11 +22,7 @@ func memoryNamespaceBaseDir(workspace, namespace string) string { } func parseMemoryNamespaceArg(args map[string]interface{}) string { - if args == nil { - return "main" - } - raw, _ := args["namespace"].(string) - return normalizeMemoryNamespace(raw) + return normalizeMemoryNamespace(MapStringArg(args, "namespace")) } func isPathUnder(parent, child string) bool { diff --git a/pkg/tools/memory_namespace_test.go b/pkg/tools/memory_namespace_test.go index 0821b8c..a3d55db 100644 --- a/pkg/tools/memory_namespace_test.go +++ b/pkg/tools/memory_namespace_test.go @@ -101,3 +101,14 @@ func TestMemorySearchToolNamespaceIsolation(t *testing.T) { t.Fatalf("namespace isolation violated, coder search leaked main data: %s", coderRes) } } + +func TestParseMemoryNamespaceArgUsesHelper(t *testing.T) { + t.Parallel() + + if got := parseMemoryNamespaceArg(map[string]interface{}{"namespace": "Coder Agent"}); got != "coder-agent" { + t.Fatalf("unexpected namespace parse result: %q", got) + } + if got := parseMemoryNamespaceArg(nil); got != "main" { + t.Fatalf("expected main namespace for nil args, got %q", got) + } +} diff --git a/pkg/tools/memory_write.go b/pkg/tools/memory_write.go index c3f6692..6caaf9e 100644 --- a/pkg/tools/memory_write.go +++ b/pkg/tools/memory_write.go @@ -69,24 +69,21 @@ func (t *MemoryWriteTool) Parameters() map[string]interface{} { } func (t *MemoryWriteTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - content, _ := args["content"].(string) + content := MapRawStringArg(args, "content") if content == "" { return "error: content is required", nil } namespace := parseMemoryNamespaceArg(args) baseDir := memoryNamespaceBaseDir(t.workspace, namespace) - kind, _ := args["kind"].(string) - kind = strings.ToLower(strings.TrimSpace(kind)) + kind := strings.ToLower(MapStringArg(args, "kind")) if kind == "" { kind = "daily" } - importance, _ := args["importance"].(string) - importance = normalizeImportance(importance) + importance := normalizeImportance(MapStringArg(args, "importance")) - source, _ := args["source"].(string) - source = strings.TrimSpace(source) + source := MapStringArg(args, "source") if source == "" { source = "user" } @@ -94,7 +91,7 @@ func (t *MemoryWriteTool) Execute(ctx context.Context, args map[string]interface tags := parseTags(args["tags"]) appendMode := true - if v, ok := args["append"].(bool); ok { + if v, ok := MapBoolArg(args, "append"); ok { appendMode = v } @@ -160,23 +157,13 @@ func normalizeImportance(v string) string { } func parseTags(raw interface{}) []string { - items, ok := raw.([]interface{}) - if !ok { + items := MapStringListArg(map[string]interface{}{"tags": raw}, "tags") + if len(items) == 0 { return nil } out := make([]string, 0, len(items)) - seen := map[string]struct{}{} - for _, it := range items { - s, _ := it.(string) - s = strings.ToLower(strings.TrimSpace(s)) - if s == "" { - continue - } - if _, exists := seen[s]; exists { - continue - } - seen[s] = struct{}{} - out = append(out, s) + for _, item := range items { + out = append(out, strings.ToLower(strings.TrimSpace(item))) } return out } diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 38d4a91..f293460 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -118,33 +118,32 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) { } func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - action, _ := args["action"].(string) - action = strings.ToLower(strings.TrimSpace(action)) + action := strings.ToLower(MapStringArg(args, "action")) if action == "" { action = "send" } - content, _ := args["content"].(string) - if msg, _ := args["message"].(string); msg != "" { + content := MapRawStringArg(args, "content") + if msg := MapRawStringArg(args, "message"); msg != "" { content = msg } - media, _ := args["media"].(string) + media := MapStringArg(args, "media") if media == "" { - if p, _ := args["path"].(string); p != "" { + if p := MapStringArg(args, "path"); p != "" { media = p } } if media == "" { - if p, _ := args["file_path"].(string); p != "" { + if p := MapStringArg(args, "file_path"); p != "" { media = p } } if media == "" { - if p, _ := args["filePath"].(string); p != "" { + if p := MapStringArg(args, "filePath"); p != "" { media = p } } - messageID, _ := args["message_id"].(string) - emoji, _ := args["emoji"].(string) + messageID := MapStringArg(args, "message_id") + emoji := MapStringArg(args, "emoji") switch action { case "send": @@ -167,9 +166,9 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) return "", fmt.Errorf("%w: %s", ErrUnsupportedAction, action) } - channel, _ := args["channel"].(string) - chatID, _ := args["chat_id"].(string) - if to, _ := args["to"].(string); to != "" { + channel := MapStringArg(args, "channel") + chatID := MapStringArg(args, "chat_id") + if to := MapStringArg(args, "to"); to != "" { chatID = to } @@ -202,8 +201,8 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) buttonRow := pooled[:0] for _, b := range rowArr { if bMap, ok := b.(map[string]interface{}); ok { - text, _ := bMap["text"].(string) - data, _ := bMap["data"].(string) + text := MapStringArg(bMap, "text") + data := MapStringArg(bMap, "data") if text != "" && data != "" { buttonRow = append(buttonRow, bus.Button{Text: text, Data: data}) } diff --git a/pkg/tools/message_process_test.go b/pkg/tools/message_process_test.go new file mode 100644 index 0000000..6a1a0f0 --- /dev/null +++ b/pkg/tools/message_process_test.go @@ -0,0 +1,68 @@ +package tools + +import ( + "context" + "strings" + "testing" + + "github.com/YspCoder/clawgo/pkg/bus" +) + +func TestMessageToolParsesStringAliases(t *testing.T) { + t.Parallel() + + tool := NewMessageTool() + tool.SetContext("telegram", "chat-1") + + called := false + tool.SetSendCallback(func(channel, chatID, action, content, media, messageID, emoji string, buttons [][]bus.Button) error { + called = true + if channel != "telegram" || chatID != "chat-2" { + t.Fatalf("unexpected target: %s %s", channel, chatID) + } + if action != "send" || content != "hello" || media != "/tmp/a.png" { + t.Fatalf("unexpected payload: %s %s %s", action, content, media) + } + if len(buttons) != 1 || len(buttons[0]) != 1 || buttons[0][0].Text != "Open" { + t.Fatalf("unexpected buttons: %#v", buttons) + } + return nil + }) + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "message": "hello", + "to": "chat-2", + "path": "/tmp/a.png", + "buttons": []interface{}{ + []interface{}{ + map[string]interface{}{"text": "Open", "data": "open"}, + }, + }, + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !called || !strings.Contains(out, "Message action=send") { + t.Fatalf("unexpected output: %s", out) + } +} + +func TestProcessToolParsesStringIntegers(t *testing.T) { + t.Parallel() + + pm := NewProcessManager(t.TempDir()) + tool := NewProcessTool(pm) + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "list", + "offset": "10", + "limit": "20", + "timeout_ms": "5", + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !strings.HasPrefix(strings.TrimSpace(out), "[") { + t.Fatalf("expected json list output, got %s", out) + } +} diff --git a/pkg/tools/nodes_tool.go b/pkg/tools/nodes_tool.go index f8f89e8..fc86433 100644 --- a/pkg/tools/nodes_tool.go +++ b/pkg/tools/nodes_tool.go @@ -45,13 +45,12 @@ func (t *NodesTool) Parameters() map[string]interface{} { func (t *NodesTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { _ = ctx - action, _ := args["action"].(string) - action = strings.TrimSpace(strings.ToLower(action)) + action := strings.TrimSpace(strings.ToLower(MapStringArg(args, "action"))) if action == "" { return "", fmt.Errorf("action is required") } - nodeID, _ := args["node"].(string) - mode, _ := args["mode"].(string) + nodeID := MapStringArg(args, "node") + mode := MapStringArg(args, "mode") if t.manager == nil { return "", fmt.Errorf("nodes manager not configured") } @@ -75,33 +74,32 @@ func (t *NodesTool) Execute(ctx context.Context, args map[string]interface{}) (s reqArgs[k] = v } } - if rawPaths, ok := args["artifact_paths"].([]interface{}); ok && len(rawPaths) > 0 { + if rawPaths := MapStringListArg(args, "artifact_paths"); len(rawPaths) > 0 { reqArgs["artifact_paths"] = rawPaths } if cmd, ok := args["command"].([]interface{}); ok && len(cmd) > 0 { reqArgs["command"] = cmd } - if facing, _ := args["facing"].(string); strings.TrimSpace(facing) != "" { + if facing := MapStringArg(args, "facing"); facing != "" { f := strings.ToLower(strings.TrimSpace(facing)) if f != "front" && f != "back" && f != "both" { return "", fmt.Errorf("invalid_args: facing must be front|back|both") } reqArgs["facing"] = f } - if d, ok := args["duration_ms"].(float64); ok { - di := int(d) + if di := MapIntArg(args, "duration_ms", 0); di > 0 { if di <= 0 || di > 300000 { return "", fmt.Errorf("invalid_args: duration_ms must be in 1..300000") } reqArgs["duration_ms"] = di } - task, _ := args["task"].(string) - model, _ := args["model"].(string) + task := MapStringArg(args, "task") + model := MapStringArg(args, "model") if action == "agent_task" && strings.TrimSpace(task) == "" { return "", fmt.Errorf("invalid_args: agent_task requires task") } if action == "canvas_action" { - if act, _ := reqArgs["action"].(string); strings.TrimSpace(act) == "" { + if act := MapStringArg(reqArgs, "action"); act == "" { return "", fmt.Errorf("invalid_args: canvas_action requires args.action") } } @@ -153,11 +151,11 @@ func (t *NodesTool) writeAudit(req nodes.Request, resp nodes.Response, mode stri if len(req.Args) > 0 { row["request_args"] = req.Args } - if used, _ := resp.Payload["used_transport"].(string); strings.TrimSpace(used) != "" { - row["used_transport"] = strings.TrimSpace(used) + if used := MapStringArg(resp.Payload, "used_transport"); used != "" { + row["used_transport"] = used } - if fallback, _ := resp.Payload["fallback_from"].(string); strings.TrimSpace(fallback) != "" { - row["fallback_from"] = strings.TrimSpace(fallback) + if fallback := MapStringArg(resp.Payload, "fallback_from"); fallback != "" { + row["fallback_from"] = fallback } if count, kinds := artifactAuditSummary(resp.Payload["artifacts"]); count > 0 { row["artifact_count"] = count @@ -196,8 +194,8 @@ func artifactAuditSummary(raw interface{}) (int, []string) { if !ok { continue } - if kind, _ := row["kind"].(string); strings.TrimSpace(kind) != "" { - kinds = append(kinds, strings.TrimSpace(kind)) + if kind := MapStringArg(row, "kind"); kind != "" { + kinds = append(kinds, kind) } } return len(items), kinds @@ -231,14 +229,14 @@ func artifactAuditPreviews(raw interface{}) []map[string]interface{} { if size, ok := row["size_bytes"]; ok { entry["size_bytes"] = size } - if text, _ := row["content_text"].(string); strings.TrimSpace(text) != "" { + if text := MapStringArg(row, "content_text"); text != "" { entry["content_text"] = trimAuditContent(text) } - if b64, _ := row["content_base64"].(string); strings.TrimSpace(b64) != "" { + if b64 := MapStringArg(row, "content_base64"); b64 != "" { entry["content_base64"] = trimAuditContent(b64) entry["content_base64_truncated"] = len(b64) > nodeAuditArtifactPreviewLimit } - if truncated, ok := row["truncated"].(bool); ok && truncated { + if truncated, ok := MapBoolArg(row, "truncated"); ok && truncated { entry["truncated"] = true } if len(entry) > 0 { diff --git a/pkg/tools/parallel.go b/pkg/tools/parallel.go index fcfa78d..35ea91c 100644 --- a/pkg/tools/parallel.go +++ b/pkg/tools/parallel.go @@ -74,8 +74,8 @@ func (t *ParallelTool) Parameters() map[string]interface{} { } func (t *ParallelTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - callsRaw, ok := args["calls"].([]interface{}) - if !ok { + callsRaw := interfaceSliceArg(args, "calls") + if len(callsRaw) == 0 { return "", fmt.Errorf("calls must be an array") } @@ -86,9 +86,9 @@ func (t *ParallelTool) Execute(ctx context.Context, args map[string]interface{}) continue } - toolName, _ := call["tool"].(string) - toolArgs, _ := call["arguments"].(map[string]interface{}) - id, _ := call["id"].(string) + toolName := MapStringArg(call, "tool") + toolArgs := MapObjectArg(call, "arguments") + id := MapStringArg(call, "id") if id == "" { id = fmt.Sprintf("call_%d_%s", i, toolName) } diff --git a/pkg/tools/parallel_arg_parsing_test.go b/pkg/tools/parallel_arg_parsing_test.go new file mode 100644 index 0000000..3e3f9f5 --- /dev/null +++ b/pkg/tools/parallel_arg_parsing_test.go @@ -0,0 +1,87 @@ +package tools + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type stubFetchTool struct{} + +func (s *stubFetchTool) Name() string { return "web_fetch" } +func (s *stubFetchTool) Description() string { return "stub" } +func (s *stubFetchTool) Parameters() map[string]interface{} { + return map[string]interface{}{} +} +func (s *stubFetchTool) Execute(_ context.Context, args map[string]interface{}) (string, error) { + return "fetched:" + MapStringArg(args, "url"), nil +} +func (s *stubFetchTool) ParallelSafe() bool { return true } + +func TestMemorySearchToolParsesStringMaxResults(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + write := NewMemoryWriteTool(workspace) + if _, err := write.Execute(context.Background(), map[string]interface{}{ + "content": "alpha beta gamma", + "kind": "longterm", + "importance": "high", + }); err != nil { + t.Fatalf("memory write failed: %v", err) + } + + search := NewMemorySearchTool(workspace) + out, err := search.Execute(context.Background(), map[string]interface{}{ + "query": "alpha", + "maxResults": "1", + }) + if err != nil { + t.Fatalf("memory search failed: %v", err) + } + if !strings.Contains(out, "alpha beta gamma") { + t.Fatalf("unexpected search output: %s", out) + } +} + +func TestParallelToolParsesStringSlices(t *testing.T) { + t.Parallel() + + reg := NewToolRegistry() + reg.Register(&stubFetchTool{}) + tool := NewParallelTool(reg, 2, map[string]struct{}{"web_fetch": {}}) + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "calls": []map[string]interface{}{ + {"tool": "web_fetch", "arguments": map[string]interface{}{"url": "https://example.com"}, "id": "first"}, + }, + }) + if err != nil { + t.Fatalf("parallel execute failed: %v", err) + } + if !strings.Contains(out, "first") || !strings.Contains(out, "https://example.com") { + t.Fatalf("unexpected parallel output: %s", out) + } +} + +func TestParallelFetchToolParsesStringURLsSlice(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("parallel fetch ok")) + })) + defer srv.Close() + + tool := NewParallelFetchTool(NewWebFetchTool(100), 2, map[string]struct{}{"web_fetch": {}}) + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "urls": []string{srv.URL}, + }) + if err != nil { + t.Fatalf("parallel_fetch execute failed: %v", err) + } + if !strings.Contains(out, "parallel fetch ok") { + t.Fatalf("unexpected parallel_fetch output: %s", out) + } +} diff --git a/pkg/tools/parallel_fetch.go b/pkg/tools/parallel_fetch.go index f96c0b2..6ede420 100644 --- a/pkg/tools/parallel_fetch.go +++ b/pkg/tools/parallel_fetch.go @@ -47,8 +47,8 @@ func (t *ParallelFetchTool) Parameters() map[string]interface{} { } func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - urlsRaw, ok := args["urls"].([]interface{}) - if !ok { + urlsRaw := interfaceSliceArg(args, "urls") + if len(urlsRaw) == 0 { return "", fmt.Errorf("urls must be an array") } @@ -66,8 +66,8 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa sem := make(chan struct{}, minParallelLimit(maxParallel, len(urlsRaw))) for i, u := range urlsRaw { - urlStr, ok := u.(string) - if !ok { + urlStr := strings.TrimSpace(fmt.Sprint(u)) + if urlStr == "" || urlStr == "" { results[i] = "Error: invalid url" continue } @@ -95,8 +95,8 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa func (t *ParallelFetchTool) executeSerial(ctx context.Context, urlsRaw []interface{}) string { results := make([]string, len(urlsRaw)) for i, u := range urlsRaw { - urlStr, ok := u.(string) - if !ok { + urlStr := strings.TrimSpace(fmt.Sprint(u)) + if urlStr == "" || urlStr == "" { results[i] = "Error: invalid url" continue } @@ -141,3 +141,31 @@ func minParallelLimit(maxParallel, total int) int { } return maxParallel } + +func interfaceSliceArg(args map[string]interface{}, key string) []interface{} { + if args == nil { + return nil + } + raw, ok := args[key] + if !ok || raw == nil { + return nil + } + switch v := raw.(type) { + case []interface{}: + return v + case []string: + out := make([]interface{}, 0, len(v)) + for _, item := range v { + out = append(out, item) + } + return out + case []map[string]interface{}: + out := make([]interface{}, 0, len(v)) + for _, item := range v { + out = append(out, item) + } + return out + default: + return nil + } +} diff --git a/pkg/tools/process_tool.go b/pkg/tools/process_tool.go index 375de86..59b62b8 100644 --- a/pkg/tools/process_tool.go +++ b/pkg/tools/process_tool.go @@ -24,18 +24,18 @@ func (t *ProcessTool) Parameters() map[string]interface{} { } func (t *ProcessTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - action, _ := args["action"].(string) - sid, _ := args["session_id"].(string) + action := MapStringArg(args, "action") + sid := MapStringArg(args, "session_id") if sid == "" { - sid, _ = args["sessionId"].(string) + sid = MapStringArg(args, "sessionId") } switch action { case "list": b, _ := json.Marshal(t.m.List()) return string(b), nil case "log": - off := toInt(args["offset"]) - lim := toInt(args["limit"]) + off := MapIntArg(args, "offset", 0) + lim := MapIntArg(args, "limit", 0) return t.m.Log(sid, off, lim) case "kill": if err := t.m.Kill(sid); err != nil { @@ -43,7 +43,7 @@ func (t *ProcessTool) Execute(ctx context.Context, args map[string]interface{}) } return "killed", nil case "poll": - timeout := toInt(args["timeout_ms"]) + timeout := MapIntArg(args, "timeout_ms", 0) if timeout < 0 { timeout = 0 } @@ -58,8 +58,8 @@ func (t *ProcessTool) Execute(ctx context.Context, args map[string]interface{}) case <-ctx.Done(): } } - off := toInt(args["offset"]) - lim := toInt(args["limit"]) + off := MapIntArg(args, "offset", 0) + lim := MapIntArg(args, "limit", 0) if lim <= 0 { lim = 1200 } @@ -80,14 +80,3 @@ func (t *ProcessTool) Execute(ctx context.Context, args map[string]interface{}) return "", nil } } - -func toInt(v interface{}) int { - switch x := v.(type) { - case float64: - return int(x) - case int: - return x - default: - return 0 - } -} diff --git a/pkg/tools/remind.go b/pkg/tools/remind.go index 145e468..a388b49 100644 --- a/pkg/tools/remind.go +++ b/pkg/tools/remind.go @@ -3,6 +3,7 @@ package tools import ( "context" "fmt" + "strings" "sync" "time" @@ -65,18 +66,18 @@ func (t *RemindTool) Execute(ctx context.Context, args map[string]interface{}) ( return "", fmt.Errorf("cron service not available") } - message, ok := args["message"].(string) - if !ok { + message := MapRawStringArg(args, "message") + if strings.TrimSpace(message) == "" { return "", fmt.Errorf("message is required") } - timeExpr, ok := args["time_expr"].(string) - if !ok { + timeExpr := MapStringArg(args, "time_expr") + if timeExpr == "" { return "", fmt.Errorf("time_expr is required") } - channel, _ := args["channel"].(string) - chatID, _ := args["chat_id"].(string) + channel := MapStringArg(args, "channel") + chatID := MapStringArg(args, "chat_id") if channel == "" || chatID == "" { t.mu.RLock() defaultChannel := t.defaultChannel diff --git a/pkg/tools/remind_test.go b/pkg/tools/remind_test.go index 5827885..741a6ae 100644 --- a/pkg/tools/remind_test.go +++ b/pkg/tools/remind_test.go @@ -37,3 +37,27 @@ func TestRemindTool_UsesToolContextForDeliveryTarget(t *testing.T) { t.Fatalf("expected to chat-123, got %q", jobs[0].Payload.To) } } + +func TestRemindTool_ParsesStringTargets(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "jobs.json") + cs := cron.NewCronService(storePath, nil) + tool := NewRemindTool(cs) + + _, err := tool.Execute(context.Background(), map[string]interface{}{ + "message": "call mom", + "time_expr": "10m", + "channel": "telegram", + "chat_id": "chat-456", + }) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + jobs := cs.ListJobs(true) + if len(jobs) != 1 { + t.Fatalf("expected 1 job, got %d", len(jobs)) + } + if jobs[0].Payload.Channel != "telegram" || jobs[0].Payload.To != "chat-456" { + t.Fatalf("unexpected delivery target: %+v", jobs[0].Payload) + } +} diff --git a/pkg/tools/repo_map.go b/pkg/tools/repo_map.go index 1772284..9eeba26 100644 --- a/pkg/tools/repo_map.go +++ b/pkg/tools/repo_map.go @@ -65,12 +65,9 @@ type repoMapEntry struct { } func (t *RepoMapTool) Execute(_ context.Context, args map[string]interface{}) (string, error) { - query, _ := args["query"].(string) - maxResults := 20 - if raw, ok := args["max_results"].(float64); ok && raw > 0 { - maxResults = int(raw) - } - forceRefresh, _ := args["refresh"].(bool) + query := MapStringArg(args, "query") + maxResults := MapIntArg(args, "max_results", 20) + forceRefresh, _ := MapBoolArg(args, "refresh") cache, err := t.loadOrBuildMap(forceRefresh) if err != nil { diff --git a/pkg/tools/runtime_snapshot_test.go b/pkg/tools/runtime_snapshot_test.go new file mode 100644 index 0000000..508a12c --- /dev/null +++ b/pkg/tools/runtime_snapshot_test.go @@ -0,0 +1,42 @@ +package tools + +import ( + "context" + "testing" + "time" +) + +func TestSubagentManagerRuntimeSnapshot(t *testing.T) { + workspace := t.TempDir() + manager := NewSubagentManager(nil, workspace, nil) + manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) { + return "snapshot-result", nil + }) + task, err := manager.SpawnTask(context.Background(), SubagentSpawnOptions{ + Task: "implement snapshot support", + AgentID: "coder", + OriginChannel: "cli", + OriginChatID: "direct", + }) + if err != nil { + t.Fatalf("spawn task failed: %v", err) + } + if _, _, err := manager.WaitTask(context.Background(), task.ID); err != nil { + t.Fatalf("wait task failed: %v", err) + } + time.Sleep(10 * time.Millisecond) + snapshot := manager.RuntimeSnapshot(20) + if len(snapshot.Tasks) == 0 || len(snapshot.Runs) == 0 { + t.Fatalf("expected runtime snapshot to include task and run records: %+v", snapshot) + } + if len(snapshot.Threads) == 0 || len(snapshot.Artifacts) == 0 { + t.Fatalf("expected runtime snapshot to include thread and artifact records: %+v", snapshot) + } + msgArtifact := snapshot.Artifacts[0] + if msgArtifact.SourceType != "agent_message" { + t.Fatalf("expected agent message artifact source type, got %+v", msgArtifact) + } + if msgArtifact.FromAgent == "" || msgArtifact.ToAgent == "" || msgArtifact.Name == "" { + t.Fatalf("expected runtime snapshot artifact to preserve message metadata, got %+v", msgArtifact) + } +} diff --git a/pkg/tools/runtime_types.go b/pkg/tools/runtime_types.go new file mode 100644 index 0000000..8a5a060 --- /dev/null +++ b/pkg/tools/runtime_types.go @@ -0,0 +1,151 @@ +package tools + +import ( + "fmt" + "strings" +) + +const ( + RuntimeStatusPending = "pending" + RuntimeStatusRouting = "routing" + RuntimeStatusRunning = "running" + RuntimeStatusWaiting = "waiting" + RuntimeStatusCompleted = "completed" + RuntimeStatusFailed = "failed" + RuntimeStatusCancelled = "cancelled" + RuntimeStatusRecovered = "recovered" +) + +type RuntimeError struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Stage string `json:"stage,omitempty"` + Retryable bool `json:"retryable,omitempty"` + Source string `json:"source,omitempty"` +} + +type DispatchDecision struct { + TargetAgent string `json:"target_agent,omitempty"` + Reason string `json:"reason,omitempty"` + Confidence float64 `json:"confidence,omitempty"` + TaskText string `json:"task_text,omitempty"` + RouteSource string `json:"route_source,omitempty"` +} + +func (d DispatchDecision) Valid() bool { + return strings.TrimSpace(d.TargetAgent) != "" && strings.TrimSpace(d.TaskText) != "" +} + +type TaskRecord struct { + ID string `json:"id"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + OwnerAgentID string `json:"owner_agent_id,omitempty"` + Status string `json:"status"` + Input string `json:"input,omitempty"` + OriginChannel string `json:"origin_channel,omitempty"` + OriginChatID string `json:"origin_chat_id,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type RunRecord struct { + ID string `json:"id"` + TaskID string `json:"task_id,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + ParentRunID string `json:"parent_run_id,omitempty"` + Kind string `json:"kind,omitempty"` + Status string `json:"status"` + Input string `json:"input,omitempty"` + Output string `json:"output,omitempty"` + Error *RuntimeError `json:"error,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type EventRecord struct { + ID string `json:"id,omitempty"` + RunID string `json:"run_id,omitempty"` + TaskID string `json:"task_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + Type string `json:"type"` + Status string `json:"status,omitempty"` + Message string `json:"message,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + Error *RuntimeError `json:"error,omitempty"` + At int64 `json:"ts"` +} + +type ArtifactRecord struct { + ID string `json:"id,omitempty"` + RunID string `json:"run_id,omitempty"` + TaskID string `json:"task_id,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + Kind string `json:"kind,omitempty"` + Name string `json:"name,omitempty"` + Content string `json:"content,omitempty"` + AgentID string `json:"agent_id,omitempty"` + FromAgent string `json:"from_agent,omitempty"` + ToAgent string `json:"to_agent,omitempty"` + ReplyTo string `json:"reply_to,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + Status string `json:"status,omitempty"` + RequiresReply bool `json:"requires_reply,omitempty"` + CreatedAt int64 `json:"created_at"` + Visible bool `json:"visible"` + SourceType string `json:"source_type,omitempty"` +} + +type ThreadRecord struct { + ID string `json:"id"` + OwnerAgentID string `json:"owner_agent_id,omitempty"` + Participants []string `json:"participants,omitempty"` + Status string `json:"status"` + Topic string `json:"topic,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type RuntimeSnapshot struct { + Tasks []TaskRecord `json:"tasks,omitempty"` + Runs []RunRecord `json:"runs,omitempty"` + Events []EventRecord `json:"events,omitempty"` + Threads []ThreadRecord `json:"threads,omitempty"` + Artifacts []ArtifactRecord `json:"artifacts,omitempty"` +} + +type ExecutionRun struct { + Run RunRecord `json:"run"` + Task TaskRecord `json:"task"` + Decision DispatchDecision `json:"decision,omitempty"` +} + +func IsTerminalRuntimeStatus(status string) bool { + switch strings.ToLower(strings.TrimSpace(status)) { + case RuntimeStatusCompleted, RuntimeStatusFailed, RuntimeStatusCancelled: + return true + default: + return false + } +} + +func NewRuntimeError(code, message, stage string, retryable bool, source string) *RuntimeError { + return &RuntimeError{ + Code: strings.TrimSpace(code), + Message: strings.TrimSpace(message), + Stage: strings.TrimSpace(stage), + Retryable: retryable, + Source: strings.TrimSpace(source), + } +} + +func EventRecordID(runID, eventType string, at int64) string { + runID = strings.TrimSpace(runID) + eventType = strings.TrimSpace(eventType) + if runID == "" && eventType == "" && at <= 0 { + return "" + } + return fmt.Sprintf("%s:%s:%d", runID, eventType, at) +} diff --git a/pkg/tools/sessions_tool.go b/pkg/tools/sessions_tool.go index 2594300..39ba5e4 100644 --- a/pkg/tools/sessions_tool.go +++ b/pkg/tools/sessions_tool.go @@ -56,51 +56,21 @@ func (t *SessionsTool) Parameters() map[string]interface{} { func (t *SessionsTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { _ = ctx - action, _ := args["action"].(string) - action = strings.ToLower(strings.TrimSpace(action)) - limit := 20 - if v, ok := args["limit"].(float64); ok && int(v) > 0 { - limit = int(v) - } - includeTools := false - if v, ok := args["include_tools"].(bool); ok { - includeTools = v - } - around := 0 - if v, ok := args["around"].(float64); ok && int(v) > 0 { - around = int(v) - } - before := 0 - if v, ok := args["before"].(float64); ok && int(v) > 0 { - before = int(v) - } - after := 0 - if v, ok := args["after"].(float64); ok && int(v) > 0 { - after = int(v) - } - activeMinutes := 0 - if v, ok := args["active_minutes"].(float64); ok && int(v) > 0 { - activeMinutes = int(v) - } - query, _ := args["query"].(string) - query = strings.ToLower(strings.TrimSpace(query)) - roleFilter, _ := args["role"].(string) - roleFilter = strings.ToLower(strings.TrimSpace(roleFilter)) - fromMeSet := false - fromMe := false - if v, ok := args["from_me"].(bool); ok { - fromMeSet = true - fromMe = v - } + action := strings.ToLower(MapStringArg(args, "action")) + limit := MapIntArg(args, "limit", 20) + includeTools, _ := MapBoolArg(args, "include_tools") + around := MapIntArg(args, "around", 0) + before := MapIntArg(args, "before", 0) + after := MapIntArg(args, "after", 0) + activeMinutes := MapIntArg(args, "active_minutes", 0) + query := strings.ToLower(MapStringArg(args, "query")) + roleFilter := strings.ToLower(MapStringArg(args, "role")) + fromMe, fromMeSet := MapBoolArg(args, "from_me") kindFilter := map[string]struct{}{} - if rawKinds, ok := args["kinds"].([]interface{}); ok { - for _, it := range rawKinds { - if s, ok := it.(string); ok { - s = strings.ToLower(strings.TrimSpace(s)) - if s != "" { - kindFilter[s] = struct{}{} - } - } + for _, s := range MapStringListArg(args, "kinds") { + s = strings.ToLower(strings.TrimSpace(s)) + if s != "" { + kindFilter[s] = struct{}{} } } @@ -160,7 +130,7 @@ func (t *SessionsTool) Execute(ctx context.Context, args map[string]interface{}) if t.historyFn == nil { return "sessions history unavailable", nil } - key, _ := args["key"].(string) + key := MapStringArg(args, "key") if key == "" { return "key is required for history", nil } diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 816de96..eb2dcd8 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -65,13 +65,13 @@ func (t *ExecTool) Parameters() map[string]interface{} { } func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - command, ok := args["command"].(string) - if !ok { + command := MapRawStringArg(args, "command") + if strings.TrimSpace(command) == "" { return "", fmt.Errorf("command is required") } cwd := t.workingDir - if wd, ok := args["working_dir"].(string); ok && wd != "" { + if wd := MapStringArg(args, "working_dir"); wd != "" { cwd = wd } @@ -87,7 +87,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st } globalCommandWatchdog.setQueuePath(resolveCommandQueuePath(queueBase)) - if bg, _ := args["background"].(bool); bg { + if bg, _ := MapBoolArg(args, "background"); bg { if t.procManager == nil { return "", fmt.Errorf("background process manager not configured") } diff --git a/pkg/tools/skill_exec.go b/pkg/tools/skill_exec.go index ff26ce7..cb4061f 100644 --- a/pkg/tools/skill_exec.go +++ b/pkg/tools/skill_exec.go @@ -59,11 +59,11 @@ func (t *SkillExecTool) Parameters() map[string]interface{} { } func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - skill, _ := args["skill"].(string) - script, _ := args["script"].(string) - reason, _ := args["reason"].(string) - callerAgent, _ := args["caller_agent"].(string) - callerScope, _ := args["caller_scope"].(string) + skill := MapStringArg(args, "skill") + script := MapStringArg(args, "script") + reason := MapStringArg(args, "reason") + callerAgent := MapStringArg(args, "caller_agent") + callerScope := MapStringArg(args, "caller_scope") reason = strings.TrimSpace(reason) if reason == "" { reason = "unspecified" @@ -115,14 +115,7 @@ func (t *SkillExecTool) Execute(ctx context.Context, args map[string]interface{} return "", err } - cmdArgs := []string{} - if rawArgs, ok := args["args"].([]interface{}); ok { - for _, item := range rawArgs { - if s, ok := item.(string); ok { - cmdArgs = append(cmdArgs, s) - } - } - } + cmdArgs := MapStringListArg(args, "args") commandLabel := relScript if len(cmdArgs) > 0 { diff --git a/pkg/tools/skill_exec_args_test.go b/pkg/tools/skill_exec_args_test.go new file mode 100644 index 0000000..b469fb9 --- /dev/null +++ b/pkg/tools/skill_exec_args_test.go @@ -0,0 +1,38 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestSkillExecParsesStringArgsList(t *testing.T) { + workspace := t.TempDir() + skillDir := filepath.Join(workspace, "skills", "demo") + scriptDir := filepath.Join(skillDir, "scripts") + if err := os.MkdirAll(scriptDir, 0755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Demo\n"), 0644); err != nil { + t.Fatalf("write skill md failed: %v", err) + } + scriptPath := filepath.Join(scriptDir, "run.sh") + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\nprintf \"%s %s\" \"$1\" \"$2\"\n"), 0755); err != nil { + t.Fatalf("write script failed: %v", err) + } + + tool := NewSkillExecTool(workspace) + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "skill": "demo", + "script": "scripts/run.sh", + "args": "hello,world", + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !strings.Contains(out, "hello world") { + t.Fatalf("unexpected output: %s", out) + } +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 5303167..d3fd356 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -90,19 +90,19 @@ func (t *SpawnTool) SetContext(channel, chatID string) { } func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - task, ok := args["task"].(string) - if !ok { + task := MapStringArg(args, "task") + if task == "" { return "", fmt.Errorf("task is required") } - label, _ := args["label"].(string) - role, _ := args["role"].(string) - agentID, _ := args["agent_id"].(string) - maxRetries := intArg(args, "max_retries") - retryBackoff := intArg(args, "retry_backoff_ms") - timeoutSec := intArg(args, "timeout_sec") - maxTaskChars := intArg(args, "max_task_chars") - maxResultChars := intArg(args, "max_result_chars") + label := MapStringArg(args, "label") + role := MapStringArg(args, "role") + agentID := MapStringArg(args, "agent_id") + maxRetries := MapIntArg(args, "max_retries", 0) + retryBackoff := MapIntArg(args, "retry_backoff_ms", 0) + timeoutSec := MapIntArg(args, "timeout_sec", 0) + maxTaskChars := MapIntArg(args, "max_task_chars", 0) + maxResultChars := MapIntArg(args, "max_result_chars", 0) if label == "" && role != "" { label = role } else if label == "" && agentID != "" { @@ -113,8 +113,8 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s return "Error: Subagent manager not configured", nil } - originChannel, _ := args["channel"].(string) - originChatID, _ := args["chat_id"].(string) + originChannel := MapStringArg(args, "channel") + originChatID := MapStringArg(args, "chat_id") if originChannel == "" || originChatID == "" { t.mu.RLock() defaultChannel := t.originChannel @@ -147,19 +147,3 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s return result, nil } - -func intArg(args map[string]interface{}, key string) int { - if args == nil { - return 0 - } - if v, ok := args[key].(float64); ok { - return int(v) - } - if v, ok := args[key].(int); ok { - return v - } - if v, ok := args[key].(int64); ok { - return int(v) - } - return 0 -} diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index c368fe5..00d0936 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -107,7 +107,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b if runStore != nil { for _, task := range runStore.List() { mgr.tasks[task.ID] = task - if task.Status == "running" { + if task.Status == RuntimeStatusRunning { mgr.recoverableTaskIDs = append(mgr.recoverableTaskIDs, task.ID) } } @@ -300,7 +300,7 @@ func (sm *SubagentManager) spawnTask(ctx context.Context, opts SubagentSpawnOpti ParentRunID: parentRunID, OriginChannel: originChannel, OriginChatID: originChatID, - Status: "running", + Status: RuntimeStatusRouting, Created: now, Updated: now, } @@ -332,7 +332,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { }() sm.mu.Lock() - task.Status = "running" + task.Status = RuntimeStatusRunning task.Created = time.Now().UnixMilli() task.Updated = task.Created sm.persistTaskLocked(task, "started", "") @@ -341,7 +341,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { result, runErr := sm.runWithRetry(ctx, task) sm.mu.Lock() if runErr != nil { - task.Status = "failed" + task.Status = RuntimeStatusFailed task.Result = fmt.Sprintf("Error: %v", runErr) task.Result = applySubagentResultQuota(task.Result, task.MaxResultChars) task.Updated = time.Now().UnixMilli() @@ -357,10 +357,10 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { Status: "delivered", CreatedAt: task.Updated, }) - sm.persistTaskLocked(task, "completed", task.Result) + sm.persistTaskLocked(task, "failed", task.Result) sm.notifyTaskWaitersLocked(task.ID) } else { - task.Status = "completed" + task.Status = RuntimeStatusCompleted task.Result = applySubagentResultQuota(result, task.MaxResultChars) task.Updated = time.Now().UnixMilli() task.WaitingReply = false @@ -789,8 +789,8 @@ func (sm *SubagentManager) KillTask(taskID string) bool { cancel() delete(sm.cancelFuncs, taskID) } - if t.Status == "running" { - t.Status = "killed" + if !IsTerminalRuntimeStatus(t.Status) { + t.Status = RuntimeStatusCancelled t.WaitingReply = false t.Updated = time.Now().UnixMilli() sm.persistTaskLocked(t, "killed", "") @@ -885,6 +885,53 @@ func (sm *SubagentManager) Events(taskID string, limit int) ([]SubagentRunEvent, return sm.runStore.Events(taskID, limit) } +func (sm *SubagentManager) RuntimeSnapshot(limit int) RuntimeSnapshot { + if sm == nil { + return RuntimeSnapshot{} + } + tasks := sm.ListTasks() + snapshot := RuntimeSnapshot{ + Tasks: make([]TaskRecord, 0, len(tasks)), + Runs: make([]RunRecord, 0, len(tasks)), + } + seenThreads := map[string]struct{}{} + for _, task := range tasks { + snapshot.Tasks = append(snapshot.Tasks, taskToTaskRecord(task)) + snapshot.Runs = append(snapshot.Runs, taskToRunRecord(task)) + if evts, err := sm.Events(task.ID, limit); err == nil { + for _, evt := range evts { + snapshot.Events = append(snapshot.Events, EventRecord{ + ID: EventRecordID(evt.RunID, evt.Type, evt.At), + RunID: evt.RunID, + TaskID: evt.RunID, + AgentID: evt.AgentID, + Type: evt.Type, + Status: evt.Status, + Message: evt.Message, + RetryCount: evt.RetryCount, + At: evt.At, + }) + } + } + threadID := strings.TrimSpace(task.ThreadID) + if threadID == "" { + continue + } + if _, ok := seenThreads[threadID]; !ok { + if thread, found := sm.Thread(threadID); found { + snapshot.Threads = append(snapshot.Threads, threadToThreadRecord(thread)) + } + seenThreads[threadID] = struct{}{} + } + if msgs, err := sm.ThreadMessages(threadID, limit); err == nil { + for _, msg := range msgs { + snapshot.Artifacts = append(snapshot.Artifacts, messageToArtifactRecord(msg)) + } + } + } + return snapshot +} + func (sm *SubagentManager) Thread(threadID string) (*AgentThread, bool) { if sm.mailboxStore == nil { return nil, false @@ -929,7 +976,7 @@ func (sm *SubagentManager) pruneArchivedLocked() { } cutoff := time.Now().Add(-time.Duration(sm.archiveAfterMinute) * time.Minute).UnixMilli() for id, t := range sm.tasks { - if t.Status == "running" { + if !IsTerminalRuntimeStatus(t.Status) { continue } if t.Updated > 0 && t.Updated < cutoff { @@ -1035,13 +1082,13 @@ func (sm *SubagentManager) WaitTask(ctx context.Context, taskID string) (*Subage task, ok := sm.tasks[taskID] if !ok && sm.runStore != nil { if persisted, found := sm.runStore.Get(taskID); found && persisted != nil { - if strings.TrimSpace(persisted.Status) != "running" { + if IsTerminalRuntimeStatus(persisted.Status) { sm.mu.Unlock() return persisted, true, nil } } } - if ok && task != nil && strings.TrimSpace(task.Status) != "running" { + if ok && task != nil && IsTerminalRuntimeStatus(task.Status) { cp := cloneSubagentTask(task) sm.mu.Unlock() return cp, true, nil @@ -1063,13 +1110,13 @@ func (sm *SubagentManager) WaitTask(ctx context.Context, taskID string) (*Subage sm.mu.Lock() sm.pruneArchivedLocked() task, ok := sm.tasks[taskID] - if ok && task != nil && strings.TrimSpace(task.Status) != "running" { + if ok && task != nil && IsTerminalRuntimeStatus(task.Status) { cp := cloneSubagentTask(task) sm.mu.Unlock() return cp, true, nil } if !ok && sm.runStore != nil { - if persisted, found := sm.runStore.Get(taskID); found && persisted != nil && strings.TrimSpace(persisted.Status) != "running" { + if persisted, found := sm.runStore.Get(taskID); found && persisted != nil && IsTerminalRuntimeStatus(persisted.Status) { sm.mu.Unlock() return persisted, true, nil } diff --git a/pkg/tools/subagent_config_manager.go b/pkg/tools/subagent_config_manager.go index 07ea39f..f4ac6e1 100644 --- a/pkg/tools/subagent_config_manager.go +++ b/pkg/tools/subagent_config_manager.go @@ -154,53 +154,15 @@ func DeleteConfigSubagent(configPath, agentID string) (map[string]interface{}, e } func stringArgFromMap(args map[string]interface{}, key string) string { - if args == nil { - return "" - } - v, _ := args[key].(string) - return strings.TrimSpace(v) + return MapStringArg(args, key) } func boolArgFromMap(args map[string]interface{}, key string) (bool, bool) { - if args == nil { - return false, false - } - raw, ok := args[key] - if !ok { - return false, false - } - switch v := raw.(type) { - case bool: - return v, true - default: - return false, false - } + return MapBoolArg(args, key) } func stringListArgFromMap(args map[string]interface{}, key string) []string { - if args == nil { - return nil - } - raw, ok := args[key] - if !ok { - return nil - } - switch v := raw.(type) { - case []string: - return normalizeKeywords(v) - case []interface{}: - items := make([]string, 0, len(v)) - for _, item := range v { - s, _ := item.(string) - s = strings.TrimSpace(s) - if s != "" { - items = append(items, s) - } - } - return normalizeKeywords(items) - default: - return nil - } + return normalizeKeywords(MapStringListArg(args, key)) } func upsertRouteRuleConfig(rules []config.AgentRouteRule, rule config.AgentRouteRule) []config.AgentRouteRule { diff --git a/pkg/tools/subagent_config_tool_test.go b/pkg/tools/subagent_config_tool_test.go index 59e44b7..642d5df 100644 --- a/pkg/tools/subagent_config_tool_test.go +++ b/pkg/tools/subagent_config_tool_test.go @@ -63,3 +63,50 @@ func TestSubagentConfigToolUpsert(t *testing.T) { t.Fatalf("expected router rules to persist") } } + +func TestSubagentConfigToolUpsertParsesStringAndCSVArgs(t *testing.T) { + workspace := t.TempDir() + configPath := filepath.Join(workspace, "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Router.Enabled = true + cfg.Agents.Subagents["main"] = config.SubagentConfig{ + Enabled: true, + Type: "router", + Role: "orchestrator", + SystemPromptFile: "agents/main/AGENT.md", + } + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("save config failed: %v", err) + } + runtimecfg.Set(cfg) + t.Cleanup(func() { runtimecfg.Set(config.DefaultConfig()) }) + + tool := NewSubagentConfigTool(configPath) + _, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "upsert", + "agent_id": "reviewer", + "enabled": "true", + "role": "testing", + "system_prompt_file": "agents/reviewer/AGENT.md", + "routing_keywords": "review, regression", + "tool_allowlist": "shell, sessions", + }) + if err != nil { + t.Fatalf("upsert failed: %v", err) + } + + reloaded, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("reload config failed: %v", err) + } + subcfg := reloaded.Agents.Subagents["reviewer"] + if !subcfg.Enabled { + t.Fatalf("expected reviewer to be enabled, got %+v", subcfg) + } + if len(subcfg.Tools.Allowlist) != 2 { + t.Fatalf("expected allowlist to parse from csv, got %+v", subcfg.Tools.Allowlist) + } + if len(reloaded.Agents.Router.Rules) != 1 || len(reloaded.Agents.Router.Rules[0].Keywords) != 2 { + t.Fatalf("expected routing keywords to parse from csv, got %+v", reloaded.Agents.Router.Rules) + } +} diff --git a/pkg/tools/subagent_mailbox.go b/pkg/tools/subagent_mailbox.go index b146839..b5c9974 100644 --- a/pkg/tools/subagent_mailbox.go +++ b/pkg/tools/subagent_mailbox.go @@ -337,6 +337,47 @@ func parseThreadSequence(threadID string) int { return n } +func threadToThreadRecord(thread *AgentThread) ThreadRecord { + if thread == nil { + return ThreadRecord{} + } + return ThreadRecord{ + ID: thread.ThreadID, + OwnerAgentID: thread.Owner, + Participants: append([]string(nil), thread.Participants...), + Status: thread.Status, + Topic: thread.Topic, + CreatedAt: thread.CreatedAt, + UpdatedAt: thread.UpdatedAt, + } +} + +func messageToArtifactRecord(msg AgentMessage) ArtifactRecord { + agentID := strings.TrimSpace(msg.FromAgent) + if agentID == "" { + agentID = strings.TrimSpace(msg.ToAgent) + } + return ArtifactRecord{ + ID: msg.MessageID, + RunID: msg.CorrelationID, + TaskID: msg.CorrelationID, + ThreadID: msg.ThreadID, + Kind: "message", + Name: msg.Type, + Content: msg.Content, + AgentID: agentID, + FromAgent: msg.FromAgent, + ToAgent: msg.ToAgent, + ReplyTo: msg.ReplyTo, + CorrelationID: msg.CorrelationID, + Status: msg.Status, + RequiresReply: msg.RequiresReply, + CreatedAt: msg.CreatedAt, + Visible: true, + SourceType: "agent_message", + } +} + func parseMessageSequence(messageID string) int { messageID = strings.TrimSpace(messageID) if !strings.HasPrefix(messageID, "msg-") { diff --git a/pkg/tools/subagent_profile.go b/pkg/tools/subagent_profile.go index 08f5f12..8dc04e4 100644 --- a/pkg/tools/subagent_profile.go +++ b/pkg/tools/subagent_profile.go @@ -269,20 +269,7 @@ func clampInt(v, min, max int) int { } func parseStringList(raw interface{}) []string { - items, ok := raw.([]interface{}) - if !ok { - return nil - } - out := make([]string, 0, len(items)) - for _, item := range items { - s, _ := item.(string) - s = strings.TrimSpace(s) - if s == "" { - continue - } - out = append(out, s) - } - return normalizeStringList(out) + return normalizeStringList(MapStringListArg(map[string]interface{}{"items": raw}, "items")) } func (s *SubagentProfileStore) mergedProfilesLocked() (map[string]SubagentProfile, error) { @@ -574,10 +561,8 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter if t.store == nil { return "subagent profile store not available", nil } - action, _ := args["action"].(string) - action = strings.ToLower(strings.TrimSpace(action)) - agentID, _ := args["agent_id"].(string) - agentID = normalizeSubagentIdentifier(agentID) + action := strings.ToLower(MapStringArg(args, "action")) + agentID := normalizeSubagentIdentifier(MapStringArg(args, "agent_id")) switch action { case "list": @@ -724,22 +709,9 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter } func stringArg(args map[string]interface{}, key string) string { - v, _ := args[key].(string) - return strings.TrimSpace(v) + return MapStringArg(args, key) } func profileIntArg(args map[string]interface{}, key string) int { - if args == nil { - return 0 - } - switch v := args[key].(type) { - case float64: - return int(v) - case int: - return v - case int64: - return int(v) - default: - return 0 - } + return MapIntArg(args, key, 0) } diff --git a/pkg/tools/subagent_profile_test.go b/pkg/tools/subagent_profile_test.go index 070ee69..18e16cd 100644 --- a/pkg/tools/subagent_profile_test.go +++ b/pkg/tools/subagent_profile_test.go @@ -44,6 +44,40 @@ func TestSubagentProfileStoreNormalization(t *testing.T) { } } +func TestSubagentProfileToolCreateParsesStringNumericArgs(t *testing.T) { + store := NewSubagentProfileStore(t.TempDir()) + tool := NewSubagentProfileTool(store) + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "create", + "agent_id": "reviewer", + "role": "testing", + "status": "active", + "system_prompt_file": "agents/reviewer/AGENT.md", + "tool_allowlist": "shell,sessions", + "max_retries": "2", + "retry_backoff_ms": "100", + "timeout_sec": "5", + }) + if err != nil { + t.Fatalf("create failed: %v", err) + } + if !strings.Contains(out, "Created subagent profile") { + t.Fatalf("unexpected output: %s", out) + } + + profile, ok, err := store.Get("reviewer") + if err != nil || !ok { + t.Fatalf("expected created profile, got ok=%v err=%v", ok, err) + } + if profile.MaxRetries != 2 || profile.TimeoutSec != 5 { + t.Fatalf("unexpected numeric fields: %+v", profile) + } + if len(profile.ToolAllowlist) != 2 { + t.Fatalf("unexpected allowlist: %+v", profile.ToolAllowlist) + } +} + func TestSubagentManagerSpawnRejectsDisabledProfile(t *testing.T) { workspace := t.TempDir() manager := NewSubagentManager(nil, workspace, nil) diff --git a/pkg/tools/subagent_router.go b/pkg/tools/subagent_router.go index c17f604..c00bd10 100644 --- a/pkg/tools/subagent_router.go +++ b/pkg/tools/subagent_router.go @@ -12,6 +12,7 @@ type RouterDispatchRequest struct { Label string Role string AgentID string + Decision *DispatchDecision NotifyMainPolicy string ThreadID string CorrelationID string @@ -32,6 +33,8 @@ type RouterReply struct { AgentID string Status string Result string + Run RunRecord + Error *RuntimeError } type SubagentRouter struct { @@ -46,6 +49,14 @@ func (r *SubagentRouter) DispatchTask(ctx context.Context, req RouterDispatchReq if r == nil || r.manager == nil { return nil, fmt.Errorf("subagent router is not configured") } + if req.Decision != nil { + if strings.TrimSpace(req.AgentID) == "" { + req.AgentID = strings.TrimSpace(req.Decision.TargetAgent) + } + if strings.TrimSpace(req.Task) == "" { + req.Task = strings.TrimSpace(req.Decision.TaskText) + } + } task, err := r.manager.SpawnTask(ctx, SubagentSpawnOptions{ Task: req.Task, Label: req.Label, @@ -92,6 +103,8 @@ func (r *SubagentRouter) WaitReply(ctx context.Context, taskID string, interval AgentID: task.AgentID, Status: task.Status, Result: strings.TrimSpace(task.Result), + Run: taskToRunRecord(task), + Error: taskRuntimeError(task), }, nil } diff --git a/pkg/tools/subagent_store.go b/pkg/tools/subagent_store.go index fd903d5..f83dab6 100644 --- a/pkg/tools/subagent_store.go +++ b/pkg/tools/subagent_store.go @@ -69,6 +69,23 @@ func (s *SubagentRunStore) load() error { if line == "" { continue } + var record RunRecord + if err := json.Unmarshal([]byte(line), &record); err == nil && strings.TrimSpace(record.ID) != "" { + task := &SubagentTask{ + ID: record.ID, + Task: record.Input, + AgentID: record.AgentID, + ThreadID: record.ThreadID, + CorrelationID: record.CorrelationID, + ParentRunID: record.ParentRunID, + Status: record.Status, + Result: record.Output, + Created: record.CreatedAt, + Updated: record.UpdatedAt, + } + s.runs[task.ID] = task + continue + } var task SubagentTask if err := json.Unmarshal([]byte(line), &task); err != nil { continue @@ -84,7 +101,7 @@ func (s *SubagentRunStore) AppendRun(task *SubagentTask) error { return nil } cp := cloneSubagentTask(task) - data, err := json.Marshal(cp) + data, err := json.Marshal(taskToRunRecord(cp)) if err != nil { return err } @@ -110,7 +127,18 @@ func (s *SubagentRunStore) AppendEvent(evt SubagentRunEvent) error { if s == nil { return nil } - data, err := json.Marshal(evt) + record := EventRecord{ + ID: EventRecordID(evt.RunID, evt.Type, evt.At), + RunID: evt.RunID, + TaskID: evt.RunID, + AgentID: evt.AgentID, + Type: evt.Type, + Status: evt.Status, + Message: evt.Message, + RetryCount: evt.RetryCount, + At: evt.At, + } + data, err := json.Marshal(record) if err != nil { return err } @@ -185,7 +213,19 @@ func (s *SubagentRunStore) Events(runID string, limit int) ([]SubagentRunEvent, } var evt SubagentRunEvent if err := json.Unmarshal([]byte(line), &evt); err != nil { - continue + var record EventRecord + if err := json.Unmarshal([]byte(line), &record); err != nil { + continue + } + evt = SubagentRunEvent{ + RunID: record.RunID, + AgentID: record.AgentID, + Type: record.Type, + Status: record.Status, + Message: record.Message, + RetryCount: record.RetryCount, + At: record.At, + } } if evt.RunID != runID { continue @@ -249,6 +289,55 @@ func cloneSubagentTask(task *SubagentTask) *SubagentTask { return &cp } +func taskToTaskRecord(task *SubagentTask) TaskRecord { + if task == nil { + return TaskRecord{} + } + return TaskRecord{ + ID: task.ID, + ThreadID: task.ThreadID, + CorrelationID: task.CorrelationID, + OwnerAgentID: task.AgentID, + Status: strings.TrimSpace(task.Status), + Input: task.Task, + OriginChannel: task.OriginChannel, + OriginChatID: task.OriginChatID, + CreatedAt: task.Created, + UpdatedAt: task.Updated, + } +} + +func taskRuntimeError(task *SubagentTask) *RuntimeError { + if task == nil || !strings.EqualFold(strings.TrimSpace(task.Status), RuntimeStatusFailed) { + return nil + } + msg := strings.TrimSpace(task.Result) + msg = strings.TrimPrefix(msg, "Error:") + msg = strings.TrimSpace(msg) + return NewRuntimeError("subagent_failed", msg, "subagent", false, "subagent") +} + +func taskToRunRecord(task *SubagentTask) RunRecord { + if task == nil { + return RunRecord{} + } + return RunRecord{ + ID: task.ID, + TaskID: task.ID, + ThreadID: task.ThreadID, + CorrelationID: task.CorrelationID, + AgentID: task.AgentID, + ParentRunID: task.ParentRunID, + Kind: "subagent", + Status: strings.TrimSpace(task.Status), + Input: task.Task, + Output: strings.TrimSpace(task.Result), + Error: taskRuntimeError(task), + CreatedAt: task.Created, + UpdatedAt: task.Updated, + } +} + func formatSubagentEventLog(evt SubagentRunEvent) string { base := fmt.Sprintf("- %d %s", evt.At, evt.Type) if strings.TrimSpace(evt.Status) != "" { diff --git a/pkg/tools/subagents_tool.go b/pkg/tools/subagents_tool.go index 8a02917..70fd29f 100644 --- a/pkg/tools/subagents_tool.go +++ b/pkg/tools/subagents_tool.go @@ -45,26 +45,14 @@ func (t *SubagentsTool) Execute(ctx context.Context, args map[string]interface{} if t.manager == nil { return "subagent manager not available", nil } - action, _ := args["action"].(string) - action = strings.ToLower(strings.TrimSpace(action)) - id, _ := args["id"].(string) - id = strings.TrimSpace(id) - message, _ := args["message"].(string) - message = strings.TrimSpace(message) - messageID, _ := args["message_id"].(string) - messageID = strings.TrimSpace(messageID) - threadID, _ := args["thread_id"].(string) - threadID = strings.TrimSpace(threadID) - agentID, _ := args["agent_id"].(string) - agentID = strings.TrimSpace(agentID) - limit := 20 - if v, ok := args["limit"].(float64); ok && int(v) > 0 { - limit = int(v) - } - recentMinutes := 0 - if v, ok := args["recent_minutes"].(float64); ok && int(v) > 0 { - recentMinutes = int(v) - } + action := strings.ToLower(MapStringArg(args, "action")) + id := MapStringArg(args, "id") + message := MapStringArg(args, "message") + messageID := MapStringArg(args, "message_id") + threadID := MapStringArg(args, "thread_id") + agentID := MapStringArg(args, "agent_id") + limit := MapIntArg(args, "limit", 20) + recentMinutes := MapIntArg(args, "recent_minutes", 0) switch action { case "list": diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 835dd5a..4b3dc2d 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -66,16 +66,14 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} return "Error: BRAVE_API_KEY not configured", nil } - query, ok := args["query"].(string) - if !ok { + query := MapStringArg(args, "query") + if query == "" { return "", fmt.Errorf("query is required") } count := t.maxResults - if c, ok := args["count"].(float64); ok { - if int(c) > 0 && int(c) <= 10 { - count = int(c) - } + if c := MapIntArg(args, "count", count); c > 0 && c <= 10 { + count = c } searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", @@ -183,8 +181,8 @@ func (t *WebFetchTool) Parameters() map[string]interface{} { } func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - urlStr, ok := args["url"].(string) - if !ok { + urlStr := MapStringArg(args, "url") + if urlStr == "" { return "", fmt.Errorf("url is required") } @@ -202,10 +200,8 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) } maxChars := t.maxChars - if mc, ok := args["maxChars"].(float64); ok { - if int(mc) > 100 { - maxChars = int(mc) - } + if mc := MapIntArg(args, "maxChars", maxChars); mc > 100 { + maxChars = mc } req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) diff --git a/workspace/embedkeep.txt b/workspace/embedkeep.txt new file mode 100644 index 0000000..b54a93c --- /dev/null +++ b/workspace/embedkeep.txt @@ -0,0 +1 @@ +embed workspace placeholder