Files
clawgo/pkg/providers/oauth_test.go
2026-03-13 12:41:27 +08:00

2172 lines
70 KiB
Go

package providers
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/YspCoder/clawgo/pkg/config"
)
// TestHTTPProviderOAuthRefreshesExpiredSession verifies that an expired access
// token in the credential file triggers exactly one refresh_token grant before
// the chat request, that the refreshed token is sent as the bearer credential,
// and that the refreshed token is persisted back to the credential file.
func TestHTTPProviderOAuthRefreshesExpiredSession(t *testing.T) {
	t.Parallel()
	var refreshCalls int32
	// The handler runs on server goroutines, where t.Fatalf must not be used
	// (FailNow only works on the test goroutine). Failures are reported with
	// t.Errorf and surfaced to the client as an HTTP error instead.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/oauth/token":
			atomic.AddInt32(&refreshCalls, 1)
			if err := r.ParseForm(); err != nil {
				t.Errorf("parse token form failed: %v", err)
				http.Error(w, "bad token form", http.StatusBadRequest)
				return
			}
			if got := r.Form.Get("grant_type"); got != "refresh_token" {
				t.Errorf("unexpected grant_type: %s", got)
				http.Error(w, "unexpected grant_type", http.StatusBadRequest)
				return
			}
			if got := r.Form.Get("refresh_token"); got != "refresh-token" {
				t.Errorf("unexpected refresh_token: %s", got)
				http.Error(w, "unexpected refresh_token", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"fresh-token","refresh_token":"refresh-token","expires_in":3600}`))
		case "/v1/responses":
			if got := r.Header.Get("Authorization"); got != "Bearer fresh-token" {
				t.Errorf("unexpected authorization header: %s", got)
				http.Error(w, "unexpected authorization", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	// Seed the credential file with a session that expired an hour ago.
	credFile := filepath.Join(t.TempDir(), "codex.json")
	initial := oauthSession{
		Provider:     "codex",
		AccessToken:  "expired-token",
		RefreshToken: "refresh-token",
		Expire:       time.Now().Add(-time.Hour).Format(time.RFC3339),
	}
	raw, err := json.Marshal(initial)
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, raw, 0o600); err != nil {
		t.Fatalf("write credential file failed: %v", err)
	}

	pc := config.ProviderConfig{
		APIBase:    server.URL + "/v1",
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "codex",
			CredentialFile: credFile,
			ClientID:       "test-client",
			TokenURL:       server.URL + "/oauth/token",
			AuthURL:        server.URL + "/oauth/authorize",
		},
	}
	oauth, err := newOAuthManager(pc, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	// Cancel the manager's background context when the test ends, as the other
	// tests in this file do.
	defer oauth.bgCancel()

	provider := NewHTTPProvider("test-oauth-refresh", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
	resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
	if err != nil {
		t.Fatalf("chat failed: %v", err)
	}
	if resp.Content != "ok" {
		t.Fatalf("unexpected chat content: %q", resp.Content)
	}
	if got := atomic.LoadInt32(&refreshCalls); got != 1 {
		t.Fatalf("expected exactly one refresh call, got %d", got)
	}

	// The refreshed token must have been written back to disk.
	savedRaw, err := os.ReadFile(credFile)
	if err != nil {
		t.Fatalf("read refreshed credential file failed: %v", err)
	}
	if !strings.Contains(string(savedRaw), "fresh-token") {
		t.Fatalf("expected refreshed token to be persisted, got %s", string(savedRaw))
	}
}
// TestOAuthLoginManualCallbackURLParse checks that a callback URL pasted by the
// user on the input reader is parsed into its authorization code and state.
func TestOAuthLoginManualCallbackURLParse(t *testing.T) {
	t.Parallel()
	const authURL = "https://example.com/auth?state=test-state"
	pasted := bytes.NewBufferString("http://localhost:1455/auth/callback?code=auth-code&state=test-state\n")
	parsed, err := waitForOAuthCodeManual(authURL, pasted)
	if err != nil {
		t.Fatalf("manual callback parse failed: %v", err)
	}
	switch {
	case parsed.Code != "auth-code":
		t.Fatalf("unexpected auth code: %s", parsed.Code)
	case parsed.State != "test-state":
		t.Fatalf("unexpected state: %s", parsed.State)
	}
}
// TestHTTPProviderOAuthSwitchesAccountOnQuota verifies that when the first
// OAuth account receives a 429 insufficient_quota error, the provider retries
// with the next configured credential file and returns that account's result.
// Each token must be tried exactly once.
func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	firstFile := filepath.Join(dir, "first.json")
	secondFile := filepath.Join(dir, "second.json")
	// writeSession stores a session that is valid for another hour.
	writeSession := func(path, token, email string) {
		t.Helper()
		raw, err := json.Marshal(oauthSession{
			Provider:    "codex",
			AccessToken: token,
			Email:       email,
			Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
		})
		if err != nil {
			t.Fatalf("marshal session failed: %v", err)
		}
		if err := os.WriteFile(path, raw, 0o600); err != nil {
			t.Fatalf("write session failed: %v", err)
		}
	}
	writeSession(firstFile, "token-a", "a@example.com")
	writeSession(secondFile, "token-b", "b@example.com")

	var tokenAUsed int32
	var tokenBUsed int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/responses" {
			http.NotFound(w, r)
			return
		}
		switch r.Header.Get("Authorization") {
		case "Bearer token-a":
			atomic.AddInt32(&tokenAUsed, 1)
			// Headers must be set before WriteHeader; later changes are ignored.
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte(`{"error":{"code":"insufficient_quota","message":"quota exceeded"}}`))
		case "Bearer token-b":
			atomic.AddInt32(&tokenBUsed, 1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-second"}`))
		default:
			// t.Fatalf is not allowed off the test goroutine.
			t.Errorf("unexpected auth header: %s", r.Header.Get("Authorization"))
			http.Error(w, "unexpected authorization", http.StatusBadRequest)
		}
	}))
	defer server.Close()

	pc := config.ProviderConfig{
		APIBase:    server.URL + "/v1",
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:        "codex",
			CredentialFile:  firstFile,
			CredentialFiles: []string{firstFile, secondFile},
			ClientID:        "test-client",
			TokenURL:        server.URL + "/oauth/token",
			AuthURL:         server.URL + "/oauth/authorize",
		},
	}
	oauth, err := newOAuthManager(pc, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer oauth.bgCancel()

	provider := NewHTTPProvider("test-oauth-quota", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
	resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
	if err != nil {
		t.Fatalf("chat failed: %v", err)
	}
	if resp.Content != "ok-from-second" {
		t.Fatalf("unexpected response content: %q", resp.Content)
	}
	aUsed, bUsed := atomic.LoadInt32(&tokenAUsed), atomic.LoadInt32(&tokenBUsed)
	if aUsed != 1 || bUsed != 1 {
		t.Fatalf("expected one attempt per token, got token-a=%d token-b=%d", aUsed, bUsed)
	}
}
// TestOAuthManagerPreRefreshesExpiringSession verifies that
// refreshExpiringSessions refreshes a session whose expiry (2 minutes away)
// falls inside the given lead window (10 minutes). It checks that exactly one
// refresh is reported and performed, and that the new token is persisted.
//
// This test overrides the package-level defaultAntigravity* variables, so it is
// intentionally NOT marked t.Parallel(). Like the other tests in this file that
// touch those globals, it runs sequentially, which keeps the mutation from
// racing with parallel tests.
func TestOAuthManagerPreRefreshesExpiringSession(t *testing.T) {
	var refreshCalls int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/oauth/token" {
			http.NotFound(w, r)
			return
		}
		atomic.AddInt32(&refreshCalls, 1)
		// t.Fatalf is not allowed off the test goroutine; report and fail the request.
		if err := r.ParseForm(); err != nil {
			t.Errorf("parse token form failed: %v", err)
			http.Error(w, "bad token form", http.StatusBadRequest)
			return
		}
		if got := r.Form.Get("grant_type"); got != "refresh_token" {
			t.Errorf("unexpected grant_type: %s", got)
			http.Error(w, "unexpected grant_type", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"access_token":"prefreshed-token","refresh_token":"refresh-token","expires_in":3600}`))
	}))
	defer server.Close()

	credFile := filepath.Join(t.TempDir(), "prefresh.json")
	raw, err := json.Marshal(oauthSession{
		Provider:     "codex",
		AccessToken:  "old-token",
		RefreshToken: "refresh-token",
		Expire:       time.Now().Add(2 * time.Minute).Format(time.RFC3339),
	})
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, raw, 0o600); err != nil {
		t.Fatalf("write session failed: %v", err)
	}

	originalAPI := defaultAntigravityAPIEndpoint
	originalUserInfo := defaultAntigravityUserInfoURL
	defaultAntigravityAPIEndpoint = server.URL
	defaultAntigravityUserInfoURL = server.URL + "/userinfo"
	t.Cleanup(func() {
		defaultAntigravityAPIEndpoint = originalAPI
		defaultAntigravityUserInfoURL = originalUserInfo
	})

	manager, err := newOAuthManager(config.ProviderConfig{
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "codex",
			CredentialFile: credFile,
			ClientID:       "test-client",
			TokenURL:       server.URL + "/oauth/token",
			AuthURL:        server.URL + "/oauth/authorize",
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer manager.bgCancel()

	result, err := manager.refreshExpiringSessions(context.Background(), 10*time.Minute)
	if err != nil {
		t.Fatalf("pre-refresh failed: %v", err)
	}
	if result == nil || result.Refreshed != 1 {
		t.Fatalf("expected one refreshed account, got %#v", result)
	}
	if got := atomic.LoadInt32(&refreshCalls); got != 1 {
		t.Fatalf("expected one refresh call, got %d", got)
	}
	saved, err := os.ReadFile(credFile)
	if err != nil {
		t.Fatalf("read saved session failed: %v", err)
	}
	if !strings.Contains(string(saved), "prefreshed-token") {
		t.Fatalf("expected prefreshed token in file, got %s", string(saved))
	}
}
// TestResolveOAuthConfigSupportsAdditionalProviders checks that each provider
// name resolves to the expected canonical provider and login flow kind,
// including the "anthropic" alias, which resolves to "claude".
func TestResolveOAuthConfigSupportsAdditionalProviders(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		provider string
		want     string
		flow     string
	}{
		{name: "anthropic-alias", provider: "anthropic", want: "claude", flow: oauthFlowCallback},
		{name: "antigravity", provider: "antigravity", want: "antigravity", flow: oauthFlowCallback},
		{name: "gemini", provider: "gemini", want: "gemini", flow: oauthFlowCallback},
		{name: "kimi", provider: "kimi", want: "kimi", flow: oauthFlowDevice},
		{name: "qwen", provider: "qwen", want: "qwen", flow: oauthFlowDevice},
	}
	for _, tt := range tests {
		tt := tt // per-iteration copy for the parallel subtest closure
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			pc := config.ProviderConfig{
				Auth:  "oauth",
				OAuth: config.ProviderOAuthConfig{Provider: tt.provider},
			}
			resolved, err := resolveOAuthConfig(pc)
			if err != nil {
				t.Fatalf("resolve oauth config failed: %v", err)
			}
			if got := resolved.Provider; got != tt.want {
				t.Fatalf("unexpected provider: %s", got)
			}
			if got := resolved.FlowKind; got != tt.flow {
				t.Fatalf("unexpected flow kind: %s", got)
			}
		})
	}
}
func TestNewOAuthManagerUsesAnthropicTransportForClaude(t *testing.T) {
t.Parallel()
manager, err := newOAuthManager(config.ProviderConfig{
Auth: "oauth",
OAuth: config.ProviderOAuthConfig{
Provider: "anthropic",
},
}, 5*time.Second)
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
defer manager.bgCancel()
if _, ok := manager.httpClient.Transport.(*anthropicOAuthRoundTripper); !ok {
t.Fatalf("expected anthropic oauth transport, got %T", manager.httpClient.Transport)
}
}
func TestNewOAuthManagerUsesDefaultTransportForNonClaude(t *testing.T) {
t.Parallel()
manager, err := newOAuthManager(config.ProviderConfig{
Auth: "oauth",
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
},
}, 5*time.Second)
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
defer manager.bgCancel()
if manager.httpClient.Transport != nil {
t.Fatalf("expected default transport for non-claude provider, got %T", manager.httpClient.Transport)
}
}
// TestResolveOAuthConfigAppliesProviderRefreshLeadDefaults pins the default
// RefreshLead that resolveOAuthConfig assigns to each provider when no explicit
// value is configured.
func TestResolveOAuthConfigAppliesProviderRefreshLeadDefaults(t *testing.T) {
	t.Parallel()
	expected := []struct {
		provider string
		want     time.Duration
	}{
		{provider: "codex", want: 5 * 24 * time.Hour},
		{provider: "anthropic", want: 4 * time.Hour},
		{provider: "antigravity", want: 5 * time.Minute},
		{provider: "gemini", want: 30 * time.Minute},
		{provider: "kimi", want: 5 * time.Minute},
		{provider: "qwen", want: 3 * time.Hour},
	}
	for _, entry := range expected {
		entry := entry // per-iteration copy for the parallel subtest closure
		t.Run(entry.provider, func(t *testing.T) {
			t.Parallel()
			resolved, err := resolveOAuthConfig(config.ProviderConfig{
				Auth:  "oauth",
				OAuth: config.ProviderOAuthConfig{Provider: entry.provider},
			})
			if err != nil {
				t.Fatalf("resolve oauth config failed: %v", err)
			}
			if got := resolved.RefreshLead; got != entry.want {
				t.Fatalf("unexpected refresh lead for %s: got %v want %v", entry.provider, got, entry.want)
			}
		})
	}
}
func TestResolveOAuthConfigUsesBuiltInGeminiClientDefaults(t *testing.T) {
t.Parallel()
cfg, err := resolveOAuthConfig(config.ProviderConfig{
Auth: "oauth",
OAuth: config.ProviderOAuthConfig{
Provider: "gemini",
},
})
if err != nil {
t.Fatalf("resolve oauth config failed: %v", err)
}
if cfg.ClientID != defaultGeminiClientIDValue {
t.Fatalf("unexpected gemini client id: %q", cfg.ClientID)
}
if cfg.ClientSecret != defaultGeminiClientSecretValue {
t.Fatalf("unexpected gemini client secret: %q", cfg.ClientSecret)
}
}
func TestResolveOAuthConfigUsesBuiltInAntigravityClientDefaults(t *testing.T) {
t.Parallel()
cfg, err := resolveOAuthConfig(config.ProviderConfig{
Auth: "oauth",
OAuth: config.ProviderOAuthConfig{
Provider: "antigravity",
},
})
if err != nil {
t.Fatalf("resolve oauth config failed: %v", err)
}
if cfg.ClientID != defaultAntigravityClientIDValue {
t.Fatalf("unexpected antigravity client id: %q", cfg.ClientID)
}
if cfg.ClientSecret != defaultAntigravityClientSecretValue {
t.Fatalf("unexpected antigravity client secret: %q", cfg.ClientSecret)
}
}
// TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses verifies that a
// session's NetworkProxy is honoured for both the token refresh and the
// responses request. The target server counts refresh and response calls, a
// forwarding proxy counts every request it relays, and the test expects the
// proxy to have seen at least two requests.
func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) {
	t.Parallel()
	var refreshCalls int32
	var responseCalls int32
	// Handlers run off the test goroutine, so failures use t.Errorf plus an
	// HTTP error response instead of t.Fatalf.
	target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/oauth/token":
			atomic.AddInt32(&refreshCalls, 1)
			if err := r.ParseForm(); err != nil {
				t.Errorf("parse token form failed: %v", err)
				http.Error(w, "bad token form", http.StatusBadRequest)
				return
			}
			if got := r.Form.Get("grant_type"); got != "refresh_token" {
				t.Errorf("unexpected grant_type: %s", got)
				http.Error(w, "unexpected grant_type", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"proxied-fresh-token","refresh_token":"refresh-token","expires_in":3600}`))
		case "/v1/responses":
			atomic.AddInt32(&responseCalls, 1)
			if got := r.Header.Get("Authorization"); got != "Bearer proxied-fresh-token" {
				t.Errorf("unexpected authorization header: %s", got)
				http.Error(w, "unexpected authorization", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-via-proxy"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer target.Close()

	var proxyCalls int32
	// proxyServer is a minimal forwarding proxy. It relays each request to the
	// absolute URL in the request line, or to target when only a path was sent,
	// and copies the upstream response back.
	proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&proxyCalls, 1)
		targetURL := r.URL.String()
		if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
			targetURL = target.URL + r.URL.Path
			if rawQuery := strings.TrimSpace(r.URL.RawQuery); rawQuery != "" {
				targetURL += "?" + rawQuery
			}
		}
		req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
		if err != nil {
			t.Errorf("create proxied request failed: %v", err)
			http.Error(w, "proxy request build failed", http.StatusBadGateway)
			return
		}
		req.Header = r.Header.Clone()
		resp, err := http.DefaultTransport.RoundTrip(req)
		if err != nil {
			t.Errorf("proxy round trip failed: %v", err)
			http.Error(w, "proxy round trip failed", http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()
		for key, values := range resp.Header {
			for _, value := range values {
				w.Header().Add(key, value)
			}
		}
		w.WriteHeader(resp.StatusCode)
		_, _ = io.Copy(w, resp.Body)
	}))
	defer proxyServer.Close()

	// Seed an expired session that routes traffic through the proxy.
	credFile := filepath.Join(t.TempDir(), "proxied.json")
	raw, err := json.Marshal(oauthSession{
		Provider:     "codex",
		AccessToken:  "expired-token",
		RefreshToken: "refresh-token",
		Expire:       time.Now().Add(-time.Hour).Format(time.RFC3339),
		NetworkProxy: proxyServer.URL,
	})
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, raw, 0o600); err != nil {
		t.Fatalf("write credential file failed: %v", err)
	}

	pc := config.ProviderConfig{
		APIBase:    target.URL + "/v1",
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "codex",
			CredentialFile: credFile,
			ClientID:       "test-client",
			TokenURL:       target.URL + "/oauth/token",
			AuthURL:        target.URL + "/oauth/authorize",
		},
	}
	oauth, err := newOAuthManager(pc, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer oauth.bgCancel()

	provider := NewHTTPProvider("test-oauth-proxy", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
	resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
	if err != nil {
		t.Fatalf("chat failed: %v", err)
	}
	if resp.Content != "ok-via-proxy" {
		t.Fatalf("unexpected response content: %q", resp.Content)
	}
	if got := atomic.LoadInt32(&refreshCalls); got != 1 {
		t.Fatalf("expected one refresh call, got %d", got)
	}
	if got := atomic.LoadInt32(&responseCalls); got != 1 {
		t.Fatalf("expected one response call, got %d", got)
	}
	if got := atomic.LoadInt32(&proxyCalls); got < 2 {
		t.Fatalf("expected proxy to receive refresh and response requests, got %d", got)
	}
}
// TestOAuthImportGeminiNestedTokenRefreshesWithTokenMetadata verifies that a
// gemini credential whose refresh data is nested under "token" can be parsed
// and refreshed. It checks that the nested client_secret reaches the token
// endpoint, that the refreshed access token is returned, and that the top-level
// email and project_id survive the refresh.
func TestOAuthImportGeminiNestedTokenRefreshesWithTokenMetadata(t *testing.T) {
	t.Parallel()
	// Handler failures use t.Errorf; t.Fatalf is not allowed off the test goroutine.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/oauth/token":
			if err := r.ParseForm(); err != nil {
				t.Errorf("parse form failed: %v", err)
				http.Error(w, "bad form", http.StatusBadRequest)
				return
			}
			if got := r.Form.Get("client_secret"); got != "secret-1" {
				t.Errorf("unexpected client_secret: %s", got)
				http.Error(w, "unexpected client_secret", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"gemini-fresh","refresh_token":"gemini-refresh","expires_in":3600}`))
		case "/userinfo":
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"email":"gemini@example.com"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	manager, err := newOAuthManager(config.ProviderConfig{
		Auth: "oauth",
		OAuth: config.ProviderOAuthConfig{
			Provider:     "gemini",
			TokenURL:     server.URL + "/oauth/token",
			ClientID:     "client-1",
			ClientSecret: "secret-1",
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	// Cancel the manager's background context when the test ends.
	defer manager.bgCancel()
	manager.cfg.UserInfoURL = server.URL + "/userinfo"

	raw := []byte(`{
"type": "gemini",
"email": "gemini@example.com",
"project_id": "demo-project",
"token": {
"refresh_token": "gemini-refresh",
"client_id": "client-1",
"client_secret": "secret-1",
"token_uri": "` + server.URL + `/oauth/token"
}
}`)
	session, err := parseImportedOAuthSession("gemini", "gemini.json", raw)
	if err != nil {
		t.Fatalf("parse imported oauth session failed: %v", err)
	}
	refreshed, err := manager.refreshImportedSession(context.Background(), session)
	if err != nil {
		t.Fatalf("refresh imported session failed: %v", err)
	}
	if refreshed.AccessToken != "gemini-fresh" {
		t.Fatalf("unexpected access token: %s", refreshed.AccessToken)
	}
	if refreshed.Email != "gemini@example.com" {
		t.Fatalf("unexpected email: %s", refreshed.Email)
	}
	if refreshed.ProjectID != "demo-project" {
		t.Fatalf("unexpected project id: %s", refreshed.ProjectID)
	}
}
// TestAntigravityEnrichSessionAddsEmailAndProjectID verifies that enrichSession
// fills in an antigravity session's Email (from the userinfo endpoint) and
// ProjectID (from /v1internal:loadCodeAssist). Both requests must carry the
// session's bearer token.
//
// Not parallel: it overrides the package-level defaultAntigravity* variables.
func TestAntigravityEnrichSessionAddsEmailAndProjectID(t *testing.T) {
	// Handler failures use t.Errorf; t.Fatalf is not allowed off the test goroutine.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/userinfo":
			if got := r.Header.Get("Authorization"); got != "Bearer antigravity-token" {
				t.Errorf("unexpected userinfo authorization: %s", got)
				http.Error(w, "unexpected authorization", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"email":"antigravity@example.com"}`))
		case "/v1internal:loadCodeAssist":
			if got := r.Header.Get("Authorization"); got != "Bearer antigravity-token" {
				t.Errorf("unexpected project authorization: %s", got)
				http.Error(w, "unexpected authorization", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"cloudaicompanionProject":"project-123"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	originalAPI := defaultAntigravityAPIEndpoint
	originalUserInfo := defaultAntigravityUserInfoURL
	defaultAntigravityAPIEndpoint = server.URL
	defaultAntigravityUserInfoURL = server.URL + "/userinfo"
	t.Cleanup(func() {
		defaultAntigravityAPIEndpoint = originalAPI
		defaultAntigravityUserInfoURL = originalUserInfo
	})

	manager, err := newOAuthManager(config.ProviderConfig{
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "antigravity",
			ClientID:       "client-id",
			ClientSecret:   "client-secret",
			AuthURL:        server.URL + "/oauth/authorize",
			RedirectURL:    "http://localhost:51121/oauth-callback",
			RefreshLeadSec: 300,
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer manager.bgCancel()

	session, err := manager.enrichSession(context.Background(), &oauthSession{
		Provider:    "antigravity",
		AccessToken: "antigravity-token",
	})
	if err != nil {
		t.Fatalf("enrich session failed: %v", err)
	}
	if session.Email != "antigravity@example.com" {
		t.Fatalf("unexpected email: %#v", session)
	}
	if session.ProjectID != "project-123" {
		t.Fatalf("expected project id enrichment, got %#v", session)
	}
}
func TestQwenDeviceFlowRequiresAccountLabelWhenEmailMissing(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/device":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"device_code":"dev-1","user_code":"user-1","verification_uri_complete":"https://chat.qwen.ai/device?code=user-1","interval":1,"expires_in":60}`))
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"qwen-token","refresh_token":"refresh-token","expires_in":3600}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
manager, err := newOAuthManager(config.ProviderConfig{
Auth: "oauth",
TimeoutSec: 5,
OAuth: config.ProviderOAuthConfig{
Provider: "qwen",
AuthURL: server.URL + "/device",
TokenURL: server.URL + "/token",
CredentialFile: filepath.Join(t.TempDir(), "qwen.json"),
},
}, 5*time.Second)
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
defer manager.bgCancel()
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
if err != nil {
t.Fatalf("start device flow failed: %v", err)
}
_, _, err = manager.completeDeviceFlow(context.Background(), "", flow, OAuthLoginOptions{})
if err == nil || !strings.Contains(err.Error(), "account_label") {
t.Fatalf("expected qwen account_label error, got %v", err)
}
session, _, err := manager.completeDeviceFlow(context.Background(), "", flow, OAuthLoginOptions{AccountLabel: "qwen-alias"})
if err != nil {
t.Fatalf("complete device flow with label failed: %v", err)
}
if session.Email != "qwen-alias" {
t.Fatalf("expected qwen alias persisted as email label, got %#v", session)
}
}
// TestImportSessionEnrichesAntigravityMetadata checks that importing a raw
// antigravity credential enriches it with the userinfo email and the
// loadCodeAssist project, and returns the models listed at <base>/models.
//
// Not parallel: it overrides the package-level defaultAntigravity* variables.
func TestImportSessionEnrichesAntigravityMetadata(t *testing.T) {
	respond := func(w http.ResponseWriter, body string) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(body))
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/userinfo":
			respond(w, `{"email":"imported@example.com"}`)
		case "/v1internal:loadCodeAssist":
			respond(w, `{"cloudaicompanionProject":"import-project"}`)
		case "/v1/models":
			respond(w, `{"data":[{"id":"g-model"}]}`)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	savedAPI, savedUserInfo := defaultAntigravityAPIEndpoint, defaultAntigravityUserInfoURL
	defaultAntigravityAPIEndpoint, defaultAntigravityUserInfoURL = server.URL, server.URL+"/userinfo"
	t.Cleanup(func() {
		defaultAntigravityAPIEndpoint, defaultAntigravityUserInfoURL = savedAPI, savedUserInfo
	})

	apiBase := server.URL + "/v1"
	manager, err := newOAuthManager(config.ProviderConfig{
		APIBase:    apiBase,
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "antigravity",
			CredentialFile: filepath.Join(t.TempDir(), "antigravity.json"),
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer manager.bgCancel()

	credential := []byte(`{"access_token":"import-token","refresh_token":"refresh-token","expired":"2030-01-01T00:00:00Z"}`)
	session, models, err := manager.importSession(context.Background(), apiBase, "auth.json", credential, OAuthLoginOptions{})
	if err != nil {
		t.Fatalf("import session failed: %v", err)
	}
	enriched := session.Email == "imported@example.com" && session.ProjectID == "import-project"
	if !enriched {
		t.Fatalf("expected antigravity enrichment, got %#v", session)
	}
	if len(models) != 1 || models[0] != "g-model" {
		t.Fatalf("unexpected models: %#v", models)
	}
}
// TestPersistSessionAddsGeminiTokenMetadata checks that persisting a gemini
// session writes a nested "token" object with token_uri, and with client_id and
// client_secret whenever the built-in defaults for those are non-empty.
func TestPersistSessionAddsGeminiTokenMetadata(t *testing.T) {
	t.Parallel()
	credFile := filepath.Join(t.TempDir(), "gemini.json")
	manager, err := newOAuthManager(config.ProviderConfig{
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "gemini",
			CredentialFile: credFile,
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer manager.bgCancel()

	session := &oauthSession{
		Provider:     "gemini",
		AccessToken:  "gem-access",
		RefreshToken: "gem-refresh",
		Expire:       "2030-01-01T00:00:00Z",
		Email:        "gem@example.com",
		FilePath:     credFile,
	}
	// persistSessionLocked expects the caller to hold manager.mu.
	manager.mu.Lock()
	err = manager.persistSessionLocked(session)
	manager.mu.Unlock()
	if err != nil {
		t.Fatalf("persist session failed: %v", err)
	}

	raw, err := os.ReadFile(credFile)
	if err != nil {
		t.Fatalf("read credential file failed: %v", err)
	}
	var payload map[string]any
	if err := json.Unmarshal(raw, &payload); err != nil {
		t.Fatalf("unmarshal credential file failed: %v", err)
	}
	tokenMap, _ := payload["token"].(map[string]any)
	if tokenMap == nil {
		t.Fatalf("expected token map in persisted gemini session, got %s", string(raw))
	}
	switch {
	case tokenMap["token_uri"] != defaultGeminiTokenURL:
		t.Fatalf("unexpected gemini token metadata: %#v", tokenMap)
	case defaultGeminiClientID != "" && tokenMap["client_id"] != defaultGeminiClientID:
		t.Fatalf("unexpected gemini client_id metadata: %#v", tokenMap)
	case defaultGeminiClientSecret != "" && tokenMap["client_secret"] != defaultGeminiClientSecret:
		t.Fatalf("unexpected gemini client_secret metadata: %#v", tokenMap)
	}
}
// TestParseImportedOAuthSessionSupportsAliasProjectAndDeviceVariants checks
// that alternative key spellings nested under "token" are recognised:
// account_label maps to Email, projectId to ProjectID, deviceId to DeviceID,
// and scopes to Scope.
func TestParseImportedOAuthSessionSupportsAliasProjectAndDeviceVariants(t *testing.T) {
	t.Parallel()
	payload := []byte(`{
"refresh_token": "rt-1",
"token": {
"access_token": "at-1",
"account_label": "alias-qwen",
"projectId": "proj-1",
"deviceId": "dev-1",
"scopes": "openid profile"
}
}`)
	session, err := parseImportedOAuthSession("qwen", "auth.json", payload)
	if err != nil {
		t.Fatalf("parse imported session failed: %v", err)
	}
	mismatch := session.Email != "alias-qwen" ||
		session.ProjectID != "proj-1" ||
		session.DeviceID != "dev-1" ||
		session.Scope != "openid profile"
	if mismatch {
		t.Fatalf("unexpected parsed session: %#v", session)
	}
}
// TestOAuthDeviceFlowQwenManualCompletes drives a full qwen device flow. It
// starts the flow, completes it with an AccountLabel, and checks the returned
// access token, the credential file path, the label stored as Email, and the
// model list fetched from <APIBase>/models.
func TestOAuthDeviceFlowQwenManualCompletes(t *testing.T) {
	t.Parallel()
	// Handler failures use t.Errorf; t.Fatalf is not allowed off the test goroutine.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/device":
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"device_code":"dev-1","user_code":"user-1","verification_uri_complete":"https://chat.qwen.ai/device?code=user-1","interval":1,"expires_in":60}`))
		case "/token":
			if err := r.ParseForm(); err != nil {
				t.Errorf("parse form failed: %v", err)
				http.Error(w, "bad form", http.StatusBadRequest)
				return
			}
			if got := r.Form.Get("device_code"); got != "dev-1" {
				t.Errorf("unexpected device_code: %s", got)
				http.Error(w, "unexpected device_code", http.StatusBadRequest)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"qwen-at","refresh_token":"qwen-rt","expires_in":3600}`))
		case "/models":
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"data":[{"id":"qwen-test"}]}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	dir := t.TempDir()
	manager, err := newOAuthManager(config.ProviderConfig{
		APIBase: server.URL,
		Auth:    "oauth",
		OAuth: config.ProviderOAuthConfig{
			Provider:       "qwen",
			CredentialFile: filepath.Join(dir, "qwen.json"),
			AuthURL:        server.URL + "/device",
			TokenURL:       server.URL + "/token",
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	// Cancel the manager's background context when the test ends.
	defer manager.bgCancel()

	flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
	if err != nil {
		t.Fatalf("start device flow failed: %v", err)
	}
	if flow.Mode != oauthFlowDevice {
		t.Fatalf("unexpected flow mode: %s", flow.Mode)
	}
	session, models, err := manager.completeDeviceFlow(context.Background(), server.URL, flow, OAuthLoginOptions{AccountLabel: "qwen-label"})
	if err != nil {
		t.Fatalf("complete device flow failed: %v", err)
	}
	if session.AccessToken != "qwen-at" {
		t.Fatalf("unexpected access token: %s", session.AccessToken)
	}
	if session.FilePath == "" {
		t.Fatalf("expected credential file path")
	}
	if session.Email != "qwen-label" {
		t.Fatalf("expected qwen label, got %#v", session)
	}
	if len(models) != 1 || models[0] != "qwen-test" {
		t.Fatalf("unexpected models: %#v", models)
	}
}
// TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth verifies that in "hybrid"
// auth mode a 429 insufficient_quota response for the API key makes the
// provider retry with the OAuth session. It expects exactly one API-key attempt
// and one OAuth attempt.
func TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	credFile := filepath.Join(dir, "oauth.json")
	raw, err := json.Marshal(oauthSession{
		Provider:    "codex",
		AccessToken: "oauth-token",
		Email:       "oauth@example.com",
		Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
	})
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, raw, 0o600); err != nil {
		t.Fatalf("write session failed: %v", err)
	}

	var apiKeyCalls int32
	var oauthCalls int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/responses" {
			http.NotFound(w, r)
			return
		}
		switch r.Header.Get("Authorization") {
		case "Bearer api-key-1":
			atomic.AddInt32(&apiKeyCalls, 1)
			// Headers must be set before WriteHeader; later changes are ignored.
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte(`{"error":{"code":"insufficient_quota","message":"quota exceeded"}}`))
		case "Bearer oauth-token":
			atomic.AddInt32(&oauthCalls, 1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-oauth"}`))
		default:
			// t.Fatalf is not allowed off the test goroutine.
			t.Errorf("unexpected auth header: %s", r.Header.Get("Authorization"))
			http.Error(w, "unexpected authorization", http.StatusBadRequest)
		}
	}))
	defer server.Close()

	pc := config.ProviderConfig{
		APIBase:    server.URL + "/v1",
		APIKey:     "api-key-1",
		Auth:       "hybrid",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "codex",
			CredentialFile: credFile,
			TokenURL:       server.URL + "/oauth/token",
			AuthURL:        server.URL + "/oauth/authorize",
		},
	}
	oauth, err := newOAuthManager(pc, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer oauth.bgCancel()

	provider := NewHTTPProvider("test-hybrid-fallback", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
	resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
	if err != nil {
		t.Fatalf("chat failed: %v", err)
	}
	if resp.Content != "ok-from-oauth" {
		t.Fatalf("unexpected response content: %q", resp.Content)
	}
	apiCalls, oauthUsed := atomic.LoadInt32(&apiKeyCalls), atomic.LoadInt32(&oauthCalls)
	if apiCalls != 1 || oauthUsed != 1 {
		t.Fatalf("expected one api-key and one oauth attempt, got api=%d oauth=%d", apiCalls, oauthUsed)
	}
}
// TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey exercises "hybrid"
// auth when both credentials would succeed. The assertions require the API key
// to be tried first and OAuth never to be called.
//
// NOTE(review): the test name says "OAuth first", but the assertions pin
// API-key-first behavior, and no OAuth-first option is set in the config below.
// Either the name or the config is stale; confirm the intended ordering before
// renaming or changing the expectations.
func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	credFile := filepath.Join(dir, "oauth.json")
	raw, err := json.Marshal(oauthSession{
		Provider:    "codex",
		AccessToken: "oauth-token",
		Email:       "oauth@example.com",
		Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
	})
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, raw, 0o600); err != nil {
		t.Fatalf("write session failed: %v", err)
	}

	var apiKeyCalls int32
	var oauthCalls int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/responses" {
			http.NotFound(w, r)
			return
		}
		switch r.Header.Get("Authorization") {
		case "Bearer api-key-1":
			atomic.AddInt32(&apiKeyCalls, 1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-api"}`))
		case "Bearer oauth-token":
			atomic.AddInt32(&oauthCalls, 1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-oauth"}`))
		default:
			// t.Fatalf is not allowed off the test goroutine.
			t.Errorf("unexpected auth header: %s", r.Header.Get("Authorization"))
			http.Error(w, "unexpected authorization", http.StatusBadRequest)
		}
	}))
	defer server.Close()

	pc := config.ProviderConfig{
		APIBase:    server.URL + "/v1",
		APIKey:     "api-key-1",
		Auth:       "hybrid",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "codex",
			CredentialFile: credFile,
			TokenURL:       server.URL + "/oauth/token",
			AuthURL:        server.URL + "/oauth/authorize",
		},
	}
	oauth, err := newOAuthManager(pc, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	defer oauth.bgCancel()

	provider := NewHTTPProvider("test-hybrid-oauth-first", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
	resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
	if err != nil {
		t.Fatalf("chat failed: %v", err)
	}
	if resp.Content != "ok-from-api" {
		t.Fatalf("unexpected response content: %q", resp.Content)
	}
	apiCalls, oauthUsed := atomic.LoadInt32(&apiKeyCalls), atomic.LoadInt32(&oauthCalls)
	if apiCalls != 1 || oauthUsed != 0 {
		t.Fatalf("expected api key first only, got api=%d oauth=%d", apiCalls, oauthUsed)
	}
}
// TestOAuthManagerCooldownSkipsExhaustedAccount checks that an account marked
// exhausted (rate limited) is left out of later attempt lists while its cooldown
// is active, and that the cooldown is reported by ListAccounts.
func TestOAuthManagerCooldownSkipsExhaustedAccount(t *testing.T) {
	t.Parallel()
	tmp := t.TempDir()
	fileA := filepath.Join(tmp, "first.json")
	fileB := filepath.Join(tmp, "second.json")
	// Seed two unexpired sessions, written in order: fileA then fileB.
	for _, seed := range []struct{ path, token string }{
		{fileA, "token-a"},
		{fileB, "token-b"},
	} {
		payload, err := json.Marshal(oauthSession{
			Provider:    "codex",
			AccessToken: seed.token,
			Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
		})
		if err != nil {
			t.Fatalf("marshal session failed: %v", err)
		}
		if err := os.WriteFile(seed.path, payload, 0o600); err != nil {
			t.Fatalf("write session failed: %v", err)
		}
	}
	mgr, err := newOAuthManager(config.ProviderConfig{
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:        "codex",
			CredentialFile:  fileA,
			CredentialFiles: []string{fileA, fileB},
			CooldownSec:     3600,
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	ctx := context.Background()
	before, err := mgr.prepareAttemptsLocked(ctx)
	if err != nil {
		t.Fatalf("prepare attempts failed: %v", err)
	}
	if got := len(before); got != 2 {
		t.Fatalf("expected 2 attempts, got %d", got)
	}
	// Exhaust the first candidate; only the second should still be offered.
	mgr.markExhausted(before[0].Session, oauthFailureRateLimit)
	after, err := mgr.prepareAttemptsLocked(ctx)
	if err != nil {
		t.Fatalf("prepare attempts after cooldown failed: %v", err)
	}
	if got := len(after); got != 1 {
		t.Fatalf("expected 1 available attempt after cooldown, got %d", got)
	}
	if tok := after[0].Token; tok != "token-b" {
		t.Fatalf("unexpected token after cooldown: %s", tok)
	}
	login := &OAuthLoginManager{manager: mgr}
	accounts, err := login.ListAccounts()
	if err != nil {
		t.Fatalf("list accounts failed: %v", err)
	}
	sawCooldown := false
	for _, acct := range accounts {
		if acct.CredentialFile == fileA && acct.CooldownUntil != "" {
			sawCooldown = true
			break
		}
	}
	if !sawCooldown {
		t.Fatalf("expected cooldown metadata to be exposed in account list")
	}
}
// TestOAuthManagerDisableSessionSkipsAccount checks that disabling a session
// removes it from the attempt list, and that the disabled flag and reason are
// written back to its credential file.
func TestOAuthManagerDisableSessionSkipsAccount(t *testing.T) {
	t.Parallel()
	tmp := t.TempDir()
	paths := []string{filepath.Join(tmp, "first.json"), filepath.Join(tmp, "second.json")}
	tokens := []string{"token-a", "token-b"}
	for i, path := range paths {
		payload, err := json.Marshal(oauthSession{
			Provider:    "codex",
			AccessToken: tokens[i],
			Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
		})
		if err != nil {
			t.Fatalf("marshal session failed: %v", err)
		}
		if err := os.WriteFile(path, payload, 0o600); err != nil {
			t.Fatalf("write session failed: %v", err)
		}
	}
	mgr, err := newOAuthManager(config.ProviderConfig{
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:        "codex",
			CredentialFile:  paths[0],
			CredentialFiles: []string{paths[0], paths[1]},
			CooldownSec:     3600,
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	ctx := context.Background()
	initial, err := mgr.prepareAttemptsLocked(ctx)
	if err != nil {
		t.Fatalf("prepare attempts failed: %v", err)
	}
	if n := len(initial); n != 2 {
		t.Fatalf("expected 2 attempts, got %d", n)
	}
	mgr.disableSession(initial[0].Session, oauthFailureRevoked, "oauth token revoked")
	remaining, err := mgr.prepareAttemptsLocked(ctx)
	if err != nil {
		t.Fatalf("prepare attempts after disable failed: %v", err)
	}
	if n := len(remaining); n != 1 {
		t.Fatalf("expected 1 available attempt after disable, got %d", n)
	}
	if tok := remaining[0].Token; tok != "token-b" {
		t.Fatalf("unexpected token after disable: %s", tok)
	}
	// The disabled state must be persisted, not just held in memory.
	stored, err := os.ReadFile(paths[0])
	if err != nil {
		t.Fatalf("read disabled session failed: %v", err)
	}
	var persisted oauthSession
	if err := json.Unmarshal(stored, &persisted); err != nil {
		t.Fatalf("unmarshal disabled session failed: %v", err)
	}
	if !persisted.Disabled || persisted.DisableReason != string(oauthFailureRevoked) {
		t.Fatalf("expected disabled session to persist, got %#v", persisted)
	}
}
// TestOAuthLoginManagerListAccountsSkipsInvalidCredentialFiles checks that a
// credential file which does not describe a usable session is left out of
// ListAccounts, while the valid file's account is still returned.
func TestOAuthLoginManagerListAccountsSkipsInvalidCredentialFiles(t *testing.T) {
	t.Parallel()
	tmp := t.TempDir()
	badPath := filepath.Join(tmp, "invalid.json")
	goodPath := filepath.Join(tmp, "valid.json")
	if err := os.WriteFile(badPath, []byte(`{"not":"a valid oauth session"}`), 0o600); err != nil {
		t.Fatalf("write invalid credential failed: %v", err)
	}
	good := oauthSession{
		Provider:     "codex",
		AccessToken:  "oauth-token",
		RefreshToken: "refresh-token",
		Email:        "user@example.com",
		Expire:       time.Now().Add(time.Hour).Format(time.RFC3339),
	}
	payload, err := json.Marshal(good)
	if err != nil {
		t.Fatalf("marshal valid session failed: %v", err)
	}
	if err := os.WriteFile(goodPath, payload, 0o600); err != nil {
		t.Fatalf("write valid credential failed: %v", err)
	}
	mgr, err := newOAuthManager(config.ProviderConfig{
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:        "codex",
			CredentialFile:  badPath,
			CredentialFiles: []string{badPath, goodPath},
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	login := &OAuthLoginManager{manager: mgr}
	accounts, err := login.ListAccounts()
	if err != nil {
		t.Fatalf("list accounts failed: %v", err)
	}
	switch {
	case len(accounts) != 1:
		t.Fatalf("expected 1 valid account, got %d (%#v)", len(accounts), accounts)
	case accounts[0].Email != "user@example.com":
		t.Fatalf("unexpected account: %#v", accounts[0])
	}
}
// TestOAuthManagerPrefersHealthierAccount checks that after one account takes a
// quota failure and its cooldown is then cleared, both accounts are offered
// again, but the account that did not fail (token-b) is ordered first.
func TestOAuthManagerPrefersHealthierAccount(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	firstFile := filepath.Join(dir, "first.json")
	secondFile := filepath.Join(dir, "second.json")
	writeSession := func(path, token string) {
		t.Helper()
		raw, err := json.Marshal(oauthSession{
			Provider:    "codex",
			AccessToken: token,
			Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
		})
		if err != nil {
			t.Fatalf("marshal session failed: %v", err)
		}
		if err := os.WriteFile(path, raw, 0o600); err != nil {
			t.Fatalf("write session failed: %v", err)
		}
	}
	writeSession(firstFile, "token-a")
	writeSession(secondFile, "token-b")
	manager, err := newOAuthManager(config.ProviderConfig{
		Auth:       "oauth",
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:        "codex",
			CredentialFile:  firstFile,
			CredentialFiles: []string{firstFile, secondFile},
			CooldownSec:     60,
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	attempts, err := manager.prepareAttemptsLocked(context.Background())
	if err != nil {
		t.Fatalf("prepare attempts failed: %v", err)
	}
	// Guard the index below so an unexpected result fails instead of panicking.
	// Also pin the precondition the final assertion depends on: the penalized
	// account must be token-a, otherwise "token-b first" proves nothing.
	if len(attempts) != 2 {
		t.Fatalf("expected 2 attempts, got %d", len(attempts))
	}
	if attempts[0].Token != "token-a" {
		t.Fatalf("expected token-a as the initial first attempt, got %s", attempts[0].Token)
	}
	manager.markExhausted(attempts[0].Session, oauthFailureQuota)
	// Remove only the cooldown entry so the account is eligible again. The test
	// expects the health drop from markExhausted to still affect ordering.
	delete(manager.cooldowns, attempts[0].Session.FilePath)
	attempts, err = manager.prepareAttemptsLocked(context.Background())
	if err != nil {
		t.Fatalf("prepare attempts after health drop failed: %v", err)
	}
	if len(attempts) != 2 {
		t.Fatalf("expected 2 attempts, got %d", len(attempts))
	}
	if attempts[0].Token != "token-b" {
		t.Fatalf("expected healthier token-b first, got %s", attempts[0].Token)
	}
}
// TestOAuthLoginManagerListAccountsIncludesCodexPlanMetadata checks that plan and
// subscription claims carried in a Codex session's ID token appear on the
// account returned by ListAccounts.
func TestOAuthLoginManagerListAccountsIncludesCodexPlanMetadata(t *testing.T) {
	t.Parallel()
	credFile := filepath.Join(t.TempDir(), "codex-plan.json")
	// Plan/subscription details are placed under the namespaced auth claim.
	authClaims := map[string]any{
		"chatgpt_account_id":                "acct-plan",
		"chatgpt_plan_type":                 "pro",
		"chatgpt_subscription_active_start": "2026-03-01T00:00:00Z",
		"chatgpt_subscription_active_until": "2026-04-01T00:00:00Z",
	}
	idToken := buildTestJWT(map[string]any{
		"email":                       "plan@example.com",
		"https://api.openai.com/auth": authClaims,
	})
	session := oauthSession{
		Provider:     "codex",
		AccessToken:  "token-plan",
		RefreshToken: "refresh-plan",
		IDToken:      idToken,
		Expire:       time.Now().Add(time.Hour).Format(time.RFC3339),
	}
	payload, err := json.Marshal(session)
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, payload, 0o600); err != nil {
		t.Fatalf("write session failed: %v", err)
	}
	mgr, err := newOAuthManager(config.ProviderConfig{
		TimeoutSec: 5,
		OAuth: config.ProviderOAuthConfig{
			Provider:       "codex",
			CredentialFile: credFile,
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	login := &OAuthLoginManager{manager: mgr}
	accounts, err := login.ListAccounts()
	if err != nil {
		t.Fatalf("list accounts failed: %v", err)
	}
	if len(accounts) != 1 {
		t.Fatalf("expected one account, got %#v", accounts)
	}
	acct := accounts[0]
	switch {
	case acct.PlanType != "pro":
		t.Fatalf("expected plan type to be extracted, got %#v", acct)
	case acct.BalanceLabel != "PRO", acct.SubActiveUntil != "2026-04-01T00:00:00Z":
		t.Fatalf("expected subscription metadata in account info, got %#v", acct)
	}
}
// buildTestJWT returns an unsigned JWT for tests. The header is
// {"alg":"none","typ":"JWT"}, the payload is the JSON encoding of claims, and
// the signature segment is empty. Both segments use unpadded base64url.
// Marshal errors are ignored; callers pass plain JSON-encodable test values.
func buildTestJWT(claims map[string]any) string {
	encodeSegment := func(v any) string {
		blob, _ := json.Marshal(v)
		return base64.RawURLEncoding.EncodeToString(blob)
	}
	head := encodeSegment(map[string]any{"alg": "none", "typ": "JWT"})
	body := encodeSegment(claims)
	return head + "." + body + "."
}
// TestClassifyOAuthFailureDifferentiatesReasons checks that classifyOAuthFailure
// tells quota exhaustion, generic rate limiting and forbidden responses apart,
// and reports retry=true for each of them.
func TestClassifyOAuthFailureDifferentiatesReasons(t *testing.T) {
	t.Parallel()
	// 429 carrying an insufficient_quota code.
	if reason, retry := classifyOAuthFailure(http.StatusTooManyRequests, []byte(`{"error":{"code":"insufficient_quota"}}`)); !retry || reason != oauthFailureQuota {
		t.Fatalf("expected quota classification, got retry=%v reason=%s", retry, reason)
	}
	// 429 with only a rate-limit message.
	if reason, retry := classifyOAuthFailure(http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`)); !retry || reason != oauthFailureRateLimit {
		t.Fatalf("expected rate-limit classification, got retry=%v reason=%s", retry, reason)
	}
	// Plain 403.
	if reason, retry := classifyOAuthFailure(http.StatusForbidden, []byte(`{"error":"forbidden"}`)); !retry || reason != oauthFailureForbidden {
		t.Fatalf("expected forbidden classification, got retry=%v reason=%s", retry, reason)
	}
}
func TestHTTPProviderHybridSkipsAPIKeyDuringCooldown(t *testing.T) {
t.Parallel()
dir := t.TempDir()
credFile := filepath.Join(dir, "oauth.json")
raw, err := json.Marshal(oauthSession{
Provider: "codex",
AccessToken: "oauth-token",
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
})
if err != nil {
t.Fatalf("marshal session failed: %v", err)
}
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
t.Fatalf("write session failed: %v", err)
}
providerRuntimeRegistry.mu.Lock()
providerRuntimeRegistry.api["cooldown-provider"] = providerRuntimeState{
API: providerAPIRuntimeState{
TokenMasked: "api-***",
HealthScore: 50,
CooldownUntil: time.Now().Add(10 * time.Minute).Format(time.RFC3339),
},
}
providerRuntimeRegistry.mu.Unlock()
manager, err := newOAuthManager(config.ProviderConfig{
Auth: "hybrid",
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
CredentialFile: credFile,
},
}, 5*time.Second)
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
provider := NewHTTPProvider("cooldown-provider", "api-key-1", "https://example.com/v1", "gpt-test", false, "hybrid", 5*time.Second, manager)
attempts, err := provider.authAttempts(context.Background())
if err != nil {
t.Fatalf("auth attempts failed: %v", err)
}
if len(attempts) != 1 || attempts[0].kind != "oauth" {
t.Fatalf("expected only oauth attempt during api cooldown, got %#v", attempts)
}
}
func TestClearProviderAPICooldownRestoresAPIKeyAttempt(t *testing.T) {
t.Parallel()
providerRuntimeRegistry.mu.Lock()
providerRuntimeRegistry.api["clear-api-provider"] = providerRuntimeState{
API: providerAPIRuntimeState{
TokenMasked: "api-***",
HealthScore: 50,
CooldownUntil: time.Now().Add(10 * time.Minute).Format(time.RFC3339),
},
}
providerRuntimeRegistry.mu.Unlock()
provider := NewHTTPProvider("clear-api-provider", "api-key-1", "https://example.com/v1", "gpt-test", false, "bearer", 5*time.Second, nil)
if _, err := provider.authAttempts(context.Background()); err == nil {
t.Fatalf("expected api key attempt to be blocked by cooldown")
}
ClearProviderAPICooldown("clear-api-provider")
attempts, err := provider.authAttempts(context.Background())
if err != nil {
t.Fatalf("expected api key attempt after clear cooldown, got %v", err)
}
if len(attempts) != 1 || attempts[0].kind != "api_key" {
t.Fatalf("unexpected attempts after clear cooldown: %#v", attempts)
}
}
// TestOAuthLoginManagerClearCooldown checks that OAuthLoginManager.ClearCooldown
// makes a rate-limited session available to prepareAttemptsLocked again.
func TestOAuthLoginManagerClearCooldown(t *testing.T) {
	t.Parallel()
	credFile := filepath.Join(t.TempDir(), "oauth.json")
	payload, err := json.Marshal(oauthSession{
		Provider:    "codex",
		AccessToken: "oauth-token",
		Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
	})
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, payload, 0o600); err != nil {
		t.Fatalf("write session failed: %v", err)
	}
	mgr, err := newOAuthManager(config.ProviderConfig{
		Auth: "oauth",
		OAuth: config.ProviderOAuthConfig{
			Provider:       "codex",
			CredentialFile: credFile,
			CooldownSec:    3600,
		},
	}, 5*time.Second)
	if err != nil {
		t.Fatalf("new oauth manager failed: %v", err)
	}
	ctx := context.Background()
	initial, err := mgr.prepareAttemptsLocked(ctx)
	if err != nil || len(initial) != 1 {
		t.Fatalf("prepare attempts failed: %v %#v", err, initial)
	}
	mgr.markExhausted(initial[0].Session, oauthFailureRateLimit)
	if err := (&OAuthLoginManager{manager: mgr}).ClearCooldown(credFile); err != nil {
		t.Fatalf("clear cooldown failed: %v", err)
	}
	restored, err := mgr.prepareAttemptsLocked(ctx)
	if err != nil || len(restored) != 1 {
		t.Fatalf("expected session available after cooldown clear, got err=%v attempts=%#v", err, restored)
	}
}
func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.T) {
t.Parallel()
dir := t.TempDir()
credFile := filepath.Join(dir, "oauth.json")
raw, err := json.Marshal(oauthSession{
Provider: "codex",
AccessToken: "oauth-token",
RefreshToken: "refresh-token",
Email: "user@example.com",
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
})
if err != nil {
t.Fatalf("marshal session failed: %v", err)
}
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
t.Fatalf("write session failed: %v", err)
}
name := "runtime-snapshot-provider"
pc := config.ProviderConfig{
APIKey: "api-key-123456",
APIBase: "https://example.com/v1",
Auth: "hybrid",
TimeoutSec: 5,
RuntimePersist: true,
RuntimeHistoryFile: filepath.Join(dir, "runtime.json"),
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
CredentialFile: credFile,
},
}
ConfigureProviderRuntime(name, pc)
manager, err := newOAuthManager(pc, 5*time.Second)
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
provider := NewHTTPProvider(name, pc.APIKey, pc.APIBase, "gpt-test", false, pc.Auth, 5*time.Second, manager)
attempts, err := provider.authAttempts(context.Background())
if err != nil {
t.Fatalf("auth attempts failed: %v", err)
}
if len(attempts) != 2 || attempts[0].kind != "api_key" || attempts[1].kind != "oauth" {
t.Fatalf("unexpected attempts order: %#v", attempts)
}
provider.markAttemptSuccess(attempts[1])
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{name: pc},
},
}
snapshot := GetProviderRuntimeSnapshot(cfg)
items, _ := snapshot["items"].([]map[string]interface{})
if len(items) == 0 {
t.Fatalf("expected provider runtime items")
}
item := items[0]
candidates, _ := item["candidate_order"].([]providerRuntimeCandidate)
if len(candidates) < 2 {
t.Fatalf("expected candidate order, got %#v", item["candidate_order"])
}
if candidates[0].Kind != "api_key" || candidates[1].Kind != "oauth" {
t.Fatalf("unexpected candidate order: %#v", candidates)
}
lastSuccess, _ := item["last_success"].(*providerRuntimeEvent)
if lastSuccess == nil || lastSuccess.Kind != "oauth" || lastSuccess.Target != "user@example.com" {
t.Fatalf("unexpected last success: %#v", item["last_success"])
}
if _, err := os.Stat(pc.RuntimeHistoryFile); err != nil {
t.Fatalf("expected runtime history file, got %v", err)
}
}
func TestConfigureProviderRuntimeLoadsPersistedEvents(t *testing.T) {
t.Parallel()
dir := t.TempDir()
name := "persisted-runtime-provider"
historyFile := filepath.Join(dir, "runtime.json")
payload := providerRuntimeState{
RecentHits: []providerRuntimeEvent{{
When: time.Now().Add(-time.Minute).Format(time.RFC3339),
Kind: "oauth",
Target: "persisted@example.com",
Reason: "ok",
}},
LastSuccess: &providerRuntimeEvent{
When: time.Now().Add(-time.Minute).Format(time.RFC3339),
Kind: "oauth",
Target: "persisted@example.com",
},
}
raw, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal runtime payload failed: %v", err)
}
if err := os.WriteFile(historyFile, raw, 0o600); err != nil {
t.Fatalf("write history file failed: %v", err)
}
ConfigureProviderRuntime(name, config.ProviderConfig{
APIBase: "https://example.com/v1",
Auth: "bearer",
RuntimePersist: true,
RuntimeHistoryFile: historyFile,
})
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {
APIBase: "https://example.com/v1",
Auth: "bearer",
RuntimePersist: true,
RuntimeHistoryFile: historyFile,
},
},
},
}
snapshot := GetProviderRuntimeSnapshot(cfg)
items, _ := snapshot["items"].([]map[string]interface{})
if len(items) == 0 {
t.Fatalf("expected provider runtime item")
}
lastSuccess, _ := items[0]["last_success"].(*providerRuntimeEvent)
if lastSuccess == nil || lastSuccess.Target != "persisted@example.com" {
t.Fatalf("expected persisted last success, got %#v", items[0]["last_success"])
}
hits, _ := items[0]["recent_hits"].([]providerRuntimeEvent)
if len(hits) == 0 || hits[0].Target != "persisted@example.com" {
t.Fatalf("expected persisted recent hits, got %#v", items[0]["recent_hits"])
}
}
func TestClearProviderRuntimeHistoryRemovesPersistedFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
name := "clear-runtime-history-provider"
historyFile := filepath.Join(dir, "runtime.json")
ConfigureProviderRuntime(name, config.ProviderConfig{
APIBase: "https://example.com/v1",
Auth: "bearer",
RuntimePersist: true,
RuntimeHistoryFile: historyFile,
})
providerRuntimeRegistry.mu.Lock()
state := providerRuntimeRegistry.api[name]
state.RecentHits = []providerRuntimeEvent{{When: time.Now().Format(time.RFC3339), Kind: "api_key", Target: "api***"}}
state.LastSuccess = &providerRuntimeEvent{When: time.Now().Format(time.RFC3339), Kind: "api_key", Target: "api***"}
persistProviderRuntimeLocked(name, state)
providerRuntimeRegistry.api[name] = state
providerRuntimeRegistry.mu.Unlock()
if _, err := os.Stat(historyFile); err != nil {
t.Fatalf("expected runtime history file, got %v", err)
}
ClearProviderRuntimeHistory(name)
if _, err := os.Stat(historyFile); !os.IsNotExist(err) {
t.Fatalf("expected runtime history file removed, got %v", err)
}
providerRuntimeRegistry.mu.Lock()
cleared := providerRuntimeRegistry.api[name]
providerRuntimeRegistry.mu.Unlock()
if len(cleared.RecentHits) != 0 || len(cleared.RecentErrors) != 0 || cleared.LastSuccess != nil {
t.Fatalf("expected runtime history cleared, got %#v", cleared)
}
}
func TestUpdateCandidateOrderRecordsSchedulerChange(t *testing.T) {
t.Parallel()
dir := t.TempDir()
credFile := filepath.Join(dir, "oauth.json")
raw, err := json.Marshal(oauthSession{
Provider: "codex",
AccessToken: "oauth-token",
Email: "user@example.com",
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
})
if err != nil {
t.Fatalf("marshal session failed: %v", err)
}
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
t.Fatalf("write session failed: %v", err)
}
name := "candidate-change-provider"
pc := config.ProviderConfig{
APIKey: "api-key-123456",
APIBase: "https://example.com/v1",
Auth: "hybrid",
TimeoutSec: 5,
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
CredentialFile: credFile,
},
}
manager, err := newOAuthManager(pc, 5*time.Second)
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
provider := NewHTTPProvider(name, pc.APIKey, pc.APIBase, "gpt-test", false, pc.Auth, 5*time.Second, manager)
attempts, err := provider.authAttempts(context.Background())
if err != nil {
t.Fatalf("auth attempts failed: %v", err)
}
if len(attempts) != 2 {
t.Fatalf("unexpected attempts: %#v", attempts)
}
provider.markAPIKeyFailure(oauthFailureRateLimit)
attempts, err = provider.authAttempts(context.Background())
if err != nil {
t.Fatalf("auth attempts after cooldown failed: %v", err)
}
if len(attempts) != 1 || attempts[0].kind != "oauth" {
t.Fatalf("unexpected attempts after cooldown: %#v", attempts)
}
providerRuntimeRegistry.mu.Lock()
state := providerRuntimeRegistry.api[name]
providerRuntimeRegistry.mu.Unlock()
if len(state.RecentChanges) == 0 || state.RecentChanges[0].Reason != "candidate_order_changed" {
t.Fatalf("expected scheduler change event, got %#v", state.RecentChanges)
}
if !strings.Contains(state.RecentChanges[0].Detail, "top ") {
t.Fatalf("expected candidate order detail, got %#v", state.RecentChanges[0])
}
}
func TestGetProviderRuntimeViewFiltersEvents(t *testing.T) {
t.Parallel()
name := "runtime-view-provider"
providerRuntimeRegistry.mu.Lock()
providerRuntimeRegistry.api[name] = providerRuntimeState{
RecentHits: []providerRuntimeEvent{
{When: time.Now().Add(-30 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "ok"},
{When: time.Now().Add(-3 * time.Hour).Format(time.RFC3339), Kind: "api_key", Target: "api***", Reason: "ok"},
},
RecentErrors: []providerRuntimeEvent{
{When: time.Now().Add(-10 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "quota"},
},
RecentChanges: []providerRuntimeEvent{
{When: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), Kind: "scheduler", Target: name, Reason: "candidate_order_changed", Detail: "top api -> oauth"},
},
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "hybrid", APIKey: "api-key"},
},
},
}
view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{
Provider: name,
Window: 2 * time.Hour,
EventKind: "oauth",
Limit: 1,
})
items, _ := view["items"].([]map[string]interface{})
if len(items) != 1 {
t.Fatalf("expected one runtime item, got %#v", view)
}
hits, _ := items[0]["recent_hits"].([]providerRuntimeEvent)
if len(hits) != 1 || hits[0].Kind != "oauth" {
t.Fatalf("expected filtered oauth hits, got %#v", items[0]["recent_hits"])
}
errors, _ := items[0]["recent_errors"].([]providerRuntimeEvent)
if len(errors) != 1 || errors[0].Reason != "quota" {
t.Fatalf("expected filtered oauth errors, got %#v", items[0]["recent_errors"])
}
changes, _ := items[0]["recent_changes"].([]providerRuntimeEvent)
if len(changes) != 0 {
t.Fatalf("expected no scheduler changes when filtering kind=oauth, got %#v", items[0]["recent_changes"])
}
events, _ := items[0]["events"].([]providerRuntimeEvent)
if len(events) != 1 {
t.Fatalf("expected merged paged events, got %#v", items[0]["events"])
}
}
// TestGetProviderRuntimeViewCursorPagination checks cursor paging over the merged
// hits+errors event list. With Limit=2, the page at cursor 0 returns two events
// and next_cursor=2. The page at cursor 2 returns the remaining event and
// next_cursor=0, meaning there are no further pages.
func TestGetProviderRuntimeViewCursorPagination(t *testing.T) {
	t.Parallel()
	name := "runtime-cursor-provider"
	now := time.Now()
	providerRuntimeRegistry.mu.Lock()
	providerRuntimeRegistry.api[name] = providerRuntimeState{
		RecentHits: []providerRuntimeEvent{
			{When: now.Add(-1 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "a", Reason: "ok"},
			{When: now.Add(-2 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "b", Reason: "ok"},
		},
		RecentErrors: []providerRuntimeEvent{
			{When: now.Add(-3 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "c", Reason: "quota"},
		},
	}
	providerRuntimeRegistry.mu.Unlock()
	cfg := &config.Config{
		Models: config.ModelsConfig{
			Providers: map[string]config.ProviderConfig{
				name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
			},
		},
	}
	// fetchPage runs the query and returns the single runtime item's events,
	// its next_cursor and the raw item. It fails the test instead of panicking
	// when the item is missing or next_cursor is not an int; the original
	// unchecked index and .(int) assertions would crash the test binary.
	fetchPage := func(q ProviderRuntimeQuery) ([]providerRuntimeEvent, int, map[string]interface{}) {
		t.Helper()
		view := GetProviderRuntimeView(cfg, q)
		items, _ := view["items"].([]map[string]interface{})
		if len(items) != 1 {
			t.Fatalf("expected one item, got %#v", view)
		}
		item := items[0]
		events, _ := item["events"].([]providerRuntimeEvent)
		next, ok := item["next_cursor"].(int)
		if !ok {
			t.Fatalf("expected int next_cursor, got %#v", item["next_cursor"])
		}
		return events, next, item
	}
	page1, next1, item1 := fetchPage(ProviderRuntimeQuery{Provider: name, Limit: 2, Cursor: 0})
	if len(page1) != 2 || next1 != 2 {
		t.Fatalf("unexpected first page %#v", item1)
	}
	page2, next2, item2 := fetchPage(ProviderRuntimeQuery{Provider: name, Limit: 2, Cursor: 2})
	if len(page2) != 1 || next2 != 0 {
		t.Fatalf("unexpected second page %#v", item2)
	}
}
func TestGetProviderRuntimeViewSortAscending(t *testing.T) {
t.Parallel()
name := "runtime-sort-provider"
now := time.Now()
providerRuntimeRegistry.mu.Lock()
providerRuntimeRegistry.api[name] = providerRuntimeState{
RecentHits: []providerRuntimeEvent{
{When: now.Add(-1 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "a", Reason: "ok"},
{When: now.Add(-3 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "b", Reason: "ok"},
},
RecentErrors: []providerRuntimeEvent{
{When: now.Add(-2 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "c", Reason: "quota"},
},
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
},
},
}
view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{Provider: name, Limit: 10, Sort: "asc"})
items, _ := view["items"].([]map[string]interface{})
if len(items) != 1 {
t.Fatalf("expected one item, got %#v", view)
}
events, _ := items[0]["events"].([]providerRuntimeEvent)
if len(events) != 3 {
t.Fatalf("expected three events, got %#v", items[0]["events"])
}
if events[0].Target != "b" || events[1].Target != "c" || events[2].Target != "a" {
t.Fatalf("expected ascending order oldest->newest, got %#v", events)
}
}
func TestGetProviderRuntimeViewFiltersByHealthAndCooldown(t *testing.T) {
t.Parallel()
name := "runtime-health-provider"
providerRuntimeRegistry.mu.Lock()
providerRuntimeRegistry.api[name] = providerRuntimeState{
API: providerAPIRuntimeState{
HealthScore: 20,
CooldownUntil: time.Now().Add(-5 * time.Minute).Format(time.RFC3339),
},
CandidateOrder: []providerRuntimeCandidate{{
Kind: "api_key",
Target: "api***",
HealthScore: 20,
CooldownUntil: time.Now().Add(-5 * time.Minute).Format(time.RFC3339),
}},
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
},
},
}
view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{
Provider: name,
HealthBelow: 30,
CooldownBefore: time.Now(),
})
items, _ := view["items"].([]map[string]interface{})
if len(items) != 1 {
t.Fatalf("expected one filtered runtime item, got %#v", view)
}
}
func TestGetProviderRuntimeSummaryFlagsUnhealthyProviders(t *testing.T) {
t.Parallel()
name := "runtime-summary-provider"
lastSuccessAt := time.Now().Add(-2 * time.Hour).Format(time.RFC3339)
topChangedAt := time.Now().Add(-3 * time.Minute).Format(time.RFC3339)
providerRuntimeRegistry.mu.Lock()
providerRuntimeRegistry.api[name] = providerRuntimeState{
API: providerAPIRuntimeState{
HealthScore: 25,
CooldownUntil: time.Now().Add(15 * time.Minute).Format(time.RFC3339),
},
RecentErrors: []providerRuntimeEvent{
{When: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), Kind: "api_key", Target: "api***", Reason: "quota"},
},
RecentChanges: []providerRuntimeEvent{
{When: topChangedAt, Kind: "scheduler", Target: name, Reason: "candidate_order_changed", Detail: "top oauth -> api"},
},
LastSuccess: &providerRuntimeEvent{
When: lastSuccessAt,
Kind: "api_key",
Target: "api***",
Reason: "ok",
},
CandidateOrder: []providerRuntimeCandidate{{
Kind: "api_key",
Target: "api***",
Available: false,
Status: "cooldown",
HealthScore: 25,
CooldownUntil: time.Now().Add(15 * time.Minute).Format(time.RFC3339),
}},
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
},
},
}
summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{HealthBelow: 30, Window: time.Hour})
if summary.TotalProviders != 1 || summary.InCooldown != 1 || summary.LowHealth != 1 || summary.RecentErrors != 1 {
t.Fatalf("unexpected summary counts: %#v", summary)
}
if summary.Critical != 1 || summary.Degraded != 0 || summary.Healthy != 0 {
t.Fatalf("unexpected status counts: %#v", summary)
}
if len(summary.Providers) != 1 || summary.Providers[0].TopCandidate == nil || summary.Providers[0].TopCandidate.Kind != "api_key" {
t.Fatalf("unexpected provider summary items: %#v", summary.Providers)
}
if summary.Providers[0].Status != "critical" {
t.Fatalf("expected critical status, got %#v", summary.Providers[0])
}
if summary.Providers[0].LastError == nil || summary.Providers[0].LastErrorReason != "quota" || summary.Providers[0].LastErrorAt == "" {
t.Fatalf("expected last error details, got %#v", summary.Providers[0])
}
if summary.Providers[0].LastSuccessAt != lastSuccessAt || summary.Providers[0].TopCandidateChangedAt != topChangedAt {
t.Fatalf("expected last success and top candidate timestamps, got %#v", summary.Providers[0])
}
if summary.Providers[0].StaleForSec < 7100 || summary.Providers[0].StaleForSec > 7300 {
t.Fatalf("expected stale_for_sec around 2h, got %#v", summary.Providers[0].StaleForSec)
}
}
func TestGetProviderRuntimeSummaryMarksRecentErrorsAsDegraded(t *testing.T) {
t.Parallel()
name := "runtime-summary-degraded-provider"
providerRuntimeRegistry.mu.Lock()
providerRuntimeRegistry.api[name] = providerRuntimeState{
RecentErrors: []providerRuntimeEvent{
{When: time.Now().Add(-10 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "rate_limit"},
},
CandidateOrder: []providerRuntimeCandidate{{
Kind: "oauth",
Target: "user@example.com",
Available: true,
Status: "ready",
HealthScore: 90,
}},
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
},
},
}
summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{HealthBelow: 30, Window: time.Hour})
if summary.TotalProviders != 1 || summary.Degraded != 1 || summary.Critical != 0 || summary.Healthy != 0 {
t.Fatalf("unexpected summary counts: %#v", summary)
}
if len(summary.Providers) != 1 || summary.Providers[0].Status != "degraded" {
t.Fatalf("expected degraded provider item, got %#v", summary.Providers)
}
if summary.Providers[0].LastErrorReason != "rate_limit" {
t.Fatalf("expected last error reason, got %#v", summary.Providers[0])
}
if summary.Providers[0].StaleForSec != -1 {
t.Fatalf("expected stale_for_sec=-1 without success event, got %#v", summary.Providers[0])
}
}
func TestGetProviderRuntimeSummaryIncludesOAuthAccountMetadata(t *testing.T) {
t.Parallel()
dir := t.TempDir()
credFile := filepath.Join(dir, "qwen.json")
raw, err := json.Marshal(oauthSession{
Provider: "qwen",
AccessToken: "qwen-token",
RefreshToken: "refresh-token",
Email: "qwen-label",
ProjectID: "proj-9",
DeviceID: "device-9",
ResourceURL: "https://chat.qwen.ai/api",
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
FilePath: credFile,
})
if err != nil {
t.Fatalf("marshal session failed: %v", err)
}
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
t.Fatalf("write session failed: %v", err)
}
cfg := &config.Config{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
"qwen-summary": {
APIBase: "https://example.com/v1",
Auth: "oauth",
TimeoutSec: 5,
OAuth: config.ProviderOAuthConfig{
Provider: "qwen",
CredentialFile: credFile,
},
},
},
},
}
summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{Provider: "qwen-summary", HealthBelow: 50})
if len(summary.Providers) != 1 {
t.Fatalf("expected one provider, got %#v", summary)
}
if len(summary.Providers[0].OAuthAccounts) != 1 {
t.Fatalf("expected oauth account metadata, got %#v", summary.Providers[0])
}
account := summary.Providers[0].OAuthAccounts[0]
if account.AccountLabel != "qwen-label" || account.ProjectID != "proj-9" || account.DeviceID != "device-9" || account.ResourceURL != "https://chat.qwen.ai/api" {
t.Fatalf("unexpected oauth account metadata: %#v", account)
}
}
// TestRefreshProviderRuntimeNowSupportsOnlyExpiring checks both modes of
// RefreshProviderRuntimeNow against a fake token endpoint.
//
// The stored session expires in 24h and RefreshLeadSec is 1800 (30m), so the
// session is not inside its refresh window:
//   - with the only-expiring flag set (true), the session must be skipped and
//     the token endpoint must not be called;
//   - with the flag cleared (false), the session must be refreshed exactly once.
func TestRefreshProviderRuntimeNowSupportsOnlyExpiring(t *testing.T) {
	t.Parallel()
	// refreshCalls is incremented from the httptest server's handler goroutine.
	// Every read in this test therefore goes through atomic.LoadInt32,
	// including the values printed in failure messages.
	var refreshCalls int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/oauth/token" {
			http.NotFound(w, r)
			return
		}
		atomic.AddInt32(&refreshCalls, 1)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"access_token":"refreshed-token","refresh_token":"refresh-token","expires_in":3600}`))
	}))
	defer server.Close()
	credFile := filepath.Join(t.TempDir(), "codex.json")
	raw, err := json.Marshal(oauthSession{
		Provider:     "codex",
		AccessToken:  "old-token",
		RefreshToken: "refresh-token",
		Expire:       time.Now().Add(24 * time.Hour).Format(time.RFC3339),
	})
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, raw, 0o600); err != nil {
		t.Fatalf("write credential file failed: %v", err)
	}
	name := "runtime-refresh-provider"
	cfg := &config.Config{
		Models: config.ModelsConfig{
			Providers: map[string]config.ProviderConfig{
				name: {
					APIBase:    server.URL + "/v1",
					Auth:       "oauth",
					TimeoutSec: 5,
					OAuth: config.ProviderOAuthConfig{
						Provider:       "codex",
						CredentialFile: credFile,
						ClientID:       "test-client",
						TokenURL:       server.URL + "/oauth/token",
						AuthURL:        server.URL + "/oauth/authorize",
						// 30-minute lead time; much shorter than the 24h remaining
						// on the session, so the session is not "expiring".
						RefreshLeadSec: 1800,
					},
				},
			},
		},
	}

	// Only-expiring pass: the session is outside the lead window and must be skipped.
	result, err := RefreshProviderRuntimeNow(cfg, name, true)
	if err != nil {
		t.Fatalf("refresh only expiring failed: %v", err)
	}
	if result == nil || result.Refreshed != 0 || result.Skipped != 1 {
		t.Fatalf("expected skip for non-expiring session, got %#v", result)
	}
	if got := atomic.LoadInt32(&refreshCalls); got != 0 {
		t.Fatalf("expected no refresh calls for only-expiring path, got %d", got)
	}

	// Forced pass: every session is refreshed regardless of its expiry.
	result, err = RefreshProviderRuntimeNow(cfg, name, false)
	if err != nil {
		t.Fatalf("refresh all failed: %v", err)
	}
	if result == nil || result.Refreshed != 1 {
		t.Fatalf("expected forced refresh, got %#v", result)
	}
	if got := atomic.LoadInt32(&refreshCalls); got != 1 {
		t.Fatalf("expected one refresh call, got %d", got)
	}
}
// TestRerankProviderRuntimeUpdatesCandidateOrder configures a "hybrid" provider
// that has both an API key and an OAuth credential file. It checks two things:
//   - RerankProviderRuntime returns a candidate order that starts with the
//     "api_key" candidate;
//   - the same ordering is visible in GetProviderRuntimeSnapshot under the single
//     item's "candidate_order" key.
func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	credFile := filepath.Join(dir, "oauth.json")
	raw, err := json.Marshal(oauthSession{
		Provider:    "codex",
		AccessToken: "oauth-token",
		Email:       "rerank@example.com",
		Expire:      time.Now().Add(time.Hour).Format(time.RFC3339),
	})
	if err != nil {
		t.Fatalf("marshal session failed: %v", err)
	}
	if err := os.WriteFile(credFile, raw, 0o600); err != nil {
		t.Fatalf("write session failed: %v", err)
	}
	name := "rerank-runtime-provider"
	cfg := &config.Config{
		Models: config.ModelsConfig{
			Providers: map[string]config.ProviderConfig{
				name: {
					APIKey:     "api-key",
					APIBase:    "https://example.com/v1",
					Auth:       "hybrid",
					TimeoutSec: 5,
					OAuth: config.ProviderOAuthConfig{
						Provider:       "codex",
						CredentialFile: credFile,
					},
				},
			},
		},
	}
	order, err := RerankProviderRuntime(cfg, name)
	if err != nil {
		t.Fatalf("rerank provider runtime failed: %v", err)
	}
	if len(order) == 0 || order[0].Kind != "api_key" {
		t.Fatalf("expected api-key-first rerank result, got %#v", order)
	}

	snapshot := GetProviderRuntimeSnapshot(cfg)
	// Check the type assertions explicitly. If the snapshot's shape changes,
	// the test then reports the real dynamic type instead of failing later
	// with a misleading "empty" message.
	items, ok := snapshot["items"].([]map[string]interface{})
	if !ok {
		t.Fatalf("unexpected snapshot items type %T in %#v", snapshot["items"], snapshot)
	}
	if len(items) != 1 {
		t.Fatalf("expected one runtime item, got %#v", snapshot)
	}
	rawOrder := items[0]["candidate_order"]
	snapshotOrder, ok := rawOrder.([]providerRuntimeCandidate)
	if !ok {
		t.Fatalf("unexpected candidate_order type %T: %#v", rawOrder, rawOrder)
	}
	if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "api_key" {
		t.Fatalf("expected api-key-first candidate order, got %#v", rawOrder)
	}
}