From 06e3599e457f92fcdc062aa5e753b0d1936d14ab Mon Sep 17 00:00:00 2001 From: LPF Date: Fri, 13 Mar 2026 09:45:09 +0800 Subject: [PATCH] fix: refresh provider runtime summary and balance detail --- pkg/api/server.go | 76 ++++++++++++++++++++++++++++++++---------- pkg/providers/oauth.go | 52 ++++++++++++++++++++++++++++- 2 files changed, 110 insertions(+), 18 deletions(-) diff --git a/pkg/api/server.go b/pkg/api/server.go index 00a58d5..5138108 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -1429,35 +1429,56 @@ func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Reque } switch strings.ToLower(strings.TrimSpace(body.Action)) { case "clear_api_cooldown": - providers.ClearProviderAPICooldown(strings.TrimSpace(body.Provider)) + cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = cfg + providers.ClearProviderAPICooldown(providerName) _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) case "clear_history": - providers.ClearProviderRuntimeHistory(strings.TrimSpace(body.Provider)) + cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = cfg + providers.ClearProviderRuntimeHistory(providerName) _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) case "refresh_now": - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - result, err := providers.RefreshProviderRuntimeNow(cfg, strings.TrimSpace(body.Provider), body.OnlyExpiring) + cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "refreshed": true, "result": result}) + result, err := providers.RefreshProviderRuntimeNow(cfg, providerName, body.OnlyExpiring) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + order, _ := providers.RerankProviderRuntime(cfg, providerName) + summary := providers.GetProviderRuntimeSummary(cfg, providers.ProviderRuntimeQuery{Provider: providerName, HealthBelow: 50}) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "ok": true, + "provider": providerName, + "refreshed": true, + "result": result, + "candidate_order": order, + "summary": summary, + }) case "rerank": - cfg, err := cfgpkg.LoadConfig(s.configPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - order, err := providers.RerankProviderRuntime(cfg, strings.TrimSpace(body.Provider)) + cfg, providerName, err := s.loadRuntimeProviderName(strings.TrimSpace(body.Provider)) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "reranked": true, "candidate_order": order}) + order, err := providers.RerankProviderRuntime(cfg, providerName) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "provider": providerName, "reranked": true, "candidate_order": order}) default: http.Error(w, "unsupported action", http.StatusBadRequest) } @@ -1509,13 +1530,34 @@ func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.Provide return nil, cfgpkg.ProviderConfig{}, err } providerName := strings.TrimSpace(name) - pc, ok := cfg.Models.Providers[providerName] + if providerName == "" { + providerName = cfgpkg.PrimaryProviderName(cfg) + } + pc, ok := cfgpkg.ProviderConfigByName(cfg, providerName) if !ok { return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName) } return cfg, pc, nil } +func (s *Server) loadRuntimeProviderName(name string) (*cfgpkg.Config, string, error) { + if strings.TrimSpace(s.configPath) == "" { + return nil, "", fmt.Errorf("config path not set") + } + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + return nil, "", err + } + providerName := strings.TrimSpace(name) + if providerName == "" { + providerName = cfgpkg.PrimaryProviderName(cfg) + } + if !cfgpkg.ProviderExists(cfg, providerName) { + return nil, "", fmt.Errorf("provider %q not found", providerName) + } + return cfg, providerName, nil +} + func (s *Server) resolveProviderConfig(name string, inline cfgpkg.ProviderConfig) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { if hasInlineProviderConfig(inline) { cfg, err := cfgpkg.LoadConfig(s.configPath) diff --git a/pkg/providers/oauth.go b/pkg/providers/oauth.go index 3010902..da6657c 100644 --- a/pkg/providers/oauth.go +++ b/pkg/providers/oauth.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "io" + "math" "net" "net/http" "net/url" @@ -1883,7 +1884,7 @@ func extractOAuthBalanceMetadata(session *oauthSession) (planType, quotaSource, } switch { case subActiveStart != "" && subActiveUntil != "": - balanceDetail = fmt.Sprintf("%s ~ %s", subActiveStart, subActiveUntil) + balanceDetail = formatSubscriptionWindow(subActiveStart, subActiveUntil) case subActiveUntil != "": balanceDetail = fmt.Sprintf("until %s", subActiveUntil) case subActiveStart != "": @@ -1919,6 +1920,55 @@ func normalizeBalanceTime(value string) string { return trimmed } +func formatSubscriptionWindow(startRaw, untilRaw string) string { + startRaw = strings.TrimSpace(startRaw) + untilRaw = strings.TrimSpace(untilRaw) + if startRaw == "" || untilRaw == "" { + return strings.TrimSpace(strings.TrimSpace(startRaw) + " ~ " + strings.TrimSpace(untilRaw)) + } + startAt, okStart := parseBalanceTime(startRaw) + untilAt, okUntil := parseBalanceTime(untilRaw) + if !okStart || !okUntil || !untilAt.After(startAt) { + return fmt.Sprintf("%s ~ %s", startRaw, untilRaw) + } + now := time.Now().UTC() + total := untilAt.Sub(startAt) + elapsed := now.Sub(startAt) + if elapsed < 0 { + elapsed = 0 + } + if elapsed > total { + elapsed = total + } + usedPct := int(math.Round(float64(elapsed) * 100 / float64(total))) + if usedPct < 0 { + usedPct = 0 + } + if usedPct > 100 { + usedPct = 100 + } + return fmt.Sprintf("%s ~ %s (%d%% used, %d%% left)", startRaw, untilRaw, usedPct, 100-usedPct) +} + +func parseBalanceTime(value string) (time.Time, bool) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return time.Time{}, false + } + layouts := []string{ + time.RFC3339Nano, + time.RFC3339, + "2006-01-02 15:04:05", + "2006-01-02", + } + for _, layout := range layouts { + if parsed, err := time.Parse(layout, trimmed); err == nil { + return parsed.UTC(), true + } + } + return time.Time{}, false +} + func firstNonEmpty(values ...string) string { for _, value := range values { if trimmed := strings.TrimSpace(value); trimmed != "" {