From d23142bce2bd818fe856198c77ef92a6921e19d7 Mon Sep 17 00:00:00 2001 From: LPF Date: Thu, 12 Mar 2026 23:02:44 +0800 Subject: [PATCH] feat: align cliproxyapi providers and auto fallback --- cmd/cmd_config.go | 31 --- pkg/agent/loop.go | 209 ++++++++++++++++--- pkg/config/config.go | 41 +++- pkg/config/validate.go | 5 - pkg/providers/aistudio_relay.go | 22 +- pkg/providers/antigravity_provider.go | 92 +++++++- pkg/providers/antigravity_provider_test.go | 52 +++++ pkg/providers/codex_provider.go | 11 +- pkg/providers/gemini_cli_provider.go | 59 ++++-- pkg/providers/gemini_provider.go | 69 ++++-- pkg/providers/gemini_provider_test.go | 19 ++ pkg/providers/http_provider.go | 96 ++++++--- pkg/providers/iflow_provider.go | 19 +- pkg/providers/kimi_provider.go | 32 +-- pkg/providers/kimi_provider_test.go | 13 ++ pkg/providers/oauth.go | 14 +- pkg/providers/openai_compat_provider_test.go | 23 ++ pkg/providers/qwen_provider.go | 49 ++++- pkg/providers/qwen_provider_test.go | 23 ++ pkg/providers/vertex_provider.go | 29 ++- 20 files changed, 718 insertions(+), 190 deletions(-) create mode 100644 pkg/providers/qwen_provider_test.go diff --git a/cmd/cmd_config.go b/cmd/cmd_config.go index c287644..c76fef4 100644 --- a/cmd/cmd_config.go +++ b/cmd/cmd_config.go @@ -265,37 +265,6 @@ func providerCmd() { cfg.Agents.Defaults.Model.Primary = providerName + "/" + targetModel } - currentFallbacks := strings.Join(cfg.Agents.Defaults.Model.Fallbacks, ",") - fallbackRaw := promptLine(reader, "agents.defaults.model.fallbacks (comma-separated provider/model refs)", currentFallbacks) - fallbacks := parseCSV(fallbackRaw) - valid := map[string]struct{}{} - for _, name := range providerNames(cfg) { - valid[name] = struct{}{} - } - filteredFallbacks := make([]string, 0, len(fallbacks)) - seen := map[string]struct{}{} - defaultRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary) - for _, fb := range fallbacks { - if fb == "" || fb == defaultRef { - continue - } - fbProvider, fbModel := config.ParseProviderModelRef(fb) - if fbProvider == "" || fbModel == "" { - fmt.Printf("Skip invalid fallback provider/model ref: %s\n", fb) - continue - } - if _, ok := valid[fbProvider]; !ok { - fmt.Printf("Skip unknown fallback provider: %s\n", fb) - continue - } - if _, ok := seen[fb]; ok { - continue - } - seen[fb] = struct{}{} - filteredFallbacks = append(filteredFallbacks, fb) - } - cfg.Agents.Defaults.Model.Fallbacks = filteredFallbacks - if err := config.SaveConfig(getConfigPath(), cfg); err != nil { fmt.Printf("Error saving config: %v\n", err) os.Exit(1) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index e9a75c6..172136b 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -39,6 +39,7 @@ import ( type AgentLoop struct { bus *bus.MessageBus + cfg *config.Config provider providers.LLMProvider workspace string model string @@ -54,6 +55,7 @@ type AgentLoop struct { audit *triggerAudit running bool sessionScheduler *SessionScheduler + providerChain []providerCandidate providerNames []string providerPool map[string]providers.LLMProvider providerResponses map[string]config.ProviderResponsesConfig @@ -73,6 +75,12 @@ type AgentLoop struct { subagentDigests map[string]*subagentDigestState } +type providerCandidate struct { + ref string + name string + model string +} + type subagentDigestItem struct { agentID string reason string @@ -315,6 +323,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers loop := &AgentLoop{ bus: msgBus, + cfg: cfg, provider: provider, workspace: workspace, model: provider.GetDefaultModel(), @@ -346,36 +355,75 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers loop.model = strings.TrimSpace(primaryModel) } go loop.runSubagentDigestTicker() - // Initialize provider fallback chain (primary + model fallbacks). + // Initialize provider fallback chain (primary + inferred providers). + loop.providerChain = []providerCandidate{} loop.providerPool = map[string]providers.LLMProvider{} loop.providerNames = []string{} primaryName := config.PrimaryProviderName(cfg) + primaryRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary) + if primaryRef == "" { + primaryRef = primaryName + "/" + loop.model + } loop.providerPool[primaryName] = provider + loop.providerChain = append(loop.providerChain, providerCandidate{ + ref: primaryRef, + name: primaryName, + model: loop.model, + }) loop.providerNames = append(loop.providerNames, primaryName) if pc, ok := config.ProviderConfigByName(cfg, primaryName); ok { loop.providerResponses[primaryName] = pc.Responses } - for _, name := range cfg.Agents.Defaults.Model.Fallbacks { - if name == "" { + seenProviders := map[string]struct{}{primaryName: {}} + providerConfigs := config.AllProviderConfigs(cfg) + providerOrder := make([]string, 0, len(providerConfigs)) + for name := range providerConfigs { + normalized := strings.TrimSpace(name) + if normalized == "" { continue } - dup := false - for _, existing := range loop.providerNames { - if existing == name { - dup = true - break - } + providerOrder = append(providerOrder, normalized) + } + sort.SliceStable(providerOrder, func(i, j int) bool { + ni := normalizeFallbackProviderName(providerOrder[i]) + nj := normalizeFallbackProviderName(providerOrder[j]) + pi := automaticFallbackPriority(ni) + pj := automaticFallbackPriority(nj) + if pi == pj { + return ni < nj } - if dup { + return pi < pj + }) + for _, rawName := range providerOrder { + providerName := strings.TrimSpace(rawName) + if providerName == "" { continue } - if p2, err := providers.CreateProviderByName(cfg, name); err == nil { - loop.providerPool[name] = p2 - loop.providerNames = append(loop.providerNames, name) - if pc, ok := config.ProviderConfigByName(cfg, name); ok { - loop.providerResponses[name] = pc.Responses - } + providerName, _ = config.ParseProviderModelRef(providerName + "/_") + if providerName == "" { + continue } + if _, dup := seenProviders[providerName]; dup { + continue + } + modelName := "" + if pc, ok := config.ProviderConfigByName(cfg, providerName); ok { + if len(pc.Models) > 0 { + modelName = strings.TrimSpace(pc.Models[0]) + } + loop.providerResponses[providerName] = pc.Responses + } + seenProviders[providerName] = struct{}{} + loop.providerNames = append(loop.providerNames, providerName) + ref := providerName + if modelName != "" { + ref += "/" + modelName + } + loop.providerChain = append(loop.providerChain, providerCandidate{ + ref: ref, + name: providerName, + model: modelName, + }) } // Inject recursive run logic so subagents can use full tool-calling flows. @@ -581,24 +629,44 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM } func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, options map[string]interface{}, primaryErr error) (*providers.LLMResponse, string, error) { - if len(al.providerNames) <= 1 { + if len(al.providerChain) <= 1 { return nil, "", primaryErr } lastErr := primaryErr - candidates := append([]string(nil), al.providerNames[1:]...) + candidateNames := make([]string, 0, len(al.providerChain)-1) + for _, candidate := range al.providerChain[1:] { + candidateNames = append(candidateNames, candidate.name) + } if al.ekg != nil { errSig := "" if primaryErr != nil { errSig = primaryErr.Error() } - candidates = al.ekg.RankProvidersForError(candidates, errSig) + candidateNames = al.ekg.RankProvidersForError(candidateNames, errSig) } - for _, name := range candidates { - p, ok := al.providerPool[name] - if !ok || p == nil { + ranked := make([]providerCandidate, 0, len(al.providerChain)-1) + used := make([]bool, len(al.providerChain)-1) + for _, name := range candidateNames { + for idx, candidate := range al.providerChain[1:] { + if used[idx] || candidate.name != name { + continue + } + used[idx] = true + ranked = append(ranked, candidate) + } + } + for idx, candidate := range al.providerChain[1:] { + if !used[idx] { + ranked = append(ranked, candidate) + } + } + for _, candidate := range ranked { + p, candidateModel, err := al.ensureProviderCandidate(candidate) + if err != nil { + lastErr = err continue } - resp, err := p.Chat(ctx, messages, toolDefs, al.model, options) + resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, options) if al.ekg != nil { st := "success" lg := "fallback provider success" @@ -608,17 +676,96 @@ func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe lg = err.Error() errSig = err.Error() } - al.ekg.Record(ekg.Event{Session: msg.SessionKey, Channel: msg.Channel, Source: "provider_fallback", Status: st, Provider: name, Model: al.model, ErrSig: errSig, Log: lg}) + al.ekg.Record(ekg.Event{Session: msg.SessionKey, Channel: msg.Channel, Source: "provider_fallback", Status: st, Provider: candidate.name, Model: candidateModel, ErrSig: errSig, Log: lg}) } if err == nil { - logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": name}) - return resp, name, nil + logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": candidate.name, "model": candidateModel, "ref": candidate.ref}) + return resp, candidate.name, nil } lastErr = err } return nil, "", lastErr } +func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (providers.LLMProvider, string, error) { + if al == nil { + return nil, "", fmt.Errorf("agent loop is nil") + } + name := strings.TrimSpace(candidate.name) + if name == "" { + return nil, "", fmt.Errorf("fallback provider name is empty") + } + al.providerMu.RLock() + existing := al.providerPool[name] + al.providerMu.RUnlock() + if existing != nil { + model := strings.TrimSpace(candidate.model) + if model == "" { + model = strings.TrimSpace(existing.GetDefaultModel()) + } + if model == "" { + return nil, "", fmt.Errorf("fallback provider %q has no model configured", name) + } + return existing, model, nil + } + if al.cfg == nil { + return nil, "", fmt.Errorf("config not available for fallback provider %q", name) + } + created, err := providers.CreateProviderByName(al.cfg, name) + if err != nil { + return nil, "", err + } + model := strings.TrimSpace(candidate.model) + if model == "" { + model = strings.TrimSpace(created.GetDefaultModel()) + } + if model == "" { + return nil, "", fmt.Errorf("fallback provider %q has no model configured", name) + } + al.providerMu.Lock() + if existing := al.providerPool[name]; existing != nil { + al.providerMu.Unlock() + return existing, model, nil + } + al.providerPool[name] = created + al.providerMu.Unlock() + return created, model, nil +} + +func automaticFallbackPriority(name string) int { + switch normalizeFallbackProviderName(name) { + case "claude": + return 10 + case "codex": + return 20 + case "gemini": + return 30 + case "gemini-cli": + return 40 + case "aistudio": + return 50 + case "vertex": + return 60 + case "antigravity": + return 70 + case "qwen": + return 80 + case "kimi": + return 90 + case "iflow": + return 100 + case "openai-compatibility": + return 110 + default: + return 1000 + } +} + +func normalizeFallbackProviderName(name string) string { + normalized, _ := config.ParseProviderModelRef(strings.TrimSpace(name) + "/_") + return strings.TrimSpace(normalized) +} + func (al *AgentLoop) setSessionProvider(sessionKey, provider string) { key := strings.TrimSpace(sessionKey) if key == "" { @@ -1188,12 +1335,9 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } if err != nil { - if fb, fbProvider, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { + if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { response = fb err = nil - if fbProvider != "" { - al.setSessionProvider(msg.SessionKey, fbProvider) - } } else { err = ferr } @@ -1542,12 +1686,9 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, options) if err != nil { - if fb, fbProvider, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { + if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { response = fb err = nil - if fbProvider != "" { - al.setSessionProvider(msg.SessionKey, fbProvider) - } } else { err = ferr } diff --git a/pkg/config/config.go b/pkg/config/config.go index 3c5efae..468e881 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -132,8 +132,7 @@ type AgentDefaults struct { } type AgentModelDefaults struct { - Primary string `json:"primary,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_PRIMARY"` - Fallbacks []string `json:"fallbacks,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_FALLBACKS"` + Primary string `json:"primary,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_PRIMARY"` } type HeartbeatConfig struct { @@ -445,7 +444,7 @@ func DefaultConfig() *Config { Agents: AgentsConfig{ Defaults: AgentDefaults{ Workspace: filepath.Join(configDir, "workspace"), - Model: AgentModelDefaults{Primary: "openai/gpt-5.4", Fallbacks: []string{}}, + Model: AgentModelDefaults{Primary: "openai/gpt-5.4"}, MaxTokens: 8192, Temperature: 0.7, MaxToolIterations: 20, @@ -660,13 +659,36 @@ func DefaultConfig() *Config { } } +func normalizeProviderNameAlias(name string) string { + switch strings.ToLower(strings.TrimSpace(name)) { + case "geminicli", "gemini_cli": + return "gemini-cli" + case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio": + return "aistudio" + case "google", "gemini-api-key", "gemini_api_key": + return "gemini" + case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key": + return "claude" + case "openai-compatibility", "openai_compatibility", "openai-compat", "openai_compat": + return "openai-compatibility" + case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", "vertex-compatibility", "vertex_compatibility": + return "vertex" + case "codex-api-key", "codex_api_key": + return "codex" + case "i-flow", "i_flow": + return "iflow" + default: + return strings.TrimSpace(name) + } +} + func ParseProviderModelRef(raw string) (provider string, model string) { trimmed := strings.TrimSpace(raw) if trimmed == "" { return "", "" } if idx := strings.Index(trimmed, "/"); idx > 0 { - return strings.TrimSpace(trimmed[:idx]), strings.TrimSpace(trimmed[idx+1:]) + return normalizeProviderNameAlias(trimmed[:idx]), strings.TrimSpace(trimmed[idx+1:]) } return "", trimmed } @@ -690,7 +712,12 @@ func ProviderConfigByName(cfg *Config, name string) (ProviderConfig, bool) { if cfg == nil { return ProviderConfig{}, false } - pc, ok := AllProviderConfigs(cfg)[strings.TrimSpace(name)] + configs := AllProviderConfigs(cfg) + trimmed := strings.TrimSpace(name) + if pc, ok := configs[trimmed]; ok { + return pc, true + } + pc, ok := configs[normalizeProviderNameAlias(trimmed)] return pc, ok } @@ -704,10 +731,10 @@ func PrimaryProviderName(cfg *Config) string { return "openai" } if provider, _ := ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); provider != "" { - return provider + return normalizeProviderNameAlias(provider) } for name := range cfg.Models.Providers { - if trimmed := strings.TrimSpace(name); trimmed != "" { + if trimmed := normalizeProviderNameAlias(name); trimmed != "" { return trimmed } } diff --git a/pkg/config/validate.go b/pkg/config/validate.go index b144871..58d1073 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -93,11 +93,6 @@ func Validate(cfg *Config) []error { if len(cfg.Models.Providers) == 0 { errs = append(errs, fmt.Errorf("models.providers must contain at least one provider")) } - for _, name := range cfg.Agents.Defaults.Model.Fallbacks { - if !ProviderExists(cfg, name) { - errs = append(errs, fmt.Errorf("agents.defaults.model.fallbacks contains unknown provider %q", name)) - } - } if primaryRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary); primaryRef != "" { providerName, modelName := ParseProviderModelRef(primaryRef) if providerName == "" { diff --git a/pkg/providers/aistudio_relay.go b/pkg/providers/aistudio_relay.go index ee9947e..25e7122 100644 --- a/pkg/providers/aistudio_relay.go +++ b/pkg/providers/aistudio_relay.go @@ -44,11 +44,31 @@ func aistudioChannelID(providerName string, options map[string]interface{}) stri } func aistudioChannelCandidates(providerName string, options map[string]interface{}) []string { - for _, key := range []string{"aistudio_channel", "aistudio_provider", "relay_provider"} { + for _, key := range []string{"aistudio_channel", "aistudio_provider", "relay_provider", "channel_id", "provider_id"} { if value, ok := stringOption(options, key); ok && strings.TrimSpace(value) != "" { return []string{strings.ToLower(strings.TrimSpace(value))} } } + for _, key := range []string{"aistudio_channels", "channel_ids", "relay_providers"} { + if values, ok := stringSliceOption(options, key); ok && len(values) > 0 { + out := make([]string, 0, len(values)) + seen := map[string]struct{}{} + for _, value := range values { + channelID := strings.ToLower(strings.TrimSpace(value)) + if channelID == "" { + continue + } + if _, exists := seen[channelID]; exists { + continue + } + seen[channelID] = struct{}{} + out = append(out, channelID) + } + if len(out) > 0 { + return out + } + } + } if runtimeSelected := preferredAIStudioRelayChannels(); len(runtimeSelected) > 0 { return runtimeSelected } diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go index 22fb52c..f918ea1 100644 --- a/pkg/providers/antigravity_provider.go +++ b/pkg/providers/antigravity_provider.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "time" ) @@ -252,6 +253,7 @@ func (p *AntigravityProvider) baseURLs() []string { func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any { request := map[string]any{} + baseModel := strings.TrimSpace(qwenBaseModel(model)) systemParts := make([]map[string]any, 0) contents := make([]map[string]any, 0, len(messages)) callNames := map[string]string{} @@ -299,6 +301,16 @@ func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolD if gen := antigravityGenerationConfig(options); len(gen) > 0 { request["generationConfig"] = gen } + if extra, ok := mapOption(options, "gemini_generation_config"); ok && len(extra) > 0 { + gen := mapFromAny(request["generationConfig"]) + if gen == nil { + gen = map[string]any{} + } + for k, v := range extra { + gen[k] = v + } + request["generationConfig"] = gen + } if toolDecls := antigravityToolDeclarations(tools); len(toolDecls) > 0 { request["tools"] = []map[string]any{{"function_declarations": toolDecls}} request["toolConfig"] = map[string]any{ @@ -307,18 +319,19 @@ func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolD } projectID := "" if session != nil { - projectID = firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"])) + projectID = firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["projectId"]), asString(session.Token["project"])) } if projectID == "" { projectID = "default-project" } + applyAntigravityThinkingSuffix(request, model) requestType := "agent" - if strings.Contains(strings.ToLower(model), "image") { + if strings.Contains(strings.ToLower(baseModel), "image") { requestType = "image_gen" } return map[string]any{ "project": projectID, - "model": strings.TrimSpace(model), + "model": baseModel, "userAgent": "antigravity", "requestType": requestType, "requestId": "agent-" + randomSessionID(), @@ -454,6 +467,79 @@ func antigravityGenerationConfig(options map[string]any) map[string]any { return cfg } +func applyAntigravityThinkingSuffix(request map[string]any, model string) { + suffix := qwenModelSuffix(model) + if suffix == "" { + return + } + baseModel := strings.TrimSpace(qwenBaseModel(model)) + gen := mapFromAny(request["generationConfig"]) + if gen == nil { + gen = map[string]any{} + } + thinkingConfig := mapFromAny(gen["thinkingConfig"]) + if thinkingConfig == nil { + thinkingConfig = map[string]any{} + } + includeThoughts, userSetIncludeThoughts := geminiExistingIncludeThoughts(thinkingConfig) + delete(thinkingConfig, "thinkingBudget") + delete(thinkingConfig, "thinking_budget") + delete(thinkingConfig, "thinkingLevel") + delete(thinkingConfig, "thinking_level") + delete(thinkingConfig, "include_thoughts") + + setIncludeThoughts := func(defaultValue bool, force bool) { + if force || !userSetIncludeThoughts { + includeThoughts = defaultValue + } + thinkingConfig["includeThoughts"] = includeThoughts + } + + lower := strings.ToLower(strings.TrimSpace(suffix)) + switch { + case lower == "auto" || lower == "-1": + thinkingConfig["thinkingBudget"] = -1 + setIncludeThoughts(true, false) + case lower == "none" || lower == "0": + if geminiUsesThinkingLevels(baseModel) { + thinkingConfig["thinkingLevel"] = "low" + } else { + thinkingConfig["thinkingBudget"] = 128 + } + setIncludeThoughts(false, true) + case isGeminiThinkingLevel(lower): + if geminiUsesThinkingLevels(baseModel) { + thinkingConfig["thinkingLevel"] = normalizeGeminiThinkingLevel(lower) + } else { + thinkingConfig["thinkingBudget"] = geminiThinkingBudgetForLevel(lower) + } + setIncludeThoughts(true, false) + default: + if budget, err := strconv.Atoi(lower); err == nil { + switch { + case budget < 0: + thinkingConfig["thinkingBudget"] = -1 + setIncludeThoughts(true, false) + case budget == 0: + if geminiUsesThinkingLevels(baseModel) { + thinkingConfig["thinkingLevel"] = "low" + } else { + thinkingConfig["thinkingBudget"] = 128 + } + setIncludeThoughts(false, true) + default: + thinkingConfig["thinkingBudget"] = budget + setIncludeThoughts(true, false) + } + } + } + if len(thinkingConfig) == 0 { + return + } + gen["thinkingConfig"] = thinkingConfig + request["generationConfig"] = gen +} + func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) { if onDelta == nil { onDelta = func(string) {} diff --git a/pkg/providers/antigravity_provider_test.go b/pkg/providers/antigravity_provider_test.go index 8d91fe4..74ad1b7 100644 --- a/pkg/providers/antigravity_provider_test.go +++ b/pkg/providers/antigravity_provider_test.go @@ -2,6 +2,7 @@ package providers import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "sync/atomic" @@ -166,3 +167,54 @@ func TestAntigravityBaseURLsIncludeProdFallback(t *testing.T) { t.Fatalf("unexpected fallback order: %#v", got) } } + +func TestAntigravityBuildRequestBodyAppliesThinkingSuffix(t *testing.T) { + p := NewAntigravityProvider("antigravity", "", "", "gemini-3-pro", false, "oauth", 0, nil) + body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-3-pro(high)", nil, &oauthSession{ProjectID: "demo-project"}, false) + if got := body["model"]; got != "gemini-3-pro" { + t.Fatalf("model = %#v, want gemini-3-pro", got) + } + request := mapFromAny(body["request"]) + gen := mapFromAny(request["generationConfig"]) + thinking := mapFromAny(gen["thinkingConfig"]) + if got := asString(thinking["thinkingLevel"]); got != "high" { + t.Fatalf("thinkingLevel = %q, want high", got) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "true" { + t.Fatalf("includeThoughts = %v, want true", thinking["includeThoughts"]) + } +} + +func TestAntigravityBuildRequestBodyDisablesThinkingOutput(t *testing.T) { + p := NewAntigravityProvider("antigravity", "", "", "gemini-2.5-pro", false, "oauth", 0, nil) + body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro(0)", nil, &oauthSession{ProjectID: "demo-project"}, false) + request := mapFromAny(body["request"]) + gen := mapFromAny(request["generationConfig"]) + thinking := mapFromAny(gen["thinkingConfig"]) + if got := intValue(thinking["thinkingBudget"]); got != 128 { + t.Fatalf("thinkingBudget = %v, want 128", thinking["thinkingBudget"]) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" { + t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"]) + } +} + +func TestAntigravityThinkingSuffixPreservesExplicitIncludeThoughts(t *testing.T) { + p := NewAntigravityProvider("antigravity", "", "", "gemini-3-pro", false, "oauth", 0, nil) + body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-3-pro(high)", map[string]interface{}{ + "gemini_generation_config": map[string]interface{}{ + "thinkingConfig": map[string]interface{}{ + "includeThoughts": false, + }, + }, + }, &oauthSession{ProjectID: "demo-project"}, false) + request := mapFromAny(body["request"]) + gen := mapFromAny(request["generationConfig"]) + thinking := mapFromAny(gen["thinkingConfig"]) + if got := asString(thinking["thinkingLevel"]); got != "high" { + t.Fatalf("thinkingLevel = %q, want high", got) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" { + t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"]) + } +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index dce5a1f..7609f9e 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -820,8 +820,15 @@ func applyCodexWebsocketHeaders(headers http.Header, attempt authAttempt, option headers.Set("User-Agent", codexCompatUserAgent) if attempt.kind != "api_key" { headers.Set("Originator", "codex_cli_rs") - if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" { - headers.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID)) + if attempt.session != nil { + accountID := firstNonEmpty( + strings.TrimSpace(attempt.session.AccountID), + strings.TrimSpace(asString(attempt.session.Token["account_id"])), + strings.TrimSpace(asString(attempt.session.Token["account-id"])), + ) + if accountID != "" { + headers.Set("Chatgpt-Account-Id", accountID) + } } } return headers diff --git a/pkg/providers/gemini_cli_provider.go b/pkg/providers/gemini_cli_provider.go index 619af46..72b97fd 100644 --- a/pkg/providers/gemini_cli_provider.go +++ b/pkg/providers/gemini_cli_provider.go @@ -9,16 +9,18 @@ import ( "io" "net/http" "regexp" + "runtime" "strconv" "strings" "time" ) const ( - geminiCLIBaseURL = "https://cloudcode-pa.googleapis.com" - geminiCLIVersion = "v1internal" - geminiCLIDefaultAlt = "sse" - geminiCLIApiClient = "genai-cli/0 gl-go/1.0" + geminiCLIBaseURL = "https://cloudcode-pa.googleapis.com" + geminiCLIVersion = "v1internal" + geminiCLIDefaultAlt = "sse" + geminiCLIClientVersion = "0.31.0" + geminiCLIApiClient = "google-genai-sdk/1.41.0 gl-node/v22.19.0" ) type GeminiCLIProvider struct { @@ -220,7 +222,7 @@ func applyGeminiCLIAttemptAuth(req *http.Request, attempt authAttempt) error { } token := strings.TrimSpace(attempt.token) if attempt.session != nil { - token = firstNonEmpty(strings.TrimSpace(attempt.session.AccessToken), token, asString(attempt.session.Token["access_token"])) + token = firstNonEmpty(strings.TrimSpace(attempt.session.AccessToken), token, asString(attempt.session.Token["access_token"]), asString(attempt.session.Token["access-token"])) } if token == "" { return fmt.Errorf("missing access token for gemini-cli") @@ -263,26 +265,53 @@ func consumeGeminiCLIStream(resp *http.Response, onDelta func(string)) ([]byte, } func geminiCLIProjectID(options map[string]interface{}, session *oauthSession) string { - if value, ok := stringOption(options, "gemini_project_id"); ok { - return value - } - if value, ok := stringOption(options, "project_id"); ok { - return value + for _, key := range []string{"gemini_project_id", "project_id", "project"} { + if value, ok := stringOption(options, key); ok { + trimmed := strings.TrimSpace(value) + if trimmed != "" { + return trimmed + } + } } if session == nil { return "" } - return firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"]), asString(session.Token["project"])) + return firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["projectId"]), asString(session.Token["project"])) +} + +func geminiCLIRuntimeOS() string { + switch runtime.GOOS { + case "windows": + return "win32" + default: + return runtime.GOOS + } +} + +func geminiCLIRuntimeArch() string { + switch runtime.GOARCH { + case "amd64": + return "x64" + case "386": + return "x86" + default: + return runtime.GOARCH + } +} + +func geminiCLIUserAgent(model string) string { + trimmedModel := strings.TrimSpace(model) + if trimmedModel == "" { + trimmedModel = "unknown" + } + return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", geminiCLIClientVersion, trimmedModel, geminiCLIRuntimeOS(), geminiCLIRuntimeArch()) } func applyGeminiCLIHeaders(req *http.Request, model string) { if req == nil { return } - if strings.TrimSpace(model) == "" { - model = "unknown" - } - req.Header.Set("User-Agent", "GeminiCLI/"+model) + req.Header.Set("User-Agent", geminiCLIUserAgent(model)) req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient) } diff --git a/pkg/providers/gemini_provider.go b/pkg/providers/gemini_provider.go index 5102872..77b39dd 100644 --- a/pkg/providers/gemini_provider.go +++ b/pkg/providers/gemini_provider.go @@ -3,8 +3,8 @@ package providers import ( "bufio" "bytes" - "encoding/base64" "context" + "encoding/base64" "encoding/json" "fmt" "io" @@ -15,8 +15,8 @@ import ( ) const ( - geminiBaseURL = "https://generativelanguage.googleapis.com" - geminiAPIVersion = "v1beta" + geminiBaseURL = "https://generativelanguage.googleapis.com" + geminiAPIVersion = "v1beta" geminiImagePreviewModel = "gemini-2.5-flash-image-preview" ) @@ -277,43 +277,50 @@ func applyGeminiThinkingSuffix(request map[string]any, model string) { if thinkingConfig == nil { thinkingConfig = map[string]any{} } + includeThoughts, userSetIncludeThoughts := geminiExistingIncludeThoughts(thinkingConfig) delete(thinkingConfig, "thinkingBudget") delete(thinkingConfig, "thinking_budget") delete(thinkingConfig, "thinkingLevel") delete(thinkingConfig, "thinking_level") delete(thinkingConfig, "include_thoughts") + setIncludeThoughts := func(defaultValue bool, force bool) { + if force || !userSetIncludeThoughts { + includeThoughts = defaultValue + } + thinkingConfig["includeThoughts"] = includeThoughts + } + lower := strings.ToLower(strings.TrimSpace(suffix)) switch { case lower == "auto" || lower == "-1": thinkingConfig["thinkingBudget"] = -1 - thinkingConfig["includeThoughts"] = true + setIncludeThoughts(true, false) case lower == "none": if geminiUsesThinkingLevels(baseModel) { thinkingConfig["thinkingLevel"] = "low" } else { thinkingConfig["thinkingBudget"] = 128 } - thinkingConfig["includeThoughts"] = false + setIncludeThoughts(false, true) case isGeminiThinkingLevel(lower): if geminiUsesThinkingLevels(baseModel) { thinkingConfig["thinkingLevel"] = normalizeGeminiThinkingLevel(lower) - thinkingConfig["includeThoughts"] = true } else { thinkingConfig["thinkingBudget"] = geminiThinkingBudgetForLevel(lower) - thinkingConfig["includeThoughts"] = true } + setIncludeThoughts(true, false) default: if budget, err := strconv.Atoi(lower); err == nil { if budget < 0 { thinkingConfig["thinkingBudget"] = -1 - thinkingConfig["includeThoughts"] = true + setIncludeThoughts(true, false) } else if budget == 0 { thinkingConfig["thinkingBudget"] = 128 - thinkingConfig["includeThoughts"] = false + setIncludeThoughts(false, true) } else { thinkingConfig["thinkingBudget"] = budget - thinkingConfig["includeThoughts"] = true + setIncludeThoughts(true, false) } } } @@ -324,6 +331,38 @@ func applyGeminiThinkingSuffix(request map[string]any, model string) { request["generationConfig"] = gen } +func geminiExistingIncludeThoughts(thinkingConfig map[string]any) (bool, bool) { + if thinkingConfig == nil { + return false, false + } + if value, ok := thinkingConfig["includeThoughts"]; ok { + return geminiBoolValue(value), true + } + if value, ok := thinkingConfig["include_thoughts"]; ok { + return geminiBoolValue(value), true + } + return false, false +} + +func geminiBoolValue(value any) bool { + switch typed := value.(type) { + case bool: + return typed + case string: + switch strings.ToLower(strings.TrimSpace(typed)) { + case "1", "true", "yes", "on": + return true + } + case int: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + } + return false +} + func geminiUsesThinkingLevels(model string) bool { trimmed := strings.ToLower(strings.TrimSpace(model)) return strings.Contains(trimmed, "gemini-3") @@ -490,10 +529,16 @@ func geminiBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string { return normalizeGeminiBaseURL(raw) } if attempt.session.Token != nil { - if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" { + if raw := firstNonEmpty( + strings.TrimSpace(asString(attempt.session.Token["base_url"])), + strings.TrimSpace(asString(attempt.session.Token["base-url"])), + ); raw != "" { return normalizeGeminiBaseURL(raw) } - if raw := strings.TrimSpace(asString(attempt.session.Token["resource_url"])); raw != "" { + if raw := firstNonEmpty( + strings.TrimSpace(asString(attempt.session.Token["resource_url"])), + strings.TrimSpace(asString(attempt.session.Token["resource-url"])), + ); raw != "" { return normalizeGeminiBaseURL(raw) } } diff --git a/pkg/providers/gemini_provider_test.go b/pkg/providers/gemini_provider_test.go index 5d1d72c..90761e8 100644 --- a/pkg/providers/gemini_provider_test.go +++ b/pkg/providers/gemini_provider_test.go @@ -292,3 +292,22 @@ func TestCreateProviderByNameRoutesAIStudioProviderViaGeminiTests(t *testing.T) t.Fatalf("provider = %T, want *AistudioProvider", provider) } } + +func TestGeminiThinkingSuffixPreservesExplicitIncludeThoughts(t *testing.T) { + p := NewGeminiProvider("gemini", "", "", "gemini-3-pro-preview", false, "api_key", 5*time.Second, nil) + body := p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-3-pro-preview(high)", map[string]interface{}{ + "gemini_generation_config": map[string]interface{}{ + "thinkingConfig": map[string]interface{}{ + "includeThoughts": false, + }, + }, + }, false) + gen := mapFromAny(body["generationConfig"]) + thinking := mapFromAny(gen["thinkingConfig"]) + if got := asString(thinking["thinkingLevel"]); got != "high" { + t.Fatalf("thinkingLevel = %q, want high", got) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" { + t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"]) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 5b0c803..cc48d49 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -2344,7 +2344,7 @@ func (p *HTTPProvider) compatBase() string { } func (p *HTTPProvider) compatModel(model string) string { - trimmed := strings.TrimSpace(model) + trimmed := strings.TrimSpace(qwenBaseModel(model)) if p.oauthProvider() == defaultKimiOAuthProvider && strings.HasPrefix(strings.ToLower(trimmed), "kimi-") { return trimmed[5:] } @@ -2356,6 +2356,9 @@ func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools [] "model": p.compatModel(model), "messages": openAICompatMessages(messages), } + if suffix := qwenModelSuffix(model); suffix != "" { + applyOpenAICompatThinkingSuffix(requestBody, suffix) + } if len(tools) > 0 { requestBody["tools"] = openAICompatTools(tools) requestBody["tool_choice"] = "auto" @@ -2680,6 +2683,29 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(ctx context.Context, mode return summary, nil } +func normalizeProviderRouteName(name string) string { + switch strings.ToLower(strings.TrimSpace(name)) { + case "geminicli", "gemini_cli": + return "gemini-cli" + case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio": + return "aistudio" + case "google", "gemini-api-key", "gemini_api_key": + return "gemini" + case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key": + return "claude" + case "openai-compatibility", "openai_compatibility", "openai-compat", "openai_compat": + return "openai-compatibility" + case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", "vertex-compatibility", "vertex_compatibility": + return "vertex" + case "codex-api-key", "codex_api_key": + return "codex" + case "i-flow", "i_flow": + return "iflow" + default: + return strings.TrimSpace(name) + } +} + func CreateProvider(cfg *config.Config) (LLMProvider, error) { name := config.PrimaryProviderName(cfg) provider, err := CreateProviderByName(cfg, name) @@ -2694,18 +2720,32 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { } func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) { - pc, err := getProviderConfigByName(cfg, name) + routeName := normalizeProviderRouteName(name) + pc, err := getProviderConfigByName(cfg, routeName) if err != nil { return nil, err } - ConfigureProviderRuntime(name, pc) - oauthProvider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider)) + ConfigureProviderRuntime(routeName, pc) + oauthProvider := normalizeOAuthProvider(pc.OAuth.Provider) if pc.APIBase == "" && oauthProvider != defaultAntigravityOAuthProvider && oauthProvider != defaultGeminiOAuthProvider && - !strings.EqualFold(name, "gemini-cli") && - !strings.EqualFold(name, "aistudio") && - !strings.EqualFold(name, "vertex") { + oauthProvider != "aistudio" && + oauthProvider != defaultCodexOAuthProvider && + oauthProvider != defaultClaudeOAuthProvider && + oauthProvider != defaultQwenOAuthProvider && + oauthProvider != defaultKimiOAuthProvider && + oauthProvider != defaultIFlowOAuthProvider && + !strings.EqualFold(routeName, "gemini-cli") && + !strings.EqualFold(routeName, "aistudio") && + !strings.EqualFold(routeName, "vertex") && + !strings.EqualFold(routeName, defaultAntigravityOAuthProvider) && + !strings.EqualFold(routeName, defaultGeminiOAuthProvider) && + !strings.EqualFold(routeName, defaultCodexOAuthProvider) && + !strings.EqualFold(routeName, defaultClaudeOAuthProvider) && + !strings.EqualFold(routeName, defaultQwenOAuthProvider) && + !strings.EqualFold(routeName, defaultKimiOAuthProvider) && + !strings.EqualFold(routeName, defaultIFlowOAuthProvider) { return nil, fmt.Errorf("no API base configured for provider %q", name) } if pc.TimeoutSec <= 0 { @@ -2722,37 +2762,37 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) return nil, err } } - if oauthProvider == defaultAntigravityOAuthProvider { - return NewAntigravityProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == defaultAntigravityOAuthProvider || strings.EqualFold(routeName, defaultAntigravityOAuthProvider) { + return NewAntigravityProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if strings.EqualFold(name, "aistudio") { - return NewAistudioProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == "aistudio" || strings.EqualFold(routeName, "aistudio") { + return NewAistudioProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if strings.EqualFold(name, "gemini-cli") { - return NewGeminiCLIProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if strings.EqualFold(routeName, "gemini-cli") { + return NewGeminiCLIProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if oauthProvider == defaultGeminiOAuthProvider || strings.EqualFold(name, defaultGeminiOAuthProvider) || strings.EqualFold(name, "aistudio") { - return NewGeminiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == defaultGeminiOAuthProvider || strings.EqualFold(routeName, defaultGeminiOAuthProvider) || strings.EqualFold(routeName, "aistudio") { + return NewGeminiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if strings.EqualFold(name, "vertex") { - return NewVertexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if strings.EqualFold(routeName, "vertex") { + return NewVertexProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if oauthProvider == defaultCodexOAuthProvider { - return NewCodexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == defaultCodexOAuthProvider || strings.EqualFold(routeName, defaultCodexOAuthProvider) { + return NewCodexProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if oauthProvider == defaultClaudeOAuthProvider { - return NewClaudeProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == defaultClaudeOAuthProvider || strings.EqualFold(routeName, defaultClaudeOAuthProvider) { + return NewClaudeProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if oauthProvider == defaultQwenOAuthProvider { - return NewQwenProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == defaultQwenOAuthProvider || strings.EqualFold(routeName, defaultQwenOAuthProvider) { + return NewQwenProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if oauthProvider == defaultKimiOAuthProvider { - return NewKimiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == defaultKimiOAuthProvider || strings.EqualFold(routeName, defaultKimiOAuthProvider) { + return NewKimiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(name, defaultIFlowOAuthProvider) { - return NewIFlowProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(routeName, defaultIFlowOAuthProvider) { + return NewIFlowProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } - return NewHTTPProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + return NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } func CreateProviders(cfg *config.Config) (map[string]LLMProvider, error) { diff --git a/pkg/providers/iflow_provider.go b/pkg/providers/iflow_provider.go index 8d3f77f..fd9de26 100644 --- a/pkg/providers/iflow_provider.go +++ b/pkg/providers/iflow_provider.go @@ -275,10 +275,16 @@ func iflowBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string { return normalizeIFlowBaseURL(raw) } if attempt.session.Token != nil { - if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" { + if raw := firstNonEmpty( + strings.TrimSpace(asString(attempt.session.Token["base_url"])), + strings.TrimSpace(asString(attempt.session.Token["base-url"])), + ); raw != "" { return normalizeIFlowBaseURL(raw) } - if raw := strings.TrimSpace(asString(attempt.session.Token["resource_url"])); raw != "" { + if raw := firstNonEmpty( + strings.TrimSpace(asString(attempt.session.Token["resource_url"])), + strings.TrimSpace(asString(attempt.session.Token["resource-url"])), + ); raw != "" { return normalizeIFlowBaseURL(raw) } } @@ -309,10 +315,11 @@ func normalizeIFlowBaseURL(raw string) string { func iflowAttemptAPIKey(attempt authAttempt) string { if attempt.session != nil && attempt.session.Token != nil { - if v := strings.TrimSpace(asString(attempt.session.Token["api_key"])); v != "" { - return v - } - if v := strings.TrimSpace(asString(attempt.session.Token["apiKey"])); v != "" { + if v := firstNonEmpty( + strings.TrimSpace(asString(attempt.session.Token["api_key"])), + strings.TrimSpace(asString(attempt.session.Token["api-key"])), + strings.TrimSpace(asString(attempt.session.Token["apiKey"])), + ); v != "" { return v } } diff --git a/pkg/providers/kimi_provider.go b/pkg/providers/kimi_provider.go index 029536a..1108912 100644 --- a/pkg/providers/kimi_provider.go +++ b/pkg/providers/kimi_provider.go @@ -119,37 +119,7 @@ func applyKimiThinking(body map[string]interface{}, model string) { if suffix == "" { return } - suffix = strings.ToLower(strings.TrimSpace(suffix)) - switch suffix { - case "low", "medium", "high", "auto": - body["reasoning_effort"] = suffix - delete(body, "thinking") - case "none": - delete(body, "reasoning_effort") - body["thinking"] = map[string]interface{}{"type": "disabled"} - default: - if budget, err := parsePositiveInt(suffix); err == nil && budget > 0 { - delete(body, "reasoning_effort") - body["thinking"] = map[string]interface{}{ - "type": "enabled", - "budget_tokens": budget, - } - } - } -} - -func parsePositiveInt(raw string) (int, error) { - var out int - for _, ch := range raw { - if ch < '0' || ch > '9' { - return 0, fmt.Errorf("non-digit") - } - out = out*10 + int(ch-'0') - } - if out <= 0 { - return 0, fmt.Errorf("not positive") - } - return out, nil + _ = applyOpenAICompatThinkingSuffix(body, suffix) } func normalizeKimiToolMessages(body map[string]interface{}) { diff --git a/pkg/providers/kimi_provider_test.go b/pkg/providers/kimi_provider_test.go index 04b59dc..51afc39 100644 --- a/pkg/providers/kimi_provider_test.go +++ b/pkg/providers/kimi_provider_test.go @@ -143,3 +143,16 @@ func TestKimiProviderCountTokens(t *testing.T) { t.Fatalf("usage = %#v, want positive prompt-only count", usage) } } + +func TestBuildKimiChatRequestSupportsNumericAutoAndDisable(t *testing.T) { + base := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil) + autoBody := buildKimiChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(-1)", nil, false) + if got := autoBody["reasoning_effort"]; got != "auto" { + t.Fatalf("reasoning_effort = %#v, want auto", got) + } + disableBody := buildKimiChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(0)", nil, false) + thinking, _ := disableBody["thinking"].(map[string]interface{}) + if got := thinking["type"]; got != "disabled" { + t.Fatalf("thinking.type = %#v, want disabled", got) + } +} diff --git a/pkg/providers/oauth.go b/pkg/providers/oauth.go index 344b87a..3010902 100644 --- a/pkg/providers/oauth.go +++ b/pkg/providers/oauth.go @@ -692,10 +692,20 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) { func normalizeOAuthProvider(provider string) string { switch strings.ToLower(strings.TrimSpace(provider)) { - case "anthropic": + case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key": return defaultClaudeOAuthProvider - case "gemini-cli": + case "gemini-cli", "geminicli", "gemini_cli", "google", "gemini-api-key", "gemini_api_key": return defaultGeminiOAuthProvider + case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio": + return "aistudio" + case "openai-compatibility", "openai_compatibility", "openai-compat", "openai_compat": + return "openai-compatibility" + case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", "vertex-compatibility", "vertex_compatibility": + return "vertex" + case "codex-api-key", "codex_api_key": + return defaultCodexOAuthProvider + case "i-flow", "i_flow": + return defaultIFlowOAuthProvider default: return strings.ToLower(strings.TrimSpace(provider)) } diff --git a/pkg/providers/openai_compat_provider_test.go b/pkg/providers/openai_compat_provider_test.go index 054f579..4a72a2d 100644 --- a/pkg/providers/openai_compat_provider_test.go +++ b/pkg/providers/openai_compat_provider_test.go @@ -157,3 +157,26 @@ func TestOpenAICompatMessagesPreserveMultimodalContentParts(t *testing.T) { t.Fatalf("image detail = %#v", got) } } + +func TestBuildOpenAICompatChatRequestAppliesThinkingSuffix(t *testing.T) { + base := NewHTTPProvider("openai", "token", "https://example.com/v1", "gpt-5", false, "api_key", 5*time.Second, nil) + body := base.buildOpenAICompatChatRequest([]Message{{Role: "user", Content: "hi"}}, nil, "gpt-5(high)", nil) + if got := body["model"]; got != "gpt-5" { + t.Fatalf("model = %#v, want gpt-5", got) + } + if got := body["reasoning_effort"]; got != "high" { + t.Fatalf("reasoning_effort = %#v, want high", got) + } +} + +func TestBuildOpenAICompatChatRequestStripsKimiPrefixAndSuffix(t *testing.T) { + base := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil) + base.oauth = &oauthManager{cfg: oauthConfig{Provider: defaultKimiOAuthProvider}} + body := base.buildOpenAICompatChatRequest([]Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(-1)", nil) + if got := body["model"]; got != "k2.5" { + t.Fatalf("model = %#v, want k2.5", got) + } + if got := body["reasoning_effort"]; got != "auto" { + t.Fatalf("reasoning_effort = %#v, want auto", got) + } +} diff --git a/pkg/providers/qwen_provider.go b/pkg/providers/qwen_provider.go index 87494f8..015566f 100644 --- a/pkg/providers/qwen_provider.go +++ b/pkg/providers/qwen_provider.go @@ -164,19 +164,58 @@ func applyQwenThinkingSuffix(body map[string]interface{}, suffix string) { if suffix == "" { return } - switch suffix { - case "low", "medium", "high", "auto": - body["reasoning_effort"] = suffix - case "none": + if applyOpenAICompatThinkingSuffix(body, suffix) { + return + } +} + +func applyOpenAICompatThinkingSuffix(body map[string]interface{}, suffix string) bool { + if body == nil { + return false + } + normalizedLevel, isLevel := normalizeOpenAICompatThinkingLevel(suffix) + switch { + case isLevel: + delete(body, "thinking") + body["reasoning_effort"] = normalizedLevel + return true + case strings.EqualFold(strings.TrimSpace(suffix), "none"): delete(body, "reasoning_effort") body["thinking"] = map[string]interface{}{"type": "disabled"} + return true default: - if n, err := strconv.Atoi(suffix); err == nil && n > 0 { + n, err := strconv.Atoi(strings.TrimSpace(suffix)) + if err != nil { + return false + } + switch { + case n < 0: + delete(body, "thinking") + body["reasoning_effort"] = "auto" + case n == 0: + delete(body, "reasoning_effort") + body["thinking"] = map[string]interface{}{"type": "disabled"} + default: + delete(body, "reasoning_effort") body["thinking"] = map[string]interface{}{ "type": "enabled", "budget_tokens": n, } } + return true + } +} + +func normalizeOpenAICompatThinkingLevel(raw string) (string, bool) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "minimal": + return "low", true + case "low", "medium", "high", "auto": + return strings.ToLower(strings.TrimSpace(raw)), true + case "xhigh", "max": + return "high", true + default: + return "", false } } diff --git a/pkg/providers/qwen_provider_test.go b/pkg/providers/qwen_provider_test.go new file mode 100644 index 0000000..e218bce --- /dev/null +++ b/pkg/providers/qwen_provider_test.go @@ -0,0 +1,23 @@ +package providers + +import ( + "testing" + "time" +) + +func TestBuildQwenChatRequestSupportsExtendedThinkingSuffixes(t *testing.T) { + base := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil) + autoBody := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max(-1)", nil, false) + if got := autoBody["reasoning_effort"]; got != "auto" { + t.Fatalf("reasoning_effort = %#v, want auto", got) + } + disableBody := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max(0)", nil, false) + thinking, _ := disableBody["thinking"].(map[string]interface{}) + if got := thinking["type"]; got != "disabled" { + t.Fatalf("thinking.type = %#v, want disabled", got) + } + minimalBody := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max(minimal)", nil, false) + if got := minimalBody["reasoning_effort"]; got != "low" { + t.Fatalf("reasoning_effort = %#v, want low", got) + } +} diff --git a/pkg/providers/vertex_provider.go b/pkg/providers/vertex_provider.go index 624259a..071771c 100644 --- a/pkg/providers/vertex_provider.go +++ b/pkg/providers/vertex_provider.go @@ -226,7 +226,10 @@ func (p *VertexProvider) endpoint(attempt authAttempt, model, action string, str func vertexBaseURLForAttempt(base *HTTPProvider, attempt authAttempt, options map[string]interface{}) string { customBase := "" if attempt.session != nil && attempt.session.Token != nil { - if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" { + if raw := firstNonEmpty( + strings.TrimSpace(asString(attempt.session.Token["base_url"])), + strings.TrimSpace(asString(attempt.session.Token["base-url"])), + ); raw != "" { customBase = normalizeVertexBaseURL(raw) } } @@ -256,11 +259,16 @@ func normalizeVertexBaseURL(raw string) string { func vertexProjectLocation(attempt authAttempt, options map[string]interface{}) (string, string, bool) { projectID := "" - if value, ok := stringOption(options, "vertex_project_id"); ok { - projectID = strings.TrimSpace(value) + for _, key := range []string{"vertex_project_id", "project_id", "project"} { + if value, ok := stringOption(options, key); ok { + projectID = strings.TrimSpace(value) + if projectID != "" { + break + } + } } if attempt.session != nil { - projectID = firstNonEmpty(projectID, strings.TrimSpace(attempt.session.ProjectID), asString(attempt.session.Token["project_id"]), asString(attempt.session.Token["projectId"]), asString(attempt.session.Token["project"])) + projectID = firstNonEmpty(projectID, strings.TrimSpace(attempt.session.ProjectID), asString(attempt.session.Token["project_id"]), asString(attempt.session.Token["project-id"]), asString(attempt.session.Token["projectId"]), asString(attempt.session.Token["project"])) if projectID == "" { projectID = strings.TrimSpace(asString(mapFromAny(attempt.session.Token["service_account"])["project_id"])) } @@ -274,11 +282,16 @@ func vertexProjectLocation(attempt authAttempt, options map[string]interface{}) func vertexLocationForAttempt(attempt authAttempt, options map[string]interface{}) string { location := "" - if value, ok := stringOption(options, "vertex_location"); ok { - location = strings.TrimSpace(value) + for _, key := range []string{"vertex_location", "location", "region"} { + if value, ok := stringOption(options, key); ok { + location = strings.TrimSpace(value) + if location != "" { + break + } + } } if attempt.session != nil { - location = firstNonEmpty(location, asString(attempt.session.Token["location"]), asString(mapFromAny(attempt.session.Token["service_account"])["location"])) + location = firstNonEmpty(location, asString(attempt.session.Token["location"]), asString(attempt.session.Token["region"]), asString(mapFromAny(attempt.session.Token["service_account"])["location"])) } if strings.TrimSpace(location) == "" { location = vertexDefaultRegion @@ -501,7 +514,7 @@ func vertexServiceAccountJSON(session *oauthSession) ([]byte, error) { return nil, fmt.Errorf("vertex service account missing") } raw := mapFromAny(session.Token["service_account"]) - if projectID := firstNonEmpty(asString(raw["project_id"]), strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project"])); projectID != "" { + if projectID := firstNonEmpty(asString(raw["project_id"]), asString(raw["project-id"]), strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["project"])); projectID != "" { raw["project_id"] = projectID } data, err := json.Marshal(raw)