mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-15 01:37:31 +08:00
Fix provider hot reload and restore memory API
This commit is contained in:
@@ -175,14 +175,14 @@ func gatewayCmd() {
|
|||||||
}
|
}
|
||||||
bindAgentLoopHandlers(agentLoop)
|
bindAgentLoopHandlers(agentLoop)
|
||||||
var reloadMu sync.Mutex
|
var reloadMu sync.Mutex
|
||||||
var applyReload func() error
|
var applyReload func(forceRuntimeReload bool) error
|
||||||
registryServer.SetConfigAfterHook(func() error {
|
registryServer.SetConfigAfterHook(func(forceRuntimeReload bool) error {
|
||||||
reloadMu.Lock()
|
reloadMu.Lock()
|
||||||
defer reloadMu.Unlock()
|
defer reloadMu.Unlock()
|
||||||
if applyReload == nil {
|
if applyReload == nil {
|
||||||
return fmt.Errorf("reload handler not ready")
|
return fmt.Errorf("reload handler not ready")
|
||||||
}
|
}
|
||||||
return applyReload()
|
return applyReload(forceRuntimeReload)
|
||||||
})
|
})
|
||||||
whatsAppBridge, whatsAppEmbedded := setupEmbeddedWhatsAppBridge(ctx, cfg)
|
whatsAppBridge, whatsAppEmbedded := setupEmbeddedWhatsAppBridge(ctx, cfg)
|
||||||
if whatsAppBridge != nil {
|
if whatsAppBridge != nil {
|
||||||
@@ -341,7 +341,7 @@ func gatewayCmd() {
|
|||||||
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, gatewayNotifySignals()...)
|
signal.Notify(sigChan, gatewayNotifySignals()...)
|
||||||
applyReload = func() error {
|
applyReload = func(forceRuntimeReload bool) error {
|
||||||
fmt.Println("\nReloading config...")
|
fmt.Println("\nReloading config...")
|
||||||
newCfg, err := config.LoadConfig(getConfigPath())
|
newCfg, err := config.LoadConfig(getConfigPath())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -357,7 +357,7 @@ func gatewayCmd() {
|
|||||||
fmt.Printf("Error starting heartbeat service: %v\n", err)
|
fmt.Printf("Error starting heartbeat service: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if reflect.DeepEqual(cfg, newCfg) {
|
if !forceRuntimeReload && reflect.DeepEqual(cfg, newCfg) {
|
||||||
fmt.Println("Config unchanged, skip reload")
|
fmt.Println("Config unchanged, skip reload")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -376,7 +376,7 @@ func gatewayCmd() {
|
|||||||
reflect.DeepEqual(cfg.Tools, newCfg.Tools) &&
|
reflect.DeepEqual(cfg.Tools, newCfg.Tools) &&
|
||||||
reflect.DeepEqual(cfg.Channels, newCfg.Channels)
|
reflect.DeepEqual(cfg.Channels, newCfg.Channels)
|
||||||
|
|
||||||
if runtimeSame {
|
if runtimeSame && !forceRuntimeReload {
|
||||||
configureLogging(newCfg)
|
configureLogging(newCfg)
|
||||||
sentinelService.Stop()
|
sentinelService.Stop()
|
||||||
sentinelService = sentinel.NewService(
|
sentinelService = sentinel.NewService(
|
||||||
@@ -451,7 +451,7 @@ func gatewayCmd() {
|
|||||||
switch {
|
switch {
|
||||||
case isGatewayReloadSignal(sig):
|
case isGatewayReloadSignal(sig):
|
||||||
reloadMu.Lock()
|
reloadMu.Lock()
|
||||||
err := applyReload()
|
err := applyReload(false)
|
||||||
reloadMu.Unlock()
|
reloadMu.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Reload failed: %v\n", err)
|
fmt.Printf("Reload failed: %v\n", err)
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ type AgentLoop struct {
|
|||||||
subagentDigestMu sync.Mutex
|
subagentDigestMu sync.Mutex
|
||||||
subagentDigestDelay time.Duration
|
subagentDigestDelay time.Duration
|
||||||
subagentDigests map[string]*subagentDigestState
|
subagentDigests map[string]*subagentDigestState
|
||||||
|
runMu sync.Mutex
|
||||||
|
runCancel context.CancelFunc
|
||||||
|
runWG sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
type providerCandidate struct {
|
type providerCandidate struct {
|
||||||
@@ -403,19 +406,34 @@ func (al *AgentLoop) readSubagentPromptFile(relPath string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||||
|
al.runMu.Lock()
|
||||||
|
if al.runCancel != nil {
|
||||||
|
al.runMu.Unlock()
|
||||||
|
return fmt.Errorf("agent loop already running")
|
||||||
|
}
|
||||||
|
runCtx, cancel := context.WithCancel(ctx)
|
||||||
|
al.runCancel = cancel
|
||||||
al.running = true
|
al.running = true
|
||||||
|
al.runMu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
al.runMu.Lock()
|
||||||
|
al.running = false
|
||||||
|
al.runCancel = nil
|
||||||
|
al.runMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
shards := al.buildSessionShards(ctx)
|
shards := al.buildSessionShards(runCtx)
|
||||||
defer func() {
|
defer func() {
|
||||||
for _, ch := range shards {
|
for _, ch := range shards {
|
||||||
close(ch)
|
close(ch)
|
||||||
}
|
}
|
||||||
|
al.runWG.Wait()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for al.running {
|
for al.running {
|
||||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
msg, ok := al.bus.ConsumeInbound(runCtx)
|
||||||
if !ok {
|
if !ok {
|
||||||
if ctx.Err() != nil {
|
if runCtx.Err() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -423,7 +441,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
|||||||
idx := sessionShardIndex(msg.SessionKey, len(shards))
|
idx := sessionShardIndex(msg.SessionKey, len(shards))
|
||||||
select {
|
select {
|
||||||
case shards[idx] <- msg:
|
case shards[idx] <- msg:
|
||||||
case <-ctx.Done():
|
case <-runCtx.Done():
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -432,7 +450,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) Stop() {
|
func (al *AgentLoop) Stop() {
|
||||||
|
al.runMu.Lock()
|
||||||
|
cancel := al.runCancel
|
||||||
|
al.runMu.Unlock()
|
||||||
|
if cancel != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
al.running = false
|
al.running = false
|
||||||
|
al.runWG.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundMessage {
|
func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundMessage {
|
||||||
@@ -440,7 +465,9 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM
|
|||||||
shards := make([]chan bus.InboundMessage, count)
|
shards := make([]chan bus.InboundMessage, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
shards[i] = make(chan bus.InboundMessage, 64)
|
shards[i] = make(chan bus.InboundMessage, 64)
|
||||||
|
al.runWG.Add(1)
|
||||||
go func(ch <-chan bus.InboundMessage) {
|
go func(ch <-chan bus.InboundMessage) {
|
||||||
|
defer al.runWG.Done()
|
||||||
for msg := range ch {
|
for msg := range ch {
|
||||||
al.processInbound(ctx, msg)
|
al.processInbound(ctx, msg)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ type Server struct {
|
|||||||
logFilePath string
|
logFilePath string
|
||||||
onChat func(ctx context.Context, sessionKey, content string) (string, error)
|
onChat func(ctx context.Context, sessionKey, content string) (string, error)
|
||||||
onChatHistory func(sessionKey string) []map[string]interface{}
|
onChatHistory func(sessionKey string) []map[string]interface{}
|
||||||
onConfigAfter func() error
|
onConfigAfter func(forceRuntimeReload bool) error
|
||||||
onCron func(action string, args map[string]interface{}) (interface{}, error)
|
onCron func(action string, args map[string]interface{}) (interface{}, error)
|
||||||
onToolsCatalog func() interface{}
|
onToolsCatalog func() interface{}
|
||||||
whatsAppBridge *channels.WhatsAppBridgeService
|
whatsAppBridge *channels.WhatsAppBridgeService
|
||||||
@@ -85,7 +85,7 @@ func (s *Server) SetChatHandler(fn func(ctx context.Context, sessionKey, content
|
|||||||
func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) {
|
func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) {
|
||||||
s.onChatHistory = fn
|
s.onChatHistory = fn
|
||||||
}
|
}
|
||||||
func (s *Server) SetConfigAfterHook(fn func() error) { s.onConfigAfter = fn }
|
func (s *Server) SetConfigAfterHook(fn func(forceRuntimeReload bool) error) { s.onConfigAfter = fn }
|
||||||
func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) {
|
func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) {
|
||||||
s.onCron = fn
|
s.onCron = fn
|
||||||
}
|
}
|
||||||
@@ -414,7 +414,7 @@ func (s *Server) persistWebUIConfig(cfg *cfgpkg.Config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.onConfigAfter != nil {
|
if s.onConfigAfter != nil {
|
||||||
return s.onConfigAfter()
|
return s.onConfigAfter(false)
|
||||||
}
|
}
|
||||||
return requestSelfReloadSignal()
|
return requestSelfReloadSignal()
|
||||||
}
|
}
|
||||||
@@ -978,7 +978,9 @@ func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.P
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.onConfigAfter != nil {
|
if s.onConfigAfter != nil {
|
||||||
if err := s.onConfigAfter(); err != nil {
|
// Provider updates can take effect through external credential files
|
||||||
|
// even when config.json remains structurally identical.
|
||||||
|
if err := s.onConfigAfter(true); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -3118,12 +3120,15 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) {
|
|||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
path := strings.TrimSpace(r.URL.Query().Get("path"))
|
path := strings.TrimSpace(r.URL.Query().Get("path"))
|
||||||
if path == "" {
|
if path == "" {
|
||||||
|
files := make([]string, 0, 16)
|
||||||
|
if _, err := os.Stat(filepath.Join(s.workspacePath, "MEMORY.md")); err == nil {
|
||||||
|
files = append(files, "MEMORY.md")
|
||||||
|
}
|
||||||
entries, err := os.ReadDir(memoryDir)
|
entries, err := os.ReadDir(memoryDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
files := make([]string, 0, len(entries))
|
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if e.IsDir() {
|
if e.IsDir() {
|
||||||
continue
|
continue
|
||||||
@@ -3133,7 +3138,11 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeJSON(w, map[string]interface{}{"ok": true, "files": files})
|
writeJSON(w, map[string]interface{}{"ok": true, "files": files})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
clean, content, found, err := readRelativeTextFile(memoryDir, path)
|
baseDir := memoryDir
|
||||||
|
if strings.EqualFold(path, "MEMORY.md") {
|
||||||
|
baseDir = strings.TrimSpace(s.workspacePath)
|
||||||
|
}
|
||||||
|
clean, content, found, err := readRelativeTextFile(baseDir, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -100,7 +100,10 @@ func TestHandleWebUIConfigPostSavesRawConfig(t *testing.T) {
|
|||||||
srv := NewServer("127.0.0.1", 0, "")
|
srv := NewServer("127.0.0.1", 0, "")
|
||||||
srv.SetConfigPath(cfgPath)
|
srv.SetConfigPath(cfgPath)
|
||||||
hookCalled := 0
|
hookCalled := 0
|
||||||
srv.SetConfigAfterHook(func() error {
|
srv.SetConfigAfterHook(func(forceRuntimeReload bool) error {
|
||||||
|
if forceRuntimeReload {
|
||||||
|
t.Fatalf("expected raw config save to use non-forced reload")
|
||||||
|
}
|
||||||
hookCalled++
|
hookCalled++
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -150,7 +153,12 @@ func TestHandleWebUIConfigPostSavesNormalizedConfig(t *testing.T) {
|
|||||||
|
|
||||||
srv := NewServer("127.0.0.1", 0, "")
|
srv := NewServer("127.0.0.1", 0, "")
|
||||||
srv.SetConfigPath(cfgPath)
|
srv.SetConfigPath(cfgPath)
|
||||||
srv.SetConfigAfterHook(func() error { return nil })
|
srv.SetConfigAfterHook(func(forceRuntimeReload bool) error {
|
||||||
|
if forceRuntimeReload {
|
||||||
|
t.Fatalf("expected normalized config save to use non-forced reload")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/config?mode=normalized", strings.NewReader(`{"core":{"gateway":{"host":"127.0.0.1","port":18790},"tools":{"shell_enabled":false,"mcp_enabled":true}},"runtime":{"router":{"enabled":true,"strategy":"rules_first","max_hops":2,"default_timeout_sec":90},"providers":{"openai":{"api_base":"https://api.openai.com/v1","auth":"bearer","timeout_sec":150}}}}`))
|
req := httptest.NewRequest(http.MethodPost, "/api/config?mode=normalized", strings.NewReader(`{"core":{"gateway":{"host":"127.0.0.1","port":18790},"tools":{"shell_enabled":false,"mcp_enabled":true}},"runtime":{"router":{"enabled":true,"strategy":"rules_first","max_hops":2,"default_timeout_sec":90},"providers":{"openai":{"api_base":"https://api.openai.com/v1","auth":"bearer","timeout_sec":150}}}}`))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -249,6 +257,98 @@ func TestHandleWebUISessionsHidesInternalSessionsByDefault(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveProviderConfigForcesRuntimeReload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tmp := t.TempDir()
|
||||||
|
cfgPath := filepath.Join(tmp, "config.json")
|
||||||
|
cfg := cfgpkg.DefaultConfig()
|
||||||
|
cfg.Logging.Enabled = false
|
||||||
|
cfg.Models.Providers["openai"] = cfgpkg.ProviderConfig{
|
||||||
|
APIBase: "https://api.openai.com/v1",
|
||||||
|
Auth: "oauth",
|
||||||
|
Models: []string{"gpt-5"},
|
||||||
|
TimeoutSec: 120,
|
||||||
|
OAuth: cfgpkg.ProviderOAuthConfig{
|
||||||
|
Provider: "codex",
|
||||||
|
CredentialFile: filepath.Join(tmp, "auth.json"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil {
|
||||||
|
t.Fatalf("save config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := NewServer("127.0.0.1", 0, "")
|
||||||
|
srv.SetConfigPath(cfgPath)
|
||||||
|
|
||||||
|
forced := false
|
||||||
|
srv.SetConfigAfterHook(func(forceRuntimeReload bool) error {
|
||||||
|
forced = forceRuntimeReload
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
pc := cfg.Models.Providers["openai"]
|
||||||
|
if err := srv.saveProviderConfig(cfg, "openai", pc); err != nil {
|
||||||
|
t.Fatalf("save provider config: %v", err)
|
||||||
|
}
|
||||||
|
if !forced {
|
||||||
|
t.Fatalf("expected provider config save to force runtime reload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tmp := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("# long-term\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write workspace memory: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Join(tmp, "memory"), 0o755); err != nil {
|
||||||
|
t.Fatalf("mkdir memory dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(tmp, "memory", "2026-03-19.md"), []byte("daily\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write daily memory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := NewServer("127.0.0.1", 0, "")
|
||||||
|
srv.SetWorkspacePath(tmp)
|
||||||
|
|
||||||
|
listReq := httptest.NewRequest(http.MethodGet, "/api/memory", nil)
|
||||||
|
listRec := httptest.NewRecorder()
|
||||||
|
srv.handleWebUIMemory(listRec, listReq)
|
||||||
|
if listRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", listRec.Code, listRec.Body.String())
|
||||||
|
}
|
||||||
|
var listPayload struct {
|
||||||
|
OK bool `json:"ok"`
|
||||||
|
Files []string `json:"files"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(listRec.Body.Bytes(), &listPayload); err != nil {
|
||||||
|
t.Fatalf("decode list payload: %v", err)
|
||||||
|
}
|
||||||
|
if len(listPayload.Files) < 2 || listPayload.Files[0] != "MEMORY.md" {
|
||||||
|
t.Fatalf("expected MEMORY.md in memory file list, got %+v", listPayload.Files)
|
||||||
|
}
|
||||||
|
|
||||||
|
readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=MEMORY.md", nil)
|
||||||
|
readRec := httptest.NewRecorder()
|
||||||
|
srv.handleWebUIMemory(readRec, readReq)
|
||||||
|
if readRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", readRec.Code, readRec.Body.String())
|
||||||
|
}
|
||||||
|
var readPayload struct {
|
||||||
|
OK bool `json:"ok"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(readRec.Body.Bytes(), &readPayload); err != nil {
|
||||||
|
t.Fatalf("decode read payload: %v", err)
|
||||||
|
}
|
||||||
|
if readPayload.Path != "MEMORY.md" || readPayload.Content != "# long-term\n" {
|
||||||
|
t.Fatalf("unexpected memory payload: %+v", readPayload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleWebUIChatLive(t *testing.T) {
|
func TestHandleWebUIChatLive(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user