diff --git a/README.md b/README.md index fa04fd4..f3b518f 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,8 @@ WebUI: http://:/?token= ``` +Gateway API docs: `docs/API.md` + ## 架构概览 默认协作模式: diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index 5afda3a..31b713e 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -2,28 +2,17 @@ package main import ( "context" - "crypto/sha256" "fmt" - "io" "net/http" "os" - "os/exec" "os/signal" "path/filepath" - "reflect" - "runtime" "strings" - "sync" "time" - "github.com/YspCoder/clawgo/pkg/agent" "github.com/YspCoder/clawgo/pkg/api" "github.com/YspCoder/clawgo/pkg/bus" - "github.com/YspCoder/clawgo/pkg/channels" - "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/cron" - "github.com/YspCoder/clawgo/pkg/heartbeat" - "github.com/YspCoder/clawgo/pkg/logger" "github.com/YspCoder/clawgo/pkg/providers" "github.com/YspCoder/clawgo/pkg/runtimecfg" "github.com/YspCoder/clawgo/pkg/sentinel" @@ -70,14 +59,6 @@ func gatewayCmd() { return dispatchCronJob(msgBus, job), nil }) configureCronServiceRuntime(cronService, cfg) - heartbeatService := buildHeartbeatService(cfg, msgBus) - sentinelService := sentinel.NewService( - getConfigPath(), - cfg.WorkspacePath(), - cfg.Sentinel.IntervalSec, - cfg.Sentinel.AutoHeal, - buildSentinelAlertHandler(cfg, msgBus), - ) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -87,7 +68,21 @@ func gatewayCmd() { fmt.Printf("Error initializing gateway runtime: %v\n", err) os.Exit(1) } - sentinelService.SetManager(channelManager) + + state := &gatewayRuntimeState{ + cfg: cfg, + agentLoop: agentLoop, + channelManager: channelManager, + heartbeatService: buildHeartbeatService(cfg, msgBus), + sentinelService: sentinel.NewService( + getConfigPath(), + cfg.WorkspacePath(), + cfg.Sentinel.IntervalSec, + cfg.Sentinel.AutoHeal, + buildSentinelAlertHandler(cfg, msgBus), + ), + } + state.sentinelService.SetManager(state.channelManager) pidFile := 
filepath.Join(filepath.Dir(getConfigPath()), "gateway.pid") if err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d\n", os.Getpid())), 0644); err != nil { @@ -96,14 +91,14 @@ func gatewayCmd() { defer os.Remove(pidFile) } - enabledChannels := channelManager.GetEnabledChannels() + enabledChannels := state.channelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf("Channels enabled: %s\n", enabledChannels) } else { fmt.Println("Warning: no channels enabled") } - fmt.Printf("Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) + fmt.Printf("Gateway started on %s:%d\n", state.cfg.Gateway.Host, state.cfg.Gateway.Port) fmt.Println("Press Ctrl+C to stop. Send SIGHUP to hot-reload config.") if err := cronService.Start(); err != nil { @@ -111,21 +106,22 @@ func gatewayCmd() { } fmt.Println("Cron service started") - if err := heartbeatService.Start(); err != nil { + if err := state.heartbeatService.Start(); err != nil { fmt.Printf("Error starting heartbeat service: %v\n", err) } fmt.Println("Heartbeat service started") - if cfg.Sentinel.Enabled { - sentinelService.Start() + if state.cfg.Sentinel.Enabled { + state.sentinelService.Start() fmt.Println("Sentinel service started") } - registryServer := api.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, cfg.Gateway.Token) + registryServer := api.NewServer(state.cfg.Gateway.Host, state.cfg.Gateway.Port, state.cfg.Gateway.Token) registryServer.SetGatewayVersion(version) registryServer.SetConfigPath(getConfigPath()) - registryServer.SetToken(cfg.Gateway.Token) - registryServer.SetWorkspacePath(cfg.WorkspacePath()) - registryServer.SetLogFilePath(cfg.LogFilePath()) + registryServer.SetToken(state.cfg.Gateway.Token) + registryServer.SetWorkspacePath(state.cfg.WorkspacePath()) + registryServer.SetLogFilePath(state.cfg.LogFilePath()) + aistudioRelay := wsrelay.NewManager(wsrelay.Options{ Path: "/v1/ws", ProviderFactory: func(r *http.Request) (string, error) { @@ -144,293 +140,28 @@ func gatewayCmd() { defer 
func() { _ = aistudioRelay.Stop(context.Background()) }() providers.SetAIStudioRelayManager(aistudioRelay) registryServer.SetProtectedRoute(aistudioRelay.Path(), aistudioRelay.Handler()) - bindAgentLoopHandlers := func(loop *agent.AgentLoop) { - registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) { - if strings.TrimSpace(content) == "" { - return "", nil - } - return loop.ProcessDirect(cctx, content, sessionKey) - }) - registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} { - h := loop.GetSessionHistory(sessionKey) - out := make([]map[string]interface{}, 0, len(h)) - for _, m := range h { - entry := map[string]interface{}{"role": m.Role, "content": m.Content} - if strings.TrimSpace(m.ToolCallID) != "" { - entry["tool_call_id"] = m.ToolCallID - } - if len(m.ToolCalls) > 0 { - entry["tool_calls"] = m.ToolCalls - } - out = append(out, entry) - } - return out - }) - registryServer.SetToolsCatalogHandler(func() interface{} { - return loop.GetToolCatalog() - }) - } - bindAgentLoopHandlers(agentLoop) - var reloadMu sync.Mutex - triggerReload := func(source string, forceRuntimeReload bool) error { - reloadMu.Lock() - defer reloadMu.Unlock() - fmt.Printf("\nReloading config (source=%s)...\n", strings.TrimSpace(source)) - newCfg, err := config.LoadConfig(getConfigPath()) - if err != nil { - return fmt.Errorf("load config: %w", err) - } - if strings.EqualFold(strings.TrimSpace(os.Getenv(envRootGranted)), "1") || strings.EqualFold(strings.TrimSpace(os.Getenv(envRootGranted)), "true") { - applyMaximumPermissionPolicy(newCfg) - } - configureCronServiceRuntime(cronService, newCfg) - heartbeatService.Stop() - heartbeatService = buildHeartbeatService(newCfg, msgBus) - if err := heartbeatService.Start(); err != nil { - fmt.Printf("Error starting heartbeat service: %v\n", err) - } - if !forceRuntimeReload && reflect.DeepEqual(cfg, newCfg) { - fmt.Println("Config unchanged, skip reload") - return nil - } 
- - if cfg.Gateway.Host != newCfg.Gateway.Host || cfg.Gateway.Port != newCfg.Gateway.Port { - fmt.Printf("Warning: gateway host/port change detected (%s:%d -> %s:%d); restart required to rebind listener\n", - cfg.Gateway.Host, cfg.Gateway.Port, newCfg.Gateway.Host, newCfg.Gateway.Port) - } - - runtimeSame := reflect.DeepEqual(cfg.Agents, newCfg.Agents) && - reflect.DeepEqual(cfg.Models, newCfg.Models) && - reflect.DeepEqual(cfg.Tools, newCfg.Tools) && - reflect.DeepEqual(cfg.Channels, newCfg.Channels) - - if runtimeSame && !forceRuntimeReload { - configureLogging(newCfg) - sentinelService.Stop() - sentinelService = sentinel.NewService( - getConfigPath(), - newCfg.WorkspacePath(), - newCfg.Sentinel.IntervalSec, - newCfg.Sentinel.AutoHeal, - buildSentinelAlertHandler(newCfg, msgBus), - ) - if newCfg.Sentinel.Enabled { - sentinelService.SetManager(channelManager) - sentinelService.Start() - } - cfg = newCfg - runtimecfg.Set(cfg) - registryServer.SetToken(cfg.Gateway.Token) - registryServer.SetWorkspacePath(cfg.WorkspacePath()) - registryServer.SetLogFilePath(cfg.LogFilePath()) - fmt.Println("Config hot-reload applied (logging/metadata only)") - return nil - } - - newAgentLoop, newChannelManager, err := buildGatewayRuntime(ctx, newCfg, msgBus, cronService) - if err != nil { - return fmt.Errorf("init runtime: %w", err) - } - - channelManager.StopAll(ctx) - agentLoop.Stop() - channelManager = newChannelManager - agentLoop = newAgentLoop - cfg = newCfg - runtimecfg.Set(cfg) - bindAgentLoopHandlers(agentLoop) - configureLogging(newCfg) - registryServer.SetToken(cfg.Gateway.Token) - registryServer.SetWorkspacePath(cfg.WorkspacePath()) - registryServer.SetLogFilePath(cfg.LogFilePath()) - if rawWeixin, ok := channelManager.GetChannel("weixin"); ok { - if weixinChannel, ok := rawWeixin.(*channels.WeixinChannel); ok { - weixinChannel.SetConfigPath(getConfigPath()) - registryServer.SetWeixinChannel(weixinChannel) - } - } else { - registryServer.SetWeixinChannel(nil) - } - 
sentinelService.Stop() - sentinelService = sentinel.NewService( - getConfigPath(), - newCfg.WorkspacePath(), - newCfg.Sentinel.IntervalSec, - newCfg.Sentinel.AutoHeal, - buildSentinelAlertHandler(newCfg, msgBus), - ) - if newCfg.Sentinel.Enabled { - sentinelService.Start() - } - sentinelService.SetManager(channelManager) - - if err := channelManager.StartAll(ctx); err != nil { - return fmt.Errorf("start channels: %w", err) - } - go agentLoop.Run(ctx) - fmt.Println("Config hot-reload applied") - return nil - } + bindAgentLoopHandlers(registryServer, state.agentLoop) + triggerReload := newGatewayReloadTrigger(ctx, state, msgBus, cronService, registryServer) registryServer.SetConfigAfterHook(func(forceRuntimeReload bool) error { return triggerReload("api", forceRuntimeReload) }) - if rawWeixin, ok := channelManager.GetChannel("weixin"); ok { - if weixinChannel, ok := rawWeixin.(*channels.WeixinChannel); ok { - weixinChannel.SetConfigPath(getConfigPath()) - registryServer.SetWeixinChannel(weixinChannel) - } - } - registryServer.SetCronHandler(func(action string, args map[string]interface{}) (interface{}, error) { - getStr := func(k string) string { - v, _ := args[k].(string) - return strings.TrimSpace(v) - } - getBoolPtr := func(k string) *bool { - v, ok := args[k].(bool) - if !ok { - return nil - } - vv := v - return &vv - } - switch strings.ToLower(strings.TrimSpace(action)) { - case "", "list": - return cronService.ListJobs(true), nil - case "get": - id := getStr("id") - if id == "" { - return nil, fmt.Errorf("id required") - } - j := cronService.GetJob(id) - if j == nil { - return nil, fmt.Errorf("job not found: %s", id) - } - return j, nil - case "create": - name := getStr("name") - if name == "" { - name = "webui-cron" - } - msg := getStr("message") - if msg == "" { - return nil, fmt.Errorf("message required") - } - schedule := cron.CronSchedule{} - if expr := getStr("expr"); expr != "" { - schedule.Expr = expr - } else { - // Backward compatibility for older 
clients. - kind := strings.ToLower(getStr("kind")) - switch kind { - case "every": - everyMS, ok := args["everyMs"].(float64) - if !ok || int64(everyMS) <= 0 { - return nil, fmt.Errorf("expr required") - } - ev := int64(everyMS) - schedule.Kind = "every" - schedule.EveryMS = &ev - case "once", "at": - atMS, ok := args["atMs"].(float64) - var at int64 - if !ok || int64(atMS) <= 0 { - at = time.Now().Add(1 * time.Minute).UnixMilli() - } else { - at = int64(atMS) - } - schedule.Kind = "at" - schedule.AtMS = &at - default: - return nil, fmt.Errorf("expr required") - } - } - deliver := false - if v, ok := args["deliver"].(bool); ok { - deliver = v - } - return cronService.AddJob(name, schedule, msg, deliver, getStr("channel"), getStr("to")) - case "update": - id := getStr("id") - if id == "" { - return nil, fmt.Errorf("id required") - } - in := cron.UpdateJobInput{} - if v := getStr("name"); v != "" { - in.Name = &v - } - if v := getStr("message"); v != "" { - in.Message = &v - } - if p := getBoolPtr("enabled"); p != nil { - in.Enabled = p - } - if p := getBoolPtr("deliver"); p != nil { - in.Deliver = p - } - if v := getStr("channel"); v != "" { - in.Channel = &v - } - if v := getStr("to"); v != "" { - in.To = &v - } - if expr := getStr("expr"); expr != "" { - s := cron.CronSchedule{Expr: expr} - in.Schedule = &s - } else if kind := strings.ToLower(getStr("kind")); kind != "" { - // Backward compatibility for older clients. 
- s := cron.CronSchedule{Kind: kind} - switch kind { - case "every": - if everyMS, ok := args["everyMs"].(float64); ok && int64(everyMS) > 0 { - ev := int64(everyMS) - s.EveryMS = &ev - } else { - return nil, fmt.Errorf("expr required") - } - case "once", "at": - s.Kind = "at" - if atMS, ok := args["atMs"].(float64); ok && int64(atMS) > 0 { - at := int64(atMS) - s.AtMS = &at - } else { - at := time.Now().Add(1 * time.Minute).UnixMilli() - s.AtMS = &at - } - default: - return nil, fmt.Errorf("expr required") - } - in.Schedule = &s - } - return cronService.UpdateJob(id, in) - case "delete": - id := getStr("id") - return map[string]interface{}{"deleted": cronService.RemoveJob(id), "id": id}, nil - case "enable": - id := getStr("id") - j := cronService.EnableJob(id, true) - return map[string]interface{}{"ok": j != nil, "id": id}, nil - case "disable": - id := getStr("id") - j := cronService.EnableJob(id, false) - return map[string]interface{}{"ok": j != nil, "id": id}, nil - default: - return nil, fmt.Errorf("unsupported cron action: %s", action) - } - }) + (&gatewayReloader{state: state, registryServer: registryServer}).bindWeixinChannel() + bindCronHandler(registryServer, cronService) + if err := registryServer.Start(ctx); err != nil { fmt.Printf("Error starting gateway server: %v\n", err) } else { - fmt.Printf("Gateway server started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) + fmt.Printf("Gateway server started on %s:%d\n", state.cfg.Gateway.Host, state.cfg.Gateway.Port) } - if err := channelManager.StartAll(ctx); err != nil { + if err := state.channelManager.StartAll(ctx); err != nil { fmt.Printf("Error starting channels: %v\n", err) } - go agentLoop.Run(ctx) - go runGatewayStartupCompactionCheck(ctx, agentLoop) - go runGatewayBootstrapInit(ctx, cfg, agentLoop) + go state.agentLoop.Run(ctx) + go runGatewayStartupCompactionCheck(ctx, state.agentLoop) + go runGatewayBootstrapInit(ctx, state.cfg, state.agentLoop) stopConfigWatcher := 
startGatewayConfigWatcher(ctx, getConfigPath(), 500*time.Millisecond, 250*time.Millisecond, func() error { return triggerReload("watcher", false) @@ -452,731 +183,14 @@ func gatewayCmd() { default: fmt.Println("\nShutting down...") cancel() - heartbeatService.Stop() - sentinelService.Stop() + state.heartbeatService.Stop() + state.sentinelService.Stop() cronService.Stop() - agentLoop.Stop() - channelManager.StopAll(ctx) + state.agentLoop.Stop() + state.channelManager.StopAll(ctx) fmt.Println("Gateway stopped") return } } } } - -func runGatewayStartupCompactionCheck(parent context.Context, agentLoop *agent.AgentLoop) { - if agentLoop == nil { - return - } - - checkCtx, cancel := context.WithTimeout(parent, 10*time.Minute) - defer cancel() - - report := agentLoop.RunStartupSelfCheckAllSessions(checkCtx) - logger.InfoCF("gateway", logger.C0110, map[string]interface{}{ - "sessions_total": report.TotalSessions, - "sessions_compacted": report.CompactedSessions, - }) -} - -func runGatewayBootstrapInit(parent context.Context, cfg *config.Config, agentLoop *agent.AgentLoop) { - if agentLoop == nil || cfg == nil { - return - } - workspace := cfg.WorkspacePath() - bootstrapPath := filepath.Join(workspace, "BOOTSTRAP.md") - if _, err := os.Stat(bootstrapPath); err != nil { - return - } - memDir := filepath.Join(workspace, "memory") - _ = os.MkdirAll(memDir, 0755) - markerPath := filepath.Join(memDir, "bootstrap.init.done") - if _, err := os.Stat(markerPath); err == nil { - return - } - - initCtx, cancel := context.WithTimeout(parent, 90*time.Second) - defer cancel() - prompt := "System startup bootstrap: read BOOTSTRAP.md and perform one-time self-initialization checks now. If already initialized, return concise status only." 
- resp, err := agentLoop.ProcessDirect(initCtx, prompt, "system:bootstrap:init") - if err != nil { - logger.ErrorCF("gateway", logger.C0111, map[string]interface{}{logger.FieldError: err.Error()}) - return - } - line := fmt.Sprintf("%s\n%s\n", time.Now().UTC().Format(time.RFC3339), strings.TrimSpace(resp)) - if err := os.WriteFile(markerPath, []byte(line), 0644); err != nil { - logger.ErrorCF("gateway", logger.C0112, map[string]interface{}{logger.FieldError: err.Error()}) - return - } - // Bootstrap only runs once. After successful initialization marker is written, - // remove BOOTSTRAP.md to avoid repeated first-run guidance. - if err := os.Remove(bootstrapPath); err != nil && !os.IsNotExist(err) { - logger.WarnCF("gateway", logger.C0113, map[string]interface{}{logger.FieldError: err.Error()}) - } - logger.InfoC("gateway", logger.C0114) -} - -type configFileFingerprint struct { - Size int64 - ModUnixNano int64 - SHA256 [32]byte -} - -func readConfigFileFingerprint(path string) (configFileFingerprint, error) { - info, err := os.Stat(path) - if err != nil { - return configFileFingerprint{}, err - } - content, err := os.ReadFile(path) - if err != nil { - return configFileFingerprint{}, err - } - return configFileFingerprint{ - Size: info.Size(), - ModUnixNano: info.ModTime().UnixNano(), - SHA256: sha256.Sum256(content), - }, nil -} - -func (f configFileFingerprint) sameContent(other configFileFingerprint) bool { - return f.Size == other.Size && f.SHA256 == other.SHA256 -} - -func startGatewayConfigWatcher(ctx context.Context, configPath string, debounce, pollInterval time.Duration, onContentChanged func() error) func() { - if debounce <= 0 { - debounce = 500 * time.Millisecond - } - if pollInterval <= 0 { - pollInterval = 250 * time.Millisecond - } - done := make(chan struct{}) - go func() { - defer close(done) - ticker := time.NewTicker(pollInterval) - defer ticker.Stop() - - last, err := readConfigFileFingerprint(configPath) - haveLast := err == nil - pending := 
false - lastDetectedAt := time.Time{} - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - current, err := readConfigFileFingerprint(configPath) - if err != nil { - continue - } - if !haveLast { - last = current - haveLast = true - continue - } - if !current.sameContent(last) { - last = current - pending = true - lastDetectedAt = time.Now() - continue - } - if pending && !lastDetectedAt.IsZero() && time.Since(lastDetectedAt) >= debounce { - pending = false - if onContentChanged != nil { - if err := onContentChanged(); err != nil { - fmt.Printf("Config watcher reload failed: %v\n", err) - } - } - } - } - } - }() - return func() { - select { - case <-done: - case <-time.After(2 * time.Second): - } - } -} - -func applyMaximumPermissionPolicy(cfg *config.Config) { - cfg.Tools.Shell.Enabled = true - cfg.Tools.Shell.Sandbox.Enabled = false -} - -func gatewayInstallServiceCmd() error { - switch runtime.GOOS { - case "darwin": - return gatewayInstallLaunchdService() - case "windows": - return gatewayInstallWindowsTask() - } - scope, unitPath, err := detectGatewayServiceScopeAndPath() - if err != nil { - return err - } - - exePath, err := os.Executable() - if err != nil { - return fmt.Errorf("resolve executable path failed: %w", err) - } - exePath, _ = filepath.Abs(exePath) - configPath := getConfigPath() - workDir := filepath.Dir(exePath) - - unitContent := buildGatewayUnitContent(scope, exePath, configPath, workDir) - if err := os.MkdirAll(filepath.Dir(unitPath), 0755); err != nil { - return fmt.Errorf("create service directory failed: %w", err) - } - if err := os.WriteFile(unitPath, []byte(unitContent), 0644); err != nil { - return fmt.Errorf("write service unit failed: %w", err) - } - - if err := runSystemctl(scope, "daemon-reload"); err != nil { - return err - } - if err := runSystemctl(scope, "enable", gatewayServiceName); err != nil { - return err - } - - fmt.Printf("Gateway service registered: %s (%s)\n", gatewayServiceName, scope) - fmt.Printf(" 
Unit file: %s\n", unitPath) - fmt.Println(" Start service: clawgo gateway start") - fmt.Println(" Restart service: clawgo gateway restart") - fmt.Println(" Stop service: clawgo gateway stop") - return nil -} - -func gatewayServiceControlCmd(action string) error { - switch runtime.GOOS { - case "darwin": - return gatewayLaunchdControl(action) - case "windows": - return gatewayWindowsTaskControl(action) - } - scope, _, err := detectInstalledGatewayService() - if err != nil { - return err - } - return runSystemctl(scope, action, gatewayServiceName) -} - -func gatewayScopePreference() string { - v := strings.ToLower(strings.TrimSpace(os.Getenv("CLAWGO_GATEWAY_SCOPE"))) - if v == "user" || v == "system" { - return v - } - return "" -} - -func detectGatewayServiceScopeAndPath() (string, string, error) { - switch runtime.GOOS { - case "linux": - default: - return "", "", fmt.Errorf("unsupported service manager for %s", runtime.GOOS) - } - switch gatewayScopePreference() { - case "user": - return userGatewayUnitPath() - case "system": - return "system", "/etc/systemd/system/" + gatewayServiceName, nil - } - if os.Geteuid() == 0 { - return "system", "/etc/systemd/system/" + gatewayServiceName, nil - } - return userGatewayUnitPath() -} - -func userGatewayUnitPath() (string, string, error) { - home, err := os.UserHomeDir() - if err != nil { - return "", "", fmt.Errorf("resolve user home failed: %w", err) - } - return "user", filepath.Join(home, ".config", "systemd", "user", gatewayServiceName), nil -} - -func detectInstalledGatewayService() (string, string, error) { - switch runtime.GOOS { - case "darwin": - return detectInstalledLaunchdService() - case "windows": - return detectInstalledWindowsTask() - } - systemPath := "/etc/systemd/system/" + gatewayServiceName - userScope, userPath, err := userGatewayUnitPath() - if err != nil { - return "", "", err - } - - systemExists := false - if info, err := os.Stat(systemPath); err == nil && !info.IsDir() { - systemExists = true - } 
- - userExists := false - if info, err := os.Stat(userPath); err == nil && !info.IsDir() { - userExists = true - } - - preferredScope := gatewayScopePreference() - switch preferredScope { - case "system": - if systemExists { - return "system", systemPath, nil - } - return "", "", fmt.Errorf("gateway service unit not found in system scope: %s", systemPath) - case "user": - if userExists { - return userScope, userPath, nil - } - return "", "", fmt.Errorf("gateway service unit not found in user scope: %s", userPath) - } - - // Auto-pick scope by current privilege to avoid non-root users accidentally - // selecting system scope when both unit files exist. - if os.Geteuid() == 0 { - if systemExists { - return "system", systemPath, nil - } - if userExists { - return userScope, userPath, nil - } - } else { - if userExists { - return userScope, userPath, nil - } - if systemExists { - return "system", systemPath, nil - } - } - - return "", "", fmt.Errorf("gateway service not registered. Run: clawgo gateway") -} - -func buildGatewayUnitContent(scope, exePath, configPath, workDir string) string { - quotedExec := fmt.Sprintf("%q gateway run --config %q", exePath, configPath) - installTarget := "default.target" - if scope == "system" { - installTarget = "multi-user.target" - } - home, err := os.UserHomeDir() - if err != nil { - home = filepath.Dir(configPath) - } - - return fmt.Sprintf(`[Unit] -Description=ClawGo Gateway -After=network.target - -[Service] -Type=simple -WorkingDirectory=%s -ExecStart=%s -Restart=always -RestartSec=3 -Environment=CLAWGO_CONFIG=%s -Environment=HOME=%s - -[Install] -WantedBy=%s -`, workDir, quotedExec, configPath, home, installTarget) -} - -func runSystemctl(scope string, args ...string) error { - cmdArgs := make([]string, 0, len(args)+1) - if scope == "user" { - cmdArgs = append(cmdArgs, "--user") - } - cmdArgs = append(cmdArgs, args...) - - cmd := exec.Command("systemctl", cmdArgs...) 
- cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - if scope == "user" { - return fmt.Errorf("systemctl --user %s failed: %w", strings.Join(args, " "), err) - } - return fmt.Errorf("systemctl %s failed: %w", strings.Join(args, " "), err) - } - return nil -} - -func gatewayLaunchdLabel() string { return "ai.clawgo.gateway" } - -func gatewayWindowsTaskName() string { return "ClawGo Gateway" } - -func detectLaunchdScopeAndPath() (string, string, error) { - label := gatewayLaunchdLabel() + ".plist" - switch gatewayScopePreference() { - case "system": - return "system", filepath.Join("/Library/LaunchDaemons", label), nil - case "user": - home, err := os.UserHomeDir() - if err != nil { - return "", "", fmt.Errorf("resolve user home failed: %w", err) - } - return "user", filepath.Join(home, "Library", "LaunchAgents", label), nil - } - if os.Geteuid() == 0 { - return "system", filepath.Join("/Library/LaunchDaemons", label), nil - } - home, err := os.UserHomeDir() - if err != nil { - return "", "", fmt.Errorf("resolve user home failed: %w", err) - } - return "user", filepath.Join(home, "Library", "LaunchAgents", label), nil -} - -func detectInstalledLaunchdService() (string, string, error) { - userScope, userPath, err := detectLaunchdScopeAndPath() - if err != nil && gatewayScopePreference() == "user" { - return "", "", err - } - systemPath := filepath.Join("/Library/LaunchDaemons", gatewayLaunchdLabel()+".plist") - systemExists := fileExists(systemPath) - userExists := fileExists(userPath) - - switch gatewayScopePreference() { - case "system": - if systemExists { - return "system", systemPath, nil - } - return "", "", fmt.Errorf("launchd plist not found in system scope: %s", systemPath) - case "user": - if userExists { - return userScope, userPath, nil - } - return "", "", fmt.Errorf("launchd plist not found in user scope: %s", userPath) - } - - if os.Geteuid() == 0 { - if systemExists { - return "system", systemPath, nil - } - if 
userExists { - return userScope, userPath, nil - } - } else { - if userExists { - return userScope, userPath, nil - } - if systemExists { - return "system", systemPath, nil - } - } - return "", "", fmt.Errorf("gateway service not registered. Run: clawgo gateway") -} - -func gatewayInstallLaunchdService() error { - scope, plistPath, err := detectLaunchdScopeAndPath() - if err != nil { - return err - } - exePath, err := os.Executable() - if err != nil { - return fmt.Errorf("resolve executable path failed: %w", err) - } - exePath, _ = filepath.Abs(exePath) - configPath := getConfigPath() - workDir := filepath.Dir(exePath) - if err := os.MkdirAll(filepath.Dir(plistPath), 0755); err != nil { - return fmt.Errorf("create launchd directory failed: %w", err) - } - content := buildGatewayLaunchdPlist(exePath, configPath, workDir) - if err := os.WriteFile(plistPath, []byte(content), 0644); err != nil { - return fmt.Errorf("write launchd plist failed: %w", err) - } - _ = runLaunchctl(scope, "bootout", launchdDomainTarget(scope), plistPath) - if err := runLaunchctl(scope, "bootstrap", launchdDomainTarget(scope), plistPath); err != nil { - return err - } - if err := runLaunchctl(scope, "kickstart", "-k", launchdServiceTarget(scope)); err != nil { - return err - } - fmt.Printf("✓ Gateway service registered: %s (%s)\n", gatewayLaunchdLabel(), scope) - fmt.Printf(" Launchd plist: %s\n", plistPath) - fmt.Println(" Start service: clawgo gateway start") - fmt.Println(" Restart service: clawgo gateway restart") - fmt.Println(" Stop service: clawgo gateway stop") - return nil -} - -func gatewayLaunchdControl(action string) error { - scope, plistPath, err := detectInstalledLaunchdService() - if err != nil { - return err - } - switch action { - case "start": - _ = runLaunchctl(scope, "bootstrap", launchdDomainTarget(scope), plistPath) - return runLaunchctl(scope, "kickstart", "-k", launchdServiceTarget(scope)) - case "stop": - return runLaunchctl(scope, "bootout", 
launchdDomainTarget(scope), plistPath) - case "restart": - _ = runLaunchctl(scope, "bootout", launchdDomainTarget(scope), plistPath) - if err := runLaunchctl(scope, "bootstrap", launchdDomainTarget(scope), plistPath); err != nil { - return err - } - return runLaunchctl(scope, "kickstart", "-k", launchdServiceTarget(scope)) - case "status": - return runLaunchctl(scope, "print", launchdServiceTarget(scope)) - default: - return fmt.Errorf("unsupported action: %s", action) - } -} - -func buildGatewayLaunchdPlist(exePath, configPath, workDir string) string { - return fmt.Sprintf(` - - - - Label - %s - ProgramArguments - - %s - gateway - run - --config - %s - - WorkingDirectory - %s - RunAtLoad - - KeepAlive - - StandardOutPath - %s - StandardErrorPath - %s - - -`, gatewayLaunchdLabel(), exePath, configPath, workDir, filepath.Join(filepath.Dir(configPath), "gateway.launchd.out.log"), filepath.Join(filepath.Dir(configPath), "gateway.launchd.err.log")) -} - -func launchdDomainTarget(scope string) string { - if scope == "system" { - return "system" - } - return fmt.Sprintf("gui/%d", os.Getuid()) -} - -func launchdServiceTarget(scope string) string { - return launchdDomainTarget(scope) + "/" + gatewayLaunchdLabel() -} - -func runLaunchctl(scope string, args ...string) error { - cmd := exec.Command("launchctl", args...) 
- cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("launchctl %s failed: %w", strings.Join(args, " "), err) - } - return nil -} - -func gatewayInstallWindowsTask() error { - exePath, err := os.Executable() - if err != nil { - return fmt.Errorf("resolve executable path failed: %w", err) - } - exePath, _ = filepath.Abs(exePath) - configPath := getConfigPath() - taskName := gatewayWindowsTaskName() - command := fmt.Sprintf(`"%s" gateway run --config "%s"`, exePath, configPath) - _ = runSCHTASKS("/Delete", "/TN", taskName, "/F") - if err := runSCHTASKS("/Create", "/TN", taskName, "/SC", "ONLOGON", "/TR", command, "/F"); err != nil { - return err - } - fmt.Printf("✓ Gateway service registered: %s (windows task)\n", taskName) - fmt.Println(" Start service: clawgo gateway start") - fmt.Println(" Restart service: clawgo gateway restart") - fmt.Println(" Stop service: clawgo gateway stop") - return nil -} - -func gatewayWindowsTaskControl(action string) error { - _, _, err := detectInstalledWindowsTask() - if err != nil { - return err - } - taskName := gatewayWindowsTaskName() - switch action { - case "start": - return runSCHTASKS("/Run", "/TN", taskName) - case "stop": - return stopGatewayProcessByPIDFile() - case "restart": - _ = stopGatewayProcessByPIDFile() - return runSCHTASKS("/Run", "/TN", taskName) - case "status": - return runSCHTASKS("/Query", "/TN", taskName, "/V", "/FO", "LIST") - default: - return fmt.Errorf("unsupported action: %s", action) - } -} - -func detectInstalledWindowsTask() (string, string, error) { - taskName := gatewayWindowsTaskName() - if err := runSCHTASKSQuiet("/Query", "/TN", taskName); err != nil { - return "", "", fmt.Errorf("gateway service not registered. Run: clawgo gateway") - } - return "user", taskName, nil -} - -func runSCHTASKS(args ...string) error { - cmd := exec.Command("schtasks", args...) 
- cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("schtasks %s failed: %w", strings.Join(args, " "), err) - } - return nil -} - -func runSCHTASKSQuiet(args ...string) error { - cmd := exec.Command("schtasks", args...) - cmd.Stdout = io.Discard - cmd.Stderr = io.Discard - return cmd.Run() -} - -func stopGatewayProcessByPIDFile() error { - pidPath := filepath.Join(filepath.Dir(getConfigPath()), "gateway.pid") - data, err := os.ReadFile(pidPath) - if err != nil { - return fmt.Errorf("gateway pid file not found: %w", err) - } - pid := strings.TrimSpace(string(data)) - if pid == "" { - return fmt.Errorf("gateway pid file is empty") - } - cmd := exec.Command("taskkill", "/PID", pid, "/T", "/F") - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("taskkill /PID %s failed: %w", pid, err) - } - return nil -} - -func fileExists(path string) bool { - info, err := os.Stat(path) - return err == nil && !info.IsDir() -} - -func buildGatewayRuntime(ctx context.Context, cfg *config.Config, msgBus *bus.MessageBus, cronService *cron.CronService) (*agent.AgentLoop, *channels.Manager, error) { - provider, err := providers.CreateProvider(cfg) - if err != nil { - return nil, nil, fmt.Errorf("create provider: %w", err) - } - - agentLoop := agent.NewAgentLoop(cfg, msgBus, provider, cronService) - agentLoop.SetConfigPath(getConfigPath()) - - startupInfo := agentLoop.GetStartupInfo() - toolsInfo := startupInfo["tools"].(map[string]interface{}) - skillsInfo := startupInfo["skills"].(map[string]interface{}) - fmt.Println("\nAgent Status:") - fmt.Printf(" 鈥?Tools: %d loaded\n", toolsInfo["count"]) - fmt.Printf(" 鈥?Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"]) - - logger.InfoCF("agent", logger.C0098, - map[string]interface{}{ - "tools_count": toolsInfo["count"], - "skills_total": skillsInfo["total"], - "skills_available": skillsInfo["available"], - }) - - 
channelManager, err := channels.NewManager(cfg, msgBus) - if err != nil { - return nil, nil, fmt.Errorf("create channel manager: %w", err) - } - - return agentLoop, channelManager, nil -} - -func normalizeCronTargetChatID(channel, chatID string) string { - ch := strings.ToLower(strings.TrimSpace(channel)) - target := strings.TrimSpace(chatID) - if ch == "" || target == "" { - return target - } - prefix := ch + ":" - if strings.HasPrefix(strings.ToLower(target), prefix) { - return strings.TrimSpace(target[len(prefix):]) - } - return target -} - -func dispatchCronJob(msgBus *bus.MessageBus, job *cron.CronJob) string { - if job == nil { - return "" - } - message := strings.TrimSpace(job.Payload.Message) - if message == "" { - return "" - } - targetChannel := strings.TrimSpace(job.Payload.Channel) - targetChatID := normalizeCronTargetChatID(targetChannel, job.Payload.To) - - if targetChannel != "" && targetChatID != "" { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: targetChannel, - ChatID: targetChatID, - Content: message, - }) - if job.Payload.Deliver { - return "delivered" - } - return "delivered_targeted" - } - - msgBus.PublishInbound(bus.InboundMessage{ - Channel: "system", - SenderID: "cron", - ChatID: "internal:cron", - Content: message, - SessionKey: fmt.Sprintf("cron:%s", job.ID), - Metadata: map[string]string{ - "trigger": "cron", - "job_id": job.ID, - }, - }) - return "scheduled" -} - -func configureCronServiceRuntime(cs *cron.CronService, cfg *config.Config) { - if cs == nil || cfg == nil { - return - } - cs.SetRuntimeOptions(cron.RuntimeOptions{ - RunLoopMinSleep: time.Duration(cfg.Cron.MinSleepSec) * time.Second, - RunLoopMaxSleep: time.Duration(cfg.Cron.MaxSleepSec) * time.Second, - RetryBackoffBase: time.Duration(cfg.Cron.RetryBackoffBaseSec) * time.Second, - RetryBackoffMax: time.Duration(cfg.Cron.RetryBackoffMaxSec) * time.Second, - MaxConsecutiveFailureRetries: int64(cfg.Cron.MaxConsecutiveFailureRetries), - MaxWorkers: 
cfg.Cron.MaxWorkers, - }) -} - -func buildHeartbeatService(cfg *config.Config, msgBus *bus.MessageBus) *heartbeat.HeartbeatService { - hbInterval := cfg.Agents.Defaults.Heartbeat.EverySec - if hbInterval <= 0 { - hbInterval = 30 * 60 - } - return heartbeat.NewHeartbeatService(cfg.WorkspacePath(), func(prompt string) (string, error) { - msgBus.PublishInbound(bus.InboundMessage{ - Channel: "system", - SenderID: "heartbeat", - ChatID: "internal:heartbeat", - Content: prompt, - SessionKey: "heartbeat:default", - Metadata: map[string]string{ - "trigger": "heartbeat", - }, - }) - return "queued", nil - }, hbInterval, cfg.Agents.Defaults.Heartbeat.Enabled, cfg.Agents.Defaults.Heartbeat.PromptTemplate) -} diff --git a/cmd/gateway_bootstrap.go b/cmd/gateway_bootstrap.go new file mode 100644 index 0000000..be6c076 --- /dev/null +++ b/cmd/gateway_bootstrap.go @@ -0,0 +1,71 @@ +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/YspCoder/clawgo/pkg/agent" + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/logger" +) + +func runGatewayStartupCompactionCheck(parent context.Context, agentLoop *agent.AgentLoop) { + if agentLoop == nil { + return + } + + checkCtx, cancel := context.WithTimeout(parent, 10*time.Minute) + defer cancel() + + report := agentLoop.RunStartupSelfCheckAllSessions(checkCtx) + logger.InfoCF("gateway", logger.C0110, map[string]interface{}{ + "sessions_total": report.TotalSessions, + "sessions_compacted": report.CompactedSessions, + }) +} + +func runGatewayBootstrapInit(parent context.Context, cfg *config.Config, agentLoop *agent.AgentLoop) { + if agentLoop == nil || cfg == nil { + return + } + workspace := cfg.WorkspacePath() + bootstrapPath := filepath.Join(workspace, "BOOTSTRAP.md") + if _, err := os.Stat(bootstrapPath); err != nil { + return + } + memDir := filepath.Join(workspace, "memory") + _ = os.MkdirAll(memDir, 0755) + markerPath := filepath.Join(memDir, 
"bootstrap.init.done") + if _, err := os.Stat(markerPath); err == nil { + return + } + + initCtx, cancel := context.WithTimeout(parent, 90*time.Second) + defer cancel() + prompt := "System startup bootstrap: read BOOTSTRAP.md and perform one-time self-initialization checks now. If already initialized, return concise status only." + resp, err := agentLoop.ProcessDirect(initCtx, prompt, "system:bootstrap:init") + if err != nil { + logger.ErrorCF("gateway", logger.C0111, map[string]interface{}{logger.FieldError: err.Error()}) + return + } + line := fmt.Sprintf("%s\n%s\n", time.Now().UTC().Format(time.RFC3339), strings.TrimSpace(resp)) + if err := os.WriteFile(markerPath, []byte(line), 0644); err != nil { + logger.ErrorCF("gateway", logger.C0112, map[string]interface{}{logger.FieldError: err.Error()}) + return + } + // Bootstrap only runs once. After successful initialization marker is written, + // remove BOOTSTRAP.md to avoid repeated first-run guidance. + if err := os.Remove(bootstrapPath); err != nil && !os.IsNotExist(err) { + logger.WarnCF("gateway", logger.C0113, map[string]interface{}{logger.FieldError: err.Error()}) + } + logger.InfoC("gateway", logger.C0114) +} + +func applyMaximumPermissionPolicy(cfg *config.Config) { + cfg.Tools.Shell.Enabled = true + cfg.Tools.Shell.Sandbox.Enabled = false +} diff --git a/cmd/gateway_reload.go b/cmd/gateway_reload.go new file mode 100644 index 0000000..2e13355 --- /dev/null +++ b/cmd/gateway_reload.go @@ -0,0 +1,230 @@ +package main + +import ( + "context" + "crypto/sha256" + "fmt" + "os" + "reflect" + "strings" + "sync" + "time" + + "github.com/YspCoder/clawgo/pkg/api" + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/channels" + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/cron" + "github.com/YspCoder/clawgo/pkg/runtimecfg" + "github.com/YspCoder/clawgo/pkg/sentinel" +) + +type gatewayReloader struct { + mu sync.Mutex + ctx context.Context + state *gatewayRuntimeState + 
msgBus *bus.MessageBus + cronService *cron.CronService + registryServer *api.Server +} + +func newGatewayReloadTrigger(ctx context.Context, state *gatewayRuntimeState, msgBus *bus.MessageBus, cronService *cron.CronService, registryServer *api.Server) func(string, bool) error { + reloader := &gatewayReloader{ + ctx: ctx, + state: state, + msgBus: msgBus, + cronService: cronService, + registryServer: registryServer, + } + return reloader.trigger +} + +func (r *gatewayReloader) trigger(source string, forceRuntimeReload bool) error { + r.mu.Lock() + defer r.mu.Unlock() + fmt.Printf("\nReloading config (source=%s)...\n", strings.TrimSpace(source)) + newCfg, err := config.LoadConfig(getConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + if strings.EqualFold(strings.TrimSpace(os.Getenv(envRootGranted)), "1") || strings.EqualFold(strings.TrimSpace(os.Getenv(envRootGranted)), "true") { + applyMaximumPermissionPolicy(newCfg) + } + configureCronServiceRuntime(r.cronService, newCfg) + r.state.heartbeatService.Stop() + r.state.heartbeatService = buildHeartbeatService(newCfg, r.msgBus) + if err := r.state.heartbeatService.Start(); err != nil { + fmt.Printf("Error starting heartbeat service: %v\n", err) + } + + if !forceRuntimeReload && reflect.DeepEqual(r.state.cfg, newCfg) { + fmt.Println("Config unchanged, skip reload") + return nil + } + + if r.state.cfg.Gateway.Host != newCfg.Gateway.Host || r.state.cfg.Gateway.Port != newCfg.Gateway.Port { + fmt.Printf("Warning: gateway host/port change detected (%s:%d -> %s:%d); restart required to rebind listener\n", + r.state.cfg.Gateway.Host, r.state.cfg.Gateway.Port, newCfg.Gateway.Host, newCfg.Gateway.Port) + } + + runtimeSame := reflect.DeepEqual(r.state.cfg.Agents, newCfg.Agents) && + reflect.DeepEqual(r.state.cfg.Models, newCfg.Models) && + reflect.DeepEqual(r.state.cfg.Tools, newCfg.Tools) && + reflect.DeepEqual(r.state.cfg.Channels, newCfg.Channels) + + if runtimeSame && !forceRuntimeReload { + 
configureLogging(newCfg) + r.state.sentinelService.Stop() + r.state.sentinelService = sentinel.NewService( + getConfigPath(), + newCfg.WorkspacePath(), + newCfg.Sentinel.IntervalSec, + newCfg.Sentinel.AutoHeal, + buildSentinelAlertHandler(newCfg, r.msgBus), + ) + if newCfg.Sentinel.Enabled { + r.state.sentinelService.SetManager(r.state.channelManager) + r.state.sentinelService.Start() + } + r.state.cfg = newCfg + runtimecfg.Set(r.state.cfg) + r.bindRegistryMetadata() + fmt.Println("Config hot-reload applied (logging/metadata only)") + return nil + } + + newAgentLoop, newChannelManager, err := buildGatewayRuntime(r.ctx, newCfg, r.msgBus, r.cronService) + if err != nil { + return fmt.Errorf("init runtime: %w", err) + } + + r.state.channelManager.StopAll(r.ctx) + r.state.agentLoop.Stop() + r.state.channelManager = newChannelManager + r.state.agentLoop = newAgentLoop + r.state.cfg = newCfg + runtimecfg.Set(r.state.cfg) + bindAgentLoopHandlers(r.registryServer, r.state.agentLoop) + configureLogging(newCfg) + r.bindRegistryMetadata() + r.bindWeixinChannel() + r.state.sentinelService.Stop() + r.state.sentinelService = sentinel.NewService( + getConfigPath(), + newCfg.WorkspacePath(), + newCfg.Sentinel.IntervalSec, + newCfg.Sentinel.AutoHeal, + buildSentinelAlertHandler(newCfg, r.msgBus), + ) + if newCfg.Sentinel.Enabled { + r.state.sentinelService.Start() + } + r.state.sentinelService.SetManager(r.state.channelManager) + + if err := r.state.channelManager.StartAll(r.ctx); err != nil { + return fmt.Errorf("start channels: %w", err) + } + go r.state.agentLoop.Run(r.ctx) + fmt.Println("Config hot-reload applied") + return nil +} + +func (r *gatewayReloader) bindRegistryMetadata() { + r.registryServer.SetToken(r.state.cfg.Gateway.Token) + r.registryServer.SetWorkspacePath(r.state.cfg.WorkspacePath()) + r.registryServer.SetLogFilePath(r.state.cfg.LogFilePath()) +} + +func (r *gatewayReloader) bindWeixinChannel() { + if rawWeixin, ok := 
r.state.channelManager.GetChannel("weixin"); ok { + if weixinChannel, ok := rawWeixin.(*channels.WeixinChannel); ok { + weixinChannel.SetConfigPath(getConfigPath()) + r.registryServer.SetWeixinChannel(weixinChannel) + } + } else { + r.registryServer.SetWeixinChannel(nil) + } +} + +type configFileFingerprint struct { + Size int64 + ModUnixNano int64 + SHA256 [32]byte +} + +func readConfigFileFingerprint(path string) (configFileFingerprint, error) { + info, err := os.Stat(path) + if err != nil { + return configFileFingerprint{}, err + } + content, err := os.ReadFile(path) + if err != nil { + return configFileFingerprint{}, err + } + return configFileFingerprint{ + Size: info.Size(), + ModUnixNano: info.ModTime().UnixNano(), + SHA256: sha256.Sum256(content), + }, nil +} + +func (f configFileFingerprint) sameContent(other configFileFingerprint) bool { + return f.Size == other.Size && f.SHA256 == other.SHA256 +} + +func startGatewayConfigWatcher(ctx context.Context, configPath string, debounce, pollInterval time.Duration, onContentChanged func() error) func() { + if debounce <= 0 { + debounce = 500 * time.Millisecond + } + if pollInterval <= 0 { + pollInterval = 250 * time.Millisecond + } + done := make(chan struct{}) + go func() { + defer close(done) + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + last, err := readConfigFileFingerprint(configPath) + haveLast := err == nil + pending := false + lastDetectedAt := time.Time{} + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + current, err := readConfigFileFingerprint(configPath) + if err != nil { + continue + } + if !haveLast { + last = current + haveLast = true + continue + } + if !current.sameContent(last) { + last = current + pending = true + lastDetectedAt = time.Now() + continue + } + if pending && !lastDetectedAt.IsZero() && time.Since(lastDetectedAt) >= debounce { + pending = false + if onContentChanged != nil { + if err := onContentChanged(); err != nil { + fmt.Printf("Config 
watcher reload failed: %v\n", err) + } + } + } + } + } + }() + return func() { + select { + case <-done: + case <-time.After(2 * time.Second): + } + } +} diff --git a/cmd/gateway_runtime.go b/cmd/gateway_runtime.go new file mode 100644 index 0000000..d50baae --- /dev/null +++ b/cmd/gateway_runtime.go @@ -0,0 +1,309 @@ +package main + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/YspCoder/clawgo/pkg/agent" + "github.com/YspCoder/clawgo/pkg/api" + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/channels" + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/cron" + "github.com/YspCoder/clawgo/pkg/heartbeat" + "github.com/YspCoder/clawgo/pkg/logger" + "github.com/YspCoder/clawgo/pkg/providers" + "github.com/YspCoder/clawgo/pkg/sentinel" +) + +type gatewayRuntimeState struct { + cfg *config.Config + agentLoop *agent.AgentLoop + channelManager *channels.Manager + heartbeatService *heartbeat.HeartbeatService + sentinelService *sentinel.Service +} + +func buildGatewayRuntime(ctx context.Context, cfg *config.Config, msgBus *bus.MessageBus, cronService *cron.CronService) (*agent.AgentLoop, *channels.Manager, error) { + provider, err := providers.CreateProvider(cfg) + if err != nil { + return nil, nil, fmt.Errorf("create provider: %w", err) + } + + agentLoop := agent.NewAgentLoop(cfg, msgBus, provider, cronService) + agentLoop.SetConfigPath(getConfigPath()) + + startupInfo := agentLoop.GetStartupInfo() + toolsInfo := startupInfo["tools"].(map[string]interface{}) + skillsInfo := startupInfo["skills"].(map[string]interface{}) + fmt.Println("\nAgent Status:") + fmt.Printf(" - Tools: %d loaded\n", toolsInfo["count"]) + fmt.Printf(" - Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"]) + + logger.InfoCF("agent", logger.C0098, + map[string]interface{}{ + "tools_count": toolsInfo["count"], + "skills_total": skillsInfo["total"], + "skills_available": skillsInfo["available"], + }) + + 
channelManager, err := channels.NewManager(cfg, msgBus) + if err != nil { + return nil, nil, fmt.Errorf("create channel manager: %w", err) + } + + return agentLoop, channelManager, nil +} + +func bindAgentLoopHandlers(registryServer *api.Server, loop *agent.AgentLoop) { + registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) { + if strings.TrimSpace(content) == "" { + return "", nil + } + return loop.ProcessDirect(cctx, content, sessionKey) + }) + registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} { + h := loop.GetSessionHistory(sessionKey) + out := make([]map[string]interface{}, 0, len(h)) + for _, m := range h { + entry := map[string]interface{}{"role": m.Role, "content": m.Content} + if strings.TrimSpace(m.ToolCallID) != "" { + entry["tool_call_id"] = m.ToolCallID + } + if len(m.ToolCalls) > 0 { + entry["tool_calls"] = m.ToolCalls + } + out = append(out, entry) + } + return out + }) + registryServer.SetToolsCatalogHandler(func() interface{} { + return loop.GetToolCatalog() + }) +} + +func bindCronHandler(registryServer *api.Server, cronService *cron.CronService) { + registryServer.SetCronHandler(func(action string, args map[string]interface{}) (interface{}, error) { + getStr := func(k string) string { + v, _ := args[k].(string) + return strings.TrimSpace(v) + } + getBoolPtr := func(k string) *bool { + v, ok := args[k].(bool) + if !ok { + return nil + } + vv := v + return &vv + } + switch strings.ToLower(strings.TrimSpace(action)) { + case "", "list": + return cronService.ListJobs(true), nil + case "get": + id := getStr("id") + if id == "" { + return nil, fmt.Errorf("id required") + } + j := cronService.GetJob(id) + if j == nil { + return nil, fmt.Errorf("job not found: %s", id) + } + return j, nil + case "create": + name := getStr("name") + if name == "" { + name = "webui-cron" + } + msg := getStr("message") + if msg == "" { + return nil, fmt.Errorf("message required") + } + 
schedule := cron.CronSchedule{} + if expr := getStr("expr"); expr != "" { + schedule.Expr = expr + } else { + // Backward compatibility for older clients. + kind := strings.ToLower(getStr("kind")) + switch kind { + case "every": + everyMS, ok := args["everyMs"].(float64) + if !ok || int64(everyMS) <= 0 { + return nil, fmt.Errorf("expr required") + } + ev := int64(everyMS) + schedule.Kind = "every" + schedule.EveryMS = &ev + case "once", "at": + atMS, ok := args["atMs"].(float64) + var at int64 + if !ok || int64(atMS) <= 0 { + at = time.Now().Add(1 * time.Minute).UnixMilli() + } else { + at = int64(atMS) + } + schedule.Kind = "at" + schedule.AtMS = &at + default: + return nil, fmt.Errorf("expr required") + } + } + deliver := false + if v, ok := args["deliver"].(bool); ok { + deliver = v + } + return cronService.AddJob(name, schedule, msg, deliver, getStr("channel"), getStr("to")) + case "update": + id := getStr("id") + if id == "" { + return nil, fmt.Errorf("id required") + } + in := cron.UpdateJobInput{} + if v := getStr("name"); v != "" { + in.Name = &v + } + if v := getStr("message"); v != "" { + in.Message = &v + } + if p := getBoolPtr("enabled"); p != nil { + in.Enabled = p + } + if p := getBoolPtr("deliver"); p != nil { + in.Deliver = p + } + if v := getStr("channel"); v != "" { + in.Channel = &v + } + if v := getStr("to"); v != "" { + in.To = &v + } + if expr := getStr("expr"); expr != "" { + s := cron.CronSchedule{Expr: expr} + in.Schedule = &s + } else if kind := strings.ToLower(getStr("kind")); kind != "" { + // Backward compatibility for older clients. 
+ s := cron.CronSchedule{Kind: kind} + switch kind { + case "every": + if everyMS, ok := args["everyMs"].(float64); ok && int64(everyMS) > 0 { + ev := int64(everyMS) + s.EveryMS = &ev + } else { + return nil, fmt.Errorf("expr required") + } + case "once", "at": + s.Kind = "at" + if atMS, ok := args["atMs"].(float64); ok && int64(atMS) > 0 { + at := int64(atMS) + s.AtMS = &at + } else { + at := time.Now().Add(1 * time.Minute).UnixMilli() + s.AtMS = &at + } + default: + return nil, fmt.Errorf("expr required") + } + in.Schedule = &s + } + return cronService.UpdateJob(id, in) + case "delete": + id := getStr("id") + return map[string]interface{}{"deleted": cronService.RemoveJob(id), "id": id}, nil + case "enable": + id := getStr("id") + j := cronService.EnableJob(id, true) + return map[string]interface{}{"ok": j != nil, "id": id}, nil + case "disable": + id := getStr("id") + j := cronService.EnableJob(id, false) + return map[string]interface{}{"ok": j != nil, "id": id}, nil + default: + return nil, fmt.Errorf("unsupported cron action: %s", action) + } + }) +} + +func normalizeCronTargetChatID(channel, chatID string) string { + ch := strings.ToLower(strings.TrimSpace(channel)) + target := strings.TrimSpace(chatID) + if ch == "" || target == "" { + return target + } + prefix := ch + ":" + if strings.HasPrefix(strings.ToLower(target), prefix) { + return strings.TrimSpace(target[len(prefix):]) + } + return target +} + +func dispatchCronJob(msgBus *bus.MessageBus, job *cron.CronJob) string { + if job == nil { + return "" + } + message := strings.TrimSpace(job.Payload.Message) + if message == "" { + return "" + } + targetChannel := strings.TrimSpace(job.Payload.Channel) + targetChatID := normalizeCronTargetChatID(targetChannel, job.Payload.To) + + if targetChannel != "" && targetChatID != "" { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: targetChannel, + ChatID: targetChatID, + Content: message, + }) + if job.Payload.Deliver { + return "delivered" + } + return 
"delivered_targeted" + } + + msgBus.PublishInbound(bus.InboundMessage{ + Channel: "system", + SenderID: "cron", + ChatID: "internal:cron", + Content: message, + SessionKey: fmt.Sprintf("cron:%s", job.ID), + Metadata: map[string]string{ + "trigger": "cron", + "job_id": job.ID, + }, + }) + return "scheduled" +} + +func configureCronServiceRuntime(cs *cron.CronService, cfg *config.Config) { + if cs == nil || cfg == nil { + return + } + cs.SetRuntimeOptions(cron.RuntimeOptions{ + RunLoopMinSleep: time.Duration(cfg.Cron.MinSleepSec) * time.Second, + RunLoopMaxSleep: time.Duration(cfg.Cron.MaxSleepSec) * time.Second, + RetryBackoffBase: time.Duration(cfg.Cron.RetryBackoffBaseSec) * time.Second, + RetryBackoffMax: time.Duration(cfg.Cron.RetryBackoffMaxSec) * time.Second, + MaxConsecutiveFailureRetries: int64(cfg.Cron.MaxConsecutiveFailureRetries), + MaxWorkers: cfg.Cron.MaxWorkers, + }) +} + +func buildHeartbeatService(cfg *config.Config, msgBus *bus.MessageBus) *heartbeat.HeartbeatService { + hbInterval := cfg.Agents.Defaults.Heartbeat.EverySec + if hbInterval <= 0 { + hbInterval = 30 * 60 + } + return heartbeat.NewHeartbeatService(cfg.WorkspacePath(), func(prompt string) (string, error) { + msgBus.PublishInbound(bus.InboundMessage{ + Channel: "system", + SenderID: "heartbeat", + ChatID: "internal:heartbeat", + Content: prompt, + SessionKey: "heartbeat:default", + Metadata: map[string]string{ + "trigger": "heartbeat", + }, + }) + return "queued", nil + }, hbInterval, cfg.Agents.Defaults.Heartbeat.Enabled, cfg.Agents.Defaults.Heartbeat.PromptTemplate) +} diff --git a/cmd/gateway_services.go b/cmd/gateway_services.go new file mode 100644 index 0000000..c6460fc --- /dev/null +++ b/cmd/gateway_services.go @@ -0,0 +1,473 @@ +package main + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" +) + +func gatewayInstallServiceCmd() error { + switch runtime.GOOS { + case "darwin": + return gatewayInstallLaunchdService() + case "windows": + return 
gatewayInstallWindowsTask() + } + scope, unitPath, err := detectGatewayServiceScopeAndPath() + if err != nil { + return err + } + + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("resolve executable path failed: %w", err) + } + exePath, _ = filepath.Abs(exePath) + configPath := getConfigPath() + workDir := filepath.Dir(exePath) + + unitContent := buildGatewayUnitContent(scope, exePath, configPath, workDir) + if err := os.MkdirAll(filepath.Dir(unitPath), 0755); err != nil { + return fmt.Errorf("create service directory failed: %w", err) + } + if err := os.WriteFile(unitPath, []byte(unitContent), 0644); err != nil { + return fmt.Errorf("write service unit failed: %w", err) + } + + if err := runSystemctl(scope, "daemon-reload"); err != nil { + return err + } + if err := runSystemctl(scope, "enable", gatewayServiceName); err != nil { + return err + } + + fmt.Printf("Gateway service registered: %s (%s)\n", gatewayServiceName, scope) + fmt.Printf(" Unit file: %s\n", unitPath) + fmt.Println(" Start service: clawgo gateway start") + fmt.Println(" Restart service: clawgo gateway restart") + fmt.Println(" Stop service: clawgo gateway stop") + return nil +} + +func gatewayServiceControlCmd(action string) error { + switch runtime.GOOS { + case "darwin": + return gatewayLaunchdControl(action) + case "windows": + return gatewayWindowsTaskControl(action) + } + scope, _, err := detectInstalledGatewayService() + if err != nil { + return err + } + return runSystemctl(scope, action, gatewayServiceName) +} + +func gatewayScopePreference() string { + v := strings.ToLower(strings.TrimSpace(os.Getenv("CLAWGO_GATEWAY_SCOPE"))) + if v == "user" || v == "system" { + return v + } + return "" +} + +func detectGatewayServiceScopeAndPath() (string, string, error) { + switch runtime.GOOS { + case "linux": + default: + return "", "", fmt.Errorf("unsupported service manager for %s", runtime.GOOS) + } + switch gatewayScopePreference() { + case "user": + return 
userGatewayUnitPath() + case "system": + return "system", "/etc/systemd/system/" + gatewayServiceName, nil + } + if os.Geteuid() == 0 { + return "system", "/etc/systemd/system/" + gatewayServiceName, nil + } + return userGatewayUnitPath() +} + +func userGatewayUnitPath() (string, string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", "", fmt.Errorf("resolve user home failed: %w", err) + } + return "user", filepath.Join(home, ".config", "systemd", "user", gatewayServiceName), nil +} + +func detectInstalledGatewayService() (string, string, error) { + switch runtime.GOOS { + case "darwin": + return detectInstalledLaunchdService() + case "windows": + return detectInstalledWindowsTask() + } + systemPath := "/etc/systemd/system/" + gatewayServiceName + userScope, userPath, err := userGatewayUnitPath() + if err != nil { + return "", "", err + } + + systemExists := false + if info, err := os.Stat(systemPath); err == nil && !info.IsDir() { + systemExists = true + } + + userExists := false + if info, err := os.Stat(userPath); err == nil && !info.IsDir() { + userExists = true + } + + preferredScope := gatewayScopePreference() + switch preferredScope { + case "system": + if systemExists { + return "system", systemPath, nil + } + return "", "", fmt.Errorf("gateway service unit not found in system scope: %s", systemPath) + case "user": + if userExists { + return userScope, userPath, nil + } + return "", "", fmt.Errorf("gateway service unit not found in user scope: %s", userPath) + } + + // Auto-pick scope by current privilege to avoid non-root users accidentally + // selecting system scope when both unit files exist. + if os.Geteuid() == 0 { + if systemExists { + return "system", systemPath, nil + } + if userExists { + return userScope, userPath, nil + } + } else { + if userExists { + return userScope, userPath, nil + } + if systemExists { + return "system", systemPath, nil + } + } + + return "", "", fmt.Errorf("gateway service not registered. 
Run: clawgo gateway") +} + +func buildGatewayUnitContent(scope, exePath, configPath, workDir string) string { + quotedExec := fmt.Sprintf("%q gateway run --config %q", exePath, configPath) + installTarget := "default.target" + if scope == "system" { + installTarget = "multi-user.target" + } + home, err := os.UserHomeDir() + if err != nil { + home = filepath.Dir(configPath) + } + + return fmt.Sprintf(`[Unit] +Description=ClawGo Gateway +After=network.target + +[Service] +Type=simple +WorkingDirectory=%s +ExecStart=%s +Restart=always +RestartSec=3 +Environment=CLAWGO_CONFIG=%s +Environment=HOME=%s + +[Install] +WantedBy=%s +`, workDir, quotedExec, configPath, home, installTarget) +} + +func runSystemctl(scope string, args ...string) error { + cmdArgs := make([]string, 0, len(args)+1) + if scope == "user" { + cmdArgs = append(cmdArgs, "--user") + } + cmdArgs = append(cmdArgs, args...) + + cmd := exec.Command("systemctl", cmdArgs...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + if scope == "user" { + return fmt.Errorf("systemctl --user %s failed: %w", strings.Join(args, " "), err) + } + return fmt.Errorf("systemctl %s failed: %w", strings.Join(args, " "), err) + } + return nil +} + +func gatewayLaunchdLabel() string { return "ai.clawgo.gateway" } + +func gatewayWindowsTaskName() string { return "ClawGo Gateway" } + +func detectLaunchdScopeAndPath() (string, string, error) { + label := gatewayLaunchdLabel() + ".plist" + switch gatewayScopePreference() { + case "system": + return "system", filepath.Join("/Library/LaunchDaemons", label), nil + case "user": + home, err := os.UserHomeDir() + if err != nil { + return "", "", fmt.Errorf("resolve user home failed: %w", err) + } + return "user", filepath.Join(home, "Library", "LaunchAgents", label), nil + } + if os.Geteuid() == 0 { + return "system", filepath.Join("/Library/LaunchDaemons", label), nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", "", 
fmt.Errorf("resolve user home failed: %w", err) + } + return "user", filepath.Join(home, "Library", "LaunchAgents", label), nil +} + +func detectInstalledLaunchdService() (string, string, error) { + userScope, userPath, err := detectLaunchdScopeAndPath() + if err != nil && gatewayScopePreference() == "user" { + return "", "", err + } + systemPath := filepath.Join("/Library/LaunchDaemons", gatewayLaunchdLabel()+".plist") + systemExists := fileExists(systemPath) + userExists := fileExists(userPath) + + switch gatewayScopePreference() { + case "system": + if systemExists { + return "system", systemPath, nil + } + return "", "", fmt.Errorf("launchd plist not found in system scope: %s", systemPath) + case "user": + if userExists { + return userScope, userPath, nil + } + return "", "", fmt.Errorf("launchd plist not found in user scope: %s", userPath) + } + + if os.Geteuid() == 0 { + if systemExists { + return "system", systemPath, nil + } + if userExists { + return userScope, userPath, nil + } + } else { + if userExists { + return userScope, userPath, nil + } + if systemExists { + return "system", systemPath, nil + } + } + return "", "", fmt.Errorf("gateway service not registered. 
Run: clawgo gateway") +} + +func gatewayInstallLaunchdService() error { + scope, plistPath, err := detectLaunchdScopeAndPath() + if err != nil { + return err + } + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("resolve executable path failed: %w", err) + } + exePath, _ = filepath.Abs(exePath) + configPath := getConfigPath() + workDir := filepath.Dir(exePath) + if err := os.MkdirAll(filepath.Dir(plistPath), 0755); err != nil { + return fmt.Errorf("create launchd directory failed: %w", err) + } + content := buildGatewayLaunchdPlist(exePath, configPath, workDir) + if err := os.WriteFile(plistPath, []byte(content), 0644); err != nil { + return fmt.Errorf("write launchd plist failed: %w", err) + } + _ = runLaunchctl(scope, "bootout", launchdDomainTarget(scope), plistPath) + if err := runLaunchctl(scope, "bootstrap", launchdDomainTarget(scope), plistPath); err != nil { + return err + } + if err := runLaunchctl(scope, "kickstart", "-k", launchdServiceTarget(scope)); err != nil { + return err + } + fmt.Printf("✓ Gateway service registered: %s (%s)\n", gatewayLaunchdLabel(), scope) + fmt.Printf(" Launchd plist: %s\n", plistPath) + fmt.Println(" Start service: clawgo gateway start") + fmt.Println(" Restart service: clawgo gateway restart") + fmt.Println(" Stop service: clawgo gateway stop") + return nil +} + +func gatewayLaunchdControl(action string) error { + scope, plistPath, err := detectInstalledLaunchdService() + if err != nil { + return err + } + switch action { + case "start": + _ = runLaunchctl(scope, "bootstrap", launchdDomainTarget(scope), plistPath) + return runLaunchctl(scope, "kickstart", "-k", launchdServiceTarget(scope)) + case "stop": + return runLaunchctl(scope, "bootout", launchdDomainTarget(scope), plistPath) + case "restart": + _ = runLaunchctl(scope, "bootout", launchdDomainTarget(scope), plistPath) + if err := runLaunchctl(scope, "bootstrap", launchdDomainTarget(scope), plistPath); err != nil { + return err + } + return 
runLaunchctl(scope, "kickstart", "-k", launchdServiceTarget(scope)) + case "status": + return runLaunchctl(scope, "print", launchdServiceTarget(scope)) + default: + return fmt.Errorf("unsupported action: %s", action) + } +} + +func buildGatewayLaunchdPlist(exePath, configPath, workDir string) string { + return fmt.Sprintf(` + + + + Label + %s + ProgramArguments + + %s + gateway + run + --config + %s + + WorkingDirectory + %s + RunAtLoad + + KeepAlive + + StandardOutPath + %s + StandardErrorPath + %s + + +`, gatewayLaunchdLabel(), exePath, configPath, workDir, filepath.Join(filepath.Dir(configPath), "gateway.launchd.out.log"), filepath.Join(filepath.Dir(configPath), "gateway.launchd.err.log")) +} + +func launchdDomainTarget(scope string) string { + if scope == "system" { + return "system" + } + return fmt.Sprintf("gui/%d", os.Getuid()) +} + +func launchdServiceTarget(scope string) string { + return launchdDomainTarget(scope) + "/" + gatewayLaunchdLabel() +} + +func runLaunchctl(scope string, args ...string) error { + cmd := exec.Command("launchctl", args...) 
+ cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("launchctl %s failed: %w", strings.Join(args, " "), err) + } + return nil +} + +func gatewayInstallWindowsTask() error { + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("resolve executable path failed: %w", err) + } + exePath, _ = filepath.Abs(exePath) + configPath := getConfigPath() + taskName := gatewayWindowsTaskName() + command := fmt.Sprintf(`"%s" gateway run --config "%s"`, exePath, configPath) + _ = runSCHTASKS("/Delete", "/TN", taskName, "/F") + if err := runSCHTASKS("/Create", "/TN", taskName, "/SC", "ONLOGON", "/TR", command, "/F"); err != nil { + return err + } + fmt.Printf("✓ Gateway service registered: %s (windows task)\n", taskName) + fmt.Println(" Start service: clawgo gateway start") + fmt.Println(" Restart service: clawgo gateway restart") + fmt.Println(" Stop service: clawgo gateway stop") + return nil +} + +func gatewayWindowsTaskControl(action string) error { + _, _, err := detectInstalledWindowsTask() + if err != nil { + return err + } + taskName := gatewayWindowsTaskName() + switch action { + case "start": + return runSCHTASKS("/Run", "/TN", taskName) + case "stop": + return stopGatewayProcessByPIDFile() + case "restart": + _ = stopGatewayProcessByPIDFile() + return runSCHTASKS("/Run", "/TN", taskName) + case "status": + return runSCHTASKS("/Query", "/TN", taskName, "/V", "/FO", "LIST") + default: + return fmt.Errorf("unsupported action: %s", action) + } +} + +func detectInstalledWindowsTask() (string, string, error) { + taskName := gatewayWindowsTaskName() + if err := runSCHTASKSQuiet("/Query", "/TN", taskName); err != nil { + return "", "", fmt.Errorf("gateway service not registered. Run: clawgo gateway") + } + return "user", taskName, nil +} + +func runSCHTASKS(args ...string) error { + cmd := exec.Command("schtasks", args...) 
+ cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("schtasks %s failed: %w", strings.Join(args, " "), err) + } + return nil +} + +func runSCHTASKSQuiet(args ...string) error { + cmd := exec.Command("schtasks", args...) + cmd.Stdout = io.Discard + cmd.Stderr = io.Discard + return cmd.Run() +} + +func stopGatewayProcessByPIDFile() error { + pidPath := filepath.Join(filepath.Dir(getConfigPath()), "gateway.pid") + data, err := os.ReadFile(pidPath) + if err != nil { + return fmt.Errorf("gateway pid file not found: %w", err) + } + pid := strings.TrimSpace(string(data)) + if pid == "" { + return fmt.Errorf("gateway pid file is empty") + } + cmd := exec.Command("taskkill", "/PID", pid, "/T", "/F") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("taskkill /PID %s failed: %w", pid, err) + } + return nil +} + +func fileExists(path string) bool { + info, err := os.Stat(path) + return err == nil && !info.IsDir() +} diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 0000000..eb05838 --- /dev/null +++ b/docs/API.md @@ -0,0 +1,1331 @@ +# Gateway API + +This document describes the HTTP and WebSocket API exposed by `clawgo gateway run`. +It is based on the routes registered in `pkg/api/server.go` and the gateway wiring in +`cmd/cmd_gateway.go`. + +## Authentication + +All `/api/*` routes and protected extra routes require the gateway token when +`gateway.token` is non-empty. `/health` does not require authentication. + +Accepted token transports: + +- `Authorization: Bearer ` +- Query parameter: `?token=` +- Cookie: `clawgo_webui_token=` +- Browser asset fallback: a request whose `Referer` URL contains the same `token` + +If `gateway.token` is empty, API authentication is disabled. 
+ +Unauthenticated requests return HTTP `401` with plain text: + +```text +unauthorized +``` + +## Base URL + +The gateway listens on `gateway.host` and `gateway.port` from `config.json`. +The sample config uses: + +```text +http://localhost:18790 +``` + +Examples below use: + +```bash +BASE=http://localhost:18790 +TOKEN= +``` + +JSON endpoints use `Content-Type: application/json` unless noted otherwise. + +## Live Connections + +The current live endpoints are WebSocket endpoints, not SSE endpoints: + +- `GET /api/chat/live` +- `GET /api/events/live` +- `GET /api/logs/live` + +Connect with the same token rules as HTTP requests, for example: + +```text +ws://localhost:18790/api/events/live?token= +``` + +## Health Check + +### `GET /health` + +Purpose: liveness probe for the gateway HTTP process. + +Authentication: none. + +Response: + +```text +ok +``` + +## Config + +### `GET /api/config` + +Purpose: read the merged gateway config. Defaults are merged with the configured +`config.json` content. + +Authentication: required when `gateway.token` is set. + +Query parameters: + +| Name | Description | +| --- | --- | +| `mode=normalized` | Return normalized config view plus `raw_config`. | +| `mode=hot` | Return merged config plus hot reload field metadata. | +| `include_hot_reload_fields=1` | Include hot reload field names and details. | + +Example: + +```bash +curl -H "Authorization: Bearer $TOKEN" "$BASE/api/config?mode=normalized" +``` + +Response: + +```json +{ + "ok": true, + "config": {}, + "raw_config": {} +} +``` + +### `POST /api/config` + +Purpose: save config and trigger gateway reload behavior. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `mode=raw` or omitted | Body is the raw config shape. | +| `mode=normalized` | Body is the normalized config view. 
| + +Example body: + +```json +{ + "gateway": { + "host": "0.0.0.0", + "port": 18790, + "token": "" + } +} +``` + +Success response: + +```json +{ + "ok": true +} +``` + +Validation errors return HTTP `400`: + +```json +{ + "ok": false, + "error": "invalid config: ...", + "errors": ["..."] +} +``` + +## Chat + +### `POST /api/chat` + +Purpose: send a direct message to the Agent runtime and receive one complete reply. + +Authentication: required. + +Request body: + +| Field | Required | Description | +| --- | --- | --- | +| `session` | No | Session key. Defaults to query `session`, then `main`. | +| `message` | No | Prompt text. | +| `media` | No | Uploaded file path. Appended to the prompt as `[file: ]`. | + +Example: + +```bash +curl -X POST "$BASE/api/chat" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"session":"main","message":"Hello"}' +``` + +Response: + +```json +{ + "ok": true, + "reply": "Hello! How can I help?", + "session": "main" +} +``` + +## Chat History + +### `GET /api/chat/history` + +Purpose: read stored messages for a session through the gateway history callback. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `session` | Session key. Defaults to `main`. | + +Response: + +```json +{ + "ok": true, + "session": "main", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] +} +``` + +## Live Chat + +### `GET /api/chat/live` + +Purpose: WebSocket chat request that streams the completed reply in small JSON +chunks. + +Authentication: required. + +Protocol: + +1. Client opens the WebSocket. +2. Client sends one JSON message with `session`, `message`, and optional `media`. +3. Server replies with zero or more `chat_chunk` messages and one `chat_done` + message, or a `chat_error` message. 
+ +Client message: + +```json +{ + "session": "main", + "message": "Summarize today", + "media": "" +} +``` + +Server chunk: + +```json +{ + "ok": true, + "type": "chat_chunk", + "session": "main", + "delta": "Partial reply text" +} +``` + +Done message: + +```json +{ + "ok": true, + "type": "chat_done", + "session": "main" +} +``` + +Error message: + +```json +{ + "ok": false, + "type": "chat_error", + "error": "invalid json", + "session": "main" +} +``` + +## Events + +### `GET /api/events/live` + +Purpose: WebSocket event stream for gateway-side events such as config changes. + +Authentication: required. + +Initial message: + +```json +{ + "type": "ready" +} +``` + +Known event payload example: + +```json +{ + "type": "config_changed", + "source": "webui" +} +``` + +The connection stays open until the client disconnects. + +## Version + +### `GET /api/version` + +Purpose: read gateway build version and compiled channel keys. + +Authentication: required. + +Response: + +```json +{ + "ok": true, + "gateway_version": "devel", + "compiled_channels": ["weixin", "telegram", "feishu"] +} +``` + +## Provider OAuth + +Provider endpoints use the configured provider when `provider` is empty. Provider +updates save `config.json` and trigger a runtime reload hook when available. + +### `GET|POST /api/provider/oauth/start` + +Purpose: start a manual OAuth login flow. + +Authentication: required. + +GET query parameters or POST JSON fields: + +| Field | Description | +| --- | --- | +| `provider` | Provider name. Defaults to primary provider. | +| `account_label` | Optional label for the imported account. | +| `network_proxy` | Optional proxy override. | +| `provider_config` | POST only. Inline provider config override. 
| + +Response: + +```json +{ + "ok": true, + "flow_id": "1710000000000000000", + "mode": "manual", + "auth_url": "https://example.com/oauth", + "user_code": "", + "instructions": "Open the URL and paste the callback.", + "account_label": "work", + "network_proxy": "" +} +``` + +### `POST /api/provider/oauth/complete` + +Purpose: complete a previously started manual OAuth flow and persist credentials. + +Authentication: required. + +Request body: + +```json +{ + "provider": "codex", + "flow_id": "1710000000000000000", + "callback_url": "http://localhost:1455/callback?code=...", + "account_label": "work", + "network_proxy": "" +} +``` + +Response: + +```json +{ + "ok": true, + "account": "user@example.com", + "credential_file": "/home/user/.clawgo/auth/codex-work.json", + "network_proxy": "", + "models": ["gpt-5.4"] +} +``` + +### `POST /api/provider/oauth/import` + +Purpose: import an OAuth credential JSON file with multipart form data. + +Authentication: required. + +Multipart fields: + +| Field | Required | Description | +| --- | --- | --- | +| `file` | Yes | Auth JSON file. | +| `provider` | No | Provider name. | +| `account_label` | No | Account label. | +| `network_proxy` | No | Proxy override. | +| `provider_config` | No | JSON encoded inline provider config. | + +Response: + +```json +{ + "ok": true, + "account": "user@example.com", + "credential_file": "/home/user/.clawgo/auth/codex-work.json", + "network_proxy": "", + "models": ["gpt-5.4"] +} +``` + +### `GET /api/provider/oauth/accounts` + +Purpose: list OAuth accounts for a provider. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `provider` | Provider name. Defaults to primary provider. | + +Response: + +```json +{ + "ok": true, + "accounts": [] +} +``` + +### `POST /api/provider/oauth/accounts` + +Purpose: manage a provider OAuth account. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `provider` | Provider name. 
Defaults to primary provider. | + +Request body: + +```json +{ + "action": "refresh", + "credential_file": "/home/user/.clawgo/auth/codex-work.json" +} +``` + +Supported actions: + +- `refresh` +- `delete` +- `clear_cooldown` + +Response examples: + +```json +{ + "ok": true, + "account": {} +} +``` + +```json +{ + "ok": true, + "deleted": true +} +``` + +```json +{ + "ok": true, + "cleared": true +} +``` + +## Provider Models / Runtime + +### `POST /api/provider/models` + +Purpose: replace the configured model list for a provider. + +Authentication: required. + +Request body: + +```json +{ + "provider": "openai", + "model": "gpt-5.4", + "models": ["gpt-5.4", "gpt-5.4-mini"] +} +``` + +At least one value from `model` or `models` is required. + +Response: + +```json +{ + "ok": true, + "models": ["gpt-5.4", "gpt-5.4-mini"] +} +``` + +### `GET /api/provider/runtime` + +Purpose: inspect provider runtime state and history. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `provider` | Provider name. Defaults to primary provider where applicable. | +| `kind` | Runtime event kind filter. | +| `reason` | Runtime reason filter. | +| `target` | Runtime target filter. | +| `sort` | Sort mode passed to provider runtime view. | +| `changes_only=true` | Include only change events. | +| `window_sec` | Time window in seconds. | +| `limit` | Max result count. | +| `cursor` | Cursor offset. | +| `health_below` | Filter by health threshold. | +| `cooldown_until_before_sec` | Filter cooldowns before now plus this many seconds. | + +Response: + +```json +{ + "ok": true, + "view": {} +} +``` + +### `POST /api/provider/runtime` + +Purpose: operate on provider runtime metadata. + +Authentication: required. 
+ +Request body: + +```json +{ + "provider": "codex", + "action": "refresh_now", + "only_expiring": true +} +``` + +Supported actions: + +- `clear_api_cooldown` +- `clear_history` +- `refresh_now` +- `rerank` + +Response examples: + +```json +{ + "ok": true, + "cleared": true +} +``` + +```json +{ + "ok": true, + "provider": "codex", + "refreshed": true, + "result": {}, + "candidate_order": [], + "summary": {} +} +``` + +```json +{ + "ok": true, + "provider": "codex", + "reranked": true, + "candidate_order": [] +} +``` + +## Weixin Login + +### `GET /api/weixin/status` + +Purpose: read Weixin channel status, pending login records, and accounts. + +Authentication: required. + +Response: + +```json +{ + "ok": true, + "enabled": false, + "base_url": "https://ilinkai.weixin.qq.com", + "pending_logins": [], + "pending_login": { + "login_id": "", + "qr_available": false + }, + "accounts": [] +} +``` + +If the channel is unavailable, the endpoint returns HTTP `200` with `ok: false` +and an `error` field. + +### `POST /api/weixin/login/start` + +Purpose: start a Weixin QR login flow. + +Authentication: required. + +Request body: none. + +Response: same shape as `GET /api/weixin/status`. + +### `POST /api/weixin/login/cancel` + +Purpose: cancel a pending Weixin login by ID. + +Authentication: required. + +Request body: + +```json +{ + "login_id": "login-id" +} +``` + +Response: same shape as `GET /api/weixin/status`. + +### `GET /api/weixin/qr.svg` + +Purpose: render the current or selected pending Weixin login QR code as SVG. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `login_id` | Optional pending login ID. Defaults to the first pending login. | + +Response headers: + +```text +Content-Type: image/svg+xml +``` + +### `POST /api/weixin/accounts/remove` + +Purpose: remove a Weixin account by bot ID. + +Authentication: required. 
+ +Request body: + +```json +{ + "bot_id": "bot-id" +} +``` + +Response: same shape as `GET /api/weixin/status`. + +### `POST /api/weixin/accounts/default` + +Purpose: set the default Weixin account by bot ID. + +Authentication: required. + +Request body: + +```json +{ + "bot_id": "bot-id" +} +``` + +Response: same shape as `GET /api/weixin/status`. + +## Upload + +### `POST /api/upload` + +Purpose: upload a file for later chat usage. + +Authentication: required. + +Content type: `multipart/form-data`. + +Multipart fields: + +| Field | Required | Description | +| --- | --- | --- | +| `file` | Yes | File to upload. | + +The server stores files under the OS temp directory in `clawgo_webui_uploads`. + +Response: + +```json +{ + "ok": true, + "path": "/tmp/clawgo_webui_uploads/1710000000000000000_input.txt", + "name": "input.txt" +} +``` + +## Cron + +### `GET /api/cron` + +Purpose: list cron jobs or read one cron job. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `id` | Optional job ID. When present, returns one job. | + +List response: + +```json +{ + "ok": true, + "jobs": [] +} +``` + +Single job response: + +```json +{ + "ok": true, + "job": {} +} +``` + +### `POST /api/cron` + +Purpose: create or mutate cron jobs through the gateway cron handler. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `id` | Optional job ID copied into the request args. | + +Request body: + +```json +{ + "action": "create", + "name": "daily-check", + "message": "Run daily check", + "expr": "0 9 * * *", + "deliver": false, + "channel": "", + "to": "" +} +``` + +Supported actions from the gateway runtime: + +- `create` +- `update` +- `delete` +- `enable` +- `disable` +- `get` +- `list` + +Legacy scheduling fields `kind`, `everyMs`, and `atMs` are still accepted by the +runtime for older clients. 
+ +Response: + +```json +{ + "ok": true, + "result": {} +} +``` + +## Skills + +### `GET /api/skills` + +Purpose: list skills, inspect a skill's files, or read a skill file. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `check_updates=1` | Check ClawHub for remote versions when `clawhub` is installed. | +| `id` | Skill ID for detail operations. | +| `files=1` | With `id`, list files in that skill. | +| `file` | With `id`, read one relative text file. | + +List response: + +```json +{ + "ok": true, + "skills": [], + "source": "clawhub", + "clawhub_installed": false, + "clawhub_path": "" +} +``` + +Files response: + +```json +{ + "ok": true, + "id": "example", + "files": ["SKILL.md"] +} +``` + +File response: + +```json +{ + "ok": true, + "id": "example", + "file": "SKILL.md", + "content": "# example" +} +``` + +### `POST /api/skills` + +Purpose: import, install, enable, disable, create, update, or write skill files. + +Authentication: required. + +Multipart upload: + +- `Content-Type: multipart/form-data` +- `file`: `.zip`, `.tar`, `.tar.gz`, or `.tgz` archive containing one or more + `SKILL.md` files. + +Multipart response: + +```json +{ + "ok": true, + "imported": ["my-skill"] +} +``` + +JSON body fields: + +| Field | Description | +| --- | --- | +| `action` | `install`, `enable`, `disable`, `write_file`, `create`, or `update`. | +| `id` | Existing skill ID. | +| `name` | Skill name. Defaults to `id`. | +| `description` | Used by `create` and `update`. | +| `tools` | String list used by `create` and `update`. | +| `system_prompt` | Used by `create` and `update`. | +| `file` | Relative file path for `write_file`. | +| `content` | File content for `write_file`. | +| `ignore_suspicious` | Adds `--force` to ClawHub install. 
| + +Example: + +```json +{ + "action": "disable", + "name": "example" +} +``` + +Response: + +```json +{ + "ok": true +} +``` + +### `DELETE /api/skills` + +Purpose: delete an enabled or disabled skill directory. + +Authentication: required. + +Query parameters: + +| Name | Required | Description | +| --- | --- | --- | +| `id` | Yes | Skill ID. | + +Response: + +```json +{ + "ok": true, + "deleted": true, + "id": "example" +} +``` + +## Sessions + +### `GET /api/sessions` + +Purpose: list user-facing session keys found in the main agent session store. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `include_internal=1` | Include internal, subagent, heartbeat, cron, and hook sessions. | + +Response: + +```json +{ + "ok": true, + "sessions": [ + { + "key": "main", + "channel": "main" + } + ] +} +``` + +## Memory + +### `GET /api/memory` + +Purpose: list memory files or read one memory file. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `path` | Optional relative path. `MEMORY.md` reads from workspace root. | + +List response: + +```json +{ + "ok": true, + "files": ["MEMORY.md"] +} +``` + +File response: + +```json +{ + "ok": true, + "path": "MEMORY.md", + "content": "..." +} +``` + +### `POST /api/memory` + +Purpose: write a memory file under `workspace/memory`. + +Authentication: required. + +Request body: + +```json +{ + "path": "notes.md", + "content": "Remember this." +} +``` + +Response: + +```json +{ + "ok": true, + "path": "notes.md" +} +``` + +### `DELETE /api/memory` + +Purpose: delete a memory file under `workspace/memory`. + +Authentication: required. + +Query parameters: + +| Name | Required | Description | +| --- | --- | --- | +| `path` | Yes | Relative memory file path. 
| + +Response: + +```json +{ + "ok": true, + "deleted": true, + "path": "notes.md" +} +``` + +## Workspace File + +### `GET /api/workspace_file` + +Purpose: read a relative text file under the configured workspace. + +Authentication: required. + +Query parameters: + +| Name | Required | Description | +| --- | --- | --- | +| `path` | Yes | Relative workspace path. Absolute and `..` paths are rejected. | + +Response: + +```json +{ + "ok": true, + "path": "AGENTS.md", + "found": true, + "content": "..." +} +``` + +### `POST /api/workspace_file` + +Purpose: write a relative text file under the configured workspace. Parent +directories are created when needed. + +Authentication: required. + +Request body: + +```json +{ + "path": "notes/today.md", + "content": "..." +} +``` + +Response: + +```json +{ + "ok": true, + "path": "notes/today.md", + "saved": true +} +``` + +## Tools + +### `GET /api/tool_allowlist_groups` + +Purpose: read built-in tool allowlist group definitions. + +Authentication: required. + +Response: + +```json +{ + "ok": true, + "groups": [] +} +``` + +### `GET /api/tools` + +Purpose: read the runtime tool catalog plus MCP-specific tool and server checks. + +Authentication: required. + +Response: + +```json +{ + "tools": [], + "mcp_tools": [], + "mcp_server_checks": [ + { + "name": "context7", + "enabled": false, + "transport": "stdio", + "status": "disabled", + "message": "server is disabled", + "command": "npx", + "resolved": "", + "package": "@upstash/context7-mcp", + "installer": "", + "installable": false, + "missing_command": false + } + ] +} +``` + +Note: this endpoint currently does not include an `ok` field. + +## MCP Install + +### `POST /api/mcp/install` + +Purpose: install an MCP package through a supported package manager and resolve +the installed binary path. + +Authentication: required. 
+ +Request body: + +```json +{ + "package": "example-mcp", + "installer": "uv" +} +``` + +Supported installers: + +- `uv` (default): runs `uv tool install <package>` +- `bun`: runs `bun add -g <package>` + +Response: + +```json +{ + "ok": true, + "package": "example-mcp", + "output": "installed example-mcp via uv", + "bin_name": "example-mcp", + "bin_path": "/home/user/.local/bin/example-mcp" +} +``` + +## Logs + +### `GET /api/logs/recent` + +Purpose: read recent log entries from the configured gateway log file. + +Authentication: required. + +Query parameters: + +| Name | Description | +| --- | --- | +| `limit` | Number of lines to scan. Defaults to `10`, maximum `200`. | + +Response: + +```json +{ + "ok": true, + "logs": [ + { + "time": "2026-05-10T00:00:00Z", + "level": "INFO", + "msg": "gateway started" + } + ] +} +``` + +JSON log lines are returned as parsed objects. Plain text lines are wrapped as +`time`, `level`, and `msg`. + +### `GET /api/logs/live` + +Purpose: WebSocket tail of new log entries from the configured gateway log file. + +Authentication: required. + +Server message: + +```json +{ + "ok": true, + "type": "log_entry", + "entry": { + "time": "2026-05-10T00:00:00Z", + "level": "INFO", + "msg": "gateway started" + } +} +``` + +If the log file cannot be opened after the WebSocket upgrade, the server sends: + +```json +{ + "ok": false, + "error": "open ...: no such file or directory" +} +``` + +## Protected Extra Routes + +### `GET /v1/ws` + +Purpose: AIStudio relay WebSocket route registered by the gateway with +`SetProtectedRoute`. + +Authentication: required. + +Provider selection: + +- Query parameter: `provider` +- Header: `X-Clawgo-Provider` +- Default: `aistudio` + +This route is owned by `pkg/wsrelay` and is documented here only because the +gateway registers it as a protected API-facing route. + +## Error Responses + +Handlers mostly use Go's `http.Error`, so many errors are plain text with an +HTTP status code instead of JSON. 
+ +Common responses: + +| Status | Body | Meaning | +| --- | --- | --- | +| `400` | `invalid json` or validation text | Invalid request body, missing required field, invalid path, or unsupported action. | +| `401` | `unauthorized` | Missing or invalid gateway token. | +| `404` | `... not found` | Requested skill, login, QR, or resource was not found. | +| `405` | `method not allowed` | HTTP method is not supported by that route. | +| `412` | `clawhub is not installed...` | Skill install requires ClawHub. | +| `429` | `clawhub rate limit exceeded...` | ClawHub rejected skill lookup or install due to rate limiting. | +| `500` | error text | Gateway config, filesystem, log, cron, or runtime handler failure. | +| `502` | error text | Upstream channel/login operation failed. | +| `503` | `weixin channel unavailable` | Weixin route needs an active Weixin channel. | + +JSON validation responses may look like: + +```json +{ + "ok": false, + "error": "invalid config: ...", + "errors": ["..."] +} +``` + +WebSocket endpoints report post-upgrade errors as JSON messages when possible. 
+ +## Common Flows + +### Send a chat message + +```bash +curl -X POST "$BASE/api/chat" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"session":"main","message":"Ping"}' +``` + +### Upload a file and reference it in chat + +```bash +UPLOAD_PATH=$( + curl -s -X POST "$BASE/api/upload" \ + -H "Authorization: Bearer $TOKEN" \ + -F "file=@./notes.txt" | + jq -r .path +) + +curl -X POST "$BASE/api/chat" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d "{\"session\":\"main\",\"message\":\"Read this file\",\"media\":\"$UPLOAD_PATH\"}" +``` + +### Start and complete OAuth + +```bash +curl -X POST "$BASE/api/provider/oauth/start" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"provider":"codex","account_label":"work"}' +``` + +Open the returned `auth_url`, then call: + +```bash +curl -X POST "$BASE/api/provider/oauth/complete" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"provider":"codex","flow_id":"","callback_url":""}' +``` + +### Watch config and log events + +Use WebSocket clients against: + +```text +ws://localhost:18790/api/events/live?token= +ws://localhost:18790/api/logs/live?token= +``` diff --git a/docs/PROJECT_OVERVIEW.md b/docs/PROJECT_OVERVIEW.md new file mode 100644 index 0000000..0759925 --- /dev/null +++ b/docs/PROJECT_OVERVIEW.md @@ -0,0 +1,395 @@ +# ClawGo 项目说明 + +本文档用于快速熟悉当前代码库:它是什么、如何运行、核心模块如何协作,以及开发时应该从哪里下手。 + +## 1. 项目定位 + +ClawGo 是一个 Go 原生的个人 AI Agent Runtime。它不只是命令行聊天壳,而是把 Agent 对话、工具调用、消息通道、子 Agent、记忆、定时任务、网关 Web API、健康巡检和服务化部署组织到同一个运行时里。 + +从代码结构看,项目当前主要服务两类使用方式: + +- `clawgo agent`:命令行直连 Agent,可交互,也可通过 `-m` 传入单次消息。 +- `clawgo gateway run`:启动常驻网关,接入 WebUI/API、Telegram、微信、飞书、定时任务、心跳和 Sentinel 巡检。 + +## 2. 目录结构 + +```text +. 
+├── cmd/ # CLI 入口和各子命令 +├── pkg/ +│ ├── agent/ # Agent 主循环、规划、上下文、子 Agent 编排 +│ ├── api/ # Gateway HTTP/WebUI API +│ ├── browser/ # Chromium 截图/内容读取辅助 +│ ├── bus/ # 入站/出站消息总线 +│ ├── channels/ # Telegram / Weixin / Feishu 通道 +│ ├── config/ # 配置结构、默认值、校验、标准化视图 +│ ├── configops/ # 配置操作辅助 +│ ├── cron/ # 定时任务调度和持久化 +│ ├── events/ # 类型化事件总线 +│ ├── heartbeat/ # 心跳服务 +│ ├── lifecycle/ # 后台循环生命周期工具 +│ ├── logger/ # 结构化日志与轮转 +│ ├── providers/ # LLM Provider 适配 +│ ├── runtimecfg/ # 运行时全局配置快照 +│ ├── scheduling/ # 资源调度键 +│ ├── sentinel/ # 健康巡检与告警 +│ ├── session/ # 会话历史持久化 +│ ├── tools/ # Agent 工具注册与实现 +│ └── wsrelay/ # WebSocket relay +├── workspace/ # 内置 workspace 模板、skills、记忆/身份文件 +├── scripts/ # 构建脚本 +├── config.example.json # 完整配置示例 +├── Dockerfile # 容器化运行 +├── docker-compose.yml # 本地 compose 示例 +├── Makefile # 构建、测试、安装目标 +├── README.md +└── README_EN.md +``` + +## 3. 启动入口 + +主入口在 `cmd/main.go`。程序会先解析全局参数: + +- `--config ` 或 `CLAWGO_CONFIG`:指定配置文件。 +- `--debug` / `-d`:启用 debug 日志。 + +随后根据第一个命令分发: + +```text +onboard 初始化配置和 workspace +agent 直接与 Agent 交互 +gateway 注册、管理或前台运行网关服务 +status 查看状态 +provider 管理模型服务商和 OAuth +config 读写配置 +cron 管理定时任务 +channel 测试和管理消息通道 +skills 管理 skills +tui 终端聊天 UI,取决于构建开关 +uninstall 卸载组件 +version 输出版本 +``` + +## 4. 核心运行链路 + +### 4.1 CLI 直连模式 + +`cmd/cmd_agent.go` 的主要流程: + +1. 加载配置。 +2. 根据配置创建 LLM Provider。 +3. 创建 `bus.MessageBus`。 +4. 创建 `cron.CronService`,用于 `remind` / `cron` 工具。 +5. 创建 `agent.AgentLoop`。 +6. 单次消息走 `ProcessDirect`,交互模式逐行读入后同样走 `ProcessDirect`。 + +简化链路: + +```text +用户输入 + -> cmd/agent + -> agent.AgentLoop.ProcessDirect + -> 构造上下文 + 历史 + 工具定义 + -> providers.LLMProvider.Chat + -> 执行 tool calls + -> 保存会话 + -> 返回回复 +``` + +### 4.2 Gateway 常驻模式 + +`cmd/cmd_gateway.go` 的 `gatewayCmd()` 是常驻服务的组合器: + +1. 加载配置并写入 `runtimecfg`。 +2. 创建消息总线。 +3. 创建 Cron、Heartbeat、Sentinel。 +4. 创建 AgentLoop 和 Channel Manager。 +5. 启动 HTTP API Server。 +6. 启动各消息通道。 +7. 启动 AgentLoop 后台消费消息。 +8. 
监听配置文件变化和系统信号,支持热重载。 + +简化链路: + +```text +外部通道 / Web API / Cron / Heartbeat + -> bus.InboundMessage + -> agent.AgentLoop.Run + -> session shard 并发处理 + -> LLM + tools + -> bus.OutboundMessage + -> channels.Manager.dispatchOutbound + -> 外部通道 +``` + +## 5. 核心模块说明 + +### 5.1 `pkg/config` + +配置中心,定义完整的 `Config` 结构: + +- `agents`:workspace、模型默认值、上下文压缩、执行参数、router、subagents。 +- `channels`:微信、Telegram、飞书和消息去重窗口。 +- `models.providers`:OpenAI 兼容接口、Codex、Claude、Gemini、Qwen、Kimi、Vertex 等 Provider 配置。 +- `gateway`:HTTP 监听地址、端口和 token。 +- `cron`:定时任务轮询、退避和并发参数。 +- `tools`:web、shell、filesystem、MCP 工具配置。 +- `logging`:日志文件和轮转。 +- `sentinel`:巡检与告警。 +- `memory`:分层记忆开关。 + +`LoadConfig()` 会先加载默认配置,再严格解析 JSON,并通过环境变量覆盖部分字段。`DefaultConfig()` 是理解默认行为的最佳入口。 + +### 5.2 `pkg/agent` + +Agent Runtime 的核心。`NewAgentLoop()` 会完成: + +- 创建 workspace。 +- 初始化 `session.SessionManager`。 +- 注册本地工具。 +- 注册 MCP 远端工具。 +- 注册消息工具、spawn 工具、subagent profile 工具。 +- 注册记忆、浏览器、摄像头、系统信息等工具。 +- 初始化 provider fallback chain。 +- 注入 subagent 的递归运行逻辑。 + +`Run()` 用 session key 分片,保证不同会话可以并发处理,同一会话尽量串行。`ProcessDirect()` 用于 CLI/Web API 的直接请求。 + +### 5.3 `pkg/tools` + +工具系统以 `Tool` 接口为核心: + +```go +type Tool interface { + Name() string + Description() string + Parameters() map[string]interface{} + Execute(ctx context.Context, args map[string]interface{}) (string, error) +} +``` + +`ToolRegistry` 负责注册、查询、执行和导出 schema。当前 AgentLoop 默认注册的能力包括: + +- 文件读写、目录列表、编辑文件。 +- shell 执行和进程管理。 +- web search、web fetch、parallel fetch。 +- `parallel` 并发工具调用。 +- `message` 外发消息。 +- `spawn` / `subagent_profile` / `sessions`。 +- `memory_search` / `memory_get` / `memory_write`。 +- `skill_exec`。 +- `browser`、`camera`、`system_info`。 +- MCP 动态发现出的远端工具。 + +子 Agent 可以通过 profile 的 allowlist 限制工具范围。`parallel` 工具会校验内部调用是否也在 allowlist 中。 + +### 5.4 `pkg/providers` + +LLM Provider 抽象在 `pkg/providers/types.go`: + +- `LLMProvider`:同步 Chat。 +- `StreamingLLMProvider`:可选流式输出。 +- `ResponsesCompactor`:可选 Responses Compact。 +- `TokenCounter`:可选 token 统计。 + 
+`CreateProviderByName()` 会根据 provider 名称或 OAuth provider 路由到不同实现: + +- `HTTPProvider`:通用 OpenAI Responses / OpenAI-compatible 适配。 +- `CodexProvider` +- `ClaudeProvider` +- `GeminiProvider` +- `GeminiCLIProvider` +- `AistudioProvider` +- `AntigravityProvider` +- `QwenProvider` +- `KimiProvider` +- `IFlowProvider` +- `VertexProvider` + +Provider 层还包含 OAuth 多账号轮换、失败冷却、runtime history 持久化、OpenAI-compatible chat 转换、Responses API 工具调用格式转换等逻辑。 + +### 5.5 `pkg/bus` 和 `pkg/channels` + +`pkg/bus` 是运行时内部消息通道: + +- `InboundMessage`:外部消息、系统触发、cron、heartbeat 进入 Agent。 +- `OutboundMessage`:Agent 或工具发往外部通道。 + +`pkg/channels.Manager` 负责: + +- 根据配置创建 Telegram、Weixin、Feishu 通道。 +- 启停所有通道。 +- 监听 outbound bus 并分发到对应通道。 +- 对 outbound 做短窗口去重和限速。 +- 提供健康检查。 + +### 5.6 `pkg/api` + +Gateway HTTP API 和 WebUI 后端。主要接口包括: + +```text +GET /health +* /api/config +* /api/chat +* /api/chat/history +* /api/chat/live +* /api/events/live +* /api/version +* /api/provider/oauth/start +* /api/provider/oauth/complete +* /api/provider/oauth/import +* /api/provider/oauth/accounts +* /api/provider/models +* /api/provider/runtime +* /api/weixin/status +* /api/weixin/login/start +* /api/weixin/login/cancel +GET /api/weixin/qr.svg +* /api/upload +* /api/cron +* /api/skills +* /api/sessions +* /api/memory +* /api/workspace_file +* /api/tool_allowlist_groups +* /api/tools +* /api/mcp/install +* /api/logs/live +* /api/logs/recent +``` + +API Server 由 gateway 注入 chat handler、history handler、cron handler、tools catalog handler 等回调,因此 `pkg/api` 不直接拥有 AgentLoop。 + +### 5.7 `pkg/cron` + +定时任务服务。任务保存在配置目录下的 `cron/jobs.json`。支持: + +- cron 表达式。 +- 一次性或间隔式兼容字段。 +- 失败退避。 +- 最大并发 worker。 +- 任务启停、更新、删除、查询。 + +Gateway 中的 cron job 可以: + +- 投递为内部 inbound message,让 Agent 处理。 +- 如果指定了 channel/to,则直接发 outbound message 到外部通道。 + +### 5.8 `pkg/session` 与记忆 + +会话历史由 `SessionManager` 持久化为 JSONL,同时维护 `sessions.json` 索引。默认主会话路径来自: + +```text +/agents/main/sessions +``` + +会话 key 会被映射为稳定 session id,并根据 key 前缀识别类型: + +- `cron:*` -> cron +- 
`subagent:*` 或包含 `:subagent:` -> subagent +- `hook:*` -> hook +- 包含 `:` -> main +- 其他 -> other + +记忆工具读写 workspace 下的记忆文件,并支持 main / subagent namespace 隔离。 + +### 5.9 `pkg/sentinel` 和 `pkg/heartbeat` + +- Heartbeat 定时生成系统 inbound message,用于让 Agent 周期性自检或继续长期任务。 +- Sentinel 定期检查 channels、config、memory、logs,可通过配置的通道或 webhook 告警。 + +## 6. 配置和运行数据位置 + +默认配置目录: + +```text +~/.clawgo +``` + +debug 模式下: + +```text +.clawgo +``` + +常见文件: + +```text +~/.clawgo/config.json +~/.clawgo/workspace/ +~/.clawgo/logs/clawgo.log +~/.clawgo/cron/jobs.json +~/.clawgo/agents/main/sessions/ +~/.clawgo/runtime/providers/*.json +``` + +配置文件也可以通过 `--config` 或 `CLAWGO_CONFIG` 指定。 + +## 7. 本地开发 + +常用命令: + +```bash +go test ./... +go build -o clawgo ./cmd +make build +make test +make dev +``` + +启动方式: + +```bash +clawgo onboard +clawgo agent +clawgo agent -m "Hello" +clawgo gateway run +``` + +Docker: + +```bash +docker compose up --build +``` + +容器会把 `/home/clawgo/.clawgo` 挂载到本地 `./.clawgo`,首次启动时如果没有配置文件会执行 `clawgo onboard`,随后进入 `gateway run`。 + +## 8. 新功能开发入口建议 + +- 新 CLI 命令:从 `cmd/main.go` 分发,再新增 `cmd/cmd_xxx.go`。 +- 新配置字段:先改 `pkg/config/config.go`,再补 `DefaultConfig()`、`Normalize()` / `Validate()` 和示例配置。 +- 新工具:在 `pkg/tools` 实现 `Tool`,再到 `agent.NewAgentLoop()` 注册。 +- 新消息通道:实现 `channels.Channel`,在 `channels.Manager.initChannels()` 接入。 +- 新 Provider:实现 `providers.LLMProvider`,在 `CreateProviderByName()` 增加路由。 +- 新 Web API:在 `pkg/api/server.go` 注册路由,尽量通过 handler 回调依赖运行时能力。 +- 新后台循环:优先复用 `pkg/lifecycle.LoopRunner` 管理启动/停止。 + +## 9. 测试分布 + +项目已有较多单元测试,重点覆盖: + +- API Server。 +- Agent planning、router、memory、subagent。 +- bus 并发关闭与投递。 +- channel 去重和平台行为。 +- config normalized / validate。 +- cron 调度。 +- provider 请求构造、OAuth、兼容协议。 +- tools 参数解析、MCP、parallel、camera、remind、skill exec。 +- session manager。 + +新增行为建议就近补测试,尤其是: + +- 配置解析和兼容迁移。 +- Provider 请求格式。 +- 工具参数解析。 +- 消息通道去重/限流。 +- Agent loop 的 tool call 配对和上下文压缩。 + +## 10. 
当前观察到的注意事项 + +- 当前工作区已有未提交修改:`pkg/api/server.go`、`pkg/bus/bus.go`、`pkg/channels/weixin.go`、`pkg/channels/weixin_test.go`、`pkg/tools/mcp.go`,以及未跟踪的 `cmd/artifacts/`。本文档没有修改这些文件。 +- README 和部分源码里的中文在当前 PowerShell 输出中显示为乱码,可能是终端编码或历史文件编码问题;编辑中文文档时建议统一使用 UTF-8。 +- `config.example.json` 中展示的配置较完整,是理解运行能力和默认部署形态的重要参考。 +- Gateway 热重载会区分元数据变更和 runtime 相关变更;host/port 变化需要重启才能重新绑定监听端口。 +- WebUI/API 可以读写部分配置,但运行时核心仍以本地 `config.json` 为主。 diff --git a/docs/optimization/01-docs-encoding-cleanup.md b/docs/optimization/01-docs-encoding-cleanup.md new file mode 100644 index 0000000..a12c313 --- /dev/null +++ b/docs/optimization/01-docs-encoding-cleanup.md @@ -0,0 +1,64 @@ +# 01 文档与编码清理 + +## 背景 + +当前 README、配置示例和部分源码输出中的中文在 PowerShell 中显示为乱码。乱码会影响新用户理解,也容易扩散到日志、WebUI、示例配置和后续文档。 + +## 目标 + +- 确认仓库文本文件统一使用 UTF-8。 +- 修复 README、配置示例、CLI 输出、Agent 通知文案中的乱码中文。 +- 保持文档内容准确,不顺手改变代码行为。 + +## 建议改动范围 + +主要文件: + +- `README.md` +- `README_EN.md` +- `config.example.json` +- `cmd/*.go` 中用户可见输出 +- `pkg/agent/*.go` 中用户可见中文通知 +- `workspace/*.md` +- `workspace/skills/**/SKILL.md` + +尽量避免修改: + +- Provider 协议逻辑 +- Gateway 生命周期逻辑 +- 工具执行逻辑 + +## 实施步骤 + +1. 扫描乱码: + + ```bash + rg "�|锛|鈹|闈|鎴|涓|乧|乀|銆" README.md README_EN.md config.example.json cmd pkg workspace + ``` + +2. 判断每处乱码原始含义,优先从上下文、英文 README、测试断言和代码行为恢复。 +3. 修复文档和用户可见字符串。 +4. 如修改测试断言中的字符串,同步更新相关测试。 +5. 确认所有 Markdown 文件可读,代码能编译。 + +## 验收标准 + +- `README.md` 中文可正常阅读。 +- `config.example.json` 中文注释/示例文字不再乱码。 +- CLI/Gateway 关键启动输出不再乱码。 +- `rg` 扫描不再出现明显历史乱码片段。 +- `go test ./...` 通过。 + +## 测试建议 + +```bash +go test ./... 
+go run ./cmd version +go run ./cmd --debug version +``` + +如果终端仍显示异常,确认终端编码和文件编码,而不是继续盲改文本。 + +## 并行注意 + +该任务可能触碰很多文件,但应只做文本修复。避免和结构重构任务互相覆盖同一段逻辑。 diff --git a/docs/optimization/02-gateway-structure-split.md b/docs/optimization/02-gateway-structure-split.md new file mode 100644 index 0000000..99050eb --- /dev/null +++ b/docs/optimization/02-gateway-structure-split.md @@ -0,0 +1,114 @@ +# 02 Gateway 结构拆分 + +## 背景 + +`cmd/cmd_gateway.go` 当前同时承担 gateway 命令解析、服务注册、runtime 组装、热重载、API handler 注入、channel 生命周期、cron/heartbeat/sentinel 编排、bootstrap 初始化等职责。文件过大,后续修改容易产生回归。 + +## 目标 + +在不改变行为的前提下拆分 Gateway 代码结构,让每类职责有清晰文件边界。 + +## 建议改动范围 + +主要文件: + +- `cmd/cmd_gateway.go` +- 新增 `cmd/gateway_runtime.go` +- 新增 `cmd/gateway_reload.go` +- 新增 `cmd/gateway_services.go` +- 新增 `cmd/gateway_bootstrap.go` + +可选文件: + +- `cmd/reload_windows.go` +- `cmd/reload_unix.go` +- `cmd/signals_windows.go` +- `cmd/signals_unix.go` + +尽量避免修改: + +- `pkg/api/server.go` +- `pkg/agent/*` +- `pkg/providers/*` +- `pkg/channels/*` + +## 建议拆分 + +### `cmd/cmd_gateway.go` + +保留命令入口和高层流程: + +- 参数分发。 +- `gateway run/start/stop/restart/status` 分支。 +- 前台 run 的主流程骨架。 + +### `cmd/gateway_runtime.go` + +放运行时组装: + +- `buildGatewayRuntime` +- `bindAgentLoopHandlers` +- `configureCronServiceRuntime` +- `buildHeartbeatService` +- `dispatchCronJob` +- `normalizeCronTargetChatID` + +### `cmd/gateway_reload.go` + +放热重载: + +- config fingerprint。 +- config watcher。 +- reload trigger。 +- runtimeSame 判断。 +- reload 后重新绑定 API handler、channel、sentinel。 + +### `cmd/gateway_services.go` + +放系统服务注册/控制: + +- Linux systemd。 +- macOS launchd。 +- Windows scheduled task。 +- PID file stop。 + +### `cmd/gateway_bootstrap.go` + +放启动辅助任务: + +- startup compaction check。 +- bootstrap init。 +- maximum permission policy。 + +## 验收标准 + +- 行为不变,命令仍可用: + + ```bash + go run ./cmd gateway run --config ./config.json + go run ./cmd gateway status --config ./config.json + ``` + +- `go test ./...` 通过。 +- `cmd/cmd_gateway.go` 明显变薄,职责聚焦于命令入口。 +- 
新文件命名和函数归属清晰。 +- 没有引入循环依赖或跨 package 重构。 + +## 测试建议 + +```bash +go test ./cmd ./pkg/api ./pkg/channels ./pkg/agent +go test ./... +``` + +如果有本地配置,可额外手动验证: + +```bash +go run ./cmd gateway run --config ./config.json +``` + +启动后修改配置文件,观察热重载日志是否仍正常。 + +## 并行注意 + +本任务应避免改 API 路由和 Provider 行为。若 API 文档任务并行进行,它只读 `pkg/api/server.go`,双方冲突较小。 diff --git a/docs/optimization/03-provider-layer-modularization.md b/docs/optimization/03-provider-layer-modularization.md new file mode 100644 index 0000000..4b7d23a --- /dev/null +++ b/docs/optimization/03-provider-layer-modularization.md @@ -0,0 +1,111 @@ +# 03 Provider 层模块化 + +## 背景 + +`pkg/providers/http_provider.go` 当前聚合了通用 HTTP provider、Responses API、OpenAI-compatible Chat、Codex/Qwen/Kimi 兼容逻辑、OAuth runtime 状态、请求构造、响应解析和部分持久化状态逻辑。文件复杂度偏高,继续新增 Provider 会越来越难。 + +## 目标 + +在保持现有行为和测试通过的前提下,把 Provider 层按职责拆开,降低单文件复杂度。 + +## 建议改动范围 + +主要文件: + +- `pkg/providers/http_provider.go` +- `pkg/providers/oauth.go` +- `pkg/providers/*_provider.go` +- 新增若干 provider 内部文件 + +建议新增文件: + +- `pkg/providers/provider_registry.go` +- `pkg/providers/responses_adapter.go` +- `pkg/providers/openai_compat_adapter.go` +- `pkg/providers/provider_runtime.go` +- `pkg/providers/provider_request_options.go` + +尽量避免修改: + +- `pkg/agent/*` +- `pkg/tools/*` +- `cmd/*` +- `pkg/api/*` + +## 建议拆分 + +### Provider 注册与路由 + +移动到 `provider_registry.go`: + +- `CreateProvider` +- `CreateProviderByName` +- `normalizeProviderRouteName` +- `getProviderConfigByName` +- provider alias / route 相关逻辑 + +### Responses API 适配 + +移动到 `responses_adapter.go`: + +- Responses request body 构造。 +- Responses input item 转换。 +- Responses tools 转换。 +- Responses response 解析。 +- Responses compact summary。 + +### OpenAI-compatible Chat 适配 + +移动到 `openai_compat_adapter.go`: + +- chat completions request body 构造。 +- multimodal message 转换。 +- function call 转换。 +- Qwen/Kimi 兼容 chat 格式相关公共逻辑。 + +### Provider runtime 状态 + +移动到 `provider_runtime.go`: + +- runtime state 结构。 +- history 持久化。 +- health score。 +- 
recent hits/errors/changes。 +- candidate order。 + +### 请求选项与通用解析 + +移动到 `provider_request_options.go`: + +- `rawOption` +- `stringOption` +- `mapOption` +- `stringSliceOption` +- `int64FromOption` +- `float64FromOption` +- 其他 request option helper。 + +## 验收标准 + +- `go test ./pkg/providers` 通过。 +- `go test ./...` 通过。 +- Provider 对外接口不变。 +- 现有配置文件无需迁移。 +- `http_provider.go` 明显缩小,并聚焦 `HTTPProvider` 本体。 +- 没有改变已有 provider 的请求格式,除非测试明确覆盖并更新说明。 + +## 测试建议 + +重点跑: + +```bash +go test ./pkg/providers -count=1 +go test ./pkg/agent ./pkg/api ./cmd -count=1 +go test ./... +``` + +Provider 层已有大量请求构造测试。拆分时应优先保持测试不动,让测试证明行为未变。 + +## 并行注意 + +该任务风险较高,建议不要和“新增 Provider”或“AgentLoop 工具调用逻辑改动”同时修改同一分支。若必须并行,先完成纯移动,再做行为改动。 diff --git a/docs/optimization/04-tool-bootstrap-refactor.md b/docs/optimization/04-tool-bootstrap-refactor.md new file mode 100644 index 0000000..85579b1 --- /dev/null +++ b/docs/optimization/04-tool-bootstrap-refactor.md @@ -0,0 +1,101 @@ +# 04 工具注册架构优化 + +## 背景 + +`agent.NewAgentLoop()` 直接注册大量工具,包括文件、shell、web、MCP、message、spawn、sessions、memory、parallel、browser、camera、system_info 等。工具注册逻辑和 AgentLoop 初始化耦合较深,也让子 Agent 可见性、默认工具集、可选工具集的边界不够清晰。 + +## 目标 + +抽出工具启动/注册逻辑,让 AgentLoop 只负责调用工具构建器,而不是直接知道每个工具的创建细节。 + +## 建议改动范围 + +主要文件: + +- `pkg/agent/loop.go` +- `pkg/tools/registry.go` +- 新增 `pkg/tools/bootstrap.go` +- 新增 `pkg/tools/bootstrap_options.go` + +可选文件: + +- `pkg/tools/tool_allowlist_groups.go` +- `pkg/tools/subagent*.go` +- `pkg/tools/mcp.go` + +尽量避免修改: + +- `pkg/providers/*` +- `pkg/api/*` +- `cmd/*` + +## 建议设计 + +新增一个工具构建入口,例如: + +```go +type BootstrapOptions struct { + Config *config.Config + Workspace string + MessageBus *bus.MessageBus + CronService *cron.CronService + Provider providers.LLMProvider + ProcessManager *ProcessManager +} + +type BootstrapResult struct { + Registry *ToolRegistry + ProcessManager *ProcessManager + SubagentManager *SubagentManager + SubagentRouter *SubagentRouter +} + +func BootstrapDefaultTools(ctx context.Context, opts 
BootstrapOptions) (*BootstrapResult, error) +``` + +AgentLoop 中保留: + +- session manager 初始化。 +- context builder 初始化。 +- provider fallback chain。 +- subagent run func 注入。 + +工具构建器负责: + +- 创建并注册基础工具。 +- 根据 config 注册 cron/remind。 +- 根据 config 注册 MCP 和远端工具。 +- 创建 subagent manager/router。 +- 注册 tool catalog 需要的 metadata。 + +## 子 Agent 可见性 + +保持现有行为: + +- 主 Agent 默认看见全部注册工具。 +- 子 Agent 由 profile allowlist 限制。 +- `skill_exec` 仍为隐式允许工具。 +- `parallel` 内部 call 仍需要逐项校验 allowlist。 + +## 验收标准 + +- `agent.NewAgentLoop()` 更短,工具注册细节迁移出去。 +- 工具列表和原来保持一致。 +- MCP discovery 行为保持一致。 +- subagent spawn/profile/sessions 工具仍正常注册。 +- `go test ./pkg/agent ./pkg/tools` 通过。 +- `go test ./...` 通过。 + +## 测试建议 + +```bash +go test ./pkg/tools -count=1 +go test ./pkg/agent -count=1 +go test ./... +``` + +建议新增一个测试断言默认工具集名称,避免重构时漏注册工具。 + +## 并行注意 + +该任务会修改 `pkg/agent/loop.go`,应避免和 AgentLoop 行为改动并行落同一分支。它不应修改 Provider 或 Gateway。 diff --git a/docs/optimization/05-gateway-api-docs.md b/docs/optimization/05-gateway-api-docs.md new file mode 100644 index 0000000..6676259 --- /dev/null +++ b/docs/optimization/05-gateway-api-docs.md @@ -0,0 +1,120 @@ +# 05 Gateway API 文档 + +## 背景 + +Gateway 暴露了较多 WebUI/API endpoint,但目前主要靠 `pkg/api/server.go` 阅读理解。对前端、集成方和测试编写者来说,需要一份稳定的 API 文档。 + +## 目标 + +新增 Gateway API 文档,说明认证方式、核心接口、请求/响应样例、错误约定和常见调用场景。 + +## 建议改动范围 + +主要文件: + +- 新增 `docs/API.md` +- 可选更新 `README.md` 或 `docs/PROJECT_OVERVIEW.md` 增加链接 + +参考文件: + +- `pkg/api/server.go` +- `cmd/cmd_gateway.go` +- `config.example.json` + +尽量避免修改: + +- API handler 行为 +- Gateway runtime 行为 +- 前端资源或 WebUI 逻辑 + +## 文档建议结构 + +```text +# Gateway API + +## 认证 +## 基础 URL +## 健康检查 +## Chat +## Chat History +## Live Chat / SSE +## Events +## Config +## Provider OAuth +## Provider Models / Runtime +## Weixin Login +## Upload +## Cron +## Skills +## Sessions +## Memory +## Workspace File +## Tools +## MCP Install +## Logs +## 错误响应 +``` + +## 需要覆盖的接口 + +从当前代码至少覆盖: + +```text +GET /health +* /api/config +* /api/chat +* /api/chat/history +* 
/api/chat/live +* /api/events/live +* /api/version +* /api/provider/oauth/start +* /api/provider/oauth/complete +* /api/provider/oauth/import +* /api/provider/oauth/accounts +* /api/provider/models +* /api/provider/runtime +* /api/weixin/status +* /api/weixin/login/start +* /api/weixin/login/cancel +GET /api/weixin/qr.svg +* /api/weixin/accounts/remove +* /api/weixin/accounts/default +* /api/upload +* /api/cron +* /api/skills +* /api/sessions +* /api/memory +* /api/workspace_file +* /api/tool_allowlist_groups +* /api/tools +* /api/mcp/install +* /api/logs/live +* /api/logs/recent +``` + +## 验收标准 + +- 新增 `docs/API.md`。 +- 每个核心接口至少有用途、方法、认证、主要参数、响应示例。 +- 文档说明 gateway token 的传递方式。 +- 文档中明确哪些接口用于 SSE/live。 +- 文档不声称尚不存在的行为。 +- 纯文档任务无需修改源码。 + +## 测试建议 + +纯文档任务可不跑全量测试。建议至少用代码核对路由: + +```bash +rg "HandleFunc|SetProtectedRoute" pkg/api/server.go +``` + +如果顺手补了 API 测试,则运行: + +```bash +go test ./pkg/api +``` + +## 并行注意 + +本任务只写文档,适合和 Gateway 拆分并行执行。若 Gateway 拆分改动路由注册位置,最终合并时再核对一次 API 文档。 diff --git a/docs/optimization/06-runtime-service-tests.md b/docs/optimization/06-runtime-service-tests.md new file mode 100644 index 0000000..144af33 --- /dev/null +++ b/docs/optimization/06-runtime-service-tests.md @@ -0,0 +1,93 @@ +# 06 Cron / Heartbeat / Sentinel 测试 + +## 背景 + +当前测试覆盖面较广,但长期运行服务相关模块还有明显空白,尤其是: + +- `pkg/cron` +- `pkg/heartbeat` +- `pkg/sentinel` + +这些模块负责定时任务、心跳和健康巡检,是常驻 Gateway 的可靠性基础。 + +## 目标 + +补充运行时服务测试,覆盖调度、持久化、失败退避、生命周期启停和巡检结果。 + +## 建议改动范围 + +主要文件: + +- 新增 `pkg/cron/service_test.go` +- 新增 `pkg/heartbeat/service_test.go` +- 新增 `pkg/sentinel/service_test.go` + +可选文件: + +- `pkg/lifecycle/loop_runner_test.go` + +尽量避免修改: + +- `cmd/cmd_gateway.go` +- `pkg/agent/*` +- `pkg/providers/*` + +如果为了可测试性需要小改生产代码,应保持非常小的改动,例如注入 clock、缩短 sleep、暴露只读状态。 + +## Cron 测试建议 + +覆盖: + +- `AddJob` 后写入 store。 +- `ListJobs(includeDisabled)` 行为。 +- `EnableJob` 启停。 +- `UpdateJob` 修改 schedule/message。 +- cron expr 计算 next run。 +- 失败后 backoff 和 consecutive failure。 +- `MaxWorkers` 限制并发。 +- 一次性任务 
`DeleteAfterRun`。 +- store 损坏或空文件时的行为。 + +## Heartbeat 测试建议 + +覆盖: + +- disabled 时 `Start()` 不触发 heartbeat。 +- enabled 时按 interval 调用回调。 +- `Stop()` 后不再触发。 +- `buildPrompt()` 使用自定义 prompt template。 +- 空 Markdown 判断。 +- ack token 提取。 + +## Sentinel 测试建议 + +覆盖: + +- config 文件不存在/损坏时返回问题。 +- memory 目录缺失时行为。 +- logs 目录/文件检查。 +- channel health check 汇总。 +- alert callback 被调用。 +- disabled 时不启动或不告警。 + +## 验收标准 + +- 新增三个模块的测试文件。 +- 测试稳定,不依赖真实网络、真实外部平台或长时间 sleep。 +- `go test ./pkg/cron ./pkg/heartbeat ./pkg/sentinel` 通过。 +- `go test ./...` 通过。 + +## 测试建议 + +```bash +go test ./pkg/cron -count=1 +go test ./pkg/heartbeat -count=1 +go test ./pkg/sentinel -count=1 +go test ./... +``` + +如果某个服务当前不易测试,优先做最小可测试性改造,而不是写依赖时间运气的测试。 + +## 并行注意 + +该任务大多新增测试文件,和结构重构冲突较小。若 Gateway 拆分同时进行,避免在本任务中修改 `cmd/`。 diff --git a/docs/optimization/README.md b/docs/optimization/README.md new file mode 100644 index 0000000..a06a4d0 --- /dev/null +++ b/docs/optimization/README.md @@ -0,0 +1,47 @@ +# ClawGo 并行优化任务拆分 + +本文档把当前值得优化的方向拆成可并行执行的工作包。每个工作包都有相对独立的文件范围、交付物和验收标准,便于多人或多个 agent 同时推进。 + +## 工作包总览 + +| 编号 | 任务 | 主要目标 | 建议并行性 | +| --- | --- | --- | --- | +| 01 | 文档与编码清理 | 统一中文文档/示例/输出文本编码,提升可读性 | 可独立执行 | +| 02 | Gateway 结构拆分 | 拆薄 `cmd/cmd_gateway.go`,降低维护成本 | 可独立执行 | +| 03 | Provider 层模块化 | 拆分 Provider 协议、OAuth、runtime 状态逻辑 | 可独立执行,但需谨慎 | +| 04 | 工具注册架构优化 | 抽出工具启动器,清晰区分默认工具、可选工具、子 Agent 可见性 | 可独立执行 | +| 05 | Gateway API 文档 | 补充 API 认证、接口、请求响应样例 | 可独立执行 | +| 06 | Cron / Heartbeat / Sentinel 测试 | 补长期运行服务测试覆盖 | 可独立执行 | + +## 推荐并行顺序 + +可以第一批同时启动: + +- `01-docs-encoding-cleanup.md` +- `02-gateway-structure-split.md` +- `05-gateway-api-docs.md` +- `06-runtime-service-tests.md` + +第二批建议在第一批基本稳定后启动: + +- `03-provider-layer-modularization.md` +- `04-tool-bootstrap-refactor.md` + +原因是 Provider 和工具注册都靠近 AgentLoop 核心,改动面更大,最好避开 Gateway 大拆分和文档编码清理的高频改动期。 + +## 并行开发约定 + +- 每个任务尽量只改自己文档中列出的主要文件。 +- 如果必须跨任务修改同一文件,先在 PR/提交说明中明确说明原因。 +- 不要顺手重构不在任务范围内的模块。 +- 每个任务完成后至少运行 `go test ./...`,纯文档任务除外。 +- 涉及行为变更时,优先补就近单元测试。 + 
+## 任务文件 + +- [01 文档与编码清理](./01-docs-encoding-cleanup.md) +- [02 Gateway 结构拆分](./02-gateway-structure-split.md) +- [03 Provider 层模块化](./03-provider-layer-modularization.md) +- [04 工具注册架构优化](./04-tool-bootstrap-refactor.md) +- [05 Gateway API 文档](./05-gateway-api-docs.md) +- [06 Cron / Heartbeat / Sentinel 测试](./06-runtime-service-tests.md) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f8113af..53301b1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -112,109 +112,29 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers sessionsManager := session.NewSessionManager(filepath.Join(filepath.Dir(cfg.WorkspacePath()), "agents", "main", "sessions")) - toolsRegistry := tools.NewToolRegistry() - processManager := tools.NewProcessManager(workspace) - readTool := tools.NewReadFileTool(workspace) - writeTool := tools.NewWriteFileTool(workspace) - listTool := tools.NewListDirTool(workspace) - toolsRegistry.Register(readTool) - toolsRegistry.Register(writeTool) - toolsRegistry.Register(listTool) - toolsRegistry.Register(tools.NewExecTool(cfg.Tools.Shell, workspace, processManager)) - toolsRegistry.Register(tools.NewProcessTool(processManager)) - toolsRegistry.Register(tools.NewSkillExecTool(workspace)) - - if cs != nil { - toolsRegistry.Register(tools.NewRemindTool(cs)) - toolsRegistry.Register(tools.NewCronTool(cs)) - } - - maxParallelCalls := cfg.Agents.Defaults.Execution.ToolMaxParallelCalls - if maxParallelCalls <= 0 { - maxParallelCalls = 4 - } - parallelSafe := make(map[string]struct{}) - for _, name := range cfg.Agents.Defaults.Execution.ToolParallelSafeNames { - trimmed := strings.TrimSpace(name) - if trimmed != "" { - parallelSafe[trimmed] = struct{}{} - } - } - - braveAPIKey := cfg.Tools.Web.Search.APIKey - toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) - webFetchTool := tools.NewWebFetchTool(50000) - toolsRegistry.Register(webFetchTool) - 
toolsRegistry.Register(tools.NewParallelFetchTool(webFetchTool, maxParallelCalls, parallelSafe)) - if cfg.Tools.MCP.Enabled { - mcpTool := tools.NewMCPTool(workspace, cfg.Tools.MCP) - toolsRegistry.Register(mcpTool) - discoveryCtx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.Tools.MCP.RequestTimeoutSec)*time.Second) - for _, remoteTool := range mcpTool.DiscoverTools(discoveryCtx) { - toolsRegistry.Register(remoteTool) - } - cancel() - } - - // Register message tool - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, action, content, media, messageID, emoji string, buttons [][]bus.Button) error { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, - Media: media, - Buttons: buttons, - Action: action, - MessageID: messageID, - Emoji: emoji, - }) - return nil - }) - toolsRegistry.Register(messageTool) - - // Register spawn tool - subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) - subagentRouter := tools.NewSubagentRouter(subagentManager) - spawnTool := tools.NewSpawnTool(subagentManager) - toolsRegistry.Register(spawnTool) - if store := subagentManager.ProfileStore(); store != nil { - toolsRegistry.Register(tools.NewSubagentProfileTool(store)) - } - toolsRegistry.Register(tools.NewSessionsTool( - func(limit int) []tools.SessionInfo { - sessions := alSessionListForTool(sessionsManager, limit) - return sessions + bootstrap, err := tools.BootstrapDefaultTools(context.Background(), tools.BootstrapOptions{ + Config: cfg, + Workspace: workspace, + MessageBus: msgBus, + CronService: cs, + Provider: provider, + SessionList: func(limit int) []tools.SessionInfo { + return alSessionListForTool(sessionsManager, limit) }, - func(key string, limit int) []providers.Message { + SessionHistory: func(key string, limit int) []providers.Message { h := sessionsManager.GetHistory(key) if limit > 0 && len(h) > limit { return h[len(h)-limit:] } return 
h }, - )) - - // Register edit file tool - editFileTool := tools.NewEditFileTool(workspace) - toolsRegistry.Register(editFileTool) - - // Register memory tools - memorySearchTool := tools.NewMemorySearchTool(workspace) - toolsRegistry.Register(memorySearchTool) - toolsRegistry.Register(tools.NewMemoryGetTool(workspace)) - toolsRegistry.Register(tools.NewMemoryWriteTool(workspace)) - - // Register parallel execution tool (leveraging Go's concurrency) - toolsRegistry.Register(tools.NewParallelTool(toolsRegistry, maxParallelCalls, parallelSafe)) - - // Register browser tool (integrated Chromium support) - toolsRegistry.Register(tools.NewBrowserTool()) - - // Register camera tool - toolsRegistry.Register(tools.NewCameraTool(workspace)) - // Register system info tool - toolsRegistry.Register(tools.NewSystemInfoTool()) + }) + if err != nil { + panic(err) + } + toolsRegistry := bootstrap.Registry + subagentManager := bootstrap.SubagentManager + subagentRouter := bootstrap.SubagentRouter loop := &AgentLoop{ bus: msgBus, diff --git a/pkg/agent/router_dispatch_test.go b/pkg/agent/router_dispatch_test.go index 853c6c4..9942939 100644 --- a/pkg/agent/router_dispatch_test.go +++ b/pkg/agent/router_dispatch_test.go @@ -27,7 +27,7 @@ func TestResolveDispatchDecisionRulesFirst(t *testing.T) { cfg.Agents.Router.Strategy = "rules_first" cfg.Agents.Subagents["coder"] = config.SubagentConfig{Enabled: true, Role: "coding", SystemPromptFile: "agents/coder/AGENT.md"} cfg.Agents.Subagents["tester"] = config.SubagentConfig{Enabled: true, Role: "testing", SystemPromptFile: "agents/tester/AGENT.md"} - cfg.Agents.Router.Rules = []config.AgentRouteRule{{AgentID: "coder", Keywords: []string{"鐧诲綍", "bug"}}} + cfg.Agents.Router.Rules = []config.AgentRouteRule{{AgentID: "coder", Keywords: []string{"登录", "bug"}}} decision := resolveDispatchDecision(cfg, "please fix the login bug and update the code") if decision.TargetAgent != "coder" || decision.TaskText == "" { diff --git a/pkg/api/server.go 
b/pkg/api/server.go index 46aaa1c..b540802 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -153,17 +153,24 @@ func (s *Server) broadcastEvent(payload map[string]interface{}) { subs = append(subs, conn) } s.eventSubsMu.Unlock() + var failed []*websocket.Conn for _, conn := range subs { if conn == nil { continue } if err := conn.WriteJSON(payload); err != nil { - s.eventSubsMu.Lock() - delete(s.eventSubs, conn) - s.eventSubsMu.Unlock() + failed = append(failed, conn) _ = conn.Close() } } + if len(failed) == 0 { + return + } + s.eventSubsMu.Lock() + for _, conn := range failed { + delete(s.eventSubs, conn) + } + s.eventSubsMu.Unlock() } func writeJSON(w http.ResponseWriter, payload interface{}) { diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 1e91432..9eda0ac 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -35,7 +35,15 @@ func (mb *MessageBus) PublishInbound(msg InboundMessage) { select { case mb.inbound <- msg: - case <-time.After(queueWriteTimeout): + return + default: + } + + timer := time.NewTimer(queueWriteTimeout) + defer stopAndDrainTimer(timer) + select { + case mb.inbound <- msg: + case <-timer.C: logger.ErrorCF("bus", logger.C0130, map[string]interface{}{ logger.FieldChannel: msg.Channel, logger.FieldChatID: msg.ChatID, @@ -62,7 +70,15 @@ func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { select { case mb.outbound <- msg: - case <-time.After(queueWriteTimeout): + return + default: + } + + timer := time.NewTimer(queueWriteTimeout) + defer stopAndDrainTimer(timer) + select { + case mb.outbound <- msg: + case <-timer.C: logger.ErrorCF("bus", logger.C0132, map[string]interface{}{ logger.FieldChannel: msg.Channel, logger.FieldChatID: msg.ChatID, @@ -70,6 +86,19 @@ func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { } } +func stopAndDrainTimer(timer *time.Timer) { + if timer == nil { + return + } + if timer.Stop() { + return + } + select { + case <-timer.C: + default: + } +} + func (mb *MessageBus) SubscribeOutbound(ctx 
context.Context) (OutboundMessage, bool) { select { case msg, ok := <-mb.outbound: diff --git a/pkg/channels/feishu.go b/pkg/channels/feishu.go index e1dfdff..f7bca0c 100644 --- a/pkg/channels/feishu.go +++ b/pkg/channels/feishu.go @@ -135,7 +135,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error logger.WarnCF("feishu", logger.C0045, map[string]interface{}{logger.FieldError: lerr.Error(), logger.FieldChatID: msg.ChatID}) continue } - links = append(links, fmt.Sprintf("琛ㄦ牸%d: %s", i+1, link)) + links = append(links, fmt.Sprintf("表格%d: %s", i+1, link)) } if len(links) > 0 { if strings.TrimSpace(workMsg.Content) != "" { @@ -883,9 +883,9 @@ func normalizeFeishuText(s string) string { // Headers: "## title" -> "title" s = regexp.MustCompile(`(?m)^#{1,6}\s+`).ReplaceAllString(s, "") // Bullet styles - s = regexp.MustCompile(`(?m)^[-*]\s+`).ReplaceAllString(s, "鈥?") + s = regexp.MustCompile(`(?m)^[-*]\s+`).ReplaceAllString(s, "- ") // Ordered list to bullet for readability - s = regexp.MustCompile(`(?m)^\d+\.\s+`).ReplaceAllString(s, "鈥?") + s = regexp.MustCompile(`(?m)^\d+\.\s+`).ReplaceAllString(s, "- ") // Bold/italic/strike markers s = regexp.MustCompile(`\*\*(.*?)\*\*`).ReplaceAllString(s, `$1`) s = regexp.MustCompile(`__(.*?)__`).ReplaceAllString(s, `$1`) diff --git a/pkg/channels/weixin.go b/pkg/channels/weixin.go index c64aa90..2a363e1 100644 --- a/pkg/channels/weixin.go +++ b/pkg/channels/weixin.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "sort" + "strconv" "strings" "sync" "time" @@ -412,6 +413,10 @@ func (c *WeixinChannel) RemoveAccount(botID string) error { } c.mu.Lock() + if c.accounts[botID] == nil { + c.mu.Unlock() + return fmt.Errorf("bot_id not found: %s", botID) + } if cancel := c.pollers[botID]; cancel != nil { cancel() delete(c.pollers, botID) @@ -430,8 +435,14 @@ func (c *WeixinChannel) RemoveAccount(botID string) error { delete(c.chatContexts, chatID) } } + c.typingMu.Lock() + delete(c.typingCache, botID) + 
c.typingMu.Unlock() if strings.TrimSpace(c.config.DefaultBotID) == botID { c.config.DefaultBotID = "" + if len(c.accountOrder) > 0 { + c.config.DefaultBotID = c.accountOrder[0] + } } c.schedulePersistLocked() c.mu.Unlock() @@ -694,17 +705,23 @@ func (c *WeixinChannel) handleInboundMessage(botID string, msg weixinInboundMess } c.mu.Unlock() - var textParts []string - var itemTypes []string + var contentBuilder strings.Builder + var itemTypesBuilder strings.Builder for _, item := range msg.ItemList { - itemTypes = append(itemTypes, fmt.Sprintf("%d", item.Type)) + if itemTypesBuilder.Len() > 0 { + itemTypesBuilder.WriteByte(',') + } + itemTypesBuilder.WriteString(strconv.Itoa(item.Type)) if item.Type == 1 { if text := strings.TrimSpace(item.TextItem.Text); text != "" { - textParts = append(textParts, text) + if contentBuilder.Len() > 0 { + contentBuilder.WriteByte('\n') + } + contentBuilder.WriteString(text) } } } - content := strings.Join(textParts, "\n") + content := contentBuilder.String() if content == "" { return } @@ -717,7 +734,7 @@ func (c *WeixinChannel) handleInboundMessage(botID string, msg weixinInboundMess metadata := map[string]string{ "bot_id": botID, "context_token": contextToken, - "item_types": strings.Join(itemTypes, ","), + "item_types": itemTypesBuilder.String(), "raw_chat_id": rawChatID, } c.HandleMessage(rawChatID, chatID, content, nil, metadata) @@ -1118,24 +1135,25 @@ func (c *WeixinChannel) doJSON(ctx context.Context, path string, payload interfa } func (c *WeixinChannel) doJSONWithTimeout(ctx context.Context, path string, payload interface{}, out interface{}, token string, timeout time.Duration) error { + reqCtx := ctx + cancel := func() {} + if timeout > 0 { + reqCtx, cancel = context.WithTimeout(ctx, timeout) + } + defer cancel() + body, err := json.Marshal(payload) if err != nil { return fmt.Errorf("marshal request: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.config.BaseURL+path, bytes.NewReader(body)) + 
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.BaseURL+path, bytes.NewReader(body)) if err != nil { return fmt.Errorf("build request: %w", err) } c.applyHeaders(req, true, token, true) - client := c.httpClient - if timeout > 0 { - clone := *c.httpClient - clone.Timeout = timeout - client = &clone - } - resp, err := client.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return err } diff --git a/pkg/channels/weixin_test.go b/pkg/channels/weixin_test.go index c57baff..63e555f 100644 --- a/pkg/channels/weixin_test.go +++ b/pkg/channels/weixin_test.go @@ -72,6 +72,50 @@ func TestWeixinHandleInboundMessageUsesCompositeSessionChatID(t *testing.T) { } } +func TestWeixinHandleInboundMessageBuildsMetadataAndContent(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: "https://ilinkai.weixin.qq.com", + Accounts: []config.WeixinAccountConfig{ + {BotID: "bot-a", BotToken: "token-a"}, + }, + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + + ch.handleInboundMessage("bot-a", weixinInboundMessage{ + FromUserID: "wx-user-1", + ContextToken: "ctx-1", + ItemList: []weixinMessageItem{ + {Type: 2}, + {Type: 1, TextItem: struct { + Text string `json:"text"` + }{Text: "hello"}}, + {Type: 1, TextItem: struct { + Text string `json:"text"` + }{Text: " world "}}, + {Type: 3}, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + msg, ok := mb.ConsumeInbound(ctx) + if !ok { + t.Fatalf("expected inbound message") + } + if msg.Content != "hello\nworld" { + t.Fatalf("unexpected content: %q", msg.Content) + } + if got := msg.Metadata["item_types"]; got != "2,1,1,3" { + t.Fatalf("unexpected item_types: %q", got) + } + if got := msg.Metadata["context_token"]; got != "ctx-1" { + t.Fatalf("unexpected context_token: %q", got) + } +} + func TestWeixinResolveAccountForCompositeChatID(t *testing.T) { mb := bus.NewMessageBus() ch, err 
:= NewWeixinChannel(config.WeixinConfig{ @@ -135,6 +179,56 @@ func TestWeixinSetDefaultAccount(t *testing.T) { } } +func TestWeixinRemoveAccountReturnsErrorWhenMissing(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: "https://ilinkai.weixin.qq.com", + Accounts: []config.WeixinAccountConfig{ + {BotID: "bot-a", BotToken: "token-a"}, + }, + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + + err = ch.RemoveAccount("bot-missing") + if err == nil { + t.Fatalf("expected remove missing account to fail") + } + if !strings.Contains(err.Error(), "bot_id not found") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestWeixinRemoveAccountReassignsDefault(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: "https://ilinkai.weixin.qq.com", + DefaultBotID: "bot-b", + Accounts: []config.WeixinAccountConfig{ + {BotID: "bot-a", BotToken: "token-a"}, + {BotID: "bot-b", BotToken: "token-b"}, + }, + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + + if err := ch.RemoveAccount("bot-b"); err != nil { + t.Fatalf("remove account: %v", err) + } + accounts := ch.ListAccounts() + if len(accounts) != 1 { + t.Fatalf("expected 1 account after removal, got %d", len(accounts)) + } + if accounts[0].BotID != "bot-a" { + t.Fatalf("expected remaining account bot-a, got %s", accounts[0].BotID) + } + if !accounts[0].Default { + t.Fatalf("expected remaining account to become default") + } +} + func TestWeixinCancelPendingLogin(t *testing.T) { mb := bus.NewMessageBus() ch, err := NewWeixinChannel(config.WeixinConfig{BaseURL: "https://ilinkai.weixin.qq.com"}, mb) @@ -422,3 +516,34 @@ func TestWeixinValidateAPIStatusErrorShape(t *testing.T) { t.Fatalf("marshal error text") } } + +func TestWeixinDoJSONWithTimeoutSetsRequestDeadline(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: 
"https://ilinkai.weixin.qq.com", + Accounts: []config.WeixinAccountConfig{ + {BotID: "bot-a", BotToken: "token-a"}, + }, + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + + ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) { + if _, ok := req.Context().Deadline(); !ok { + t.Fatalf("expected request context to have deadline") + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"ret":0,"errcode":0}`)), + Header: make(http.Header), + }, nil + })} + + var out weixinAPIResponse + if err := ch.doJSONWithTimeout(context.Background(), "/ilink/bot/sendmessage", map[string]interface{}{ + "msg": map[string]interface{}{"item_list": []map[string]interface{}{}}, + }, &out, "token-a", 50*time.Millisecond); err != nil { + t.Fatalf("doJSONWithTimeout: %v", err) + } +} diff --git a/pkg/cron/service_test.go b/pkg/cron/service_test.go new file mode 100644 index 0000000..3f1e926 --- /dev/null +++ b/pkg/cron/service_test.go @@ -0,0 +1,239 @@ +package cron + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" +) + +func newTestCronService(t *testing.T, onJob JobHandler) *CronService { + t.Helper() + return NewCronService(filepath.Join(t.TempDir(), "cron.json"), onJob) +} + +func TestAddListEnableAndUpdateJobPersistsStore(t *testing.T) { + cs := newTestCronService(t, nil) + at := time.Now().Add(time.Hour).UnixMilli() + + job, err := cs.AddJob("daily note", CronSchedule{Kind: "at", AtMS: &at}, "hello", true, "telegram", "chat-1") + if err != nil { + t.Fatalf("AddJob returned error: %v", err) + } + + data, err := os.ReadFile(cs.storePath) + if err != nil { + t.Fatalf("store was not written: %v", err) + } + var stored CronStore + if err := json.Unmarshal(data, &stored); err != nil { + t.Fatalf("store JSON did not decode: %v", err) + } + if len(stored.Jobs) != 1 || stored.Jobs[0].ID != job.ID { + 
t.Fatalf("stored jobs = %+v, want one persisted job %q", stored.Jobs, job.ID) + } + + if got := cs.ListJobs(false); len(got) != 1 { + t.Fatalf("ListJobs(false) returned %d jobs, want 1", len(got)) + } + if disabled := cs.EnableJob(job.ID, false); disabled == nil || disabled.Enabled { + t.Fatalf("EnableJob(false) = %+v, want disabled job", disabled) + } + if got := cs.ListJobs(false); len(got) != 0 { + t.Fatalf("ListJobs(false) returned disabled jobs: %+v", got) + } + if got := cs.ListJobs(true); len(got) != 1 { + t.Fatalf("ListJobs(true) returned %d jobs, want 1", len(got)) + } + + nextAt := time.Now().Add(2 * time.Hour).UnixMilli() + name := "updated" + msg := "new message" + deliver := false + channel := "feishu" + to := "chat-2" + deleteAfterRun := true + updated, err := cs.UpdateJob(job.ID, UpdateJobInput{ + Name: &name, + Enabled: boolPtr(true), + Schedule: &CronSchedule{Kind: "at", AtMS: &nextAt}, + Message: &msg, + Deliver: &deliver, + Channel: &channel, + To: &to, + DeleteAfterRun: &deleteAfterRun, + }) + if err != nil { + t.Fatalf("UpdateJob returned error: %v", err) + } + if updated.Name != name || updated.Payload.Message != msg || updated.Payload.Deliver != deliver || + updated.Payload.Channel != channel || updated.Payload.To != to || !updated.DeleteAfterRun || !updated.Enabled { + t.Fatalf("updated job = %+v, want all fields updated", updated) + } + if updated.State.NextRunAtMS == nil || *updated.State.NextRunAtMS != nextAt { + t.Fatalf("updated next run = %v, want %d", updated.State.NextRunAtMS, nextAt) + } +} + +func TestComputeNextRunForCronExpression(t *testing.T) { + cs := newTestCronService(t, nil) + base := time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC).UnixMilli() + now := time.Date(2026, 1, 1, 12, 2, 0, 0, time.UTC).UnixMilli() + + next := cs.computeNextRunAfter(&CronSchedule{Kind: "cron", Expr: "*/5 * * * *", TZ: "UTC"}, base, now) + if next == nil { + t.Fatal("computeNextRunAfter returned nil") + } + want := time.Date(2026, 1, 1, 12, 5, 0, 0, 
time.UTC).UnixMilli() + if *next != want { + t.Fatalf("next run = %s, want %s", time.UnixMilli(*next), time.UnixMilli(want)) + } +} + +func TestFailureBackoffTracksConsecutiveFailures(t *testing.T) { + cs := newTestCronService(t, func(job *CronJob) (string, error) { + return "", errors.New("boom") + }) + cs.SetRuntimeOptions(RuntimeOptions{ + RetryBackoffBase: 10 * time.Millisecond, + RetryBackoffMax: 40 * time.Millisecond, + MaxConsecutiveFailureRetries: 2, + MaxWorkers: 1, + }) + + now := time.Now().UnixMilli() + next := now - 100 + cs.store.Jobs = []CronJob{{ + ID: "job-1", + Enabled: true, + Schedule: CronSchedule{ + Kind: "cron", + Expr: "@every 1h", + }, + State: CronJobState{NextRunAtMS: &next}, + }} + + if !cs.executeJobByID("job-1") { + t.Fatal("first executeJobByID returned false") + } + job := cs.GetJob("job-1") + if job.State.LastStatus != "error" || job.State.TotalRuns != 1 || job.State.TotalFailures != 1 || job.State.ConsecutiveFailures != 1 { + t.Fatalf("first failure state = %+v", job.State) + } + if job.State.NextRunAtMS == nil || *job.State.NextRunAtMS <= time.Now().UnixMilli() || *job.State.NextRunAtMS > time.Now().Add(time.Second).UnixMilli() { + t.Fatalf("first retry next run = %v, want near-future backoff", job.State.NextRunAtMS) + } + + cs.store.Jobs[0].State.NextRunAtMS = int64Ptr(time.Now().UnixMilli() - 100) + if !cs.executeJobByID("job-1") { + t.Fatal("second executeJobByID returned false") + } + job = cs.GetJob("job-1") + if job.State.ConsecutiveFailures != 2 || job.State.TotalFailures != 2 { + t.Fatalf("second failure state = %+v", job.State) + } +} + +func TestMaxWorkersLimitsConcurrentExecutions(t *testing.T) { + var active int64 + var maxActive int64 + started := make(chan struct{}, 3) + release := make(chan struct{}) + cs := newTestCronService(t, func(job *CronJob) (string, error) { + cur := atomic.AddInt64(&active, 1) + for { + old := atomic.LoadInt64(&maxActive) + if cur <= old || atomic.CompareAndSwapInt64(&maxActive, old, cur) { 
+ break + } + } + started <- struct{}{} + <-release + atomic.AddInt64(&active, -1) + return "", nil + }) + cs.SetRuntimeOptions(RuntimeOptions{MaxWorkers: 2, RunLoopMinSleep: time.Millisecond, RunLoopMaxSleep: time.Millisecond}) + + due := time.Now().UnixMilli() - 100 + for i := 0; i < 3; i++ { + cs.store.Jobs = append(cs.store.Jobs, CronJob{ + ID: string(rune('a' + i)), + Enabled: true, + Schedule: CronSchedule{ + Kind: "cron", + Expr: "@every 1h", + }, + State: CronJobState{NextRunAtMS: &due}, + }) + } + + var wg sync.WaitGroup + wg.Add(1) + cs.runner.Start(func(stop <-chan struct{}) { + defer wg.Done() + cs.checkJobs() + }) + + <-started + <-started + select { + case <-started: + t.Fatal("third job started before a worker slot was released") + case <-time.After(25 * time.Millisecond): + } + close(release) + wg.Wait() + cs.Stop() + + if got := atomic.LoadInt64(&maxActive); got > 2 { + t.Fatalf("max concurrent executions = %d, want <= 2", got) + } +} + +func TestDeleteAfterRunRemovesOneTimeJob(t *testing.T) { + cs := newTestCronService(t, nil) + due := time.Now().UnixMilli() - 100 + cs.store.Jobs = []CronJob{{ + ID: "once", + Enabled: true, + Schedule: CronSchedule{Kind: "at", AtMS: &due}, + State: CronJobState{NextRunAtMS: &due}, + DeleteAfterRun: true, + }} + + if !cs.executeJobByID("once") { + t.Fatal("executeJobByID returned false") + } + if got := cs.GetJob("once"); got != nil { + t.Fatalf("GetJob returned %+v, want deleted job", got) + } +} + +func TestLoadStoreHandlesMissingAndCorruptFiles(t *testing.T) { + cs := newTestCronService(t, nil) + if err := cs.Load(); err != nil { + t.Fatalf("Load missing store returned error: %v", err) + } + if len(cs.ListJobs(true)) != 0 { + t.Fatalf("missing store loaded jobs: %+v", cs.ListJobs(true)) + } + + if err := os.WriteFile(cs.storePath, []byte("{not-json"), 0644); err != nil { + t.Fatalf("write corrupt store: %v", err) + } + if err := cs.Load(); err == nil { + t.Fatal("Load corrupt store returned nil error") + } +} + 
+func boolPtr(v bool) *bool { + return &v +} + +func int64Ptr(v int64) *int64 { + return &v +} diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go new file mode 100644 index 0000000..73ccc15 --- /dev/null +++ b/pkg/heartbeat/service_test.go @@ -0,0 +1,125 @@ +package heartbeat + +import ( + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestStartDisabledDoesNotTriggerHeartbeat(t *testing.T) { + var calls int64 + hs := NewHeartbeatService(t.TempDir(), func(prompt string) (string, error) { + atomic.AddInt64(&calls, 1) + return "", nil + }, 1, false, "") + hs.interval = time.Millisecond + + if err := hs.Start(); err != nil { + t.Fatalf("Start returned error: %v", err) + } + time.Sleep(15 * time.Millisecond) + hs.Stop() + + if got := atomic.LoadInt64(&calls); got != 0 { + t.Fatalf("heartbeat calls = %d, want 0", got) + } +} + +func TestEnabledStartCallsCallbackAndStopPreventsMoreCalls(t *testing.T) { + var calls int64 + hs := NewHeartbeatService(t.TempDir(), func(prompt string) (string, error) { + atomic.AddInt64(&calls, 1) + return "", nil + }, 1, true, "") + hs.interval = 5 * time.Millisecond + + if err := hs.Start(); err != nil { + t.Fatalf("Start returned error: %v", err) + } + waitForHeartbeatCalls(t, &calls, 1) + hs.Stop() + afterStop := atomic.LoadInt64(&calls) + time.Sleep(20 * time.Millisecond) + + if got := atomic.LoadInt64(&calls); got != afterStop { + t.Fatalf("heartbeat calls after Stop = %d, want %d", got, afterStop) + } +} + +func TestBuildPromptUsesTemplateAndSkipsEmptyMarkdown(t *testing.T) { + workspace := t.TempDir() + writeFile(t, filepath.Join(workspace, "AGENTS.md"), "# Policy\nheartbeat_ack_token: ACK_OK\n") + writeFile(t, filepath.Join(workspace, "HEARTBEAT.md"), "# Heartbeat\n\n## Notes\n") + + hs := NewHeartbeatService(workspace, nil, 1, true, "Custom template") + prompt := hs.buildPrompt() + + for _, want := range []string{"Custom template", "Current time:", "## AGENTS.md", 
"heartbeat_ack_token: ACK_OK", "## HEARTBEAT.md"} { + if !strings.Contains(prompt, want) { + t.Fatalf("prompt missing %q:\n%s", want, prompt) + } + } + if strings.Contains(prompt, "## Notes") { + t.Fatalf("prompt included effectively empty HEARTBEAT.md content:\n%s", prompt) + } +} + +func TestBuildPromptDefaultMentionsAckToken(t *testing.T) { + workspace := t.TempDir() + writeFile(t, filepath.Join(workspace, "AGENTS.md"), "- heartbeat_ack_token: `ACK_DONE`\n") + writeFile(t, filepath.Join(workspace, "HEARTBEAT.md"), "Take a tiny action.\n") + + hs := NewHeartbeatService(workspace, nil, 1, true, "") + prompt := hs.buildPrompt() + + if !strings.Contains(prompt, "return ACK_DONE") { + t.Fatalf("default prompt did not include ack token:\n%s", prompt) + } + if !strings.Contains(prompt, "Take a tiny action.") { + t.Fatalf("prompt did not include heartbeat notes:\n%s", prompt) + } +} + +func TestIsEffectivelyEmptyMarkdown(t *testing.T) { + if !isEffectivelyEmptyMarkdown("# Title\n\n## Empty\n") { + t.Fatal("heading-only markdown should be treated as empty") + } + if isEffectivelyEmptyMarkdown("# Title\n\nDo the thing.\n") { + t.Fatal("markdown with body text should not be treated as empty") + } +} + +func TestHeartbeatAckTokenFromText(t *testing.T) { + got := heartbeatAckTokenFromText("# Runtime\n- heartbeat_ack_token: \"ACK-123\"\n") + if got != "ACK-123" { + t.Fatalf("ack token = %q, want ACK-123", got) + } + if got := heartbeatAckTokenFromText("heartbeat: none"); got != "" { + t.Fatalf("ack token = %q, want empty", got) + } +} + +func waitForHeartbeatCalls(t *testing.T, calls *int64, want int64) { + t.Helper() + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + if atomic.LoadInt64(calls) >= want { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("heartbeat calls = %d, want at least %d", atomic.LoadInt64(calls), want) +} + +func writeFile(t *testing.T, path string, content string) { + t.Helper() + if err := 
os.MkdirAll(filepath.Dir(path), 0755); err != nil { + t.Fatalf("MkdirAll(%s): %v", filepath.Dir(path), err) + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("WriteFile(%s): %v", path, err) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index e8620b5..9e0f233 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -6,17 +6,11 @@ import ( "crypto/rand" "encoding/json" "fmt" - "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/logger" "io" "net/http" - "net/url" - "os" - "path/filepath" - "regexp" "runtime" "strings" - "sync" "time" ) @@ -31,135 +25,6 @@ const ( kimiCompatUserAgent = "KimiCLI/1.10.6" ) -type providerAPIRuntimeState struct { - TokenMasked string `json:"token_masked,omitempty"` - CooldownUntil string `json:"cooldown_until,omitempty"` - FailureCount int `json:"failure_count,omitempty"` - LastFailure string `json:"last_failure,omitempty"` - HealthScore int `json:"health_score,omitempty"` -} - -type providerRuntimeEvent struct { - When string `json:"when,omitempty"` - Kind string `json:"kind,omitempty"` - Target string `json:"target,omitempty"` - Reason string `json:"reason,omitempty"` - Detail string `json:"detail,omitempty"` -} - -func recordProviderRuntimeChange(providerName, kind, target, reason, detail string) { - name := strings.TrimSpace(providerName) - if name == "" || strings.TrimSpace(reason) == "" { - return - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: strings.TrimSpace(kind), - Target: strings.TrimSpace(target), - Reason: strings.TrimSpace(reason), - Detail: strings.TrimSpace(detail), - }, runtimeEventLimit(state)) - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state -} - -type 
providerRuntimeCandidate struct { - Kind string `json:"kind,omitempty"` - Target string `json:"target,omitempty"` - Available bool `json:"available"` - Status string `json:"status,omitempty"` - CooldownUntil string `json:"cooldown_until,omitempty"` - HealthScore int `json:"health_score,omitempty"` - FailureCount int `json:"failure_count,omitempty"` -} - -type providerRuntimePersistConfig struct { - Enabled bool - File string - MaxEvents int - Loaded bool - LoadAttempt bool -} - -type ProviderRuntimeQuery struct { - Provider string - Window time.Duration - EventKind string - Reason string - Target string - Limit int - Cursor int - HealthBelow int - CooldownBefore time.Time - Sort string - ChangesOnly bool -} - -type ProviderRefreshAccountResult struct { - Target string `json:"target,omitempty"` - Status string `json:"status,omitempty"` - Detail string `json:"detail,omitempty"` - Expire string `json:"expire,omitempty"` -} - -type ProviderRefreshResult struct { - Provider string `json:"provider,omitempty"` - Checked int `json:"checked,omitempty"` - Refreshed int `json:"refreshed,omitempty"` - Skipped int `json:"skipped,omitempty"` - Failed int `json:"failed,omitempty"` - Accounts []ProviderRefreshAccountResult `json:"accounts,omitempty"` -} - -type ProviderRuntimeSummaryItem struct { - Name string `json:"name,omitempty"` - Auth string `json:"auth,omitempty"` - Status string `json:"status,omitempty"` - APIState providerAPIRuntimeState `json:"api_state,omitempty"` - OAuthAccounts []OAuthAccountInfo `json:"oauth_accounts,omitempty"` - CandidateOrder []providerRuntimeCandidate `json:"candidate_order,omitempty"` - LastSuccess *providerRuntimeEvent `json:"last_success,omitempty"` - LastSuccessAt string `json:"last_success_at,omitempty"` - LastError *providerRuntimeEvent `json:"last_error,omitempty"` - LastErrorAt string `json:"last_error_at,omitempty"` - LastErrorReason string `json:"last_error_reason,omitempty"` - TopCandidateChangedAt string 
`json:"top_candidate_changed_at,omitempty"` - StaleForSec int64 `json:"stale_for_sec,omitempty"` - InCooldown bool `json:"in_cooldown"` - LowHealth bool `json:"low_health"` - HasRecentErrors bool `json:"has_recent_errors"` - TopCandidate *providerRuntimeCandidate `json:"top_candidate,omitempty"` -} - -type ProviderRuntimeSummary struct { - TotalProviders int `json:"total_providers"` - Healthy int `json:"healthy"` - Degraded int `json:"degraded"` - Critical int `json:"critical"` - InCooldown int `json:"in_cooldown"` - LowHealth int `json:"low_health"` - RecentErrors int `json:"recent_errors"` - Providers []ProviderRuntimeSummaryItem `json:"providers,omitempty"` -} - -type providerRuntimeState struct { - API providerAPIRuntimeState `json:"api_state,omitempty"` - RecentHits []providerRuntimeEvent `json:"recent_hits,omitempty"` - RecentErrors []providerRuntimeEvent `json:"recent_errors,omitempty"` - RecentChanges []providerRuntimeEvent `json:"recent_changes,omitempty"` - LastSuccess *providerRuntimeEvent `json:"last_success,omitempty"` - CandidateOrder []providerRuntimeCandidate `json:"candidate_order,omitempty"` - Persist providerRuntimePersistConfig `json:"-"` -} - -var providerRuntimeRegistry = struct { - mu sync.Mutex - api map[string]providerRuntimeState -}{api: map[string]providerRuntimeState{}} - type HTTPProvider struct { providerName string apiKey string @@ -190,28 +55,6 @@ func NewHTTPProvider(providerName, apiKey, apiBase, defaultModel string, support } } -func ConfigureProviderRuntime(providerName string, pc config.ProviderConfig) { - name := strings.TrimSpace(providerName) - if name == "" { - return - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - state.Persist = providerRuntimePersistConfig{ - Enabled: pc.RuntimePersist, - File: runtimeHistoryFile(name, pc), - MaxEvents: runtimeHistoryMax(pc), - Loaded: state.Persist.Loaded, - LoadAttempt: state.Persist.LoadAttempt, - } - if 
state.Persist.Enabled && !state.Persist.LoadAttempt { - state.Persist.LoadAttempt = true - loadPersistedProviderRuntimeLocked(name, &state) - } - providerRuntimeRegistry.api[name] = state -} - func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { if p.apiBase == "" { return nil, fmt.Errorf("API base not configured") @@ -265,445 +108,6 @@ func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools return parseResponsesAPIResponse(body) } -func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) { - input := make([]map[string]interface{}, 0, len(messages)) - pendingCalls := map[string]struct{}{} - for _, msg := range messages { - input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) - } - requestBody := map[string]interface{}{ - "model": model, - "input": input, - } - responseTools := buildResponsesTools(tools, options) - if len(responseTools) > 0 { - requestBody["tools"] = responseTools - requestBody["tool_choice"] = "auto" - if tc, ok := rawOption(options, "tool_choice"); ok { - requestBody["tool_choice"] = tc - } - if tc, ok := rawOption(options, "responses_tool_choice"); ok { - requestBody["tool_choice"] = tc - } - } - if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { - requestBody["max_output_tokens"] = maxTokens - } - if temperature, ok := float64FromOption(options, "temperature"); ok { - requestBody["temperature"] = temperature - } - if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 { - requestBody["include"] = include - } - if metadata, ok := mapOption(options, "responses_metadata"); ok && len(metadata) > 0 { - requestBody["metadata"] = metadata - } - if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID 
!= "" { - requestBody["previous_response_id"] = prevID - } - if p.useOpenAICompatChatUpstream() { - chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options) - return p.postJSON(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody) - } - if p.useCodexCompat() { - requestBody = p.codexCompatRequestBody(requestBody) - return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, nil) - } - return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody) -} - -func toResponsesInputItemsWithState(msg Message, pendingCalls map[string]struct{}) []map[string]interface{} { - role := strings.ToLower(strings.TrimSpace(msg.Role)) - switch role { - case "system", "developer", "user": - if content := responsesMessageContent(msg); len(content) > 0 { - return []map[string]interface{}{{ - "type": "message", - "role": role, - "content": content, - }} - } - return []map[string]interface{}{responsesMessageItem(role, msg.Content, "input_text")} - case "assistant": - items := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls)) - if msg.Content != "" || len(msg.ToolCalls) == 0 { - items = append(items, responsesMessageItem(role, msg.Content, "output_text")) - } - for _, tc := range msg.ToolCalls { - callID := tc.ID - if callID == "" { - continue - } - name := tc.Name - argsRaw := "" - if tc.Function != nil { - if tc.Function.Name != "" { - name = tc.Function.Name - } - argsRaw = tc.Function.Arguments - } - if name == "" { - continue - } - if argsRaw == "" { - argsJSON, err := json.Marshal(tc.Arguments) - if err != nil { - argsRaw = "{}" - } else { - argsRaw = string(argsJSON) - } - } - if pendingCalls != nil { - pendingCalls[callID] = struct{}{} - } - items = append(items, map[string]interface{}{ - "type": "function_call", - "call_id": callID, - "name": name, - "arguments": argsRaw, - }) - } - if len(items) == 0 { - return []map[string]interface{}{responsesMessageItem(role, msg.Content, "output_text")} - } - return 
items - case "tool": - callID := msg.ToolCallID - if callID == "" { - return nil - } - if pendingCalls != nil { - if _, ok := pendingCalls[callID]; !ok { - // Strict pairing: drop orphan/duplicate tool outputs instead of degrading role. - return nil - } - delete(pendingCalls, callID) - } - return []map[string]interface{}{map[string]interface{}{ - "type": "function_call_output", - "call_id": callID, - "output": msg.Content, - }} - default: - return []map[string]interface{}{responsesMessageItem("user", msg.Content, "input_text")} - } -} - -func responsesMessageContent(msg Message) []map[string]interface{} { - content := make([]map[string]interface{}, 0, len(msg.ContentParts)) - for _, part := range msg.ContentParts { - switch strings.ToLower(strings.TrimSpace(part.Type)) { - case "input_text", "text": - if part.Text == "" { - continue - } - content = append(content, map[string]interface{}{ - "type": "input_text", - "text": part.Text, - }) - case "input_image", "image": - entry := map[string]interface{}{ - "type": "input_image", - } - if part.ImageURL != "" { - entry["image_url"] = part.ImageURL - } - if part.FileID != "" { - entry["file_id"] = part.FileID - } - if detail := strings.TrimSpace(part.Detail); detail != "" { - entry["detail"] = detail - } - if _, ok := entry["image_url"]; !ok { - if _, ok := entry["file_id"]; !ok { - continue - } - } - content = append(content, entry) - case "input_file", "file": - entry := map[string]interface{}{ - "type": "input_file", - } - if part.FileData != "" { - entry["file_data"] = part.FileData - } - if part.FileID != "" { - entry["file_id"] = part.FileID - } - if part.FileURL != "" { - entry["file_url"] = part.FileURL - } - if part.Filename != "" { - entry["filename"] = part.Filename - } - if _, ok := entry["file_data"]; !ok { - if _, ok := entry["file_id"]; !ok { - if _, ok := entry["file_url"]; !ok { - continue - } - } - } - content = append(content, entry) - } - } - return content -} - -func buildResponsesTools(tools 
[]ToolDefinition, options map[string]interface{}) []map[string]interface{} { - responseTools := make([]map[string]interface{}, 0, len(tools)+2) - for _, t := range tools { - typ := strings.ToLower(strings.TrimSpace(t.Type)) - if typ == "" { - typ = "function" - } - if typ == "function" { - name := strings.TrimSpace(t.Function.Name) - if name == "" { - name = strings.TrimSpace(t.Name) - } - if name == "" { - continue - } - entry := map[string]interface{}{ - "type": "function", - "name": name, - "parameters": map[string]interface{}{}, - } - if t.Function.Parameters != nil { - entry["parameters"] = t.Function.Parameters - } else if t.Parameters != nil { - entry["parameters"] = t.Parameters - } - desc := strings.TrimSpace(t.Function.Description) - if desc == "" { - desc = strings.TrimSpace(t.Description) - } - if desc != "" { - entry["description"] = desc - } - if t.Function.Strict != nil { - entry["strict"] = *t.Function.Strict - } else if t.Strict != nil { - entry["strict"] = *t.Strict - } - responseTools = append(responseTools, entry) - continue - } - - // Built-in tool types (web_search, file_search, code_interpreter, etc.). - entry := map[string]interface{}{ - "type": typ, - } - if name := strings.TrimSpace(t.Name); name != "" { - entry["name"] = name - } - if desc := strings.TrimSpace(t.Description); desc != "" { - entry["description"] = desc - } - if t.Strict != nil { - entry["strict"] = *t.Strict - } - for k, v := range t.Parameters { - entry[k] = v - } - responseTools = append(responseTools, entry) - } - - if extraTools, ok := mapSliceOption(options, "responses_tools"); ok { - responseTools = append(responseTools, extraTools...) 
- } - return responseTools -} - -func rawOption(options map[string]interface{}, key string) (interface{}, bool) { - if options == nil { - return nil, false - } - v, ok := options[key] - if !ok || v == nil { - return nil, false - } - return v, true -} - -func stringOption(options map[string]interface{}, key string) (string, bool) { - v, ok := rawOption(options, key) - if !ok { - return "", false - } - s, ok := v.(string) - if !ok { - return "", false - } - return strings.TrimSpace(s), true -} - -func mapOption(options map[string]interface{}, key string) (map[string]interface{}, bool) { - v, ok := rawOption(options, key) - if !ok { - return nil, false - } - m, ok := v.(map[string]interface{}) - return m, ok -} - -func stringSliceOption(options map[string]interface{}, key string) ([]string, bool) { - v, ok := rawOption(options, key) - if !ok { - return nil, false - } - switch t := v.(type) { - case []string: - out := make([]string, 0, len(t)) - for _, item := range t { - if s := strings.TrimSpace(item); s != "" { - out = append(out, s) - } - } - return out, true - case []interface{}: - out := make([]string, 0, len(t)) - for _, item := range t { - s := strings.TrimSpace(fmt.Sprintf("%v", item)) - if s != "" { - out = append(out, s) - } - } - return out, true - } - return nil, false -} - -func mapSliceOption(options map[string]interface{}, key string) ([]map[string]interface{}, bool) { - v, ok := rawOption(options, key) - if !ok { - return nil, false - } - switch t := v.(type) { - case []map[string]interface{}: - return t, true - case []interface{}: - out := make([]map[string]interface{}, 0, len(t)) - for _, item := range t { - m, ok := item.(map[string]interface{}) - if ok { - out = append(out, m) - } - } - return out, true - } - return nil, false -} - -func responsesMessageItem(role, text, contentType string) map[string]interface{} { - ct := contentType - if ct == "" { - ct = "input_text" - } - return map[string]interface{}{ - "type": "message", - "role": role, - 
"content": []map[string]interface{}{ - { - "type": ct, - "text": text, - }, - }, - } -} - -func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) { - input := make([]map[string]interface{}, 0, len(messages)) - pendingCalls := map[string]struct{}{} - for _, msg := range messages { - input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) - } - requestBody := map[string]interface{}{ - "model": model, - "input": input, - "stream": true, - } - responseTools := buildResponsesTools(tools, options) - if len(responseTools) > 0 { - requestBody["tools"] = responseTools - requestBody["tool_choice"] = "auto" - if tc, ok := rawOption(options, "tool_choice"); ok { - requestBody["tool_choice"] = tc - } - if tc, ok := rawOption(options, "responses_tool_choice"); ok { - requestBody["tool_choice"] = tc - } - } - if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { - requestBody["max_output_tokens"] = maxTokens - } - if temperature, ok := float64FromOption(options, "temperature"); ok { - requestBody["temperature"] = temperature - } - if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 { - requestBody["include"] = include - } - if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 { - requestBody["stream_options"] = streamOpts - } - if p.useOpenAICompatChatUpstream() { - chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options) - chatBody["stream"] = true - streamOptions := map[string]interface{}{"include_usage": true} - chatBody["stream_options"] = streamOptions - return p.postJSONStream(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody, func(event string) { - var obj map[string]interface{} - if err := json.Unmarshal([]byte(event), &obj); err != nil { - return - } - choices, _ := 
obj["choices"].([]interface{}) - for _, choice := range choices { - item, _ := choice.(map[string]interface{}) - delta, _ := item["delta"].(map[string]interface{}) - if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" { - onDelta(txt) - } - } - }) - } - if p.useCodexCompat() { - requestBody = p.codexCompatRequestBody(requestBody) - return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, func(event string) { - var obj map[string]interface{} - if err := json.Unmarshal([]byte(event), &obj); err != nil { - return - } - if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" { - onDelta(d) - return - } - if delta, ok := obj["delta"].(map[string]interface{}); ok { - if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" { - onDelta(txt) - } - } - }) - } - return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) { - var obj map[string]interface{} - if err := json.Unmarshal([]byte(event), &obj); err != nil { - return - } - typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])) - if typ == "response.output_text.delta" { - if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" { - onDelta(d) - } - return - } - if delta, ok := obj["delta"].(map[string]interface{}); ok { - if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" { - onDelta(txt) - } - } - }) -} - func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payload interface{}, onEvent func(string)) ([]byte, int, string, error) { result, err := p.executeStreamAttempts(ctx, endpoint, payload, nil, onEvent) if err != nil { @@ -720,94 +124,6 @@ func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload in return result.Body, result.StatusCode, result.ContentType, nil } -type authAttempt struct { - session *oauthSession - token string - kind string -} - -func (p *HTTPProvider) authAttempts(ctx context.Context) 
([]authAttempt, error) { - mode := strings.ToLower(strings.TrimSpace(p.authMode)) - if mode == "oauth" || mode == "hybrid" { - out := make([]authAttempt, 0, 1) - apiAttempt, apiReady := p.apiKeyAttempt() - if p.oauth == nil { - if mode == "hybrid" && apiReady { - return []authAttempt{apiAttempt}, nil - } - return nil, fmt.Errorf("oauth is enabled but provider session manager is not configured") - } - attempts, err := p.oauth.prepareAttemptsLocked(ctx) - if err != nil { - return nil, err - } - oauthAttempts := make([]authAttempt, 0, len(attempts)) - for _, attempt := range attempts { - oauthAttempts = append(oauthAttempts, authAttempt{session: attempt.Session, token: attempt.Token, kind: "oauth"}) - } - if mode == "hybrid" && apiReady { - out = append(out, apiAttempt) - } - if len(attempts) == 0 { - if len(out) > 0 { - p.updateCandidateOrder(out) - return out, nil - } - return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first") - } - out = append(out, oauthAttempts...) 
- p.updateCandidateOrder(out) - return out, nil - } - apiAttempt, apiReady := p.apiKeyAttempt() - if !apiReady { - return nil, fmt.Errorf("api key temporarily unavailable") - } - out := []authAttempt{apiAttempt} - p.updateCandidateOrder(out) - return out, nil -} - -func (p *HTTPProvider) updateCandidateOrder(attempts []authAttempt) { - name := strings.TrimSpace(p.providerName) - if name == "" { - return - } - candidates := make([]providerRuntimeCandidate, 0, len(attempts)) - for _, attempt := range attempts { - candidate := providerRuntimeCandidate{ - Kind: attempt.kind, - Available: true, - Status: "ready", - } - if attempt.kind == "api_key" { - candidate.Target = maskToken(p.apiKey) - candidate.HealthScore = providerAPIHealth(name) - } else if attempt.session != nil { - candidate.Target = firstNonEmpty(attempt.session.Email, attempt.session.AccountID, attempt.session.FilePath) - candidate.HealthScore = sessionHealthScore(attempt.session) - candidate.FailureCount = attempt.session.FailureCount - candidate.CooldownUntil = attempt.session.CooldownUntil - } - candidates = append(candidates, candidate) - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - if !providerCandidatesEqual(state.CandidateOrder, candidates) { - state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: "scheduler", - Target: name, - Reason: "candidate_order_changed", - Detail: candidateOrderChangeDetail(state.CandidateOrder, candidates), - }, runtimeEventLimit(state)) - } - state.CandidateOrder = candidates - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state -} - func applyAttemptAuth(req *http.Request, attempt authAttempt) { if req == nil { return @@ -1074,1431 +390,6 @@ func classifyOAuthFailure(status int, body []byte) (oauthFailureReason, bool) { return "", false } -func (p *HTTPProvider) 
markAttemptSuccess(attempt authAttempt) { - if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { - p.oauth.markSuccess(attempt.session) - } - if attempt.kind == "api_key" { - p.markAPIKeySuccess() - } - p.recordProviderHit(attempt, "") -} - -func (p *HTTPProvider) markAPIKeyFailure(reason oauthFailureReason) { - name := strings.TrimSpace(p.providerName) - if name == "" { - return - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - if state.API.HealthScore <= 0 { - state.API.HealthScore = 100 - } - state.API.FailureCount++ - state.API.LastFailure = string(reason) - state.API.HealthScore = maxInt(1, state.API.HealthScore-healthPenaltyForReason(reason)) - cooldown := 15 * time.Minute - switch reason { - case oauthFailureQuota: - cooldown = 60 * time.Minute - case oauthFailureForbidden: - cooldown = 30 * time.Minute - } - state.API.CooldownUntil = time.Now().Add(cooldown).Format(time.RFC3339) - state.API.TokenMasked = maskToken(p.apiKey) - state.RecentErrors = appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: "api_key", - Target: maskToken(p.apiKey), - Reason: string(reason), - }, runtimeEventLimit(state)) - state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: "api_key", - Target: maskToken(p.apiKey), - Reason: "api_key_cooldown_" + string(reason), - Detail: "api key entered cooldown after request failure", - }, runtimeEventLimit(state)) - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state -} - -func (p *HTTPProvider) markAPIKeySuccess() { - name := strings.TrimSpace(p.providerName) - if name == "" { - return - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - if state.API.HealthScore <= 0 { - state.API.HealthScore = 100 - 
} else { - state.API.HealthScore = minInt(100, state.API.HealthScore+3) - } - wasCooling := strings.TrimSpace(state.API.CooldownUntil) != "" - state.API.CooldownUntil = "" - state.API.TokenMasked = maskToken(p.apiKey) - if wasCooling { - state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: "api_key", - Target: maskToken(p.apiKey), - Reason: "api_key_recovered", - Detail: "api key cooldown cleared after successful request", - }, runtimeEventLimit(state)) - } - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state -} - -func (p *HTTPProvider) apiKeyAttempt() (authAttempt, bool) { - token := strings.TrimSpace(p.apiKey) - if token == "" { - return authAttempt{}, false - } - name := strings.TrimSpace(p.providerName) - if name == "" { - return authAttempt{token: token, kind: "api_key"}, true - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - if state.API.TokenMasked == "" { - state.API.TokenMasked = maskToken(token) - } - if state.API.HealthScore <= 0 { - state.API.HealthScore = 100 - } - if state.API.CooldownUntil != "" { - if until, err := time.Parse(time.RFC3339, state.API.CooldownUntil); err == nil { - if time.Now().Before(until) { - providerRuntimeRegistry.api[name] = state - return authAttempt{}, false - } - } - state.API.CooldownUntil = "" - } - providerRuntimeRegistry.api[name] = state - return authAttempt{token: token, kind: "api_key"}, true -} - -func providerAPIHealth(name string) int { - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - if state.API.HealthScore <= 0 { - return 100 - } - return state.API.HealthScore -} - -func maskToken(value string) string { - value = strings.TrimSpace(value) - if value == "" { - return "" - } - if len(value) <= 8 { - return value[:2] + "***" - } - return 
value[:4] + "***" + value[len(value)-4:] -} - -func appendRuntimeEvent(events []providerRuntimeEvent, event providerRuntimeEvent, limit int) []providerRuntimeEvent { - out := append([]providerRuntimeEvent{event}, events...) - if limit <= 0 { - limit = 8 - } - if len(out) > limit { - out = out[:limit] - } - return out -} - -func (p *HTTPProvider) recordProviderHit(attempt authAttempt, reason string) { - name := strings.TrimSpace(p.providerName) - if name == "" { - return - } - target := "" - if attempt.kind == "api_key" { - target = maskToken(p.apiKey) - } else if attempt.session != nil { - target = firstNonEmpty(attempt.session.Email, attempt.session.AccountID, attempt.session.FilePath) - } - event := providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: attempt.kind, - Target: target, - Reason: reason, - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - state.RecentHits = appendRuntimeEvent(state.RecentHits, event, runtimeEventLimit(state)) - state.LastSuccess = &event - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state -} - -func recordProviderOAuthError(providerName string, session *oauthSession, reason oauthFailureReason) { - name := strings.TrimSpace(providerName) - if name == "" || session == nil { - return - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - state.RecentErrors = appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: "oauth", - Target: firstNonEmpty(session.Email, session.AccountID, session.FilePath), - Reason: string(reason), - }, runtimeEventLimit(state)) - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state -} - -func ClearProviderAPICooldown(providerName string) { - name := strings.TrimSpace(providerName) - if name == "" { - return - } - 
providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - target := state.API.TokenMasked - state.API.CooldownUntil = "" - state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ - When: time.Now().Format(time.RFC3339), - Kind: "api_key", - Target: target, - Reason: "manual_clear_api_cooldown", - Detail: "api key cooldown cleared from runtime panel", - }, runtimeEventLimit(state)) - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state -} - -func ClearProviderRuntimeHistory(providerName string) { - name := strings.TrimSpace(providerName) - if name == "" { - return - } - providerRuntimeRegistry.mu.Lock() - defer providerRuntimeRegistry.mu.Unlock() - state := providerRuntimeRegistry.api[name] - state.RecentHits = nil - state.RecentErrors = nil - state.RecentChanges = nil - state.LastSuccess = nil - if state.Persist.Enabled && strings.TrimSpace(state.Persist.File) != "" { - _ = os.Remove(state.Persist.File) - } - providerRuntimeRegistry.api[name] = state -} - -func runtimeEventLimit(state providerRuntimeState) int { - if state.Persist.MaxEvents > 0 { - return state.Persist.MaxEvents - } - return 8 -} - -func runtimeHistoryMax(pc config.ProviderConfig) int { - if pc.RuntimeHistoryMax > 0 { - return pc.RuntimeHistoryMax - } - return 24 -} - -func runtimeHistoryFile(name string, pc config.ProviderConfig) string { - if file := strings.TrimSpace(pc.RuntimeHistoryFile); file != "" { - return file - } - return filepath.Join(config.GetConfigDir(), "runtime", "providers", strings.TrimSpace(name)+".json") -} - -func loadPersistedProviderRuntimeLocked(name string, state *providerRuntimeState) { - if state == nil || !state.Persist.Enabled || strings.TrimSpace(state.Persist.File) == "" { - return - } - raw, err := os.ReadFile(state.Persist.File) - if err != nil { - if os.IsNotExist(err) { - state.Persist.Loaded = true - } - return - } - var persisted 
providerRuntimeState - if err := json.Unmarshal(raw, &persisted); err != nil { - state.Persist.Loaded = true - return - } - if state.API == (providerAPIRuntimeState{}) { - state.API = persisted.API - } - if len(state.RecentHits) == 0 { - state.RecentHits = persisted.RecentHits - } - if len(state.RecentErrors) == 0 { - state.RecentErrors = persisted.RecentErrors - } - if len(state.RecentChanges) == 0 { - state.RecentChanges = persisted.RecentChanges - } - if state.LastSuccess == nil && persisted.LastSuccess != nil { - last := *persisted.LastSuccess - state.LastSuccess = &last - } - if len(state.CandidateOrder) == 0 { - state.CandidateOrder = persisted.CandidateOrder - } - state.Persist.Loaded = true -} - -func persistProviderRuntimeLocked(name string, state providerRuntimeState) { - if !state.Persist.Enabled || strings.TrimSpace(state.Persist.File) == "" { - return - } - if err := os.MkdirAll(filepath.Dir(state.Persist.File), 0o700); err != nil { - return - } - payload := providerRuntimeState{ - API: state.API, - RecentHits: trimRuntimeEvents(state.RecentHits, runtimeEventLimit(state)), - RecentErrors: trimRuntimeEvents(state.RecentErrors, runtimeEventLimit(state)), - RecentChanges: trimRuntimeEvents(state.RecentChanges, runtimeEventLimit(state)), - LastSuccess: state.LastSuccess, - CandidateOrder: state.CandidateOrder, - } - raw, err := json.MarshalIndent(payload, "", " ") - if err != nil { - return - } - _ = os.WriteFile(state.Persist.File, raw, 0o600) -} - -func trimRuntimeEvents(events []providerRuntimeEvent, limit int) []providerRuntimeEvent { - if limit <= 0 || len(events) <= limit { - return events - } - return events[:limit] -} - -func eventTimeUnix(event providerRuntimeEvent) int64 { - when, err := time.Parse(time.RFC3339, strings.TrimSpace(event.When)) - if err != nil { - return 0 - } - return when.Unix() -} - -func filterRuntimeEvents(events []providerRuntimeEvent, query ProviderRuntimeQuery) []providerRuntimeEvent { - if len(events) == 0 { - return nil - 
} - kind := strings.TrimSpace(query.EventKind) - reason := strings.TrimSpace(query.Reason) - target := strings.ToLower(strings.TrimSpace(query.Target)) - var cutoff time.Time - if query.Window > 0 { - cutoff = time.Now().Add(-query.Window) - } - filtered := make([]providerRuntimeEvent, 0, len(events)) - for _, event := range events { - if !cutoff.IsZero() { - when, err := time.Parse(time.RFC3339, strings.TrimSpace(event.When)) - if err != nil || when.Before(cutoff) { - continue - } - } - if kind != "" && !strings.EqualFold(strings.TrimSpace(event.Kind), kind) { - continue - } - if reason != "" && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Reason)), strings.ToLower(reason)) { - continue - } - if target != "" && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Target)), target) && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Detail)), target) { - continue - } - filtered = append(filtered, event) - } - return filtered -} - -func mergeRuntimeEvents(item map[string]interface{}, query ProviderRuntimeQuery) ([]providerRuntimeEvent, int) { - hits, _ := item["recent_hits"].([]providerRuntimeEvent) - errors, _ := item["recent_errors"].([]providerRuntimeEvent) - changes, _ := item["recent_changes"].([]providerRuntimeEvent) - merged := make([]providerRuntimeEvent, 0, len(hits)+len(errors)+len(changes)) - if !query.ChangesOnly { - merged = append(merged, filterRuntimeEvents(hits, query)...) - merged = append(merged, filterRuntimeEvents(errors, query)...) - } - merged = append(merged, filterRuntimeEvents(changes, query)...) 
- desc := !strings.EqualFold(strings.TrimSpace(query.Sort), "asc") - for i := 0; i < len(merged); i++ { - for j := i + 1; j < len(merged); j++ { - left := eventTimeUnix(merged[i]) - right := eventTimeUnix(merged[j]) - swap := right > left - if !desc { - swap = right < left - } - if swap { - merged[i], merged[j] = merged[j], merged[i] - } - } - } - start := query.Cursor - if start < 0 { - start = 0 - } - if start > len(merged) { - start = len(merged) - } - limit := query.Limit - if limit <= 0 { - limit = 20 - } - end := start + limit - if end > len(merged) { - end = len(merged) - } - nextCursor := 0 - if end < len(merged) { - nextCursor = end - } - return merged[start:end], nextCursor -} - -func matchesProviderCandidateFilters(item map[string]interface{}, query ProviderRuntimeQuery) bool { - if query.HealthBelow <= 0 && query.CooldownBefore.IsZero() { - return true - } - apiState, _ := item["api_state"].(providerAPIRuntimeState) - candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) - if query.HealthBelow > 0 { - if runtimeHealthValue(apiState.HealthScore) < query.HealthBelow { - return true - } - for _, candidate := range candidates { - if runtimeHealthValue(candidate.HealthScore) < query.HealthBelow { - return true - } - } - } - if !query.CooldownBefore.IsZero() { - values := []string{apiState.CooldownUntil} - for _, candidate := range candidates { - values = append(values, candidate.CooldownUntil) - } - for _, value := range values { - if strings.TrimSpace(value) == "" { - continue - } - until, err := time.Parse(time.RFC3339, strings.TrimSpace(value)) - if err == nil && until.Before(query.CooldownBefore) { - return true - } - } - } - return false -} - -func providerInCooldown(item map[string]interface{}) bool { - apiState, _ := item["api_state"].(providerAPIRuntimeState) - if cooldownActive(apiState.CooldownUntil) { - return true - } - candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) - for _, candidate := range candidates { - 
if cooldownActive(candidate.CooldownUntil) { - return true - } - } - return false -} - -func cooldownActive(value string) bool { - if strings.TrimSpace(value) == "" { - return false - } - until, err := time.Parse(time.RFC3339, strings.TrimSpace(value)) - return err == nil && time.Now().Before(until) -} - -func buildProviderCandidateOrder(_ string, pc config.ProviderConfig, accounts []OAuthAccountInfo, api providerAPIRuntimeState) []providerRuntimeCandidate { - authMode := strings.ToLower(strings.TrimSpace(pc.Auth)) - apiCandidate := providerRuntimeCandidate{ - Kind: "api_key", - Target: maskToken(pc.APIKey), - Available: strings.TrimSpace(pc.APIKey) != "", - Status: "ready", - CooldownUntil: strings.TrimSpace(api.CooldownUntil), - HealthScore: runtimeHealthValue(api.HealthScore), - FailureCount: api.FailureCount, - } - if strings.TrimSpace(apiCandidate.CooldownUntil) != "" { - if until, err := time.Parse(time.RFC3339, apiCandidate.CooldownUntil); err == nil && time.Now().Before(until) { - apiCandidate.Available = false - apiCandidate.Status = "cooldown" - } - } - oauthAvailable := make([]providerRuntimeCandidate, 0, len(accounts)) - oauthUnavailable := make([]providerRuntimeCandidate, 0, len(accounts)) - for _, account := range accounts { - candidate := providerRuntimeCandidate{ - Kind: "oauth", - Target: firstNonEmpty(account.Email, account.AccountID, account.CredentialFile), - Available: true, - Status: "ready", - CooldownUntil: strings.TrimSpace(account.CooldownUntil), - HealthScore: runtimeHealthValue(account.HealthScore), - FailureCount: account.FailureCount, - } - if strings.TrimSpace(candidate.CooldownUntil) != "" { - if until, err := time.Parse(time.RFC3339, candidate.CooldownUntil); err == nil && time.Now().Before(until) { - candidate.Available = false - candidate.Status = "cooldown" - } - } - if candidate.Available { - oauthAvailable = append(oauthAvailable, candidate) - } else { - oauthUnavailable = append(oauthUnavailable, candidate) - } - } - 
sortRuntimeCandidates(oauthAvailable) - sortRuntimeCandidates(oauthUnavailable) - out := make([]providerRuntimeCandidate, 0, 1+len(accounts)) - switch authMode { - case "oauth": - out = append(out, oauthAvailable...) - case "hybrid": - if apiCandidate.Target != "" && apiCandidate.Available { - out = append(out, apiCandidate) - } - out = append(out, oauthAvailable...) - case "none": - default: - if apiCandidate.Target != "" { - out = append(out, apiCandidate) - } - } - if authMode == "hybrid" { - if apiCandidate.Target != "" && !apiCandidate.Available { - out = append(out, apiCandidate) - } - out = append(out, oauthUnavailable...) - } else if authMode == "oauth" { - out = append(out, oauthUnavailable...) - } - return out -} - -func runtimeHealthValue(value int) int { - if value <= 0 { - return 100 - } - return value -} - -func sortRuntimeCandidates(items []providerRuntimeCandidate) { - for i := 0; i < len(items); i++ { - for j := i + 1; j < len(items); j++ { - if items[j].HealthScore > items[i].HealthScore || (items[j].HealthScore == items[i].HealthScore && items[j].Target < items[i].Target) { - items[i], items[j] = items[j], items[i] - } - } - } -} - -func providerCandidatesEqual(left, right []providerRuntimeCandidate) bool { - if len(left) != len(right) { - return false - } - for i := range left { - if left[i].Kind != right[i].Kind || - left[i].Target != right[i].Target || - left[i].Available != right[i].Available || - left[i].Status != right[i].Status || - left[i].CooldownUntil != right[i].CooldownUntil || - left[i].HealthScore != right[i].HealthScore || - left[i].FailureCount != right[i].FailureCount { - return false - } - } - return true -} - -func summarizeCandidate(candidate providerRuntimeCandidate) string { - target := strings.TrimSpace(candidate.Target) - if target == "" { - target = "-" - } - return strings.TrimSpace(candidate.Kind) + ":" + target -} - -func candidateOrderChangeDetail(before, after []providerRuntimeCandidate) string { - if len(before) == 
0 && len(after) == 0 { - return "" - } - beforeTop := "-" - afterTop := "-" - if len(before) > 0 { - beforeTop = summarizeCandidate(before[0]) - } - if len(after) > 0 { - afterTop = summarizeCandidate(after[0]) - } - beforeOrder := make([]string, 0, len(before)) - for _, item := range before { - beforeOrder = append(beforeOrder, summarizeCandidate(item)) - } - afterOrder := make([]string, 0, len(after)) - for _, item := range after { - afterOrder = append(afterOrder, summarizeCandidate(item)) - } - return fmt.Sprintf("top %s -> %s | order [%s] -> [%s]", beforeTop, afterTop, strings.Join(beforeOrder, " > "), strings.Join(afterOrder, " > ")) -} - -func GetProviderRuntimeSnapshot(cfg *config.Config) map[string]interface{} { - if cfg == nil { - return map[string]interface{}{"items": []interface{}{}} - } - items := make([]map[string]interface{}, 0) - configs := getAllProviderConfigs(cfg) - for name, pc := range configs { - ConfigureProviderRuntime(name, pc) - providerRuntimeRegistry.mu.Lock() - state := providerRuntimeRegistry.api[name] - providerRuntimeRegistry.mu.Unlock() - item := map[string]interface{}{ - "name": name, - "auth": strings.TrimSpace(pc.Auth), - "api_base": strings.TrimSpace(pc.APIBase), - "api_state": state.API, - "recent_hits": state.RecentHits, - "recent_errors": state.RecentErrors, - "recent_changes": state.RecentChanges, - "last_success": state.LastSuccess, - } - candidateOrder := state.CandidateOrder - if strings.EqualFold(name, "aistudio") { - if accounts := listAIStudioRelayAccounts(); len(accounts) > 0 { - item["oauth_accounts"] = accounts - } - } else if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { - if mgr, err := NewOAuthLoginManager(pc, time.Duration(maxInt(pc.TimeoutSec, 90))*time.Second); err == nil { - if accounts, err := mgr.ListAccounts(); err == nil { - item["oauth_accounts"] = accounts - candidateOrder = buildProviderCandidateOrder(name, pc, accounts, state.API) - 
} - } - } else if len(candidateOrder) == 0 && strings.TrimSpace(pc.APIKey) != "" { - candidateOrder = buildProviderCandidateOrder(name, pc, nil, state.API) - } - if len(candidateOrder) > 0 { - providerRuntimeRegistry.mu.Lock() - state = providerRuntimeRegistry.api[name] - state.CandidateOrder = candidateOrder - persistProviderRuntimeLocked(name, state) - providerRuntimeRegistry.api[name] = state - providerRuntimeRegistry.mu.Unlock() - } - item["candidate_order"] = candidateOrder - items = append(items, item) - } - return map[string]interface{}{"items": items} -} - -func GetProviderRuntimeView(cfg *config.Config, query ProviderRuntimeQuery) map[string]interface{} { - if cfg == nil { - return map[string]interface{}{"items": []interface{}{}} - } - snapshot := GetProviderRuntimeSnapshot(cfg) - rawItems, _ := snapshot["items"].([]map[string]interface{}) - if len(rawItems) == 0 { - return map[string]interface{}{"items": []interface{}{}} - } - filterName := strings.TrimSpace(query.Provider) - items := make([]map[string]interface{}, 0, len(rawItems)) - for _, item := range rawItems { - name := strings.TrimSpace(fmt.Sprintf("%v", item["name"])) - if filterName != "" && name != filterName { - continue - } - next := map[string]interface{}{} - for key, value := range item { - next[key] = value - } - hits, _ := item["recent_hits"].([]providerRuntimeEvent) - errors, _ := item["recent_errors"].([]providerRuntimeEvent) - changes, _ := item["recent_changes"].([]providerRuntimeEvent) - next["recent_hits"] = filterRuntimeEvents(hits, query) - next["recent_errors"] = filterRuntimeEvents(errors, query) - next["recent_changes"] = filterRuntimeEvents(changes, query) - if query.ChangesOnly { - next["recent_hits"] = []providerRuntimeEvent{} - next["recent_errors"] = []providerRuntimeEvent{} - } - events, nextCursor := mergeRuntimeEvents(next, query) - next["events"] = events - next["next_cursor"] = nextCursor - if !matchesProviderCandidateFilters(next, query) { - continue - } - items = 
append(items, next) - } - return map[string]interface{}{"items": items} -} - -func GetProviderRuntimeSummary(cfg *config.Config, query ProviderRuntimeQuery) ProviderRuntimeSummary { - snapshot := GetProviderRuntimeSnapshot(cfg) - rawItems, _ := snapshot["items"].([]map[string]interface{}) - summary := ProviderRuntimeSummary{Providers: make([]ProviderRuntimeSummaryItem, 0, len(rawItems))} - for _, item := range rawItems { - name := strings.TrimSpace(fmt.Sprintf("%v", item["name"])) - if strings.TrimSpace(query.Provider) != "" && name != strings.TrimSpace(query.Provider) { - continue - } - auth := strings.TrimSpace(fmt.Sprintf("%v", item["auth"])) - apiState, _ := item["api_state"].(providerAPIRuntimeState) - accounts, _ := item["oauth_accounts"].([]OAuthAccountInfo) - candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) - errors, _ := item["recent_errors"].([]providerRuntimeEvent) - changes, _ := item["recent_changes"].([]providerRuntimeEvent) - errors = filterRuntimeEvents(errors, query) - changes = filterRuntimeEvents(changes, query) - lastSuccess, _ := item["last_success"].(*providerRuntimeEvent) - inCooldown := providerInCooldown(item) - lowHealth := matchesProviderCandidateFilters(item, ProviderRuntimeQuery{HealthBelow: maxInt(query.HealthBelow, 1)}) - hasRecentErrors := len(errors) > 0 - lastError := latestProviderRuntimeEvent(errors) - topChangedAt := latestRuntimeChangeAt(changes, "candidate_order_changed") - status := providerRuntimeSummaryStatus(inCooldown, lowHealth, hasRecentErrors) - providerItem := ProviderRuntimeSummaryItem{ - Name: name, - Auth: auth, - Status: status, - APIState: apiState, - OAuthAccounts: accounts, - CandidateOrder: candidates, - LastSuccess: lastSuccess, - LastError: lastError, - TopCandidateChangedAt: topChangedAt, - InCooldown: inCooldown, - LowHealth: lowHealth, - HasRecentErrors: hasRecentErrors, - } - if lastSuccess != nil { - providerItem.LastSuccessAt = strings.TrimSpace(lastSuccess.When) - if when := 
parseRuntimeEventTime(*lastSuccess); !when.IsZero() { - providerItem.StaleForSec = int64(time.Since(when).Seconds()) - } - } else { - providerItem.StaleForSec = -1 - } - if lastError != nil { - providerItem.LastErrorAt = strings.TrimSpace(lastError.When) - providerItem.LastErrorReason = strings.TrimSpace(lastError.Reason) - } - if len(candidates) > 0 { - top := candidates[0] - providerItem.TopCandidate = &top - } - summary.TotalProviders++ - switch status { - case "critical": - summary.Critical++ - case "degraded": - summary.Degraded++ - default: - summary.Healthy++ - } - if inCooldown { - summary.InCooldown++ - } - if lowHealth { - summary.LowHealth++ - } - if hasRecentErrors { - summary.RecentErrors++ - } - if inCooldown || lowHealth || hasRecentErrors || strings.TrimSpace(query.Provider) != "" { - summary.Providers = append(summary.Providers, providerItem) - } - } - return summary -} - -func latestProviderRuntimeEvent(events []providerRuntimeEvent) *providerRuntimeEvent { - if len(events) == 0 { - return nil - } - best := events[0] - bestTime := eventTimeUnix(best) - for i := 1; i < len(events); i++ { - currentTime := eventTimeUnix(events[i]) - if currentTime > bestTime { - best = events[i] - bestTime = currentTime - } - } - copyEvent := best - return ©Event -} - -func latestRuntimeChangeAt(events []providerRuntimeEvent, reason string) string { - targetReason := strings.TrimSpace(reason) - if targetReason == "" || len(events) == 0 { - return "" - } - var latest *providerRuntimeEvent - var latestUnix int64 - for i := range events { - if !strings.EqualFold(strings.TrimSpace(events[i].Reason), targetReason) { - continue - } - currentUnix := eventTimeUnix(events[i]) - if latest == nil || currentUnix > latestUnix { - eventCopy := events[i] - latest = &eventCopy - latestUnix = currentUnix - } - } - if latest == nil { - return "" - } - return strings.TrimSpace(latest.When) -} - -func parseRuntimeEventTime(event providerRuntimeEvent) time.Time { - when, err := 
time.Parse(time.RFC3339, strings.TrimSpace(event.When)) - if err != nil { - return time.Time{} - } - return when -} - -func providerRuntimeSummaryStatus(inCooldown, lowHealth, hasRecentErrors bool) string { - if inCooldown || lowHealth { - return "critical" - } - if hasRecentErrors { - return "degraded" - } - return "healthy" -} - -func RefreshProviderRuntimeNow(cfg *config.Config, providerName string, onlyExpiring bool) (*ProviderRefreshResult, error) { - pc, err := getProviderConfigByName(cfg, providerName) - if err != nil { - return nil, err - } - if !strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") && !strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { - return nil, fmt.Errorf("provider %q does not use oauth", providerName) - } - manager, err := newOAuthManager(pc, time.Duration(maxInt(pc.TimeoutSec, 90))*time.Second) - if err != nil { - return nil, err - } - defer manager.bgCancel() - manager.providerName = strings.TrimSpace(providerName) - lead := 365 * 24 * time.Hour - if onlyExpiring { - lead = manager.cfg.RefreshLead - if lead <= 0 { - lead = 30 * time.Minute - } - } - return manager.refreshExpiringSessions(context.Background(), lead) -} - -func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerRuntimeCandidate, error) { - provider, err := CreateProviderByName(cfg, providerName) - if err != nil { - return nil, err - } - httpProvider, ok := unwrapHTTPProvider(provider) - if !ok { - return nil, fmt.Errorf("provider %q does not support runtime rerank", providerName) - } - _, err = httpProvider.authAttempts(context.Background()) - if err != nil && !strings.Contains(strings.ToLower(err.Error()), "oauth session not found") { - return nil, err - } - providerRuntimeRegistry.mu.Lock() - order := append([]providerRuntimeCandidate(nil), providerRuntimeRegistry.api[strings.TrimSpace(providerName)].CandidateOrder...) 
- providerRuntimeRegistry.mu.Unlock() - return order, nil -} - -func unwrapHTTPProvider(provider LLMProvider) (*HTTPProvider, bool) { - switch typed := provider.(type) { - case *HTTPProvider: - return typed, true - case *CodexProvider: - if typed == nil { - return nil, false - } - return typed.base, typed.base != nil - case *AntigravityProvider: - if typed == nil { - return nil, false - } - return typed.base, typed.base != nil - case *ClaudeProvider: - if typed == nil { - return nil, false - } - return typed.base, typed.base != nil - case *QwenProvider: - if typed == nil { - return nil, false - } - return typed.base, typed.base != nil - case *KimiProvider: - if typed == nil { - return nil, false - } - return typed.base, typed.base != nil - default: - return nil, false - } -} - -func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) { - var resp struct { - Status string `json:"status"` - Output []struct { - ID string `json:"id"` - Type string `json:"type"` - CallID string `json:"call_id"` - Name string `json:"name"` - ArgsRaw string `json:"arguments"` - Role string `json:"role"` - Content []struct { - Type string `json:"type"` - Text string `json:"text"` - } `json:"content"` - } `json:"output"` - OutputText string `json:"output_text"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - } - if err := json.Unmarshal(body, &resp); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - toolCalls := make([]ToolCall, 0) - outputText := strings.TrimSpace(resp.OutputText) - for _, item := range resp.Output { - switch strings.TrimSpace(item.Type) { - case "function_call": - name := strings.TrimSpace(item.Name) - if name == "" { - continue - } - args := map[string]interface{}{} - if strings.TrimSpace(item.ArgsRaw) != "" { - if err := json.Unmarshal([]byte(item.ArgsRaw), &args); err != nil { - args["raw"] = item.ArgsRaw - } - } 
- id := strings.TrimSpace(item.CallID) - if id == "" { - id = strings.TrimSpace(item.ID) - } - if id == "" { - id = fmt.Sprintf("call_%d", len(toolCalls)+1) - } - toolCalls = append(toolCalls, ToolCall{ID: id, Name: name, Arguments: args}) - case "message": - if outputText == "" { - texts := make([]string, 0, len(item.Content)) - for _, c := range item.Content { - if strings.TrimSpace(c.Type) == "output_text" && strings.TrimSpace(c.Text) != "" { - texts = append(texts, c.Text) - } - } - if len(texts) > 0 { - outputText = strings.Join(texts, "\n") - } - } - } - } - - if len(toolCalls) == 0 { - compatCalls, cleanedContent := parseCompatFunctionCalls(outputText) - if len(compatCalls) > 0 { - toolCalls = compatCalls - outputText = cleanedContent - } - } - - finishReason := strings.TrimSpace(resp.Status) - if finishReason == "" || finishReason == "completed" { - finishReason = "stop" - } - - var usage *UsageInfo - if resp.Usage.TotalTokens > 0 || resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 { - usage = &UsageInfo{PromptTokens: resp.Usage.InputTokens, CompletionTokens: resp.Usage.OutputTokens, TotalTokens: resp.Usage.TotalTokens} - } - return &LLMResponse{Content: strings.TrimSpace(outputText), ToolCalls: toolCalls, FinishReason: finishReason, Usage: usage}, nil -} - -func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) { - var payload struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - } - if err := json.Unmarshal(body, &payload); err != nil { - return nil, 
err - } - if len(payload.Choices) == 0 { - return &LLMResponse{}, nil - } - choice := payload.Choices[0] - resp := &LLMResponse{ - Content: choice.Message.Content, - FinishReason: choice.FinishReason, - } - if payload.Usage.TotalTokens > 0 || payload.Usage.PromptTokens > 0 || payload.Usage.CompletionTokens > 0 { - resp.Usage = &UsageInfo{ - PromptTokens: payload.Usage.PromptTokens, - CompletionTokens: payload.Usage.CompletionTokens, - TotalTokens: payload.Usage.TotalTokens, - } - } - if len(choice.Message.ToolCalls) > 0 { - resp.ToolCalls = make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - resp.ToolCalls = append(resp.ToolCalls, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: &FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - }, - Name: tc.Function.Name, - }) - } - } - return resp, nil -} - -func previewResponseBody(body []byte) string { - preview := strings.TrimSpace(string(body)) - preview = strings.ReplaceAll(preview, "\n", " ") - preview = strings.ReplaceAll(preview, "\r", " ") - if preview == "" { - return "" - } - const maxLen = 600 - if len(preview) > maxLen { - return preview[:maxLen] + "..." 
- } - return preview -} - -func int64FromOption(options map[string]interface{}, key string) (int64, bool) { - if options == nil { - return 0, false - } - v, ok := options[key] - if !ok { - return 0, false - } - switch t := v.(type) { - case int: - return int64(t), true - case int64: - return t, true - case float64: - return int64(t), true - default: - return 0, false - } -} - -func float64FromOption(options map[string]interface{}, key string) (float64, bool) { - if options == nil { - return 0, false - } - v, ok := options[key] - if !ok { - return 0, false - } - switch t := v.(type) { - case float32: - return float64(t), true - case float64: - return t, true - case int: - return float64(t), true - default: - return 0, false - } -} - -func normalizeAPIBase(raw string) string { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return "" - } - u, err := url.Parse(trimmed) - if err != nil { - return strings.TrimRight(trimmed, "/") - } - u.Path = strings.TrimRight(u.Path, "/") - return strings.TrimRight(u.String(), "/") -} - -func endpointFor(base, relative string) string { - b := strings.TrimRight(strings.TrimSpace(base), "/") - if b == "" { - return relative - } - if strings.HasSuffix(b, relative) { - return b - } - if relative == "/responses/compact" && strings.HasSuffix(b, "/responses") { - return b + "/compact" - } - if relative == "/responses" && strings.HasSuffix(b, "/responses/compact") { - return strings.TrimSuffix(b, "/compact") - } - return b + relative -} - -func (p *HTTPProvider) useCodexCompat() bool { - if p == nil || p.oauth == nil { - return false - } - if !strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultCodexOAuthProvider) { - return false - } - base := strings.ToLower(strings.TrimSpace(p.apiBase)) - if base == "" { - return true - } - return strings.Contains(base, "api.openai.com") || strings.Contains(base, "chatgpt.com/backend-api/codex") -} - -func (p *HTTPProvider) codexCompatBase() string { - if p == nil { - return 
codexCompatBaseURL - } - base := strings.ToLower(strings.TrimSpace(p.apiBase)) - if strings.Contains(base, "chatgpt.com/backend-api/codex") { - return normalizeAPIBase(p.apiBase) - } - if base != "" && !strings.Contains(base, "api.openai.com") { - return normalizeAPIBase(p.apiBase) - } - return codexCompatBaseURL -} - -func (p *HTTPProvider) codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} { - return codexCompatRequestBody(requestBody) -} - -func (p *HTTPProvider) oauthProvider() string { - if p == nil || p.oauth == nil { - return "" - } - return strings.ToLower(strings.TrimSpace(p.oauth.cfg.Provider)) -} - -func (p *HTTPProvider) useOpenAICompatChatUpstream() bool { - switch p.oauthProvider() { - case defaultQwenOAuthProvider, defaultKimiOAuthProvider: - return true - default: - return false - } -} - -func (p *HTTPProvider) compatBase() string { - switch p.oauthProvider() { - case defaultQwenOAuthProvider: - if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") { - return normalizeAPIBase(p.apiBase) - } - return qwenCompatBaseURL - case defaultKimiOAuthProvider: - if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") { - return normalizeAPIBase(p.apiBase) - } - return kimiCompatBaseURL - default: - return normalizeAPIBase(p.apiBase) - } -} - -func (p *HTTPProvider) compatModel(model string) string { - trimmed := strings.TrimSpace(qwenBaseModel(model)) - if p.oauthProvider() == defaultKimiOAuthProvider && strings.HasPrefix(strings.ToLower(trimmed), "kimi-") { - return trimmed[5:] - } - return trimmed -} - -func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]interface{} { - requestBody := map[string]interface{}{ - "model": p.compatModel(model), - "messages": openAICompatMessages(messages), - } - if suffix := qwenModelSuffix(model); 
suffix != "" { - applyOpenAICompatThinkingSuffix(requestBody, suffix) - } - if len(tools) > 0 { - requestBody["tools"] = openAICompatTools(tools) - requestBody["tool_choice"] = "auto" - if tc, ok := rawOption(options, "tool_choice"); ok { - requestBody["tool_choice"] = tc - } - } - if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { - requestBody["max_tokens"] = maxTokens - } - if temperature, ok := float64FromOption(options, "temperature"); ok { - requestBody["temperature"] = temperature - } - return requestBody -} - -func openAICompatMessages(messages []Message) []map[string]interface{} { - out := make([]map[string]interface{}, 0, len(messages)) - for _, msg := range messages { - role := strings.ToLower(strings.TrimSpace(msg.Role)) - content := openAICompatMessageContent(msg) - switch role { - case "system": - out = append(out, map[string]interface{}{"role": "system", "content": content}) - case "developer": - out = append(out, map[string]interface{}{"role": "user", "content": content}) - case "assistant": - item := map[string]interface{}{"role": "assistant", "content": content} - if len(msg.ToolCalls) > 0 { - toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls)) - for _, tc := range msg.ToolCalls { - args := "" - if tc.Function != nil { - args = tc.Function.Arguments - } - if args == "" { - raw, _ := json.Marshal(tc.Arguments) - args = string(raw) - } - name := tc.Name - if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" { - name = tc.Function.Name - } - toolCalls = append(toolCalls, map[string]interface{}{ - "id": tc.ID, - "type": "function", - "function": map[string]interface{}{ - "name": name, - "arguments": args, - }, - }) - } - item["tool_calls"] = toolCalls - } - out = append(out, item) - case "tool": - out = append(out, map[string]interface{}{ - "role": "tool", - "tool_call_id": msg.ToolCallID, - "content": content, - }) - default: - out = append(out, map[string]interface{}{"role": "user", "content": content}) - } 
- } - return out -} - -func openAICompatMessageContent(msg Message) interface{} { - if len(msg.ContentParts) == 0 { - return msg.Content - } - parts := make([]map[string]interface{}, 0, len(msg.ContentParts)) - for _, part := range msg.ContentParts { - switch strings.ToLower(strings.TrimSpace(part.Type)) { - case "text", "input_text": - if strings.TrimSpace(part.Text) == "" { - continue - } - parts = append(parts, map[string]interface{}{ - "type": "text", - "text": part.Text, - }) - case "input_image", "image_url": - imageURL := strings.TrimSpace(part.ImageURL) - if imageURL == "" { - continue - } - payload := map[string]interface{}{ - "type": "image_url", - "image_url": map[string]interface{}{ - "url": imageURL, - }, - } - if detail := strings.TrimSpace(part.Detail); detail != "" { - payload["image_url"].(map[string]interface{})["detail"] = detail - } - parts = append(parts, payload) - default: - if strings.TrimSpace(part.Text) == "" { - continue - } - parts = append(parts, map[string]interface{}{ - "type": "text", - "text": part.Text, - }) - } - } - if len(parts) == 0 { - return msg.Content - } - if len(parts) == 1 && parts[0]["type"] == "text" && len(msg.ToolCalls) == 0 { - if text, _ := parts[0]["text"].(string); text != "" { - return text - } - } - return parts -} - -func openAICompatTools(tools []ToolDefinition) []map[string]interface{} { - out := make([]map[string]interface{}, 0, len(tools)) - for _, tool := range tools { - out = append(out, map[string]interface{}{ - "type": "function", - "function": map[string]interface{}{ - "name": tool.Function.Name, - "description": tool.Function.Description, - "parameters": tool.Function.Parameters, - }, - }) - } - return out -} - -func codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} { - if requestBody == nil { - requestBody = map[string]interface{}{} - } - requestBody["stream"] = true - requestBody["store"] = false - requestBody["parallel_tool_calls"] = true - if _, ok := 
requestBody["include"]; !ok { - requestBody["include"] = []string{"reasoning.encrypted_content"} - } - delete(requestBody, "max_output_tokens") - delete(requestBody, "max_completion_tokens") - delete(requestBody, "temperature") - delete(requestBody, "top_p") - delete(requestBody, "truncation") - delete(requestBody, "user") - if input, ok := requestBody["input"].([]map[string]interface{}); ok { - for _, item := range input { - if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") { - item["role"] = "developer" - } - } - requestBody["input"] = input - } - return requestBody -} - -func parseCompatFunctionCalls(content string) ([]ToolCall, string) { - if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { - return nil, content - } - blockRe := regexp.MustCompile(`(?is)\s*(.*?)\s*`) - blocks := blockRe.FindAllStringSubmatch(content, -1) - if len(blocks) == 0 { - return nil, content - } - toolCalls := make([]ToolCall, 0, len(blocks)) - for i, block := range blocks { - raw := block[1] - invoke := extractTag(raw, "invoke") - if invoke != "" { - raw = invoke - } - name := extractTag(raw, "toolname") - if strings.TrimSpace(name) == "" { - name = extractTag(raw, "tool_name") - } - name = strings.TrimSpace(name) - if name == "" { - continue - } - args := map[string]interface{}{} - paramsRaw := strings.TrimSpace(extractTag(raw, "parameters")) - if paramsRaw != "" { - if strings.HasPrefix(paramsRaw, "{") && strings.HasSuffix(paramsRaw, "}") { - _ = json.Unmarshal([]byte(paramsRaw), &args) - } - if len(args) == 0 { - paramTagRe := regexp.MustCompile(`(?is)<([a-zA-Z0-9_:-]+)>\s*(.*?)\s*`) - matches := paramTagRe.FindAllStringSubmatch(paramsRaw, -1) - for _, m := range matches { - if len(m) < 4 || !strings.EqualFold(strings.TrimSpace(m[1]), strings.TrimSpace(m[3])) { - continue - } - k := strings.TrimSpace(m[1]) - v := strings.TrimSpace(m[2]) - if k == "" || v == "" { - continue - } - args[k] = v - } - } - } - toolCalls = 
append(toolCalls, ToolCall{ID: fmt.Sprintf("compat_call_%d", i+1), Name: name, Arguments: args}) - } - cleaned := strings.TrimSpace(blockRe.ReplaceAllString(content, "")) - return toolCalls, cleaned -} - -func extractTag(src string, tag string) string { - re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag))) - m := re.FindStringSubmatch(src) - if len(m) < 2 { - return "" - } - return strings.TrimSpace(m[1]) -} - func (p *HTTPProvider) GetDefaultModel() string { return p.defaultModel } @@ -2506,228 +397,3 @@ func (p *HTTPProvider) GetDefaultModel() string { func (p *HTTPProvider) SupportsResponsesCompact() bool { return p != nil && p.supportsResponsesCompact } - -func (p *HTTPProvider) BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error) { - if !p.SupportsResponsesCompact() { - return "", fmt.Errorf("responses compact is not enabled for this provider") - } - input := make([]map[string]interface{}, 0, len(messages)+1) - if strings.TrimSpace(existingSummary) != "" { - input = append(input, responsesMessageItem("system", "Existing summary:\n"+strings.TrimSpace(existingSummary), "input_text")) - } - pendingCalls := map[string]struct{}{} - for _, msg := range messages { - input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) 
- } - if len(input) == 0 { - return strings.TrimSpace(existingSummary), nil - } - - compactReq := map[string]interface{}{"model": model, "input": input} - compactBody, statusCode, contentType, err := p.postJSON(ctx, endpointFor(p.apiBase, "/responses/compact"), compactReq) - if err != nil { - return "", fmt.Errorf("responses compact request failed: %w", err) - } - if statusCode != http.StatusOK { - return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(compactBody)) - } - if !json.Valid(compactBody) { - return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(compactBody)) - } - - var compactResp struct { - Output interface{} `json:"output"` - CompactedInput interface{} `json:"compacted_input"` - Compacted interface{} `json:"compacted"` - } - if err := json.Unmarshal(compactBody, &compactResp); err != nil { - return "", fmt.Errorf("responses compact request failed: invalid JSON: %w", err) - } - compactPayload := compactResp.Output - if compactPayload == nil { - compactPayload = compactResp.CompactedInput - } - if compactPayload == nil { - compactPayload = compactResp.Compacted - } - payloadBytes, err := json.Marshal(compactPayload) - if err != nil { - return "", fmt.Errorf("failed to serialize compact output: %w", err) - } - compactedPayload := strings.TrimSpace(string(payloadBytes)) - if compactedPayload == "" || compactedPayload == "null" { - return "", fmt.Errorf("empty compact output") - } - if len(compactedPayload) > 12000 { - compactedPayload = compactedPayload[:12000] + "..." 
- } - - summaryPrompt := fmt.Sprintf( - "Compacted conversation JSON:\n%s\n\nReturn a concise markdown summary with sections: Key Facts, Decisions, Open Items, Next Steps.", - compactedPayload, - ) - summaryReq := map[string]interface{}{ - "model": model, - "input": summaryPrompt, - } - if maxSummaryChars > 0 { - estMaxTokens := maxSummaryChars / 3 - if estMaxTokens < 128 { - estMaxTokens = 128 - } - summaryReq["max_output_tokens"] = estMaxTokens - } - summaryBody, summaryStatus, summaryType, err := p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), summaryReq) - if err != nil { - return "", fmt.Errorf("responses summary request failed: %w", err) - } - if summaryStatus != http.StatusOK { - return "", fmt.Errorf("responses summary request failed (status %d, content-type %q): %s", summaryStatus, summaryType, previewResponseBody(summaryBody)) - } - if !json.Valid(summaryBody) { - return "", fmt.Errorf("responses summary request failed (status %d, content-type %q): non-JSON response: %s", summaryStatus, summaryType, previewResponseBody(summaryBody)) - } - summaryResp, err := parseResponsesAPIResponse(summaryBody) - if err != nil { - return "", fmt.Errorf("responses summary request failed: %w", err) - } - summary := strings.TrimSpace(summaryResp.Content) - if summary == "" { - return "", fmt.Errorf("empty summary after responses compact") - } - if maxSummaryChars > 0 && len(summary) > maxSummaryChars { - summary = summary[:maxSummaryChars] - } - return summary, nil -} - -func normalizeProviderRouteName(name string) string { - switch strings.ToLower(strings.TrimSpace(name)) { - case "geminicli", "gemini_cli": - return "gemini-cli" - case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio": - return "aistudio" - case "google", "gemini-api-key", "gemini_api_key": - return "gemini" - case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key": - return "claude" - case "openai-compatibility", 
"openai_compatibility", "openai-compat", "openai_compat": - return "openai-compatibility" - case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", "vertex-compatibility", "vertex_compatibility": - return "vertex" - case "codex-api-key", "codex_api_key": - return "codex" - case "i-flow", "i_flow": - return "iflow" - default: - return strings.TrimSpace(name) - } -} - -func CreateProvider(cfg *config.Config) (LLMProvider, error) { - name := config.PrimaryProviderName(cfg) - provider, err := CreateProviderByName(cfg, name) - if err != nil { - return nil, err - } - _, model := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary) - if hp, ok := provider.(*HTTPProvider); ok && strings.TrimSpace(model) != "" { - hp.defaultModel = strings.TrimSpace(model) - } - return provider, nil -} - -func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) { - routeName := normalizeProviderRouteName(name) - pc, err := getProviderConfigByName(cfg, routeName) - if err != nil { - return nil, err - } - ConfigureProviderRuntime(routeName, pc) - oauthProvider := normalizeOAuthProvider(pc.OAuth.Provider) - if pc.APIBase == "" && - oauthProvider != defaultAntigravityOAuthProvider && - oauthProvider != defaultGeminiOAuthProvider && - oauthProvider != "aistudio" && - oauthProvider != defaultCodexOAuthProvider && - oauthProvider != defaultClaudeOAuthProvider && - oauthProvider != defaultQwenOAuthProvider && - oauthProvider != defaultKimiOAuthProvider && - oauthProvider != defaultIFlowOAuthProvider && - !strings.EqualFold(routeName, "gemini-cli") && - !strings.EqualFold(routeName, "aistudio") && - !strings.EqualFold(routeName, "vertex") && - !strings.EqualFold(routeName, defaultAntigravityOAuthProvider) && - !strings.EqualFold(routeName, defaultGeminiOAuthProvider) && - !strings.EqualFold(routeName, defaultCodexOAuthProvider) && - !strings.EqualFold(routeName, defaultClaudeOAuthProvider) && - !strings.EqualFold(routeName, 
defaultQwenOAuthProvider) && - !strings.EqualFold(routeName, defaultKimiOAuthProvider) && - !strings.EqualFold(routeName, defaultIFlowOAuthProvider) { - return nil, fmt.Errorf("no API base configured for provider %q", name) - } - if pc.TimeoutSec <= 0 { - return nil, fmt.Errorf("invalid timeout_sec for provider %q: %d", name, pc.TimeoutSec) - } - defaultModel := "" - if len(pc.Models) > 0 { - defaultModel = pc.Models[0] - } - var oauth *oauthManager - if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { - oauth, err = newOAuthManager(pc, time.Duration(pc.TimeoutSec)*time.Second) - if err != nil { - return nil, err - } - } - if oauthProvider == defaultAntigravityOAuthProvider || strings.EqualFold(routeName, defaultAntigravityOAuthProvider) { - return NewAntigravityProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if oauthProvider == "aistudio" || strings.EqualFold(routeName, "aistudio") { - return NewAistudioProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if strings.EqualFold(routeName, "gemini-cli") { - return NewGeminiCLIProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if oauthProvider == defaultGeminiOAuthProvider || strings.EqualFold(routeName, defaultGeminiOAuthProvider) || strings.EqualFold(routeName, "aistudio") { - return NewGeminiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if strings.EqualFold(routeName, "vertex") { - return NewVertexProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, 
oauth), nil - } - if oauthProvider == defaultCodexOAuthProvider || strings.EqualFold(routeName, defaultCodexOAuthProvider) { - return NewCodexProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if oauthProvider == defaultClaudeOAuthProvider || strings.EqualFold(routeName, defaultClaudeOAuthProvider) { - return NewClaudeProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if oauthProvider == defaultQwenOAuthProvider || strings.EqualFold(routeName, defaultQwenOAuthProvider) { - return NewQwenProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if oauthProvider == defaultKimiOAuthProvider || strings.EqualFold(routeName, defaultKimiOAuthProvider) { - return NewKimiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(routeName, defaultIFlowOAuthProvider) { - return NewIFlowProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil - } - return NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil -} - -func ProviderSupportsResponsesCompact(cfg *config.Config, name string) bool { - pc, err := getProviderConfigByName(cfg, name) - if err != nil { - return false - } - return pc.SupportsResponsesCompact -} - -func getAllProviderConfigs(cfg *config.Config) map[string]config.ProviderConfig { - return config.AllProviderConfigs(cfg) -} - -func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, 
error) { - if pc, ok := config.ProviderConfigByName(cfg, name); ok { - return pc, nil - } - return config.ProviderConfig{}, fmt.Errorf("provider %q not found", strings.TrimSpace(name)) -} diff --git a/pkg/providers/openai_compat_adapter.go b/pkg/providers/openai_compat_adapter.go new file mode 100644 index 0000000..f0376a2 --- /dev/null +++ b/pkg/providers/openai_compat_adapter.go @@ -0,0 +1,368 @@ +package providers + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) { + var payload struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, err + } + if len(payload.Choices) == 0 { + return &LLMResponse{}, nil + } + choice := payload.Choices[0] + resp := &LLMResponse{ + Content: choice.Message.Content, + FinishReason: choice.FinishReason, + } + if payload.Usage.TotalTokens > 0 || payload.Usage.PromptTokens > 0 || payload.Usage.CompletionTokens > 0 { + resp.Usage = &UsageInfo{ + PromptTokens: payload.Usage.PromptTokens, + CompletionTokens: payload.Usage.CompletionTokens, + TotalTokens: payload.Usage.TotalTokens, + } + } + if len(choice.Message.ToolCalls) > 0 { + resp.ToolCalls = make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + resp.ToolCalls = append(resp.ToolCalls, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: &FunctionCall{ + Name: tc.Function.Name, + Arguments: 
tc.Function.Arguments, + }, + Name: tc.Function.Name, + }) + } + } + return resp, nil +} + +func (p *HTTPProvider) useCodexCompat() bool { + if p == nil || p.oauth == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultCodexOAuthProvider) { + return false + } + base := strings.ToLower(strings.TrimSpace(p.apiBase)) + if base == "" { + return true + } + return strings.Contains(base, "api.openai.com") || strings.Contains(base, "chatgpt.com/backend-api/codex") +} + +func (p *HTTPProvider) codexCompatBase() string { + if p == nil { + return codexCompatBaseURL + } + base := strings.ToLower(strings.TrimSpace(p.apiBase)) + if strings.Contains(base, "chatgpt.com/backend-api/codex") { + return normalizeAPIBase(p.apiBase) + } + if base != "" && !strings.Contains(base, "api.openai.com") { + return normalizeAPIBase(p.apiBase) + } + return codexCompatBaseURL +} + +func (p *HTTPProvider) codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} { + return codexCompatRequestBody(requestBody) +} + +func (p *HTTPProvider) oauthProvider() string { + if p == nil || p.oauth == nil { + return "" + } + return strings.ToLower(strings.TrimSpace(p.oauth.cfg.Provider)) +} + +func (p *HTTPProvider) useOpenAICompatChatUpstream() bool { + switch p.oauthProvider() { + case defaultQwenOAuthProvider, defaultKimiOAuthProvider: + return true + default: + return false + } +} + +func (p *HTTPProvider) compatBase() string { + switch p.oauthProvider() { + case defaultQwenOAuthProvider: + if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") { + return normalizeAPIBase(p.apiBase) + } + return qwenCompatBaseURL + case defaultKimiOAuthProvider: + if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") { + return normalizeAPIBase(p.apiBase) + } + return kimiCompatBaseURL + default: + return normalizeAPIBase(p.apiBase) + } +} + +func (p 
*HTTPProvider) compatModel(model string) string { + trimmed := strings.TrimSpace(qwenBaseModel(model)) + if p.oauthProvider() == defaultKimiOAuthProvider && strings.HasPrefix(strings.ToLower(trimmed), "kimi-") { + return trimmed[5:] + } + return trimmed +} + +func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]interface{} { + requestBody := map[string]interface{}{ + "model": p.compatModel(model), + "messages": openAICompatMessages(messages), + } + if suffix := qwenModelSuffix(model); suffix != "" { + applyOpenAICompatThinkingSuffix(requestBody, suffix) + } + if len(tools) > 0 { + requestBody["tools"] = openAICompatTools(tools) + requestBody["tool_choice"] = "auto" + if tc, ok := rawOption(options, "tool_choice"); ok { + requestBody["tool_choice"] = tc + } + } + if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { + requestBody["max_tokens"] = maxTokens + } + if temperature, ok := float64FromOption(options, "temperature"); ok { + requestBody["temperature"] = temperature + } + return requestBody +} + +func openAICompatMessages(messages []Message) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(messages)) + for _, msg := range messages { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + content := openAICompatMessageContent(msg) + switch role { + case "system": + out = append(out, map[string]interface{}{"role": "system", "content": content}) + case "developer": + out = append(out, map[string]interface{}{"role": "user", "content": content}) + case "assistant": + item := map[string]interface{}{"role": "assistant", "content": content} + if len(msg.ToolCalls) > 0 { + toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + args := "" + if tc.Function != nil { + args = tc.Function.Arguments + } + if args == "" { + raw, _ := json.Marshal(tc.Arguments) + args = string(raw) + } + name 
:= tc.Name + if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" { + name = tc.Function.Name + } + toolCalls = append(toolCalls, map[string]interface{}{ + "id": tc.ID, + "type": "function", + "function": map[string]interface{}{ + "name": name, + "arguments": args, + }, + }) + } + item["tool_calls"] = toolCalls + } + out = append(out, item) + case "tool": + out = append(out, map[string]interface{}{ + "role": "tool", + "tool_call_id": msg.ToolCallID, + "content": content, + }) + default: + out = append(out, map[string]interface{}{"role": "user", "content": content}) + } + } + return out +} + +func openAICompatMessageContent(msg Message) interface{} { + if len(msg.ContentParts) == 0 { + return msg.Content + } + parts := make([]map[string]interface{}, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "text", "input_text": + if strings.TrimSpace(part.Text) == "" { + continue + } + parts = append(parts, map[string]interface{}{ + "type": "text", + "text": part.Text, + }) + case "input_image", "image_url": + imageURL := strings.TrimSpace(part.ImageURL) + if imageURL == "" { + continue + } + payload := map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": imageURL, + }, + } + if detail := strings.TrimSpace(part.Detail); detail != "" { + payload["image_url"].(map[string]interface{})["detail"] = detail + } + parts = append(parts, payload) + default: + if strings.TrimSpace(part.Text) == "" { + continue + } + parts = append(parts, map[string]interface{}{ + "type": "text", + "text": part.Text, + }) + } + } + if len(parts) == 0 { + return msg.Content + } + if len(parts) == 1 && parts[0]["type"] == "text" && len(msg.ToolCalls) == 0 { + if text, _ := parts[0]["text"].(string); text != "" { + return text + } + } + return parts +} + +func openAICompatTools(tools []ToolDefinition) []map[string]interface{} { + out := make([]map[string]interface{}, 0, 
len(tools)) + for _, tool := range tools { + out = append(out, map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": tool.Function.Name, + "description": tool.Function.Description, + "parameters": tool.Function.Parameters, + }, + }) + } + return out +} + +func codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} { + if requestBody == nil { + requestBody = map[string]interface{}{} + } + requestBody["stream"] = true + requestBody["store"] = false + requestBody["parallel_tool_calls"] = true + if _, ok := requestBody["include"]; !ok { + requestBody["include"] = []string{"reasoning.encrypted_content"} + } + delete(requestBody, "max_output_tokens") + delete(requestBody, "max_completion_tokens") + delete(requestBody, "temperature") + delete(requestBody, "top_p") + delete(requestBody, "truncation") + delete(requestBody, "user") + if input, ok := requestBody["input"].([]map[string]interface{}); ok { + for _, item := range input { + if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") { + item["role"] = "developer" + } + } + requestBody["input"] = input + } + return requestBody +} + +func parseCompatFunctionCalls(content string) ([]ToolCall, string) { + if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { + return nil, content + } + blockRe := regexp.MustCompile(`(?is)\s*(.*?)\s*`) + blocks := blockRe.FindAllStringSubmatch(content, -1) + if len(blocks) == 0 { + return nil, content + } + toolCalls := make([]ToolCall, 0, len(blocks)) + for i, block := range blocks { + raw := block[1] + invoke := extractTag(raw, "invoke") + if invoke != "" { + raw = invoke + } + name := extractTag(raw, "toolname") + if strings.TrimSpace(name) == "" { + name = extractTag(raw, "tool_name") + } + name = strings.TrimSpace(name) + if name == "" { + continue + } + args := map[string]interface{}{} + paramsRaw := strings.TrimSpace(extractTag(raw, "parameters")) + if paramsRaw != "" { + 
if strings.HasPrefix(paramsRaw, "{") && strings.HasSuffix(paramsRaw, "}") { + _ = json.Unmarshal([]byte(paramsRaw), &args) + } + if len(args) == 0 { + paramTagRe := regexp.MustCompile(`(?is)<([a-zA-Z0-9_:-]+)>\s*(.*?)\s*`) + matches := paramTagRe.FindAllStringSubmatch(paramsRaw, -1) + for _, m := range matches { + if len(m) < 4 || !strings.EqualFold(strings.TrimSpace(m[1]), strings.TrimSpace(m[3])) { + continue + } + k := strings.TrimSpace(m[1]) + v := strings.TrimSpace(m[2]) + if k == "" || v == "" { + continue + } + args[k] = v + } + } + } + toolCalls = append(toolCalls, ToolCall{ID: fmt.Sprintf("compat_call_%d", i+1), Name: name, Arguments: args}) + } + cleaned := strings.TrimSpace(blockRe.ReplaceAllString(content, "")) + return toolCalls, cleaned +} + +func extractTag(src string, tag string) string { + re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag))) + m := re.FindStringSubmatch(src) + if len(m) < 2 { + return "" + } + return strings.TrimSpace(m[1]) +} diff --git a/pkg/providers/provider_registry.go b/pkg/providers/provider_registry.go new file mode 100644 index 0000000..8a572ee --- /dev/null +++ b/pkg/providers/provider_registry.go @@ -0,0 +1,139 @@ +package providers + +import ( + "fmt" + "github.com/YspCoder/clawgo/pkg/config" + "strings" + "time" +) + +func normalizeProviderRouteName(name string) string { + switch strings.ToLower(strings.TrimSpace(name)) { + case "geminicli", "gemini_cli": + return "gemini-cli" + case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio": + return "aistudio" + case "google", "gemini-api-key", "gemini_api_key": + return "gemini" + case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key": + return "claude" + case "openai-compatibility", "openai_compatibility", "openai-compat", "openai_compat": + return "openai-compatibility" + case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", 
"vertex-compatibility", "vertex_compatibility": + return "vertex" + case "codex-api-key", "codex_api_key": + return "codex" + case "i-flow", "i_flow": + return "iflow" + default: + return strings.TrimSpace(name) + } +} + +func CreateProvider(cfg *config.Config) (LLMProvider, error) { + name := config.PrimaryProviderName(cfg) + provider, err := CreateProviderByName(cfg, name) + if err != nil { + return nil, err + } + _, model := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary) + if hp, ok := provider.(*HTTPProvider); ok && strings.TrimSpace(model) != "" { + hp.defaultModel = strings.TrimSpace(model) + } + return provider, nil +} + +func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) { + routeName := normalizeProviderRouteName(name) + pc, err := getProviderConfigByName(cfg, routeName) + if err != nil { + return nil, err + } + ConfigureProviderRuntime(routeName, pc) + oauthProvider := normalizeOAuthProvider(pc.OAuth.Provider) + if pc.APIBase == "" && + oauthProvider != defaultAntigravityOAuthProvider && + oauthProvider != defaultGeminiOAuthProvider && + oauthProvider != "aistudio" && + oauthProvider != defaultCodexOAuthProvider && + oauthProvider != defaultClaudeOAuthProvider && + oauthProvider != defaultQwenOAuthProvider && + oauthProvider != defaultKimiOAuthProvider && + oauthProvider != defaultIFlowOAuthProvider && + !strings.EqualFold(routeName, "gemini-cli") && + !strings.EqualFold(routeName, "aistudio") && + !strings.EqualFold(routeName, "vertex") && + !strings.EqualFold(routeName, defaultAntigravityOAuthProvider) && + !strings.EqualFold(routeName, defaultGeminiOAuthProvider) && + !strings.EqualFold(routeName, defaultCodexOAuthProvider) && + !strings.EqualFold(routeName, defaultClaudeOAuthProvider) && + !strings.EqualFold(routeName, defaultQwenOAuthProvider) && + !strings.EqualFold(routeName, defaultKimiOAuthProvider) && + !strings.EqualFold(routeName, defaultIFlowOAuthProvider) { + return nil, fmt.Errorf("no API base 
configured for provider %q", name) + } + if pc.TimeoutSec <= 0 { + return nil, fmt.Errorf("invalid timeout_sec for provider %q: %d", name, pc.TimeoutSec) + } + defaultModel := "" + if len(pc.Models) > 0 { + defaultModel = pc.Models[0] + } + var oauth *oauthManager + if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + oauth, err = newOAuthManager(pc, time.Duration(pc.TimeoutSec)*time.Second) + if err != nil { + return nil, err + } + } + if oauthProvider == defaultAntigravityOAuthProvider || strings.EqualFold(routeName, defaultAntigravityOAuthProvider) { + return NewAntigravityProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == "aistudio" || strings.EqualFold(routeName, "aistudio") { + return NewAistudioProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if strings.EqualFold(routeName, "gemini-cli") { + return NewGeminiCLIProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultGeminiOAuthProvider || strings.EqualFold(routeName, defaultGeminiOAuthProvider) || strings.EqualFold(routeName, "aistudio") { + return NewGeminiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if strings.EqualFold(routeName, "vertex") { + return NewVertexProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultCodexOAuthProvider || strings.EqualFold(routeName, defaultCodexOAuthProvider) { + return NewCodexProvider(routeName, pc.APIKey, pc.APIBase, 
defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultClaudeOAuthProvider || strings.EqualFold(routeName, defaultClaudeOAuthProvider) { + return NewClaudeProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultQwenOAuthProvider || strings.EqualFold(routeName, defaultQwenOAuthProvider) { + return NewQwenProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultKimiOAuthProvider || strings.EqualFold(routeName, defaultKimiOAuthProvider) { + return NewKimiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(routeName, defaultIFlowOAuthProvider) { + return NewIFlowProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + return NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil +} + +func ProviderSupportsResponsesCompact(cfg *config.Config, name string) bool { + pc, err := getProviderConfigByName(cfg, name) + if err != nil { + return false + } + return pc.SupportsResponsesCompact +} + +func getAllProviderConfigs(cfg *config.Config) map[string]config.ProviderConfig { + return config.AllProviderConfigs(cfg) +} + +func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, error) { + if pc, ok := config.ProviderConfigByName(cfg, name); ok { + return pc, nil + } + return config.ProviderConfig{}, fmt.Errorf("provider %q not found", strings.TrimSpace(name)) 
+} diff --git a/pkg/providers/provider_request_options.go b/pkg/providers/provider_request_options.go new file mode 100644 index 0000000..ed442a5 --- /dev/null +++ b/pkg/providers/provider_request_options.go @@ -0,0 +1,171 @@ +package providers + +import ( + "fmt" + "net/url" + "strings" +) + +func rawOption(options map[string]interface{}, key string) (interface{}, bool) { + if options == nil { + return nil, false + } + v, ok := options[key] + if !ok || v == nil { + return nil, false + } + return v, true +} + +func stringOption(options map[string]interface{}, key string) (string, bool) { + v, ok := rawOption(options, key) + if !ok { + return "", false + } + s, ok := v.(string) + if !ok { + return "", false + } + return strings.TrimSpace(s), true +} + +func mapOption(options map[string]interface{}, key string) (map[string]interface{}, bool) { + v, ok := rawOption(options, key) + if !ok { + return nil, false + } + m, ok := v.(map[string]interface{}) + return m, ok +} + +func stringSliceOption(options map[string]interface{}, key string) ([]string, bool) { + v, ok := rawOption(options, key) + if !ok { + return nil, false + } + switch t := v.(type) { + case []string: + out := make([]string, 0, len(t)) + for _, item := range t { + if s := strings.TrimSpace(item); s != "" { + out = append(out, s) + } + } + return out, true + case []interface{}: + out := make([]string, 0, len(t)) + for _, item := range t { + s := strings.TrimSpace(fmt.Sprintf("%v", item)) + if s != "" { + out = append(out, s) + } + } + return out, true + } + return nil, false +} + +func mapSliceOption(options map[string]interface{}, key string) ([]map[string]interface{}, bool) { + v, ok := rawOption(options, key) + if !ok { + return nil, false + } + switch t := v.(type) { + case []map[string]interface{}: + return t, true + case []interface{}: + out := make([]map[string]interface{}, 0, len(t)) + for _, item := range t { + m, ok := item.(map[string]interface{}) + if ok { + out = append(out, m) + } + } + 
// previewResponseBody renders body as a single-line preview for logs and error
// messages: surrounding whitespace is trimmed, CR and LF become spaces, and
// anything beyond 600 bytes is replaced by "...".
//
// Truncation never splits a multi-byte UTF-8 sequence; when the 600-byte mark
// falls inside one, the cut moves back to the start of that sequence.
func previewResponseBody(body []byte) string {
	preview := strings.TrimSpace(string(body))
	preview = strings.ReplaceAll(preview, "\n", " ")
	preview = strings.ReplaceAll(preview, "\r", " ")
	if preview == "" {
		return ""
	}
	const maxLen = 600
	if len(preview) <= maxLen {
		return preview
	}
	cut := maxLen
	// A UTF-8 continuation byte has the bit pattern 10xxxxxx. If the first
	// dropped byte is one, step back to the lead byte. The walk is capped at
	// three bytes (the longest possible run of continuation bytes) so that
	// non-UTF-8 payloads still produce a preview.
	for back := 0; back < 3 && cut > 0 && preview[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return preview[:cut] + "..."
}
// ProviderRuntimeQuery carries the filter, sort and paging options applied to
// provider runtime data by GetProviderRuntimeView and its helpers.
type ProviderRuntimeQuery struct {
	// Provider, when non-empty after trimming, restricts GetProviderRuntimeView
	// to the provider with exactly this name.
	Provider string
	// Window, when > 0, makes filterRuntimeEvents drop events older than
	// now-Window as well as events whose timestamp cannot be parsed.
	Window time.Duration
	// EventKind, when non-empty, must equal the event Kind (case-insensitive).
	EventKind string
	// Reason, when non-empty, must be a case-insensitive substring of the
	// event Reason.
	Reason string
	// Target, when non-empty, must be a case-insensitive substring of the
	// event Target or Detail.
	Target string
	// Limit is the page size used by mergeRuntimeEvents; values <= 0 fall back
	// to 20.
	Limit int
	// Cursor is the start offset into the merged event list; it is clamped to
	// [0, len(merged)].
	Cursor int
	// HealthBelow, when > 0, makes matchesProviderCandidateFilters accept a
	// provider whose API state or any candidate has a health value below it.
	HealthBelow int
	// CooldownBefore, when non-zero, makes matchesProviderCandidateFilters
	// accept a provider with a cooldown timestamp earlier than this instant.
	CooldownBefore time.Time
	// Sort selects the merged event order: "asc" (case-insensitive) is oldest
	// first, anything else is newest first.
	Sort string
	// ChangesOnly limits merged events to recent changes, leaving hits and
	// errors out.
	ChangesOnly bool
}
`json:"in_cooldown"` + LowHealth int `json:"low_health"` + RecentErrors int `json:"recent_errors"` + Providers []ProviderRuntimeSummaryItem `json:"providers,omitempty"` +} + +type providerRuntimeState struct { + API providerAPIRuntimeState `json:"api_state,omitempty"` + RecentHits []providerRuntimeEvent `json:"recent_hits,omitempty"` + RecentErrors []providerRuntimeEvent `json:"recent_errors,omitempty"` + RecentChanges []providerRuntimeEvent `json:"recent_changes,omitempty"` + LastSuccess *providerRuntimeEvent `json:"last_success,omitempty"` + CandidateOrder []providerRuntimeCandidate `json:"candidate_order,omitempty"` + Persist providerRuntimePersistConfig `json:"-"` +} + +var providerRuntimeRegistry = struct { + mu sync.Mutex + api map[string]providerRuntimeState +}{api: map[string]providerRuntimeState{}} + +func ConfigureProviderRuntime(providerName string, pc config.ProviderConfig) { + name := strings.TrimSpace(providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + state.Persist = providerRuntimePersistConfig{ + Enabled: pc.RuntimePersist, + File: runtimeHistoryFile(name, pc), + MaxEvents: runtimeHistoryMax(pc), + Loaded: state.Persist.Loaded, + LoadAttempt: state.Persist.LoadAttempt, + } + if state.Persist.Enabled && !state.Persist.LoadAttempt { + state.Persist.LoadAttempt = true + loadPersistedProviderRuntimeLocked(name, &state) + } + providerRuntimeRegistry.api[name] = state +} + +type authAttempt struct { + session *oauthSession + token string + kind string +} + +func (p *HTTPProvider) authAttempts(ctx context.Context) ([]authAttempt, error) { + mode := strings.ToLower(strings.TrimSpace(p.authMode)) + if mode == "oauth" || mode == "hybrid" { + out := make([]authAttempt, 0, 1) + apiAttempt, apiReady := p.apiKeyAttempt() + if p.oauth == nil { + if mode == "hybrid" && apiReady { + return []authAttempt{apiAttempt}, nil + } + return nil, 
fmt.Errorf("oauth is enabled but provider session manager is not configured") + } + attempts, err := p.oauth.prepareAttemptsLocked(ctx) + if err != nil { + return nil, err + } + oauthAttempts := make([]authAttempt, 0, len(attempts)) + for _, attempt := range attempts { + oauthAttempts = append(oauthAttempts, authAttempt{session: attempt.Session, token: attempt.Token, kind: "oauth"}) + } + if mode == "hybrid" && apiReady { + out = append(out, apiAttempt) + } + if len(attempts) == 0 { + if len(out) > 0 { + p.updateCandidateOrder(out) + return out, nil + } + return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first") + } + out = append(out, oauthAttempts...) + p.updateCandidateOrder(out) + return out, nil + } + apiAttempt, apiReady := p.apiKeyAttempt() + if !apiReady { + return nil, fmt.Errorf("api key temporarily unavailable") + } + out := []authAttempt{apiAttempt} + p.updateCandidateOrder(out) + return out, nil +} + +func (p *HTTPProvider) updateCandidateOrder(attempts []authAttempt) { + name := strings.TrimSpace(p.providerName) + if name == "" { + return + } + candidates := make([]providerRuntimeCandidate, 0, len(attempts)) + for _, attempt := range attempts { + candidate := providerRuntimeCandidate{ + Kind: attempt.kind, + Available: true, + Status: "ready", + } + if attempt.kind == "api_key" { + candidate.Target = maskToken(p.apiKey) + candidate.HealthScore = providerAPIHealth(name) + } else if attempt.session != nil { + candidate.Target = firstNonEmpty(attempt.session.Email, attempt.session.AccountID, attempt.session.FilePath) + candidate.HealthScore = sessionHealthScore(attempt.session) + candidate.FailureCount = attempt.session.FailureCount + candidate.CooldownUntil = attempt.session.CooldownUntil + } + candidates = append(candidates, candidate) + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if !providerCandidatesEqual(state.CandidateOrder, candidates) 
{ + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "scheduler", + Target: name, + Reason: "candidate_order_changed", + Detail: candidateOrderChangeDetail(state.CandidateOrder, candidates), + }, runtimeEventLimit(state)) + } + state.CandidateOrder = candidates + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func (p *HTTPProvider) markAttemptSuccess(attempt authAttempt) { + if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { + p.oauth.markSuccess(attempt.session) + } + if attempt.kind == "api_key" { + p.markAPIKeySuccess() + } + p.recordProviderHit(attempt, "") +} + +func (p *HTTPProvider) markAPIKeyFailure(reason oauthFailureReason) { + name := strings.TrimSpace(p.providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if state.API.HealthScore <= 0 { + state.API.HealthScore = 100 + } + state.API.FailureCount++ + state.API.LastFailure = string(reason) + state.API.HealthScore = maxInt(1, state.API.HealthScore-healthPenaltyForReason(reason)) + cooldown := 15 * time.Minute + switch reason { + case oauthFailureQuota: + cooldown = 60 * time.Minute + case oauthFailureForbidden: + cooldown = 30 * time.Minute + } + state.API.CooldownUntil = time.Now().Add(cooldown).Format(time.RFC3339) + state.API.TokenMasked = maskToken(p.apiKey) + state.RecentErrors = appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: maskToken(p.apiKey), + Reason: string(reason), + }, runtimeEventLimit(state)) + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: maskToken(p.apiKey), + Reason: "api_key_cooldown_" + string(reason), + Detail: "api key entered 
cooldown after request failure", + }, runtimeEventLimit(state)) + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func (p *HTTPProvider) markAPIKeySuccess() { + name := strings.TrimSpace(p.providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if state.API.HealthScore <= 0 { + state.API.HealthScore = 100 + } else { + state.API.HealthScore = minInt(100, state.API.HealthScore+3) + } + wasCooling := strings.TrimSpace(state.API.CooldownUntil) != "" + state.API.CooldownUntil = "" + state.API.TokenMasked = maskToken(p.apiKey) + if wasCooling { + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: maskToken(p.apiKey), + Reason: "api_key_recovered", + Detail: "api key cooldown cleared after successful request", + }, runtimeEventLimit(state)) + } + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func (p *HTTPProvider) apiKeyAttempt() (authAttempt, bool) { + token := strings.TrimSpace(p.apiKey) + if token == "" { + return authAttempt{}, false + } + name := strings.TrimSpace(p.providerName) + if name == "" { + return authAttempt{token: token, kind: "api_key"}, true + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if state.API.TokenMasked == "" { + state.API.TokenMasked = maskToken(token) + } + if state.API.HealthScore <= 0 { + state.API.HealthScore = 100 + } + if state.API.CooldownUntil != "" { + if until, err := time.Parse(time.RFC3339, state.API.CooldownUntil); err == nil { + if time.Now().Before(until) { + providerRuntimeRegistry.api[name] = state + return authAttempt{}, false + } + } + state.API.CooldownUntil = "" + } + providerRuntimeRegistry.api[name] = state + return authAttempt{token: 
// maskToken returns a redacted form of a credential for display and logging.
//
// The value is whitespace-trimmed first. An empty value yields "". Values
// longer than 8 bytes keep their first and last 4 bytes around "***"; values
// of 2 to 8 bytes keep only their first 2 bytes followed by "***". A
// single-byte value is fully redacted as "***".
func maskToken(value string) string {
	value = strings.TrimSpace(value)
	switch n := len(value); {
	case n == 0:
		return ""
	case n < 2:
		// Slicing value[:2] on a one-byte token would panic.
		return "***"
	case n <= 8:
		return value[:2] + "***"
	default:
		return value[:4] + "***" + value[n-4:]
	}
}
appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "oauth", + Target: firstNonEmpty(session.Email, session.AccountID, session.FilePath), + Reason: string(reason), + }, runtimeEventLimit(state)) + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func ClearProviderAPICooldown(providerName string) { + name := strings.TrimSpace(providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + target := state.API.TokenMasked + state.API.CooldownUntil = "" + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: target, + Reason: "manual_clear_api_cooldown", + Detail: "api key cooldown cleared from runtime panel", + }, runtimeEventLimit(state)) + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func ClearProviderRuntimeHistory(providerName string) { + name := strings.TrimSpace(providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + state.RecentHits = nil + state.RecentErrors = nil + state.RecentChanges = nil + state.LastSuccess = nil + if state.Persist.Enabled && strings.TrimSpace(state.Persist.File) != "" { + _ = os.Remove(state.Persist.File) + } + providerRuntimeRegistry.api[name] = state +} + +func runtimeEventLimit(state providerRuntimeState) int { + if state.Persist.MaxEvents > 0 { + return state.Persist.MaxEvents + } + return 8 +} + +func runtimeHistoryMax(pc config.ProviderConfig) int { + if pc.RuntimeHistoryMax > 0 { + return pc.RuntimeHistoryMax + } + return 24 +} + +func runtimeHistoryFile(name string, pc config.ProviderConfig) string { + if file := strings.TrimSpace(pc.RuntimeHistoryFile); file != "" { 
+ return file + } + return filepath.Join(config.GetConfigDir(), "runtime", "providers", strings.TrimSpace(name)+".json") +} + +func loadPersistedProviderRuntimeLocked(name string, state *providerRuntimeState) { + if state == nil || !state.Persist.Enabled || strings.TrimSpace(state.Persist.File) == "" { + return + } + raw, err := os.ReadFile(state.Persist.File) + if err != nil { + if os.IsNotExist(err) { + state.Persist.Loaded = true + } + return + } + var persisted providerRuntimeState + if err := json.Unmarshal(raw, &persisted); err != nil { + state.Persist.Loaded = true + return + } + if state.API == (providerAPIRuntimeState{}) { + state.API = persisted.API + } + if len(state.RecentHits) == 0 { + state.RecentHits = persisted.RecentHits + } + if len(state.RecentErrors) == 0 { + state.RecentErrors = persisted.RecentErrors + } + if len(state.RecentChanges) == 0 { + state.RecentChanges = persisted.RecentChanges + } + if state.LastSuccess == nil && persisted.LastSuccess != nil { + last := *persisted.LastSuccess + state.LastSuccess = &last + } + if len(state.CandidateOrder) == 0 { + state.CandidateOrder = persisted.CandidateOrder + } + state.Persist.Loaded = true +} + +func persistProviderRuntimeLocked(name string, state providerRuntimeState) { + if !state.Persist.Enabled || strings.TrimSpace(state.Persist.File) == "" { + return + } + if err := os.MkdirAll(filepath.Dir(state.Persist.File), 0o700); err != nil { + return + } + payload := providerRuntimeState{ + API: state.API, + RecentHits: trimRuntimeEvents(state.RecentHits, runtimeEventLimit(state)), + RecentErrors: trimRuntimeEvents(state.RecentErrors, runtimeEventLimit(state)), + RecentChanges: trimRuntimeEvents(state.RecentChanges, runtimeEventLimit(state)), + LastSuccess: state.LastSuccess, + CandidateOrder: state.CandidateOrder, + } + raw, err := json.MarshalIndent(payload, "", " ") + if err != nil { + return + } + _ = os.WriteFile(state.Persist.File, raw, 0o600) +} + +func trimRuntimeEvents(events 
[]providerRuntimeEvent, limit int) []providerRuntimeEvent { + if limit <= 0 || len(events) <= limit { + return events + } + return events[:limit] +} + +func eventTimeUnix(event providerRuntimeEvent) int64 { + when, err := time.Parse(time.RFC3339, strings.TrimSpace(event.When)) + if err != nil { + return 0 + } + return when.Unix() +} + +func filterRuntimeEvents(events []providerRuntimeEvent, query ProviderRuntimeQuery) []providerRuntimeEvent { + if len(events) == 0 { + return nil + } + kind := strings.TrimSpace(query.EventKind) + reason := strings.TrimSpace(query.Reason) + target := strings.ToLower(strings.TrimSpace(query.Target)) + var cutoff time.Time + if query.Window > 0 { + cutoff = time.Now().Add(-query.Window) + } + filtered := make([]providerRuntimeEvent, 0, len(events)) + for _, event := range events { + if !cutoff.IsZero() { + when, err := time.Parse(time.RFC3339, strings.TrimSpace(event.When)) + if err != nil || when.Before(cutoff) { + continue + } + } + if kind != "" && !strings.EqualFold(strings.TrimSpace(event.Kind), kind) { + continue + } + if reason != "" && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Reason)), strings.ToLower(reason)) { + continue + } + if target != "" && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Target)), target) && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Detail)), target) { + continue + } + filtered = append(filtered, event) + } + return filtered +} + +func mergeRuntimeEvents(item map[string]interface{}, query ProviderRuntimeQuery) ([]providerRuntimeEvent, int) { + hits, _ := item["recent_hits"].([]providerRuntimeEvent) + errors, _ := item["recent_errors"].([]providerRuntimeEvent) + changes, _ := item["recent_changes"].([]providerRuntimeEvent) + merged := make([]providerRuntimeEvent, 0, len(hits)+len(errors)+len(changes)) + if !query.ChangesOnly { + merged = append(merged, filterRuntimeEvents(hits, query)...) + merged = append(merged, filterRuntimeEvents(errors, query)...) 
+ } + merged = append(merged, filterRuntimeEvents(changes, query)...) + desc := !strings.EqualFold(strings.TrimSpace(query.Sort), "asc") + for i := 0; i < len(merged); i++ { + for j := i + 1; j < len(merged); j++ { + left := eventTimeUnix(merged[i]) + right := eventTimeUnix(merged[j]) + swap := right > left + if !desc { + swap = right < left + } + if swap { + merged[i], merged[j] = merged[j], merged[i] + } + } + } + start := query.Cursor + if start < 0 { + start = 0 + } + if start > len(merged) { + start = len(merged) + } + limit := query.Limit + if limit <= 0 { + limit = 20 + } + end := start + limit + if end > len(merged) { + end = len(merged) + } + nextCursor := 0 + if end < len(merged) { + nextCursor = end + } + return merged[start:end], nextCursor +} + +func matchesProviderCandidateFilters(item map[string]interface{}, query ProviderRuntimeQuery) bool { + if query.HealthBelow <= 0 && query.CooldownBefore.IsZero() { + return true + } + apiState, _ := item["api_state"].(providerAPIRuntimeState) + candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) + if query.HealthBelow > 0 { + if runtimeHealthValue(apiState.HealthScore) < query.HealthBelow { + return true + } + for _, candidate := range candidates { + if runtimeHealthValue(candidate.HealthScore) < query.HealthBelow { + return true + } + } + } + if !query.CooldownBefore.IsZero() { + values := []string{apiState.CooldownUntil} + for _, candidate := range candidates { + values = append(values, candidate.CooldownUntil) + } + for _, value := range values { + if strings.TrimSpace(value) == "" { + continue + } + until, err := time.Parse(time.RFC3339, strings.TrimSpace(value)) + if err == nil && until.Before(query.CooldownBefore) { + return true + } + } + } + return false +} + +func providerInCooldown(item map[string]interface{}) bool { + apiState, _ := item["api_state"].(providerAPIRuntimeState) + if cooldownActive(apiState.CooldownUntil) { + return true + } + candidates, _ := 
// cooldownActive reports whether value holds an RFC 3339 timestamp that lies
// in the future. Blank or unparsable input counts as "no cooldown".
func cooldownActive(value string) bool {
	stamp := strings.TrimSpace(value)
	if stamp == "" {
		return false
	}
	deadline, err := time.Parse(time.RFC3339, stamp)
	if err != nil {
		return false
	}
	return deadline.After(time.Now())
}
candidate) + } else { + oauthUnavailable = append(oauthUnavailable, candidate) + } + } + sortRuntimeCandidates(oauthAvailable) + sortRuntimeCandidates(oauthUnavailable) + out := make([]providerRuntimeCandidate, 0, 1+len(accounts)) + switch authMode { + case "oauth": + out = append(out, oauthAvailable...) + case "hybrid": + if apiCandidate.Target != "" && apiCandidate.Available { + out = append(out, apiCandidate) + } + out = append(out, oauthAvailable...) + case "none": + default: + if apiCandidate.Target != "" { + out = append(out, apiCandidate) + } + } + if authMode == "hybrid" { + if apiCandidate.Target != "" && !apiCandidate.Available { + out = append(out, apiCandidate) + } + out = append(out, oauthUnavailable...) + } else if authMode == "oauth" { + out = append(out, oauthUnavailable...) + } + return out +} + +func runtimeHealthValue(value int) int { + if value <= 0 { + return 100 + } + return value +} + +func sortRuntimeCandidates(items []providerRuntimeCandidate) { + for i := 0; i < len(items); i++ { + for j := i + 1; j < len(items); j++ { + if items[j].HealthScore > items[i].HealthScore || (items[j].HealthScore == items[i].HealthScore && items[j].Target < items[i].Target) { + items[i], items[j] = items[j], items[i] + } + } + } +} + +func providerCandidatesEqual(left, right []providerRuntimeCandidate) bool { + if len(left) != len(right) { + return false + } + for i := range left { + if left[i].Kind != right[i].Kind || + left[i].Target != right[i].Target || + left[i].Available != right[i].Available || + left[i].Status != right[i].Status || + left[i].CooldownUntil != right[i].CooldownUntil || + left[i].HealthScore != right[i].HealthScore || + left[i].FailureCount != right[i].FailureCount { + return false + } + } + return true +} + +func summarizeCandidate(candidate providerRuntimeCandidate) string { + target := strings.TrimSpace(candidate.Target) + if target == "" { + target = "-" + } + return strings.TrimSpace(candidate.Kind) + ":" + target +} + +func 
candidateOrderChangeDetail(before, after []providerRuntimeCandidate) string { + if len(before) == 0 && len(after) == 0 { + return "" + } + beforeTop := "-" + afterTop := "-" + if len(before) > 0 { + beforeTop = summarizeCandidate(before[0]) + } + if len(after) > 0 { + afterTop = summarizeCandidate(after[0]) + } + beforeOrder := make([]string, 0, len(before)) + for _, item := range before { + beforeOrder = append(beforeOrder, summarizeCandidate(item)) + } + afterOrder := make([]string, 0, len(after)) + for _, item := range after { + afterOrder = append(afterOrder, summarizeCandidate(item)) + } + return fmt.Sprintf("top %s -> %s | order [%s] -> [%s]", beforeTop, afterTop, strings.Join(beforeOrder, " > "), strings.Join(afterOrder, " > ")) +} + +func GetProviderRuntimeSnapshot(cfg *config.Config) map[string]interface{} { + if cfg == nil { + return map[string]interface{}{"items": []interface{}{}} + } + items := make([]map[string]interface{}, 0) + configs := getAllProviderConfigs(cfg) + for name, pc := range configs { + ConfigureProviderRuntime(name, pc) + providerRuntimeRegistry.mu.Lock() + state := providerRuntimeRegistry.api[name] + providerRuntimeRegistry.mu.Unlock() + item := map[string]interface{}{ + "name": name, + "auth": strings.TrimSpace(pc.Auth), + "api_base": strings.TrimSpace(pc.APIBase), + "api_state": state.API, + "recent_hits": state.RecentHits, + "recent_errors": state.RecentErrors, + "recent_changes": state.RecentChanges, + "last_success": state.LastSuccess, + } + candidateOrder := state.CandidateOrder + if strings.EqualFold(name, "aistudio") { + if accounts := listAIStudioRelayAccounts(); len(accounts) > 0 { + item["oauth_accounts"] = accounts + } + } else if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + if mgr, err := NewOAuthLoginManager(pc, time.Duration(maxInt(pc.TimeoutSec, 90))*time.Second); err == nil { + if accounts, err := mgr.ListAccounts(); err == nil { + 
item["oauth_accounts"] = accounts + candidateOrder = buildProviderCandidateOrder(name, pc, accounts, state.API) + } + } + } else if len(candidateOrder) == 0 && strings.TrimSpace(pc.APIKey) != "" { + candidateOrder = buildProviderCandidateOrder(name, pc, nil, state.API) + } + if len(candidateOrder) > 0 { + providerRuntimeRegistry.mu.Lock() + state = providerRuntimeRegistry.api[name] + state.CandidateOrder = candidateOrder + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state + providerRuntimeRegistry.mu.Unlock() + } + item["candidate_order"] = candidateOrder + items = append(items, item) + } + return map[string]interface{}{"items": items} +} + +func GetProviderRuntimeView(cfg *config.Config, query ProviderRuntimeQuery) map[string]interface{} { + if cfg == nil { + return map[string]interface{}{"items": []interface{}{}} + } + snapshot := GetProviderRuntimeSnapshot(cfg) + rawItems, _ := snapshot["items"].([]map[string]interface{}) + if len(rawItems) == 0 { + return map[string]interface{}{"items": []interface{}{}} + } + filterName := strings.TrimSpace(query.Provider) + items := make([]map[string]interface{}, 0, len(rawItems)) + for _, item := range rawItems { + name := strings.TrimSpace(fmt.Sprintf("%v", item["name"])) + if filterName != "" && name != filterName { + continue + } + next := map[string]interface{}{} + for key, value := range item { + next[key] = value + } + hits, _ := item["recent_hits"].([]providerRuntimeEvent) + errors, _ := item["recent_errors"].([]providerRuntimeEvent) + changes, _ := item["recent_changes"].([]providerRuntimeEvent) + next["recent_hits"] = filterRuntimeEvents(hits, query) + next["recent_errors"] = filterRuntimeEvents(errors, query) + next["recent_changes"] = filterRuntimeEvents(changes, query) + if query.ChangesOnly { + next["recent_hits"] = []providerRuntimeEvent{} + next["recent_errors"] = []providerRuntimeEvent{} + } + events, nextCursor := mergeRuntimeEvents(next, query) + next["events"] = events + 
next["next_cursor"] = nextCursor + if !matchesProviderCandidateFilters(next, query) { + continue + } + items = append(items, next) + } + return map[string]interface{}{"items": items} +} + +func GetProviderRuntimeSummary(cfg *config.Config, query ProviderRuntimeQuery) ProviderRuntimeSummary { + snapshot := GetProviderRuntimeSnapshot(cfg) + rawItems, _ := snapshot["items"].([]map[string]interface{}) + summary := ProviderRuntimeSummary{Providers: make([]ProviderRuntimeSummaryItem, 0, len(rawItems))} + for _, item := range rawItems { + name := strings.TrimSpace(fmt.Sprintf("%v", item["name"])) + if strings.TrimSpace(query.Provider) != "" && name != strings.TrimSpace(query.Provider) { + continue + } + auth := strings.TrimSpace(fmt.Sprintf("%v", item["auth"])) + apiState, _ := item["api_state"].(providerAPIRuntimeState) + accounts, _ := item["oauth_accounts"].([]OAuthAccountInfo) + candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) + errors, _ := item["recent_errors"].([]providerRuntimeEvent) + changes, _ := item["recent_changes"].([]providerRuntimeEvent) + errors = filterRuntimeEvents(errors, query) + changes = filterRuntimeEvents(changes, query) + lastSuccess, _ := item["last_success"].(*providerRuntimeEvent) + inCooldown := providerInCooldown(item) + lowHealth := matchesProviderCandidateFilters(item, ProviderRuntimeQuery{HealthBelow: maxInt(query.HealthBelow, 1)}) + hasRecentErrors := len(errors) > 0 + lastError := latestProviderRuntimeEvent(errors) + topChangedAt := latestRuntimeChangeAt(changes, "candidate_order_changed") + status := providerRuntimeSummaryStatus(inCooldown, lowHealth, hasRecentErrors) + providerItem := ProviderRuntimeSummaryItem{ + Name: name, + Auth: auth, + Status: status, + APIState: apiState, + OAuthAccounts: accounts, + CandidateOrder: candidates, + LastSuccess: lastSuccess, + LastError: lastError, + TopCandidateChangedAt: topChangedAt, + InCooldown: inCooldown, + LowHealth: lowHealth, + HasRecentErrors: hasRecentErrors, + } 
+ if lastSuccess != nil { + providerItem.LastSuccessAt = strings.TrimSpace(lastSuccess.When) + if when := parseRuntimeEventTime(*lastSuccess); !when.IsZero() { + providerItem.StaleForSec = int64(time.Since(when).Seconds()) + } + } else { + providerItem.StaleForSec = -1 + } + if lastError != nil { + providerItem.LastErrorAt = strings.TrimSpace(lastError.When) + providerItem.LastErrorReason = strings.TrimSpace(lastError.Reason) + } + if len(candidates) > 0 { + top := candidates[0] + providerItem.TopCandidate = &top + } + summary.TotalProviders++ + switch status { + case "critical": + summary.Critical++ + case "degraded": + summary.Degraded++ + default: + summary.Healthy++ + } + if inCooldown { + summary.InCooldown++ + } + if lowHealth { + summary.LowHealth++ + } + if hasRecentErrors { + summary.RecentErrors++ + } + if inCooldown || lowHealth || hasRecentErrors || strings.TrimSpace(query.Provider) != "" { + summary.Providers = append(summary.Providers, providerItem) + } + } + return summary +} + +func latestProviderRuntimeEvent(events []providerRuntimeEvent) *providerRuntimeEvent { + if len(events) == 0 { + return nil + } + best := events[0] + bestTime := eventTimeUnix(best) + for i := 1; i < len(events); i++ { + currentTime := eventTimeUnix(events[i]) + if currentTime > bestTime { + best = events[i] + bestTime = currentTime + } + } + copyEvent := best + return ©Event +} + +func latestRuntimeChangeAt(events []providerRuntimeEvent, reason string) string { + targetReason := strings.TrimSpace(reason) + if targetReason == "" || len(events) == 0 { + return "" + } + var latest *providerRuntimeEvent + var latestUnix int64 + for i := range events { + if !strings.EqualFold(strings.TrimSpace(events[i].Reason), targetReason) { + continue + } + currentUnix := eventTimeUnix(events[i]) + if latest == nil || currentUnix > latestUnix { + eventCopy := events[i] + latest = &eventCopy + latestUnix = currentUnix + } + } + if latest == nil { + return "" + } + return 
// providerRuntimeSummaryStatus collapses the three health flags into a single
// status label. A cooldown or low health makes the provider "critical";
// otherwise recent errors make it "degraded"; otherwise it is "healthy".
func providerRuntimeSummaryStatus(inCooldown, lowHealth, hasRecentErrors bool) string {
	switch {
	case inCooldown, lowHealth:
		return "critical"
	case hasRecentErrors:
		return "degraded"
	default:
		return "healthy"
	}
}
providerRuntimeRegistry.api[strings.TrimSpace(providerName)].CandidateOrder...) + providerRuntimeRegistry.mu.Unlock() + return order, nil +} + +func unwrapHTTPProvider(provider LLMProvider) (*HTTPProvider, bool) { + switch typed := provider.(type) { + case *HTTPProvider: + return typed, true + case *CodexProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *AntigravityProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *ClaudeProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *QwenProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *KimiProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + default: + return nil, false + } +} diff --git a/pkg/providers/responses_adapter.go b/pkg/providers/responses_adapter.go new file mode 100644 index 0000000..2037b21 --- /dev/null +++ b/pkg/providers/responses_adapter.go @@ -0,0 +1,546 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" +) + +func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) { + input := make([]map[string]interface{}, 0, len(messages)) + pendingCalls := map[string]struct{}{} + for _, msg := range messages { + input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) 
+ } + requestBody := map[string]interface{}{ + "model": model, + "input": input, + } + responseTools := buildResponsesTools(tools, options) + if len(responseTools) > 0 { + requestBody["tools"] = responseTools + requestBody["tool_choice"] = "auto" + if tc, ok := rawOption(options, "tool_choice"); ok { + requestBody["tool_choice"] = tc + } + if tc, ok := rawOption(options, "responses_tool_choice"); ok { + requestBody["tool_choice"] = tc + } + } + if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { + requestBody["max_output_tokens"] = maxTokens + } + if temperature, ok := float64FromOption(options, "temperature"); ok { + requestBody["temperature"] = temperature + } + if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 { + requestBody["include"] = include + } + if metadata, ok := mapOption(options, "responses_metadata"); ok && len(metadata) > 0 { + requestBody["metadata"] = metadata + } + if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" { + requestBody["previous_response_id"] = prevID + } + if p.useOpenAICompatChatUpstream() { + chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options) + return p.postJSON(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody) + } + if p.useCodexCompat() { + requestBody = p.codexCompatRequestBody(requestBody) + return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, nil) + } + return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody) +} + +func toResponsesInputItemsWithState(msg Message, pendingCalls map[string]struct{}) []map[string]interface{} { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + switch role { + case "system", "developer", "user": + if content := responsesMessageContent(msg); len(content) > 0 { + return []map[string]interface{}{{ + "type": "message", + "role": role, + "content": content, + }} + } + return 
[]map[string]interface{}{responsesMessageItem(role, msg.Content, "input_text")} + case "assistant": + items := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls)) + if msg.Content != "" || len(msg.ToolCalls) == 0 { + items = append(items, responsesMessageItem(role, msg.Content, "output_text")) + } + for _, tc := range msg.ToolCalls { + callID := tc.ID + if callID == "" { + continue + } + name := tc.Name + argsRaw := "" + if tc.Function != nil { + if tc.Function.Name != "" { + name = tc.Function.Name + } + argsRaw = tc.Function.Arguments + } + if name == "" { + continue + } + if argsRaw == "" { + argsJSON, err := json.Marshal(tc.Arguments) + if err != nil { + argsRaw = "{}" + } else { + argsRaw = string(argsJSON) + } + } + if pendingCalls != nil { + pendingCalls[callID] = struct{}{} + } + items = append(items, map[string]interface{}{ + "type": "function_call", + "call_id": callID, + "name": name, + "arguments": argsRaw, + }) + } + if len(items) == 0 { + return []map[string]interface{}{responsesMessageItem(role, msg.Content, "output_text")} + } + return items + case "tool": + callID := msg.ToolCallID + if callID == "" { + return nil + } + if pendingCalls != nil { + if _, ok := pendingCalls[callID]; !ok { + // Strict pairing: drop orphan/duplicate tool outputs instead of degrading role. 
+ return nil + } + delete(pendingCalls, callID) + } + return []map[string]interface{}{map[string]interface{}{ + "type": "function_call_output", + "call_id": callID, + "output": msg.Content, + }} + default: + return []map[string]interface{}{responsesMessageItem("user", msg.Content, "input_text")} + } +} + +func responsesMessageContent(msg Message) []map[string]interface{} { + content := make([]map[string]interface{}, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "input_text", "text": + if part.Text == "" { + continue + } + content = append(content, map[string]interface{}{ + "type": "input_text", + "text": part.Text, + }) + case "input_image", "image": + entry := map[string]interface{}{ + "type": "input_image", + } + if part.ImageURL != "" { + entry["image_url"] = part.ImageURL + } + if part.FileID != "" { + entry["file_id"] = part.FileID + } + if detail := strings.TrimSpace(part.Detail); detail != "" { + entry["detail"] = detail + } + if _, ok := entry["image_url"]; !ok { + if _, ok := entry["file_id"]; !ok { + continue + } + } + content = append(content, entry) + case "input_file", "file": + entry := map[string]interface{}{ + "type": "input_file", + } + if part.FileData != "" { + entry["file_data"] = part.FileData + } + if part.FileID != "" { + entry["file_id"] = part.FileID + } + if part.FileURL != "" { + entry["file_url"] = part.FileURL + } + if part.Filename != "" { + entry["filename"] = part.Filename + } + if _, ok := entry["file_data"]; !ok { + if _, ok := entry["file_id"]; !ok { + if _, ok := entry["file_url"]; !ok { + continue + } + } + } + content = append(content, entry) + } + } + return content +} + +func buildResponsesTools(tools []ToolDefinition, options map[string]interface{}) []map[string]interface{} { + responseTools := make([]map[string]interface{}, 0, len(tools)+2) + for _, t := range tools { + typ := strings.ToLower(strings.TrimSpace(t.Type)) + if typ == "" { + 
typ = "function" + } + if typ == "function" { + name := strings.TrimSpace(t.Function.Name) + if name == "" { + name = strings.TrimSpace(t.Name) + } + if name == "" { + continue + } + entry := map[string]interface{}{ + "type": "function", + "name": name, + "parameters": map[string]interface{}{}, + } + if t.Function.Parameters != nil { + entry["parameters"] = t.Function.Parameters + } else if t.Parameters != nil { + entry["parameters"] = t.Parameters + } + desc := strings.TrimSpace(t.Function.Description) + if desc == "" { + desc = strings.TrimSpace(t.Description) + } + if desc != "" { + entry["description"] = desc + } + if t.Function.Strict != nil { + entry["strict"] = *t.Function.Strict + } else if t.Strict != nil { + entry["strict"] = *t.Strict + } + responseTools = append(responseTools, entry) + continue + } + + // Built-in tool types (web_search, file_search, code_interpreter, etc.). + entry := map[string]interface{}{ + "type": typ, + } + if name := strings.TrimSpace(t.Name); name != "" { + entry["name"] = name + } + if desc := strings.TrimSpace(t.Description); desc != "" { + entry["description"] = desc + } + if t.Strict != nil { + entry["strict"] = *t.Strict + } + for k, v := range t.Parameters { + entry[k] = v + } + responseTools = append(responseTools, entry) + } + + if extraTools, ok := mapSliceOption(options, "responses_tools"); ok { + responseTools = append(responseTools, extraTools...) 
+ } + return responseTools +} + +func responsesMessageItem(role, text, contentType string) map[string]interface{} { + ct := contentType + if ct == "" { + ct = "input_text" + } + return map[string]interface{}{ + "type": "message", + "role": role, + "content": []map[string]interface{}{ + { + "type": ct, + "text": text, + }, + }, + } +} + +func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) { + input := make([]map[string]interface{}, 0, len(messages)) + pendingCalls := map[string]struct{}{} + for _, msg := range messages { + input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) + } + requestBody := map[string]interface{}{ + "model": model, + "input": input, + "stream": true, + } + responseTools := buildResponsesTools(tools, options) + if len(responseTools) > 0 { + requestBody["tools"] = responseTools + requestBody["tool_choice"] = "auto" + if tc, ok := rawOption(options, "tool_choice"); ok { + requestBody["tool_choice"] = tc + } + if tc, ok := rawOption(options, "responses_tool_choice"); ok { + requestBody["tool_choice"] = tc + } + } + if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { + requestBody["max_output_tokens"] = maxTokens + } + if temperature, ok := float64FromOption(options, "temperature"); ok { + requestBody["temperature"] = temperature + } + if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 { + requestBody["include"] = include + } + if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 { + requestBody["stream_options"] = streamOpts + } + if p.useOpenAICompatChatUpstream() { + chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options) + chatBody["stream"] = true + streamOptions := map[string]interface{}{"include_usage": true} + chatBody["stream_options"] = streamOptions + 
return p.postJSONStream(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + choices, _ := obj["choices"].([]interface{}) + for _, choice := range choices { + item, _ := choice.(map[string]interface{}) + delta, _ := item["delta"].(map[string]interface{}) + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" { + onDelta(txt) + } + } + }) + } + if p.useCodexCompat() { + requestBody = p.codexCompatRequestBody(requestBody) + return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" { + onDelta(d) + return + } + if delta, ok := obj["delta"].(map[string]interface{}); ok { + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" { + onDelta(txt) + } + } + }) + } + return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])) + if typ == "response.output_text.delta" { + if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" { + onDelta(d) + } + return + } + if delta, ok := obj["delta"].(map[string]interface{}); ok { + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" { + onDelta(txt) + } + } + }) +} + +func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) { + var resp struct { + Status string `json:"status"` + Output []struct { + ID string `json:"id"` + Type string `json:"type"` + CallID string `json:"call_id"` + Name string `json:"name"` + ArgsRaw string `json:"arguments"` + Role string `json:"role"` + Content 
[]struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"output"` + OutputText string `json:"output_text"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + toolCalls := make([]ToolCall, 0) + outputText := strings.TrimSpace(resp.OutputText) + for _, item := range resp.Output { + switch strings.TrimSpace(item.Type) { + case "function_call": + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + args := map[string]interface{}{} + if strings.TrimSpace(item.ArgsRaw) != "" { + if err := json.Unmarshal([]byte(item.ArgsRaw), &args); err != nil { + args["raw"] = item.ArgsRaw + } + } + id := strings.TrimSpace(item.CallID) + if id == "" { + id = strings.TrimSpace(item.ID) + } + if id == "" { + id = fmt.Sprintf("call_%d", len(toolCalls)+1) + } + toolCalls = append(toolCalls, ToolCall{ID: id, Name: name, Arguments: args}) + case "message": + if outputText == "" { + texts := make([]string, 0, len(item.Content)) + for _, c := range item.Content { + if strings.TrimSpace(c.Type) == "output_text" && strings.TrimSpace(c.Text) != "" { + texts = append(texts, c.Text) + } + } + if len(texts) > 0 { + outputText = strings.Join(texts, "\n") + } + } + } + } + + if len(toolCalls) == 0 { + compatCalls, cleanedContent := parseCompatFunctionCalls(outputText) + if len(compatCalls) > 0 { + toolCalls = compatCalls + outputText = cleanedContent + } + } + + finishReason := strings.TrimSpace(resp.Status) + if finishReason == "" || finishReason == "completed" { + finishReason = "stop" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 || resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 { + usage = &UsageInfo{PromptTokens: resp.Usage.InputTokens, CompletionTokens: 
resp.Usage.OutputTokens, TotalTokens: resp.Usage.TotalTokens} + } + return &LLMResponse{Content: strings.TrimSpace(outputText), ToolCalls: toolCalls, FinishReason: finishReason, Usage: usage}, nil +} + +func (p *HTTPProvider) BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error) { + if !p.SupportsResponsesCompact() { + return "", fmt.Errorf("responses compact is not enabled for this provider") + } + input := make([]map[string]interface{}, 0, len(messages)+1) + if strings.TrimSpace(existingSummary) != "" { + input = append(input, responsesMessageItem("system", "Existing summary:\n"+strings.TrimSpace(existingSummary), "input_text")) + } + pendingCalls := map[string]struct{}{} + for _, msg := range messages { + input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) + } + if len(input) == 0 { + return strings.TrimSpace(existingSummary), nil + } + + compactReq := map[string]interface{}{"model": model, "input": input} + compactBody, statusCode, contentType, err := p.postJSON(ctx, endpointFor(p.apiBase, "/responses/compact"), compactReq) + if err != nil { + return "", fmt.Errorf("responses compact request failed: %w", err) + } + if statusCode != http.StatusOK { + return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(compactBody)) + } + if !json.Valid(compactBody) { + return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(compactBody)) + } + + var compactResp struct { + Output interface{} `json:"output"` + CompactedInput interface{} `json:"compacted_input"` + Compacted interface{} `json:"compacted"` + } + if err := json.Unmarshal(compactBody, &compactResp); err != nil { + return "", fmt.Errorf("responses compact request failed: invalid JSON: %w", err) + } + compactPayload := 
compactResp.Output + if compactPayload == nil { + compactPayload = compactResp.CompactedInput + } + if compactPayload == nil { + compactPayload = compactResp.Compacted + } + payloadBytes, err := json.Marshal(compactPayload) + if err != nil { + return "", fmt.Errorf("failed to serialize compact output: %w", err) + } + compactedPayload := strings.TrimSpace(string(payloadBytes)) + if compactedPayload == "" || compactedPayload == "null" { + return "", fmt.Errorf("empty compact output") + } + if len(compactedPayload) > 12000 { + compactedPayload = compactedPayload[:12000] + "..." + } + + summaryPrompt := fmt.Sprintf( + "Compacted conversation JSON:\n%s\n\nReturn a concise markdown summary with sections: Key Facts, Decisions, Open Items, Next Steps.", + compactedPayload, + ) + summaryReq := map[string]interface{}{ + "model": model, + "input": summaryPrompt, + } + if maxSummaryChars > 0 { + estMaxTokens := maxSummaryChars / 3 + if estMaxTokens < 128 { + estMaxTokens = 128 + } + summaryReq["max_output_tokens"] = estMaxTokens + } + summaryBody, summaryStatus, summaryType, err := p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), summaryReq) + if err != nil { + return "", fmt.Errorf("responses summary request failed: %w", err) + } + if summaryStatus != http.StatusOK { + return "", fmt.Errorf("responses summary request failed (status %d, content-type %q): %s", summaryStatus, summaryType, previewResponseBody(summaryBody)) + } + if !json.Valid(summaryBody) { + return "", fmt.Errorf("responses summary request failed (status %d, content-type %q): non-JSON response: %s", summaryStatus, summaryType, previewResponseBody(summaryBody)) + } + summaryResp, err := parseResponsesAPIResponse(summaryBody) + if err != nil { + return "", fmt.Errorf("responses summary request failed: %w", err) + } + summary := strings.TrimSpace(summaryResp.Content) + if summary == "" { + return "", fmt.Errorf("empty summary after responses compact") + } + if maxSummaryChars > 0 && len(summary) > 
maxSummaryChars { + summary = summary[:maxSummaryChars] + } + return summary, nil +} diff --git a/pkg/sentinel/service.go b/pkg/sentinel/service.go index e18651c..d455ba0 100644 --- a/pkg/sentinel/service.go +++ b/pkg/sentinel/service.go @@ -16,6 +16,11 @@ import ( type AlertFunc func(msg string) +type channelHealthManager interface { + CheckHealth(ctx context.Context) map[string]error + RestartChannel(ctx context.Context, name string) error +} + type Service struct { cfgPath string workspace string @@ -25,7 +30,7 @@ type Service struct { runner *lifecycle.LoopRunner mu sync.RWMutex lastAlerts map[string]time.Time - mgr *channels.Manager + mgr channelHealthManager healingChannels map[string]bool } diff --git a/pkg/sentinel/service_test.go b/pkg/sentinel/service_test.go new file mode 100644 index 0000000..9366aef --- /dev/null +++ b/pkg/sentinel/service_test.go @@ -0,0 +1,168 @@ +package sentinel + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/config" +) + +func TestCheckConfigReportsMissingAndCorruptConfig(t *testing.T) { + s := NewService(filepath.Join(t.TempDir(), "missing.json"), t.TempDir(), 60, false, nil) + issues := s.checkConfig() + if len(issues) != 1 || !strings.Contains(issues[0], "config file missing") { + t.Fatalf("missing config issues = %+v", issues) + } + + cfgPath := filepath.Join(t.TempDir(), "config.json") + if err := os.WriteFile(cfgPath, []byte("{bad-json"), 0644); err != nil { + t.Fatalf("write corrupt config: %v", err) + } + s = NewService(cfgPath, t.TempDir(), 60, false, nil) + issues = s.checkConfig() + if len(issues) != 1 || !strings.Contains(issues[0], "config parse failed") { + t.Fatalf("corrupt config issues = %+v", issues) + } +} + +func TestCheckMemoryReportsMissingAndAutoHeals(t *testing.T) { + workspace := t.TempDir() + s := NewService(validConfigFile(t, t.TempDir()), workspace, 60, false, nil) + + issues := s.checkMemory() + if 
len(issues) != 1 || !strings.Contains(issues[0], "memory dir missing") { + t.Fatalf("missing memory issues = %+v", issues) + } + + s = NewService(validConfigFile(t, t.TempDir()), workspace, 60, true, nil) + issues = s.checkMemory() + if len(issues) != 1 || !strings.Contains(issues[0], "auto-healed") { + t.Fatalf("auto-heal memory issues = %+v", issues) + } + if _, err := os.Stat(filepath.Join(workspace, "memory")); err != nil { + t.Fatalf("memory dir was not created: %v", err) + } + + issues = s.checkMemory() + if len(issues) != 1 || !strings.Contains(issues[0], "MEMORY.md missing, auto-healed") { + t.Fatalf("auto-heal MEMORY.md issues = %+v", issues) + } + if _, err := os.Stat(filepath.Join(workspace, "memory", "MEMORY.md")); err != nil { + t.Fatalf("MEMORY.md was not created: %v", err) + } +} + +func TestCheckLogsReportsMissingLogDirAndAutoHeals(t *testing.T) { + root := t.TempDir() + logDir := filepath.Join(root, "logs") + cfgPath := validConfigFileWithLogDir(t, t.TempDir(), logDir) + + s := NewService(cfgPath, t.TempDir(), 60, false, nil) + issues := s.checkLogs() + if len(issues) != 1 || !strings.Contains(issues[0], "log dir missing") { + t.Fatalf("missing log dir issues = %+v", issues) + } + + s = NewService(cfgPath, t.TempDir(), 60, true, nil) + issues = s.checkLogs() + if len(issues) != 1 || !strings.Contains(issues[0], "auto-healed") { + t.Fatalf("auto-heal log dir issues = %+v", issues) + } + if _, err := os.Stat(logDir); err != nil { + t.Fatalf("log dir was not created: %v", err) + } +} + +func TestCheckChannelsReportsHealthFailuresAndRestartsWhenAutoHealEnabled(t *testing.T) { + mgr := &fakeHealthManager{health: map[string]error{"telegram": errors.New("offline")}} + s := NewService(validConfigFile(t, t.TempDir()), t.TempDir(), 60, true, nil) + s.mgr = mgr + + issues := s.checkChannels() + if len(issues) != 1 || !strings.Contains(issues[0], "telegram health check failed") { + t.Fatalf("channel issues = %+v", issues) + } + waitForRestarts(t, 
&mgr.restarts, 1) +} + +func TestRunChecksCallsAlertCallbackAndSuppressesDuplicates(t *testing.T) { + cfgPath := filepath.Join(t.TempDir(), "missing.json") + workspace := t.TempDir() + var alerts int64 + s := NewService(cfgPath, workspace, 60, false, func(msg string) { + atomic.AddInt64(&alerts, 1) + }) + + s.runChecks() + if got := atomic.LoadInt64(&alerts); got == 0 { + t.Fatal("runChecks did not call alert callback") + } + first := atomic.LoadInt64(&alerts) + s.runChecks() + if got := atomic.LoadInt64(&alerts); got != first { + t.Fatalf("duplicate alerts = %d, want suppressed at %d", got, first) + } +} + +func TestStartStopLifecycle(t *testing.T) { + s := NewService(validConfigFile(t, t.TempDir()), t.TempDir(), 3600, false, nil) + s.Start() + if !s.runner.Running() { + t.Fatal("Start did not mark service running") + } + s.Stop() + if s.runner.Running() { + t.Fatal("Stop left service running") + } +} + +type fakeHealthManager struct { + health map[string]error + restarts int64 +} + +func (m *fakeHealthManager) CheckHealth(ctx context.Context) map[string]error { + return m.health +} + +func (m *fakeHealthManager) RestartChannel(ctx context.Context, name string) error { + atomic.AddInt64(&m.restarts, 1) + return nil +} + +func waitForRestarts(t *testing.T, restarts *int64, want int64) { + t.Helper() + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + if atomic.LoadInt64(restarts) >= want { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("restarts = %d, want at least %d", atomic.LoadInt64(restarts), want) +} + +func validConfigFile(t *testing.T, dir string) string { + t.Helper() + return validConfigFileWithLogDir(t, dir, filepath.Join(dir, "logs")) +} + +func validConfigFileWithLogDir(t *testing.T, dir, logDir string) string { + t.Helper() + cfg := config.DefaultConfig() + cfg.Logging.Enabled = true + cfg.Logging.Dir = logDir + cfg.Logging.Filename = "clawgo.log" + cfg.Agents.Defaults.Workspace = filepath.Join(dir, 
"workspace") + cfgPath := filepath.Join(dir, "config.json") + if err := config.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("SaveConfig: %v", err) + } + return cfgPath +} diff --git a/pkg/tools/bootstrap.go b/pkg/tools/bootstrap.go new file mode 100644 index 0000000..578f24e --- /dev/null +++ b/pkg/tools/bootstrap.go @@ -0,0 +1,157 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/config" +) + +func BootstrapDefaultTools(ctx context.Context, opts BootstrapOptions) (*BootstrapResult, error) { + if opts.Config == nil { + return nil, fmt.Errorf("config is required") + } + workspace := strings.TrimSpace(opts.Workspace) + if workspace == "" { + workspace = opts.Config.WorkspacePath() + } + if ctx == nil { + ctx = context.Background() + } + + registry := NewToolRegistry() + processManager := opts.ProcessManager + if processManager == nil { + processManager = NewProcessManager(workspace) + } + + registerFilesystemTools(registry, workspace) + registerShellTools(registry, opts, workspace, processManager) + registerCronTools(registry, opts) + + maxParallelCalls, parallelSafe := bootstrapParallelConfig(opts.Config) + registerWebTools(registry, opts, maxParallelCalls, parallelSafe) + registerMCPTools(ctx, registry, opts, workspace) + registerMessageTool(registry, opts.MessageBus) + + subagentManager, subagentRouter := registerSubagentTools(registry, opts, workspace) + registerSessionTools(registry, opts) + registerMemoryTools(registry, workspace) + + registry.Register(NewParallelTool(registry, maxParallelCalls, parallelSafe)) + registry.Register(NewBrowserTool()) + registry.Register(NewCameraTool(workspace)) + registry.Register(NewSystemInfoTool()) + + return &BootstrapResult{ + Registry: registry, + ProcessManager: processManager, + SubagentManager: subagentManager, + SubagentRouter: subagentRouter, + }, nil +} + +func registerFilesystemTools(registry *ToolRegistry, workspace 
string) { + registry.Register(NewReadFileTool(workspace)) + registry.Register(NewWriteFileTool(workspace)) + registry.Register(NewListDirTool(workspace)) + registry.Register(NewSkillExecTool(workspace)) + registry.Register(NewEditFileTool(workspace)) +} + +func registerShellTools(registry *ToolRegistry, opts BootstrapOptions, workspace string, processManager *ProcessManager) { + registry.Register(NewExecTool(opts.Config.Tools.Shell, workspace, processManager)) + registry.Register(NewProcessTool(processManager)) +} + +func registerCronTools(registry *ToolRegistry, opts BootstrapOptions) { + if opts.CronService == nil { + return + } + registry.Register(NewRemindTool(opts.CronService)) + registry.Register(NewCronTool(opts.CronService)) +} + +func bootstrapParallelConfig(cfg *config.Config) (int, map[string]struct{}) { + maxParallelCalls := cfg.Agents.Defaults.Execution.ToolMaxParallelCalls + if maxParallelCalls <= 0 { + maxParallelCalls = 4 + } + parallelSafe := make(map[string]struct{}) + for _, name := range cfg.Agents.Defaults.Execution.ToolParallelSafeNames { + trimmed := strings.TrimSpace(name) + if trimmed != "" { + parallelSafe[trimmed] = struct{}{} + } + } + return maxParallelCalls, parallelSafe +} + +func registerWebTools(registry *ToolRegistry, opts BootstrapOptions, maxParallelCalls int, parallelSafe map[string]struct{}) { + searchCfg := opts.Config.Tools.Web.Search + registry.Register(NewWebSearchTool(searchCfg.APIKey, searchCfg.MaxResults)) + webFetchTool := NewWebFetchTool(50000) + registry.Register(webFetchTool) + registry.Register(NewParallelFetchTool(webFetchTool, maxParallelCalls, parallelSafe)) +} + +func registerMCPTools(ctx context.Context, registry *ToolRegistry, opts BootstrapOptions, workspace string) { + mcpCfg := opts.Config.Tools.MCP + if !mcpCfg.Enabled { + return + } + mcpTool := NewMCPTool(workspace, mcpCfg) + registry.Register(mcpTool) + timeoutSec := mcpCfg.RequestTimeoutSec + if timeoutSec <= 0 { + timeoutSec = 20 + } + discoveryCtx, 
cancel := context.WithTimeout(ctx, time.Duration(timeoutSec)*time.Second) + defer cancel() + for _, remoteTool := range mcpTool.DiscoverTools(discoveryCtx) { + registry.Register(remoteTool) + } +} + +func registerMessageTool(registry *ToolRegistry, msgBus *bus.MessageBus) { + messageTool := NewMessageTool() + if msgBus != nil { + messageTool.SetSendCallback(func(channel, chatID, action, content, media, messageID, emoji string, buttons [][]bus.Button) error { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + Media: media, + Buttons: buttons, + Action: action, + MessageID: messageID, + Emoji: emoji, + }) + return nil + }) + } + registry.Register(messageTool) +} + +func registerSubagentTools(registry *ToolRegistry, opts BootstrapOptions, workspace string) (*SubagentManager, *SubagentRouter) { + subagentManager := NewSubagentManager(opts.Provider, workspace, opts.MessageBus) + subagentRouter := NewSubagentRouter(subagentManager) + registry.Register(NewSpawnTool(subagentManager)) + if store := subagentManager.ProfileStore(); store != nil { + registry.Register(NewSubagentProfileTool(store)) + } + return subagentManager, subagentRouter +} + +func registerSessionTools(registry *ToolRegistry, opts BootstrapOptions) { + registry.Register(NewSessionsTool(opts.SessionList, opts.SessionHistory)) +} + +func registerMemoryTools(registry *ToolRegistry, workspace string) { + registry.Register(NewMemorySearchTool(workspace)) + registry.Register(NewMemoryGetTool(workspace)) + registry.Register(NewMemoryWriteTool(workspace)) +} diff --git a/pkg/tools/bootstrap_options.go b/pkg/tools/bootstrap_options.go new file mode 100644 index 0000000..4456089 --- /dev/null +++ b/pkg/tools/bootstrap_options.go @@ -0,0 +1,30 @@ +package tools + +import ( + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/cron" + "github.com/YspCoder/clawgo/pkg/providers" +) + +type SessionListFunc 
func(limit int) []SessionInfo + +type SessionHistoryFunc func(key string, limit int) []providers.Message + +type BootstrapOptions struct { + Config *config.Config + Workspace string + MessageBus *bus.MessageBus + CronService *cron.CronService + Provider providers.LLMProvider + ProcessManager *ProcessManager + SessionList SessionListFunc + SessionHistory SessionHistoryFunc +} + +type BootstrapResult struct { + Registry *ToolRegistry + ProcessManager *ProcessManager + SubagentManager *SubagentManager + SubagentRouter *SubagentRouter +} diff --git a/pkg/tools/bootstrap_test.go b/pkg/tools/bootstrap_test.go new file mode 100644 index 0000000..002514f --- /dev/null +++ b/pkg/tools/bootstrap_test.go @@ -0,0 +1,61 @@ +package tools + +import ( + "reflect" + "sort" + "testing" + + "github.com/YspCoder/clawgo/pkg/config" +) + +func TestBootstrapDefaultToolsRegistersExpectedLocalTools(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = t.TempDir() + cfg.Tools.MCP.Enabled = false + + result, err := BootstrapDefaultTools(t.Context(), BootstrapOptions{ + Config: cfg, + Workspace: cfg.WorkspacePath(), + }) + if err != nil { + t.Fatalf("bootstrap default tools: %v", err) + } + if result == nil || result.Registry == nil { + t.Fatalf("expected registry in bootstrap result") + } + if result.ProcessManager == nil { + t.Fatalf("expected process manager in bootstrap result") + } + if result.SubagentManager == nil || result.SubagentRouter == nil { + t.Fatalf("expected subagent manager and router in bootstrap result") + } + + got := result.Registry.List() + sort.Strings(got) + want := []string{ + "browser", + "camera_snap", + "edit_file", + "exec", + "list_dir", + "memory_get", + "memory_search", + "memory_write", + "message", + "parallel", + "parallel_fetch", + "process", + "read_file", + "sessions", + "skill_exec", + "spawn", + "subagent_profile", + "system_info", + "web_fetch", + "web_search", + "write_file", + } + if !reflect.DeepEqual(got, want) { + 
t.Fatalf("default tool names mismatch\n got: %v\nwant: %v", got, want) + } +} diff --git a/pkg/tools/mcp.go b/pkg/tools/mcp.go index d2bf0b4..9cfd00d 100644 --- a/pkg/tools/mcp.go +++ b/pkg/tools/mcp.go @@ -346,13 +346,30 @@ type mcpClient struct { cmd *exec.Cmd stdin io.WriteCloser reader *bufio.Reader - stderr bytes.Buffer + stderr mcpSafeBuffer writeMu sync.Mutex waiters sync.Map nextID atomic.Int64 } +type mcpSafeBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (b *mcpSafeBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *mcpSafeBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + type mcpInbound struct { JSONRPC string `json:"jsonrpc"` ID interface{} `json:"id,omitempty"` diff --git a/pkg/tools/remind_test.go b/pkg/tools/remind_test.go index 741a6ae..7b0cd82 100644 --- a/pkg/tools/remind_test.go +++ b/pkg/tools/remind_test.go @@ -15,7 +15,7 @@ func TestRemindTool_UsesToolContextForDeliveryTarget(t *testing.T) { tool.SetContext("telegram", "chat-123") _, err := tool.Execute(context.Background(), map[string]interface{}{ - "message": "鍠濇按", + "message": "喝水", "time_expr": "10m", }) if err != nil {