diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index 3018b54..0e65191 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -175,14 +175,14 @@ func gatewayCmd() { } bindAgentLoopHandlers(agentLoop) var reloadMu sync.Mutex - var applyReload func() error - registryServer.SetConfigAfterHook(func() error { + var applyReload func(forceRuntimeReload bool) error + registryServer.SetConfigAfterHook(func(forceRuntimeReload bool) error { reloadMu.Lock() defer reloadMu.Unlock() if applyReload == nil { return fmt.Errorf("reload handler not ready") } - return applyReload() + return applyReload(forceRuntimeReload) }) whatsAppBridge, whatsAppEmbedded := setupEmbeddedWhatsAppBridge(ctx, cfg) if whatsAppBridge != nil { @@ -341,7 +341,7 @@ func gatewayCmd() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, gatewayNotifySignals()...) - applyReload = func() error { + applyReload = func(forceRuntimeReload bool) error { fmt.Println("\nReloading config...") newCfg, err := config.LoadConfig(getConfigPath()) if err != nil { @@ -357,7 +357,7 @@ func gatewayCmd() { fmt.Printf("Error starting heartbeat service: %v\n", err) } - if reflect.DeepEqual(cfg, newCfg) { + if !forceRuntimeReload && reflect.DeepEqual(cfg, newCfg) { fmt.Println("Config unchanged, skip reload") return nil } @@ -376,7 +376,7 @@ func gatewayCmd() { reflect.DeepEqual(cfg.Tools, newCfg.Tools) && reflect.DeepEqual(cfg.Channels, newCfg.Channels) - if runtimeSame { + if runtimeSame && !forceRuntimeReload { configureLogging(newCfg) sentinelService.Stop() sentinelService = sentinel.NewService( @@ -451,7 +451,7 @@ func gatewayCmd() { switch { case isGatewayReloadSignal(sig): reloadMu.Lock() - err := applyReload() + err := applyReload(false) reloadMu.Unlock() if err != nil { fmt.Printf("Reload failed: %v\n", err) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f31eefc..f5bb170 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -67,6 +67,9 @@ type AgentLoop struct { subagentDigestMu sync.Mutex subagentDigestDelay time.Duration subagentDigests map[string]*subagentDigestState + runMu sync.Mutex + runCancel context.CancelFunc + runWG sync.WaitGroup } type providerCandidate struct { @@ -403,19 +406,34 @@ func (al *AgentLoop) readSubagentPromptFile(relPath string) string { } func (al *AgentLoop) Run(ctx context.Context) error { + al.runMu.Lock() + if al.runCancel != nil { + al.runMu.Unlock() + return fmt.Errorf("agent loop already running") + } + runCtx, cancel := context.WithCancel(ctx) + al.runCancel = cancel al.running = true + al.runMu.Unlock() + defer func() { + al.runMu.Lock() + al.running = false + al.runCancel = nil + al.runMu.Unlock() + }() - shards := al.buildSessionShards(ctx) + shards := al.buildSessionShards(runCtx) defer func() { for _, ch := range shards { close(ch) } + al.runWG.Wait() }() for al.running { - msg, ok := al.bus.ConsumeInbound(ctx) + msg, ok := al.bus.ConsumeInbound(runCtx) if !ok { - if ctx.Err() != nil { + if runCtx.Err() != nil { return nil } continue @@ -423,7 +441,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { idx := sessionShardIndex(msg.SessionKey, len(shards)) select { case shards[idx] <- msg: - case <-ctx.Done(): + case <-runCtx.Done(): return nil } } @@ -432,7 +450,14 @@ func (al *AgentLoop) Run(ctx context.Context) error { } func (al *AgentLoop) Stop() { + al.runMu.Lock() + cancel := al.runCancel + al.runMu.Unlock() + if cancel != nil { + cancel() + } al.running = false + al.runWG.Wait() } func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundMessage { @@ -440,7 +465,9 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM shards := make([]chan bus.InboundMessage, count) for i := 0; i < count; i++ { shards[i] = make(chan bus.InboundMessage, 64) + al.runWG.Add(1) go func(ch <-chan bus.InboundMessage) { + defer al.runWG.Done() for msg := range ch { al.processInbound(ctx, msg) } diff --git a/pkg/api/server.go b/pkg/api/server.go index 1be2b71..6a38974 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -48,7 +48,7 @@ type Server struct { logFilePath string onChat func(ctx context.Context, sessionKey, content string) (string, error) onChatHistory func(sessionKey string) []map[string]interface{} - onConfigAfter func() error + onConfigAfter func(forceRuntimeReload bool) error onCron func(action string, args map[string]interface{}) (interface{}, error) onToolsCatalog func() interface{} whatsAppBridge *channels.WhatsAppBridgeService @@ -85,7 +85,7 @@ func (s *Server) SetChatHandler(fn func(ctx context.Context, sessionKey, content func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) { s.onChatHistory = fn } -func (s *Server) SetConfigAfterHook(fn func() error) { s.onConfigAfter = fn } +func (s *Server) SetConfigAfterHook(fn func(forceRuntimeReload bool) error) { s.onConfigAfter = fn } func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) { s.onCron = fn } @@ -414,7 +414,7 @@ func (s *Server) persistWebUIConfig(cfg *cfgpkg.Config) error { return err } if s.onConfigAfter != nil { - return s.onConfigAfter() + return s.onConfigAfter(false) } return requestSelfReloadSignal() } @@ -978,7 +978,9 @@ func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.P return err } if s.onConfigAfter != nil { - if err := s.onConfigAfter(); err != nil { + // Provider updates can take effect through external credential files + // even when config.json remains structurally identical. + if err := s.onConfigAfter(true); err != nil { return err } } else { @@ -3118,12 +3120,15 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { case http.MethodGet: path := strings.TrimSpace(r.URL.Query().Get("path")) if path == "" { + files := make([]string, 0, 16) + if _, err := os.Stat(filepath.Join(s.workspacePath, "MEMORY.md")); err == nil { + files = append(files, "MEMORY.md") + } entries, err := os.ReadDir(memoryDir) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - files := make([]string, 0, len(entries)) for _, e := range entries { if e.IsDir() { continue @@ -3133,7 +3138,11 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]interface{}{"ok": true, "files": files}) return } - clean, content, found, err := readRelativeTextFile(memoryDir, path) + baseDir := memoryDir + if strings.EqualFold(path, "MEMORY.md") { + baseDir = strings.TrimSpace(s.workspacePath) + } + clean, content, found, err := readRelativeTextFile(baseDir, path) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index 33548dc..372d85e 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -100,7 +100,10 @@ func TestHandleWebUIConfigPostSavesRawConfig(t *testing.T) { srv := NewServer("127.0.0.1", 0, "") srv.SetConfigPath(cfgPath) hookCalled := 0 - srv.SetConfigAfterHook(func() error { + srv.SetConfigAfterHook(func(forceRuntimeReload bool) error { + if forceRuntimeReload { + t.Fatalf("expected raw config save to use non-forced reload") + } hookCalled++ return nil }) @@ -150,7 +153,12 @@ func TestHandleWebUIConfigPostSavesNormalizedConfig(t *testing.T) { srv := NewServer("127.0.0.1", 0, "") srv.SetConfigPath(cfgPath) - srv.SetConfigAfterHook(func() error { return nil }) + srv.SetConfigAfterHook(func(forceRuntimeReload bool) error { + if forceRuntimeReload { + t.Fatalf("expected normalized config save to use non-forced reload") + } + return nil + }) req := httptest.NewRequest(http.MethodPost, "/api/config?mode=normalized", strings.NewReader(`{"core":{"gateway":{"host":"127.0.0.1","port":18790},"tools":{"shell_enabled":false,"mcp_enabled":true}},"runtime":{"router":{"enabled":true,"strategy":"rules_first","max_hops":2,"default_timeout_sec":90},"providers":{"openai":{"api_base":"https://api.openai.com/v1","auth":"bearer","timeout_sec":150}}}}`)) req.Header.Set("Content-Type", "application/json") @@ -249,6 +257,98 @@ func TestHandleWebUISessionsHidesInternalSessionsByDefault(t *testing.T) { } } +func TestSaveProviderConfigForcesRuntimeReload(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Logging.Enabled = false + cfg.Models.Providers["openai"] = cfgpkg.ProviderConfig{ + APIBase: "https://api.openai.com/v1", + Auth: "oauth", + Models: []string{"gpt-5"}, + TimeoutSec: 120, + OAuth: cfgpkg.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: filepath.Join(tmp, "auth.json"), + }, + } + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetConfigPath(cfgPath) + + forced := false + srv.SetConfigAfterHook(func(forceRuntimeReload bool) error { + forced = forceRuntimeReload + return nil + }) + + pc := cfg.Models.Providers["openai"] + if err := srv.saveProviderConfig(cfg, "openai", pc); err != nil { + t.Fatalf("save provider config: %v", err) + } + if !forced { + t.Fatalf("expected provider config save to force runtime reload") + } +} + +func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + if err := os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("# long-term\n"), 0o644); err != nil { + t.Fatalf("write workspace memory: %v", err) + } + if err := os.MkdirAll(filepath.Join(tmp, "memory"), 0o755); err != nil { + t.Fatalf("mkdir memory dir: %v", err) + } + if err := os.WriteFile(filepath.Join(tmp, "memory", "2026-03-19.md"), []byte("daily\n"), 0o644); err != nil { + t.Fatalf("write daily memory: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetWorkspacePath(tmp) + + listReq := httptest.NewRequest(http.MethodGet, "/api/memory", nil) + listRec := httptest.NewRecorder() + srv.handleWebUIMemory(listRec, listReq) + if listRec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", listRec.Code, listRec.Body.String()) + } + var listPayload struct { + OK bool `json:"ok"` + Files []string `json:"files"` + } + if err := json.Unmarshal(listRec.Body.Bytes(), &listPayload); err != nil { + t.Fatalf("decode list payload: %v", err) + } + if len(listPayload.Files) < 2 || listPayload.Files[0] != "MEMORY.md" { + t.Fatalf("expected MEMORY.md in memory file list, got %+v", listPayload.Files) + } + + readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=MEMORY.md", nil) + readRec := httptest.NewRecorder() + srv.handleWebUIMemory(readRec, readReq) + if readRec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", readRec.Code, readRec.Body.String()) + } + var readPayload struct { + OK bool `json:"ok"` + Path string `json:"path"` + Content string `json:"content"` + } + if err := json.Unmarshal(readRec.Body.Bytes(), &readPayload); err != nil { + t.Fatalf("decode read payload: %v", err) + } + if readPayload.Path != "MEMORY.md" || readPayload.Content != "# long-term\n" { + t.Fatalf("unexpected memory payload: %+v", readPayload) + } +} + func TestHandleWebUIChatLive(t *testing.T) { t.Parallel()