mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-20 13:47:28 +08:00
feat: align cliproxyapi providers and auto fallback
This commit is contained in:
@@ -265,37 +265,6 @@ func providerCmd() {
|
||||
cfg.Agents.Defaults.Model.Primary = providerName + "/" + targetModel
|
||||
}
|
||||
|
||||
currentFallbacks := strings.Join(cfg.Agents.Defaults.Model.Fallbacks, ",")
|
||||
fallbackRaw := promptLine(reader, "agents.defaults.model.fallbacks (comma-separated provider/model refs)", currentFallbacks)
|
||||
fallbacks := parseCSV(fallbackRaw)
|
||||
valid := map[string]struct{}{}
|
||||
for _, name := range providerNames(cfg) {
|
||||
valid[name] = struct{}{}
|
||||
}
|
||||
filteredFallbacks := make([]string, 0, len(fallbacks))
|
||||
seen := map[string]struct{}{}
|
||||
defaultRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary)
|
||||
for _, fb := range fallbacks {
|
||||
if fb == "" || fb == defaultRef {
|
||||
continue
|
||||
}
|
||||
fbProvider, fbModel := config.ParseProviderModelRef(fb)
|
||||
if fbProvider == "" || fbModel == "" {
|
||||
fmt.Printf("Skip invalid fallback provider/model ref: %s\n", fb)
|
||||
continue
|
||||
}
|
||||
if _, ok := valid[fbProvider]; !ok {
|
||||
fmt.Printf("Skip unknown fallback provider: %s\n", fb)
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[fb]; ok {
|
||||
continue
|
||||
}
|
||||
seen[fb] = struct{}{}
|
||||
filteredFallbacks = append(filteredFallbacks, fb)
|
||||
}
|
||||
cfg.Agents.Defaults.Model.Fallbacks = filteredFallbacks
|
||||
|
||||
if err := config.SaveConfig(getConfigPath(), cfg); err != nil {
|
||||
fmt.Printf("Error saving config: %v\n", err)
|
||||
os.Exit(1)
|
||||
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
|
||||
type AgentLoop struct {
|
||||
bus *bus.MessageBus
|
||||
cfg *config.Config
|
||||
provider providers.LLMProvider
|
||||
workspace string
|
||||
model string
|
||||
@@ -54,6 +55,7 @@ type AgentLoop struct {
|
||||
audit *triggerAudit
|
||||
running bool
|
||||
sessionScheduler *SessionScheduler
|
||||
providerChain []providerCandidate
|
||||
providerNames []string
|
||||
providerPool map[string]providers.LLMProvider
|
||||
providerResponses map[string]config.ProviderResponsesConfig
|
||||
@@ -73,6 +75,12 @@ type AgentLoop struct {
|
||||
subagentDigests map[string]*subagentDigestState
|
||||
}
|
||||
|
||||
type providerCandidate struct {
|
||||
ref string
|
||||
name string
|
||||
model string
|
||||
}
|
||||
|
||||
type subagentDigestItem struct {
|
||||
agentID string
|
||||
reason string
|
||||
@@ -315,6 +323,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
|
||||
loop := &AgentLoop{
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
provider: provider,
|
||||
workspace: workspace,
|
||||
model: provider.GetDefaultModel(),
|
||||
@@ -346,36 +355,75 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
loop.model = strings.TrimSpace(primaryModel)
|
||||
}
|
||||
go loop.runSubagentDigestTicker()
|
||||
// Initialize provider fallback chain (primary + model fallbacks).
|
||||
// Initialize provider fallback chain (primary + inferred providers).
|
||||
loop.providerChain = []providerCandidate{}
|
||||
loop.providerPool = map[string]providers.LLMProvider{}
|
||||
loop.providerNames = []string{}
|
||||
primaryName := config.PrimaryProviderName(cfg)
|
||||
primaryRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary)
|
||||
if primaryRef == "" {
|
||||
primaryRef = primaryName + "/" + loop.model
|
||||
}
|
||||
loop.providerPool[primaryName] = provider
|
||||
loop.providerChain = append(loop.providerChain, providerCandidate{
|
||||
ref: primaryRef,
|
||||
name: primaryName,
|
||||
model: loop.model,
|
||||
})
|
||||
loop.providerNames = append(loop.providerNames, primaryName)
|
||||
if pc, ok := config.ProviderConfigByName(cfg, primaryName); ok {
|
||||
loop.providerResponses[primaryName] = pc.Responses
|
||||
}
|
||||
for _, name := range cfg.Agents.Defaults.Model.Fallbacks {
|
||||
if name == "" {
|
||||
seenProviders := map[string]struct{}{primaryName: {}}
|
||||
providerConfigs := config.AllProviderConfigs(cfg)
|
||||
providerOrder := make([]string, 0, len(providerConfigs))
|
||||
for name := range providerConfigs {
|
||||
normalized := strings.TrimSpace(name)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
dup := false
|
||||
for _, existing := range loop.providerNames {
|
||||
if existing == name {
|
||||
dup = true
|
||||
break
|
||||
}
|
||||
providerOrder = append(providerOrder, normalized)
|
||||
}
|
||||
sort.SliceStable(providerOrder, func(i, j int) bool {
|
||||
ni := normalizeFallbackProviderName(providerOrder[i])
|
||||
nj := normalizeFallbackProviderName(providerOrder[j])
|
||||
pi := automaticFallbackPriority(ni)
|
||||
pj := automaticFallbackPriority(nj)
|
||||
if pi == pj {
|
||||
return ni < nj
|
||||
}
|
||||
if dup {
|
||||
return pi < pj
|
||||
})
|
||||
for _, rawName := range providerOrder {
|
||||
providerName := strings.TrimSpace(rawName)
|
||||
if providerName == "" {
|
||||
continue
|
||||
}
|
||||
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
|
||||
loop.providerPool[name] = p2
|
||||
loop.providerNames = append(loop.providerNames, name)
|
||||
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
|
||||
loop.providerResponses[name] = pc.Responses
|
||||
}
|
||||
providerName, _ = config.ParseProviderModelRef(providerName + "/_")
|
||||
if providerName == "" {
|
||||
continue
|
||||
}
|
||||
if _, dup := seenProviders[providerName]; dup {
|
||||
continue
|
||||
}
|
||||
modelName := ""
|
||||
if pc, ok := config.ProviderConfigByName(cfg, providerName); ok {
|
||||
if len(pc.Models) > 0 {
|
||||
modelName = strings.TrimSpace(pc.Models[0])
|
||||
}
|
||||
loop.providerResponses[providerName] = pc.Responses
|
||||
}
|
||||
seenProviders[providerName] = struct{}{}
|
||||
loop.providerNames = append(loop.providerNames, providerName)
|
||||
ref := providerName
|
||||
if modelName != "" {
|
||||
ref += "/" + modelName
|
||||
}
|
||||
loop.providerChain = append(loop.providerChain, providerCandidate{
|
||||
ref: ref,
|
||||
name: providerName,
|
||||
model: modelName,
|
||||
})
|
||||
}
|
||||
|
||||
// Inject recursive run logic so subagents can use full tool-calling flows.
|
||||
@@ -581,24 +629,44 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM
|
||||
}
|
||||
|
||||
func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, options map[string]interface{}, primaryErr error) (*providers.LLMResponse, string, error) {
|
||||
if len(al.providerNames) <= 1 {
|
||||
if len(al.providerChain) <= 1 {
|
||||
return nil, "", primaryErr
|
||||
}
|
||||
lastErr := primaryErr
|
||||
candidates := append([]string(nil), al.providerNames[1:]...)
|
||||
candidateNames := make([]string, 0, len(al.providerChain)-1)
|
||||
for _, candidate := range al.providerChain[1:] {
|
||||
candidateNames = append(candidateNames, candidate.name)
|
||||
}
|
||||
if al.ekg != nil {
|
||||
errSig := ""
|
||||
if primaryErr != nil {
|
||||
errSig = primaryErr.Error()
|
||||
}
|
||||
candidates = al.ekg.RankProvidersForError(candidates, errSig)
|
||||
candidateNames = al.ekg.RankProvidersForError(candidateNames, errSig)
|
||||
}
|
||||
for _, name := range candidates {
|
||||
p, ok := al.providerPool[name]
|
||||
if !ok || p == nil {
|
||||
ranked := make([]providerCandidate, 0, len(al.providerChain)-1)
|
||||
used := make([]bool, len(al.providerChain)-1)
|
||||
for _, name := range candidateNames {
|
||||
for idx, candidate := range al.providerChain[1:] {
|
||||
if used[idx] || candidate.name != name {
|
||||
continue
|
||||
}
|
||||
used[idx] = true
|
||||
ranked = append(ranked, candidate)
|
||||
}
|
||||
}
|
||||
for idx, candidate := range al.providerChain[1:] {
|
||||
if !used[idx] {
|
||||
ranked = append(ranked, candidate)
|
||||
}
|
||||
}
|
||||
for _, candidate := range ranked {
|
||||
p, candidateModel, err := al.ensureProviderCandidate(candidate)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
resp, err := p.Chat(ctx, messages, toolDefs, al.model, options)
|
||||
resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, options)
|
||||
if al.ekg != nil {
|
||||
st := "success"
|
||||
lg := "fallback provider success"
|
||||
@@ -608,17 +676,96 @@ func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe
|
||||
lg = err.Error()
|
||||
errSig = err.Error()
|
||||
}
|
||||
al.ekg.Record(ekg.Event{Session: msg.SessionKey, Channel: msg.Channel, Source: "provider_fallback", Status: st, Provider: name, Model: al.model, ErrSig: errSig, Log: lg})
|
||||
al.ekg.Record(ekg.Event{Session: msg.SessionKey, Channel: msg.Channel, Source: "provider_fallback", Status: st, Provider: candidate.name, Model: candidateModel, ErrSig: errSig, Log: lg})
|
||||
}
|
||||
if err == nil {
|
||||
logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": name})
|
||||
return resp, name, nil
|
||||
logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": candidate.name, "model": candidateModel, "ref": candidate.ref})
|
||||
return resp, candidate.name, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return nil, "", lastErr
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (providers.LLMProvider, string, error) {
|
||||
if al == nil {
|
||||
return nil, "", fmt.Errorf("agent loop is nil")
|
||||
}
|
||||
name := strings.TrimSpace(candidate.name)
|
||||
if name == "" {
|
||||
return nil, "", fmt.Errorf("fallback provider name is empty")
|
||||
}
|
||||
al.providerMu.RLock()
|
||||
existing := al.providerPool[name]
|
||||
al.providerMu.RUnlock()
|
||||
if existing != nil {
|
||||
model := strings.TrimSpace(candidate.model)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(existing.GetDefaultModel())
|
||||
}
|
||||
if model == "" {
|
||||
return nil, "", fmt.Errorf("fallback provider %q has no model configured", name)
|
||||
}
|
||||
return existing, model, nil
|
||||
}
|
||||
if al.cfg == nil {
|
||||
return nil, "", fmt.Errorf("config not available for fallback provider %q", name)
|
||||
}
|
||||
created, err := providers.CreateProviderByName(al.cfg, name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
model := strings.TrimSpace(candidate.model)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(created.GetDefaultModel())
|
||||
}
|
||||
if model == "" {
|
||||
return nil, "", fmt.Errorf("fallback provider %q has no model configured", name)
|
||||
}
|
||||
al.providerMu.Lock()
|
||||
if existing := al.providerPool[name]; existing != nil {
|
||||
al.providerMu.Unlock()
|
||||
return existing, model, nil
|
||||
}
|
||||
al.providerPool[name] = created
|
||||
al.providerMu.Unlock()
|
||||
return created, model, nil
|
||||
}
|
||||
|
||||
func automaticFallbackPriority(name string) int {
|
||||
switch normalizeFallbackProviderName(name) {
|
||||
case "claude":
|
||||
return 10
|
||||
case "codex":
|
||||
return 20
|
||||
case "gemini":
|
||||
return 30
|
||||
case "gemini-cli":
|
||||
return 40
|
||||
case "aistudio":
|
||||
return 50
|
||||
case "vertex":
|
||||
return 60
|
||||
case "antigravity":
|
||||
return 70
|
||||
case "qwen":
|
||||
return 80
|
||||
case "kimi":
|
||||
return 90
|
||||
case "iflow":
|
||||
return 100
|
||||
case "openai-compatibility":
|
||||
return 110
|
||||
default:
|
||||
return 1000
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeFallbackProviderName(name string) string {
|
||||
normalized, _ := config.ParseProviderModelRef(strings.TrimSpace(name) + "/_")
|
||||
return strings.TrimSpace(normalized)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) setSessionProvider(sessionKey, provider string) {
|
||||
key := strings.TrimSpace(sessionKey)
|
||||
if key == "" {
|
||||
@@ -1188,12 +1335,9 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if fb, fbProvider, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
|
||||
if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
|
||||
response = fb
|
||||
err = nil
|
||||
if fbProvider != "" {
|
||||
al.setSessionProvider(msg.SessionKey, fbProvider)
|
||||
}
|
||||
} else {
|
||||
err = ferr
|
||||
}
|
||||
@@ -1542,12 +1686,9 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
|
||||
|
||||
if err != nil {
|
||||
if fb, fbProvider, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
|
||||
if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
|
||||
response = fb
|
||||
err = nil
|
||||
if fbProvider != "" {
|
||||
al.setSessionProvider(msg.SessionKey, fbProvider)
|
||||
}
|
||||
} else {
|
||||
err = ferr
|
||||
}
|
||||
|
||||
@@ -132,8 +132,7 @@ type AgentDefaults struct {
|
||||
}
|
||||
|
||||
type AgentModelDefaults struct {
|
||||
Primary string `json:"primary,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_PRIMARY"`
|
||||
Fallbacks []string `json:"fallbacks,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_FALLBACKS"`
|
||||
Primary string `json:"primary,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_PRIMARY"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
@@ -445,7 +444,7 @@ func DefaultConfig() *Config {
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: filepath.Join(configDir, "workspace"),
|
||||
Model: AgentModelDefaults{Primary: "openai/gpt-5.4", Fallbacks: []string{}},
|
||||
Model: AgentModelDefaults{Primary: "openai/gpt-5.4"},
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
@@ -660,13 +659,36 @@ func DefaultConfig() *Config {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeProviderNameAlias(name string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(name)) {
|
||||
case "geminicli", "gemini_cli":
|
||||
return "gemini-cli"
|
||||
case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio":
|
||||
return "aistudio"
|
||||
case "google", "gemini-api-key", "gemini_api_key":
|
||||
return "gemini"
|
||||
case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key":
|
||||
return "claude"
|
||||
case "openai-compatibility", "openai_compatibility", "openai-compat", "openai_compat":
|
||||
return "openai-compatibility"
|
||||
case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", "vertex-compatibility", "vertex_compatibility":
|
||||
return "vertex"
|
||||
case "codex-api-key", "codex_api_key":
|
||||
return "codex"
|
||||
case "i-flow", "i_flow":
|
||||
return "iflow"
|
||||
default:
|
||||
return strings.TrimSpace(name)
|
||||
}
|
||||
}
|
||||
|
||||
func ParseProviderModelRef(raw string) (provider string, model string) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", ""
|
||||
}
|
||||
if idx := strings.Index(trimmed, "/"); idx > 0 {
|
||||
return strings.TrimSpace(trimmed[:idx]), strings.TrimSpace(trimmed[idx+1:])
|
||||
return normalizeProviderNameAlias(trimmed[:idx]), strings.TrimSpace(trimmed[idx+1:])
|
||||
}
|
||||
return "", trimmed
|
||||
}
|
||||
@@ -690,7 +712,12 @@ func ProviderConfigByName(cfg *Config, name string) (ProviderConfig, bool) {
|
||||
if cfg == nil {
|
||||
return ProviderConfig{}, false
|
||||
}
|
||||
pc, ok := AllProviderConfigs(cfg)[strings.TrimSpace(name)]
|
||||
configs := AllProviderConfigs(cfg)
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if pc, ok := configs[trimmed]; ok {
|
||||
return pc, true
|
||||
}
|
||||
pc, ok := configs[normalizeProviderNameAlias(trimmed)]
|
||||
return pc, ok
|
||||
}
|
||||
|
||||
@@ -704,10 +731,10 @@ func PrimaryProviderName(cfg *Config) string {
|
||||
return "openai"
|
||||
}
|
||||
if provider, _ := ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); provider != "" {
|
||||
return provider
|
||||
return normalizeProviderNameAlias(provider)
|
||||
}
|
||||
for name := range cfg.Models.Providers {
|
||||
if trimmed := strings.TrimSpace(name); trimmed != "" {
|
||||
if trimmed := normalizeProviderNameAlias(name); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,11 +93,6 @@ func Validate(cfg *Config) []error {
|
||||
if len(cfg.Models.Providers) == 0 {
|
||||
errs = append(errs, fmt.Errorf("models.providers must contain at least one provider"))
|
||||
}
|
||||
for _, name := range cfg.Agents.Defaults.Model.Fallbacks {
|
||||
if !ProviderExists(cfg, name) {
|
||||
errs = append(errs, fmt.Errorf("agents.defaults.model.fallbacks contains unknown provider %q", name))
|
||||
}
|
||||
}
|
||||
if primaryRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary); primaryRef != "" {
|
||||
providerName, modelName := ParseProviderModelRef(primaryRef)
|
||||
if providerName == "" {
|
||||
|
||||
@@ -44,11 +44,31 @@ func aistudioChannelID(providerName string, options map[string]interface{}) stri
|
||||
}
|
||||
|
||||
func aistudioChannelCandidates(providerName string, options map[string]interface{}) []string {
|
||||
for _, key := range []string{"aistudio_channel", "aistudio_provider", "relay_provider"} {
|
||||
for _, key := range []string{"aistudio_channel", "aistudio_provider", "relay_provider", "channel_id", "provider_id"} {
|
||||
if value, ok := stringOption(options, key); ok && strings.TrimSpace(value) != "" {
|
||||
return []string{strings.ToLower(strings.TrimSpace(value))}
|
||||
}
|
||||
}
|
||||
for _, key := range []string{"aistudio_channels", "channel_ids", "relay_providers"} {
|
||||
if values, ok := stringSliceOption(options, key); ok && len(values) > 0 {
|
||||
out := make([]string, 0, len(values))
|
||||
seen := map[string]struct{}{}
|
||||
for _, value := range values {
|
||||
channelID := strings.ToLower(strings.TrimSpace(value))
|
||||
if channelID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[channelID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[channelID] = struct{}{}
|
||||
out = append(out, channelID)
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
if runtimeSelected := preferredAIStudioRelayChannels(); len(runtimeSelected) > 0 {
|
||||
return runtimeSelected
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -252,6 +253,7 @@ func (p *AntigravityProvider) baseURLs() []string {
|
||||
|
||||
func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any {
|
||||
request := map[string]any{}
|
||||
baseModel := strings.TrimSpace(qwenBaseModel(model))
|
||||
systemParts := make([]map[string]any, 0)
|
||||
contents := make([]map[string]any, 0, len(messages))
|
||||
callNames := map[string]string{}
|
||||
@@ -299,6 +301,16 @@ func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolD
|
||||
if gen := antigravityGenerationConfig(options); len(gen) > 0 {
|
||||
request["generationConfig"] = gen
|
||||
}
|
||||
if extra, ok := mapOption(options, "gemini_generation_config"); ok && len(extra) > 0 {
|
||||
gen := mapFromAny(request["generationConfig"])
|
||||
if gen == nil {
|
||||
gen = map[string]any{}
|
||||
}
|
||||
for k, v := range extra {
|
||||
gen[k] = v
|
||||
}
|
||||
request["generationConfig"] = gen
|
||||
}
|
||||
if toolDecls := antigravityToolDeclarations(tools); len(toolDecls) > 0 {
|
||||
request["tools"] = []map[string]any{{"function_declarations": toolDecls}}
|
||||
request["toolConfig"] = map[string]any{
|
||||
@@ -307,18 +319,19 @@ func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolD
|
||||
}
|
||||
projectID := ""
|
||||
if session != nil {
|
||||
projectID = firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"]))
|
||||
projectID = firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["projectId"]), asString(session.Token["project"]))
|
||||
}
|
||||
if projectID == "" {
|
||||
projectID = "default-project"
|
||||
}
|
||||
applyAntigravityThinkingSuffix(request, model)
|
||||
requestType := "agent"
|
||||
if strings.Contains(strings.ToLower(model), "image") {
|
||||
if strings.Contains(strings.ToLower(baseModel), "image") {
|
||||
requestType = "image_gen"
|
||||
}
|
||||
return map[string]any{
|
||||
"project": projectID,
|
||||
"model": strings.TrimSpace(model),
|
||||
"model": baseModel,
|
||||
"userAgent": "antigravity",
|
||||
"requestType": requestType,
|
||||
"requestId": "agent-" + randomSessionID(),
|
||||
@@ -454,6 +467,79 @@ func antigravityGenerationConfig(options map[string]any) map[string]any {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applyAntigravityThinkingSuffix(request map[string]any, model string) {
|
||||
suffix := qwenModelSuffix(model)
|
||||
if suffix == "" {
|
||||
return
|
||||
}
|
||||
baseModel := strings.TrimSpace(qwenBaseModel(model))
|
||||
gen := mapFromAny(request["generationConfig"])
|
||||
if gen == nil {
|
||||
gen = map[string]any{}
|
||||
}
|
||||
thinkingConfig := mapFromAny(gen["thinkingConfig"])
|
||||
if thinkingConfig == nil {
|
||||
thinkingConfig = map[string]any{}
|
||||
}
|
||||
includeThoughts, userSetIncludeThoughts := geminiExistingIncludeThoughts(thinkingConfig)
|
||||
delete(thinkingConfig, "thinkingBudget")
|
||||
delete(thinkingConfig, "thinking_budget")
|
||||
delete(thinkingConfig, "thinkingLevel")
|
||||
delete(thinkingConfig, "thinking_level")
|
||||
delete(thinkingConfig, "include_thoughts")
|
||||
|
||||
setIncludeThoughts := func(defaultValue bool, force bool) {
|
||||
if force || !userSetIncludeThoughts {
|
||||
includeThoughts = defaultValue
|
||||
}
|
||||
thinkingConfig["includeThoughts"] = includeThoughts
|
||||
}
|
||||
|
||||
lower := strings.ToLower(strings.TrimSpace(suffix))
|
||||
switch {
|
||||
case lower == "auto" || lower == "-1":
|
||||
thinkingConfig["thinkingBudget"] = -1
|
||||
setIncludeThoughts(true, false)
|
||||
case lower == "none" || lower == "0":
|
||||
if geminiUsesThinkingLevels(baseModel) {
|
||||
thinkingConfig["thinkingLevel"] = "low"
|
||||
} else {
|
||||
thinkingConfig["thinkingBudget"] = 128
|
||||
}
|
||||
setIncludeThoughts(false, true)
|
||||
case isGeminiThinkingLevel(lower):
|
||||
if geminiUsesThinkingLevels(baseModel) {
|
||||
thinkingConfig["thinkingLevel"] = normalizeGeminiThinkingLevel(lower)
|
||||
} else {
|
||||
thinkingConfig["thinkingBudget"] = geminiThinkingBudgetForLevel(lower)
|
||||
}
|
||||
setIncludeThoughts(true, false)
|
||||
default:
|
||||
if budget, err := strconv.Atoi(lower); err == nil {
|
||||
switch {
|
||||
case budget < 0:
|
||||
thinkingConfig["thinkingBudget"] = -1
|
||||
setIncludeThoughts(true, false)
|
||||
case budget == 0:
|
||||
if geminiUsesThinkingLevels(baseModel) {
|
||||
thinkingConfig["thinkingLevel"] = "low"
|
||||
} else {
|
||||
thinkingConfig["thinkingBudget"] = 128
|
||||
}
|
||||
setIncludeThoughts(false, true)
|
||||
default:
|
||||
thinkingConfig["thinkingBudget"] = budget
|
||||
setIncludeThoughts(true, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(thinkingConfig) == 0 {
|
||||
return
|
||||
}
|
||||
gen["thinkingConfig"] = thinkingConfig
|
||||
request["generationConfig"] = gen
|
||||
}
|
||||
|
||||
func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) {
|
||||
if onDelta == nil {
|
||||
onDelta = func(string) {}
|
||||
|
||||
@@ -2,6 +2,7 @@ package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
@@ -166,3 +167,54 @@ func TestAntigravityBaseURLsIncludeProdFallback(t *testing.T) {
|
||||
t.Fatalf("unexpected fallback order: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequestBodyAppliesThinkingSuffix(t *testing.T) {
|
||||
p := NewAntigravityProvider("antigravity", "", "", "gemini-3-pro", false, "oauth", 0, nil)
|
||||
body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-3-pro(high)", nil, &oauthSession{ProjectID: "demo-project"}, false)
|
||||
if got := body["model"]; got != "gemini-3-pro" {
|
||||
t.Fatalf("model = %#v, want gemini-3-pro", got)
|
||||
}
|
||||
request := mapFromAny(body["request"])
|
||||
gen := mapFromAny(request["generationConfig"])
|
||||
thinking := mapFromAny(gen["thinkingConfig"])
|
||||
if got := asString(thinking["thinkingLevel"]); got != "high" {
|
||||
t.Fatalf("thinkingLevel = %q, want high", got)
|
||||
}
|
||||
if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "true" {
|
||||
t.Fatalf("includeThoughts = %v, want true", thinking["includeThoughts"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequestBodyDisablesThinkingOutput(t *testing.T) {
|
||||
p := NewAntigravityProvider("antigravity", "", "", "gemini-2.5-pro", false, "oauth", 0, nil)
|
||||
body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro(0)", nil, &oauthSession{ProjectID: "demo-project"}, false)
|
||||
request := mapFromAny(body["request"])
|
||||
gen := mapFromAny(request["generationConfig"])
|
||||
thinking := mapFromAny(gen["thinkingConfig"])
|
||||
if got := intValue(thinking["thinkingBudget"]); got != 128 {
|
||||
t.Fatalf("thinkingBudget = %v, want 128", thinking["thinkingBudget"])
|
||||
}
|
||||
if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" {
|
||||
t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityThinkingSuffixPreservesExplicitIncludeThoughts(t *testing.T) {
|
||||
p := NewAntigravityProvider("antigravity", "", "", "gemini-3-pro", false, "oauth", 0, nil)
|
||||
body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-3-pro(high)", map[string]interface{}{
|
||||
"gemini_generation_config": map[string]interface{}{
|
||||
"thinkingConfig": map[string]interface{}{
|
||||
"includeThoughts": false,
|
||||
},
|
||||
},
|
||||
}, &oauthSession{ProjectID: "demo-project"}, false)
|
||||
request := mapFromAny(body["request"])
|
||||
gen := mapFromAny(request["generationConfig"])
|
||||
thinking := mapFromAny(gen["thinkingConfig"])
|
||||
if got := asString(thinking["thinkingLevel"]); got != "high" {
|
||||
t.Fatalf("thinkingLevel = %q, want high", got)
|
||||
}
|
||||
if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" {
|
||||
t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -820,8 +820,15 @@ func applyCodexWebsocketHeaders(headers http.Header, attempt authAttempt, option
|
||||
headers.Set("User-Agent", codexCompatUserAgent)
|
||||
if attempt.kind != "api_key" {
|
||||
headers.Set("Originator", "codex_cli_rs")
|
||||
if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" {
|
||||
headers.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID))
|
||||
if attempt.session != nil {
|
||||
accountID := firstNonEmpty(
|
||||
strings.TrimSpace(attempt.session.AccountID),
|
||||
strings.TrimSpace(asString(attempt.session.Token["account_id"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["account-id"])),
|
||||
)
|
||||
if accountID != "" {
|
||||
headers.Set("Chatgpt-Account-Id", accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
|
||||
@@ -9,16 +9,18 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
geminiCLIBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
geminiCLIDefaultAlt = "sse"
|
||||
geminiCLIApiClient = "genai-cli/0 gl-go/1.0"
|
||||
geminiCLIBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
geminiCLIDefaultAlt = "sse"
|
||||
geminiCLIClientVersion = "0.31.0"
|
||||
geminiCLIApiClient = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
|
||||
)
|
||||
|
||||
type GeminiCLIProvider struct {
|
||||
@@ -220,7 +222,7 @@ func applyGeminiCLIAttemptAuth(req *http.Request, attempt authAttempt) error {
|
||||
}
|
||||
token := strings.TrimSpace(attempt.token)
|
||||
if attempt.session != nil {
|
||||
token = firstNonEmpty(strings.TrimSpace(attempt.session.AccessToken), token, asString(attempt.session.Token["access_token"]))
|
||||
token = firstNonEmpty(strings.TrimSpace(attempt.session.AccessToken), token, asString(attempt.session.Token["access_token"]), asString(attempt.session.Token["access-token"]))
|
||||
}
|
||||
if token == "" {
|
||||
return fmt.Errorf("missing access token for gemini-cli")
|
||||
@@ -263,26 +265,53 @@ func consumeGeminiCLIStream(resp *http.Response, onDelta func(string)) ([]byte,
|
||||
}
|
||||
|
||||
func geminiCLIProjectID(options map[string]interface{}, session *oauthSession) string {
|
||||
if value, ok := stringOption(options, "gemini_project_id"); ok {
|
||||
return value
|
||||
}
|
||||
if value, ok := stringOption(options, "project_id"); ok {
|
||||
return value
|
||||
for _, key := range []string{"gemini_project_id", "project_id", "project"} {
|
||||
if value, ok := stringOption(options, key); ok {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
if session == nil {
|
||||
return ""
|
||||
}
|
||||
return firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"]), asString(session.Token["project"]))
|
||||
return firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["projectId"]), asString(session.Token["project"]))
|
||||
}
|
||||
|
||||
func geminiCLIRuntimeOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return "win32"
|
||||
default:
|
||||
return runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
func geminiCLIRuntimeArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func geminiCLIUserAgent(model string) string {
|
||||
trimmedModel := strings.TrimSpace(model)
|
||||
if trimmedModel == "" {
|
||||
trimmedModel = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", geminiCLIClientVersion, trimmedModel, geminiCLIRuntimeOS(), geminiCLIRuntimeArch())
|
||||
}
|
||||
|
||||
func applyGeminiCLIHeaders(req *http.Request, model string) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(model) == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
req.Header.Set("User-Agent", "GeminiCLI/"+model)
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent(model))
|
||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ package providers
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiBaseURL = "https://generativelanguage.googleapis.com"
|
||||
geminiAPIVersion = "v1beta"
|
||||
geminiBaseURL = "https://generativelanguage.googleapis.com"
|
||||
geminiAPIVersion = "v1beta"
|
||||
geminiImagePreviewModel = "gemini-2.5-flash-image-preview"
|
||||
)
|
||||
|
||||
@@ -277,43 +277,50 @@ func applyGeminiThinkingSuffix(request map[string]any, model string) {
|
||||
if thinkingConfig == nil {
|
||||
thinkingConfig = map[string]any{}
|
||||
}
|
||||
includeThoughts, userSetIncludeThoughts := geminiExistingIncludeThoughts(thinkingConfig)
|
||||
delete(thinkingConfig, "thinkingBudget")
|
||||
delete(thinkingConfig, "thinking_budget")
|
||||
delete(thinkingConfig, "thinkingLevel")
|
||||
delete(thinkingConfig, "thinking_level")
|
||||
delete(thinkingConfig, "include_thoughts")
|
||||
|
||||
setIncludeThoughts := func(defaultValue bool, force bool) {
|
||||
if force || !userSetIncludeThoughts {
|
||||
includeThoughts = defaultValue
|
||||
}
|
||||
thinkingConfig["includeThoughts"] = includeThoughts
|
||||
}
|
||||
|
||||
lower := strings.ToLower(strings.TrimSpace(suffix))
|
||||
switch {
|
||||
case lower == "auto" || lower == "-1":
|
||||
thinkingConfig["thinkingBudget"] = -1
|
||||
thinkingConfig["includeThoughts"] = true
|
||||
setIncludeThoughts(true, false)
|
||||
case lower == "none":
|
||||
if geminiUsesThinkingLevels(baseModel) {
|
||||
thinkingConfig["thinkingLevel"] = "low"
|
||||
} else {
|
||||
thinkingConfig["thinkingBudget"] = 128
|
||||
}
|
||||
thinkingConfig["includeThoughts"] = false
|
||||
setIncludeThoughts(false, true)
|
||||
case isGeminiThinkingLevel(lower):
|
||||
if geminiUsesThinkingLevels(baseModel) {
|
||||
thinkingConfig["thinkingLevel"] = normalizeGeminiThinkingLevel(lower)
|
||||
thinkingConfig["includeThoughts"] = true
|
||||
} else {
|
||||
thinkingConfig["thinkingBudget"] = geminiThinkingBudgetForLevel(lower)
|
||||
thinkingConfig["includeThoughts"] = true
|
||||
}
|
||||
setIncludeThoughts(true, false)
|
||||
default:
|
||||
if budget, err := strconv.Atoi(lower); err == nil {
|
||||
if budget < 0 {
|
||||
thinkingConfig["thinkingBudget"] = -1
|
||||
thinkingConfig["includeThoughts"] = true
|
||||
setIncludeThoughts(true, false)
|
||||
} else if budget == 0 {
|
||||
thinkingConfig["thinkingBudget"] = 128
|
||||
thinkingConfig["includeThoughts"] = false
|
||||
setIncludeThoughts(false, true)
|
||||
} else {
|
||||
thinkingConfig["thinkingBudget"] = budget
|
||||
thinkingConfig["includeThoughts"] = true
|
||||
setIncludeThoughts(true, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -324,6 +331,38 @@ func applyGeminiThinkingSuffix(request map[string]any, model string) {
|
||||
request["generationConfig"] = gen
|
||||
}
|
||||
|
||||
func geminiExistingIncludeThoughts(thinkingConfig map[string]any) (bool, bool) {
|
||||
if thinkingConfig == nil {
|
||||
return false, false
|
||||
}
|
||||
if value, ok := thinkingConfig["includeThoughts"]; ok {
|
||||
return geminiBoolValue(value), true
|
||||
}
|
||||
if value, ok := thinkingConfig["include_thoughts"]; ok {
|
||||
return geminiBoolValue(value), true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func geminiBoolValue(value any) bool {
|
||||
switch typed := value.(type) {
|
||||
case bool:
|
||||
return typed
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(typed)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
}
|
||||
case int:
|
||||
return typed != 0
|
||||
case int64:
|
||||
return typed != 0
|
||||
case float64:
|
||||
return typed != 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func geminiUsesThinkingLevels(model string) bool {
|
||||
trimmed := strings.ToLower(strings.TrimSpace(model))
|
||||
return strings.Contains(trimmed, "gemini-3")
|
||||
@@ -490,10 +529,16 @@ func geminiBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string {
|
||||
return normalizeGeminiBaseURL(raw)
|
||||
}
|
||||
if attempt.session.Token != nil {
|
||||
if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" {
|
||||
if raw := firstNonEmpty(
|
||||
strings.TrimSpace(asString(attempt.session.Token["base_url"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["base-url"])),
|
||||
); raw != "" {
|
||||
return normalizeGeminiBaseURL(raw)
|
||||
}
|
||||
if raw := strings.TrimSpace(asString(attempt.session.Token["resource_url"])); raw != "" {
|
||||
if raw := firstNonEmpty(
|
||||
strings.TrimSpace(asString(attempt.session.Token["resource_url"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["resource-url"])),
|
||||
); raw != "" {
|
||||
return normalizeGeminiBaseURL(raw)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,3 +292,22 @@ func TestCreateProviderByNameRoutesAIStudioProviderViaGeminiTests(t *testing.T)
|
||||
t.Fatalf("provider = %T, want *AistudioProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiThinkingSuffixPreservesExplicitIncludeThoughts(t *testing.T) {
|
||||
p := NewGeminiProvider("gemini", "", "", "gemini-3-pro-preview", false, "api_key", 5*time.Second, nil)
|
||||
body := p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-3-pro-preview(high)", map[string]interface{}{
|
||||
"gemini_generation_config": map[string]interface{}{
|
||||
"thinkingConfig": map[string]interface{}{
|
||||
"includeThoughts": false,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
gen := mapFromAny(body["generationConfig"])
|
||||
thinking := mapFromAny(gen["thinkingConfig"])
|
||||
if got := asString(thinking["thinkingLevel"]); got != "high" {
|
||||
t.Fatalf("thinkingLevel = %q, want high", got)
|
||||
}
|
||||
if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" {
|
||||
t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2344,7 +2344,7 @@ func (p *HTTPProvider) compatBase() string {
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) compatModel(model string) string {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
trimmed := strings.TrimSpace(qwenBaseModel(model))
|
||||
if p.oauthProvider() == defaultKimiOAuthProvider && strings.HasPrefix(strings.ToLower(trimmed), "kimi-") {
|
||||
return trimmed[5:]
|
||||
}
|
||||
@@ -2356,6 +2356,9 @@ func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools []
|
||||
"model": p.compatModel(model),
|
||||
"messages": openAICompatMessages(messages),
|
||||
}
|
||||
if suffix := qwenModelSuffix(model); suffix != "" {
|
||||
applyOpenAICompatThinkingSuffix(requestBody, suffix)
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = openAICompatTools(tools)
|
||||
requestBody["tool_choice"] = "auto"
|
||||
@@ -2680,6 +2683,29 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(ctx context.Context, mode
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func normalizeProviderRouteName(name string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(name)) {
|
||||
case "geminicli", "gemini_cli":
|
||||
return "gemini-cli"
|
||||
case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio":
|
||||
return "aistudio"
|
||||
case "google", "gemini-api-key", "gemini_api_key":
|
||||
return "gemini"
|
||||
case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key":
|
||||
return "claude"
|
||||
case "openai-compatibility", "openai_compatibility", "openai-compat", "openai_compat":
|
||||
return "openai-compatibility"
|
||||
case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", "vertex-compatibility", "vertex_compatibility":
|
||||
return "vertex"
|
||||
case "codex-api-key", "codex_api_key":
|
||||
return "codex"
|
||||
case "i-flow", "i_flow":
|
||||
return "iflow"
|
||||
default:
|
||||
return strings.TrimSpace(name)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
name := config.PrimaryProviderName(cfg)
|
||||
provider, err := CreateProviderByName(cfg, name)
|
||||
@@ -2694,18 +2720,32 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
}
|
||||
|
||||
func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) {
|
||||
pc, err := getProviderConfigByName(cfg, name)
|
||||
routeName := normalizeProviderRouteName(name)
|
||||
pc, err := getProviderConfigByName(cfg, routeName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ConfigureProviderRuntime(name, pc)
|
||||
oauthProvider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider))
|
||||
ConfigureProviderRuntime(routeName, pc)
|
||||
oauthProvider := normalizeOAuthProvider(pc.OAuth.Provider)
|
||||
if pc.APIBase == "" &&
|
||||
oauthProvider != defaultAntigravityOAuthProvider &&
|
||||
oauthProvider != defaultGeminiOAuthProvider &&
|
||||
!strings.EqualFold(name, "gemini-cli") &&
|
||||
!strings.EqualFold(name, "aistudio") &&
|
||||
!strings.EqualFold(name, "vertex") {
|
||||
oauthProvider != "aistudio" &&
|
||||
oauthProvider != defaultCodexOAuthProvider &&
|
||||
oauthProvider != defaultClaudeOAuthProvider &&
|
||||
oauthProvider != defaultQwenOAuthProvider &&
|
||||
oauthProvider != defaultKimiOAuthProvider &&
|
||||
oauthProvider != defaultIFlowOAuthProvider &&
|
||||
!strings.EqualFold(routeName, "gemini-cli") &&
|
||||
!strings.EqualFold(routeName, "aistudio") &&
|
||||
!strings.EqualFold(routeName, "vertex") &&
|
||||
!strings.EqualFold(routeName, defaultAntigravityOAuthProvider) &&
|
||||
!strings.EqualFold(routeName, defaultGeminiOAuthProvider) &&
|
||||
!strings.EqualFold(routeName, defaultCodexOAuthProvider) &&
|
||||
!strings.EqualFold(routeName, defaultClaudeOAuthProvider) &&
|
||||
!strings.EqualFold(routeName, defaultQwenOAuthProvider) &&
|
||||
!strings.EqualFold(routeName, defaultKimiOAuthProvider) &&
|
||||
!strings.EqualFold(routeName, defaultIFlowOAuthProvider) {
|
||||
return nil, fmt.Errorf("no API base configured for provider %q", name)
|
||||
}
|
||||
if pc.TimeoutSec <= 0 {
|
||||
@@ -2722,37 +2762,37 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if oauthProvider == defaultAntigravityOAuthProvider {
|
||||
return NewAntigravityProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == defaultAntigravityOAuthProvider || strings.EqualFold(routeName, defaultAntigravityOAuthProvider) {
|
||||
return NewAntigravityProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if strings.EqualFold(name, "aistudio") {
|
||||
return NewAistudioProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == "aistudio" || strings.EqualFold(routeName, "aistudio") {
|
||||
return NewAistudioProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if strings.EqualFold(name, "gemini-cli") {
|
||||
return NewGeminiCLIProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if strings.EqualFold(routeName, "gemini-cli") {
|
||||
return NewGeminiCLIProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultGeminiOAuthProvider || strings.EqualFold(name, defaultGeminiOAuthProvider) || strings.EqualFold(name, "aistudio") {
|
||||
return NewGeminiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == defaultGeminiOAuthProvider || strings.EqualFold(routeName, defaultGeminiOAuthProvider) || strings.EqualFold(routeName, "aistudio") {
|
||||
return NewGeminiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if strings.EqualFold(name, "vertex") {
|
||||
return NewVertexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if strings.EqualFold(routeName, "vertex") {
|
||||
return NewVertexProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultCodexOAuthProvider {
|
||||
return NewCodexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == defaultCodexOAuthProvider || strings.EqualFold(routeName, defaultCodexOAuthProvider) {
|
||||
return NewCodexProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultClaudeOAuthProvider {
|
||||
return NewClaudeProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == defaultClaudeOAuthProvider || strings.EqualFold(routeName, defaultClaudeOAuthProvider) {
|
||||
return NewClaudeProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultQwenOAuthProvider {
|
||||
return NewQwenProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == defaultQwenOAuthProvider || strings.EqualFold(routeName, defaultQwenOAuthProvider) {
|
||||
return NewQwenProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultKimiOAuthProvider {
|
||||
return NewKimiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == defaultKimiOAuthProvider || strings.EqualFold(routeName, defaultKimiOAuthProvider) {
|
||||
return NewKimiProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(name, defaultIFlowOAuthProvider) {
|
||||
return NewIFlowProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(routeName, defaultIFlowOAuthProvider) {
|
||||
return NewIFlowProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
return NewHTTPProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
return NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
|
||||
func CreateProviders(cfg *config.Config) (map[string]LLMProvider, error) {
|
||||
|
||||
@@ -275,10 +275,16 @@ func iflowBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string {
|
||||
return normalizeIFlowBaseURL(raw)
|
||||
}
|
||||
if attempt.session.Token != nil {
|
||||
if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" {
|
||||
if raw := firstNonEmpty(
|
||||
strings.TrimSpace(asString(attempt.session.Token["base_url"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["base-url"])),
|
||||
); raw != "" {
|
||||
return normalizeIFlowBaseURL(raw)
|
||||
}
|
||||
if raw := strings.TrimSpace(asString(attempt.session.Token["resource_url"])); raw != "" {
|
||||
if raw := firstNonEmpty(
|
||||
strings.TrimSpace(asString(attempt.session.Token["resource_url"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["resource-url"])),
|
||||
); raw != "" {
|
||||
return normalizeIFlowBaseURL(raw)
|
||||
}
|
||||
}
|
||||
@@ -309,10 +315,11 @@ func normalizeIFlowBaseURL(raw string) string {
|
||||
|
||||
func iflowAttemptAPIKey(attempt authAttempt) string {
|
||||
if attempt.session != nil && attempt.session.Token != nil {
|
||||
if v := strings.TrimSpace(asString(attempt.session.Token["api_key"])); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(asString(attempt.session.Token["apiKey"])); v != "" {
|
||||
if v := firstNonEmpty(
|
||||
strings.TrimSpace(asString(attempt.session.Token["api_key"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["api-key"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["apiKey"])),
|
||||
); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,37 +119,7 @@ func applyKimiThinking(body map[string]interface{}, model string) {
|
||||
if suffix == "" {
|
||||
return
|
||||
}
|
||||
suffix = strings.ToLower(strings.TrimSpace(suffix))
|
||||
switch suffix {
|
||||
case "low", "medium", "high", "auto":
|
||||
body["reasoning_effort"] = suffix
|
||||
delete(body, "thinking")
|
||||
case "none":
|
||||
delete(body, "reasoning_effort")
|
||||
body["thinking"] = map[string]interface{}{"type": "disabled"}
|
||||
default:
|
||||
if budget, err := parsePositiveInt(suffix); err == nil && budget > 0 {
|
||||
delete(body, "reasoning_effort")
|
||||
body["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parsePositiveInt(raw string) (int, error) {
|
||||
var out int
|
||||
for _, ch := range raw {
|
||||
if ch < '0' || ch > '9' {
|
||||
return 0, fmt.Errorf("non-digit")
|
||||
}
|
||||
out = out*10 + int(ch-'0')
|
||||
}
|
||||
if out <= 0 {
|
||||
return 0, fmt.Errorf("not positive")
|
||||
}
|
||||
return out, nil
|
||||
_ = applyOpenAICompatThinkingSuffix(body, suffix)
|
||||
}
|
||||
|
||||
func normalizeKimiToolMessages(body map[string]interface{}) {
|
||||
|
||||
@@ -143,3 +143,16 @@ func TestKimiProviderCountTokens(t *testing.T) {
|
||||
t.Fatalf("usage = %#v, want positive prompt-only count", usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKimiChatRequestSupportsNumericAutoAndDisable(t *testing.T) {
|
||||
base := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil)
|
||||
autoBody := buildKimiChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(-1)", nil, false)
|
||||
if got := autoBody["reasoning_effort"]; got != "auto" {
|
||||
t.Fatalf("reasoning_effort = %#v, want auto", got)
|
||||
}
|
||||
disableBody := buildKimiChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(0)", nil, false)
|
||||
thinking, _ := disableBody["thinking"].(map[string]interface{})
|
||||
if got := thinking["type"]; got != "disabled" {
|
||||
t.Fatalf("thinking.type = %#v, want disabled", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -692,10 +692,20 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) {
|
||||
|
||||
func normalizeOAuthProvider(provider string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic":
|
||||
case "anthropic", "claude-code", "claude_code", "claude-api-key", "claude_api_key":
|
||||
return defaultClaudeOAuthProvider
|
||||
case "gemini-cli":
|
||||
case "gemini-cli", "geminicli", "gemini_cli", "google", "gemini-api-key", "gemini_api_key":
|
||||
return defaultGeminiOAuthProvider
|
||||
case "aistudio", "ai-studio", "ai_studio", "google-ai-studio", "google_ai_studio", "googleaistudio":
|
||||
return "aistudio"
|
||||
case "openai-compatibility", "openai_compatibility", "openai-compat", "openai_compat":
|
||||
return "openai-compatibility"
|
||||
case "vertex-api-key", "vertex_api_key", "vertex-compat", "vertex_compat", "vertex-compatibility", "vertex_compatibility":
|
||||
return "vertex"
|
||||
case "codex-api-key", "codex_api_key":
|
||||
return defaultCodexOAuthProvider
|
||||
case "i-flow", "i_flow":
|
||||
return defaultIFlowOAuthProvider
|
||||
default:
|
||||
return strings.ToLower(strings.TrimSpace(provider))
|
||||
}
|
||||
|
||||
@@ -157,3 +157,26 @@ func TestOpenAICompatMessagesPreserveMultimodalContentParts(t *testing.T) {
|
||||
t.Fatalf("image detail = %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompatChatRequestAppliesThinkingSuffix(t *testing.T) {
|
||||
base := NewHTTPProvider("openai", "token", "https://example.com/v1", "gpt-5", false, "api_key", 5*time.Second, nil)
|
||||
body := base.buildOpenAICompatChatRequest([]Message{{Role: "user", Content: "hi"}}, nil, "gpt-5(high)", nil)
|
||||
if got := body["model"]; got != "gpt-5" {
|
||||
t.Fatalf("model = %#v, want gpt-5", got)
|
||||
}
|
||||
if got := body["reasoning_effort"]; got != "high" {
|
||||
t.Fatalf("reasoning_effort = %#v, want high", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompatChatRequestStripsKimiPrefixAndSuffix(t *testing.T) {
|
||||
base := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil)
|
||||
base.oauth = &oauthManager{cfg: oauthConfig{Provider: defaultKimiOAuthProvider}}
|
||||
body := base.buildOpenAICompatChatRequest([]Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(-1)", nil)
|
||||
if got := body["model"]; got != "k2.5" {
|
||||
t.Fatalf("model = %#v, want k2.5", got)
|
||||
}
|
||||
if got := body["reasoning_effort"]; got != "auto" {
|
||||
t.Fatalf("reasoning_effort = %#v, want auto", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,19 +164,58 @@ func applyQwenThinkingSuffix(body map[string]interface{}, suffix string) {
|
||||
if suffix == "" {
|
||||
return
|
||||
}
|
||||
switch suffix {
|
||||
case "low", "medium", "high", "auto":
|
||||
body["reasoning_effort"] = suffix
|
||||
case "none":
|
||||
if applyOpenAICompatThinkingSuffix(body, suffix) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func applyOpenAICompatThinkingSuffix(body map[string]interface{}, suffix string) bool {
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
normalizedLevel, isLevel := normalizeOpenAICompatThinkingLevel(suffix)
|
||||
switch {
|
||||
case isLevel:
|
||||
delete(body, "thinking")
|
||||
body["reasoning_effort"] = normalizedLevel
|
||||
return true
|
||||
case strings.EqualFold(strings.TrimSpace(suffix), "none"):
|
||||
delete(body, "reasoning_effort")
|
||||
body["thinking"] = map[string]interface{}{"type": "disabled"}
|
||||
return true
|
||||
default:
|
||||
if n, err := strconv.Atoi(suffix); err == nil && n > 0 {
|
||||
n, err := strconv.Atoi(strings.TrimSpace(suffix))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case n < 0:
|
||||
delete(body, "thinking")
|
||||
body["reasoning_effort"] = "auto"
|
||||
case n == 0:
|
||||
delete(body, "reasoning_effort")
|
||||
body["thinking"] = map[string]interface{}{"type": "disabled"}
|
||||
default:
|
||||
delete(body, "reasoning_effort")
|
||||
body["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
"budget_tokens": n,
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAICompatThinkingLevel(raw string) (string, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "minimal":
|
||||
return "low", true
|
||||
case "low", "medium", "high", "auto":
|
||||
return strings.ToLower(strings.TrimSpace(raw)), true
|
||||
case "xhigh", "max":
|
||||
return "high", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
23
pkg/providers/qwen_provider_test.go
Normal file
23
pkg/providers/qwen_provider_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBuildQwenChatRequestSupportsExtendedThinkingSuffixes(t *testing.T) {
|
||||
base := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil)
|
||||
autoBody := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max(-1)", nil, false)
|
||||
if got := autoBody["reasoning_effort"]; got != "auto" {
|
||||
t.Fatalf("reasoning_effort = %#v, want auto", got)
|
||||
}
|
||||
disableBody := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max(0)", nil, false)
|
||||
thinking, _ := disableBody["thinking"].(map[string]interface{})
|
||||
if got := thinking["type"]; got != "disabled" {
|
||||
t.Fatalf("thinking.type = %#v, want disabled", got)
|
||||
}
|
||||
minimalBody := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max(minimal)", nil, false)
|
||||
if got := minimalBody["reasoning_effort"]; got != "low" {
|
||||
t.Fatalf("reasoning_effort = %#v, want low", got)
|
||||
}
|
||||
}
|
||||
@@ -226,7 +226,10 @@ func (p *VertexProvider) endpoint(attempt authAttempt, model, action string, str
|
||||
func vertexBaseURLForAttempt(base *HTTPProvider, attempt authAttempt, options map[string]interface{}) string {
|
||||
customBase := ""
|
||||
if attempt.session != nil && attempt.session.Token != nil {
|
||||
if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" {
|
||||
if raw := firstNonEmpty(
|
||||
strings.TrimSpace(asString(attempt.session.Token["base_url"])),
|
||||
strings.TrimSpace(asString(attempt.session.Token["base-url"])),
|
||||
); raw != "" {
|
||||
customBase = normalizeVertexBaseURL(raw)
|
||||
}
|
||||
}
|
||||
@@ -256,11 +259,16 @@ func normalizeVertexBaseURL(raw string) string {
|
||||
|
||||
func vertexProjectLocation(attempt authAttempt, options map[string]interface{}) (string, string, bool) {
|
||||
projectID := ""
|
||||
if value, ok := stringOption(options, "vertex_project_id"); ok {
|
||||
projectID = strings.TrimSpace(value)
|
||||
for _, key := range []string{"vertex_project_id", "project_id", "project"} {
|
||||
if value, ok := stringOption(options, key); ok {
|
||||
projectID = strings.TrimSpace(value)
|
||||
if projectID != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if attempt.session != nil {
|
||||
projectID = firstNonEmpty(projectID, strings.TrimSpace(attempt.session.ProjectID), asString(attempt.session.Token["project_id"]), asString(attempt.session.Token["projectId"]), asString(attempt.session.Token["project"]))
|
||||
projectID = firstNonEmpty(projectID, strings.TrimSpace(attempt.session.ProjectID), asString(attempt.session.Token["project_id"]), asString(attempt.session.Token["project-id"]), asString(attempt.session.Token["projectId"]), asString(attempt.session.Token["project"]))
|
||||
if projectID == "" {
|
||||
projectID = strings.TrimSpace(asString(mapFromAny(attempt.session.Token["service_account"])["project_id"]))
|
||||
}
|
||||
@@ -274,11 +282,16 @@ func vertexProjectLocation(attempt authAttempt, options map[string]interface{})
|
||||
|
||||
func vertexLocationForAttempt(attempt authAttempt, options map[string]interface{}) string {
|
||||
location := ""
|
||||
if value, ok := stringOption(options, "vertex_location"); ok {
|
||||
location = strings.TrimSpace(value)
|
||||
for _, key := range []string{"vertex_location", "location", "region"} {
|
||||
if value, ok := stringOption(options, key); ok {
|
||||
location = strings.TrimSpace(value)
|
||||
if location != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if attempt.session != nil {
|
||||
location = firstNonEmpty(location, asString(attempt.session.Token["location"]), asString(mapFromAny(attempt.session.Token["service_account"])["location"]))
|
||||
location = firstNonEmpty(location, asString(attempt.session.Token["location"]), asString(attempt.session.Token["region"]), asString(mapFromAny(attempt.session.Token["service_account"])["location"]))
|
||||
}
|
||||
if strings.TrimSpace(location) == "" {
|
||||
location = vertexDefaultRegion
|
||||
@@ -501,7 +514,7 @@ func vertexServiceAccountJSON(session *oauthSession) ([]byte, error) {
|
||||
return nil, fmt.Errorf("vertex service account missing")
|
||||
}
|
||||
raw := mapFromAny(session.Token["service_account"])
|
||||
if projectID := firstNonEmpty(asString(raw["project_id"]), strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project"])); projectID != "" {
|
||||
if projectID := firstNonEmpty(asString(raw["project_id"]), asString(raw["project-id"]), strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["project"])); projectID != "" {
|
||||
raw["project_id"] = projectID
|
||||
}
|
||||
data, err := json.Marshal(raw)
|
||||
|
||||
Reference in New Issue
Block a user