mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-14 22:09:37 +08:00
2172 lines
70 KiB
Go
2172 lines
70 KiB
Go
package providers
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/YspCoder/clawgo/pkg/config"
|
|
)
|
|
|
|
func TestHTTPProviderOAuthRefreshesExpiredSession(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var refreshCalls int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
atomic.AddInt32(&refreshCalls, 1)
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("parse token form failed: %v", err)
|
|
}
|
|
if got := r.Form.Get("grant_type"); got != "refresh_token" {
|
|
t.Fatalf("unexpected grant_type: %s", got)
|
|
}
|
|
if got := r.Form.Get("refresh_token"); got != "refresh-token" {
|
|
t.Fatalf("unexpected refresh_token: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"fresh-token","refresh_token":"refresh-token","expires_in":3600}`))
|
|
case "/v1/responses":
|
|
if got := r.Header.Get("Authorization"); got != "Bearer fresh-token" {
|
|
t.Fatalf("unexpected authorization header: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
credFile := filepath.Join(t.TempDir(), "codex.json")
|
|
initial := oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "expired-token",
|
|
RefreshToken: "refresh-token",
|
|
Expire: time.Now().Add(-time.Hour).Format(time.RFC3339),
|
|
}
|
|
raw, err := json.Marshal(initial)
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write credential file failed: %v", err)
|
|
}
|
|
|
|
pc := config.ProviderConfig{
|
|
APIBase: server.URL + "/v1",
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
ClientID: "test-client",
|
|
TokenURL: server.URL + "/oauth/token",
|
|
AuthURL: server.URL + "/oauth/authorize",
|
|
},
|
|
}
|
|
oauth, err := newOAuthManager(pc, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
provider := NewHTTPProvider("test-oauth-refresh", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
|
|
|
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
|
if err != nil {
|
|
t.Fatalf("chat failed: %v", err)
|
|
}
|
|
if resp.Content != "ok" {
|
|
t.Fatalf("unexpected chat content: %q", resp.Content)
|
|
}
|
|
if got := atomic.LoadInt32(&refreshCalls); got != 1 {
|
|
t.Fatalf("expected exactly one refresh call, got %d", got)
|
|
}
|
|
|
|
savedRaw, err := os.ReadFile(credFile)
|
|
if err != nil {
|
|
t.Fatalf("read refreshed credential file failed: %v", err)
|
|
}
|
|
if !strings.Contains(string(savedRaw), "fresh-token") {
|
|
t.Fatalf("expected refreshed token to be persisted, got %s", string(savedRaw))
|
|
}
|
|
}
|
|
|
|
func TestOAuthLoginManualCallbackURLParse(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result, err := waitForOAuthCodeManual(
|
|
"https://example.com/auth?state=test-state",
|
|
bytes.NewBufferString("http://localhost:1455/auth/callback?code=auth-code&state=test-state\n"),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("manual callback parse failed: %v", err)
|
|
}
|
|
if result.Code != "auth-code" {
|
|
t.Fatalf("unexpected auth code: %s", result.Code)
|
|
}
|
|
if result.State != "test-state" {
|
|
t.Fatalf("unexpected state: %s", result.State)
|
|
}
|
|
}
|
|
|
|
func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
firstFile := filepath.Join(dir, "first.json")
|
|
secondFile := filepath.Join(dir, "second.json")
|
|
writeSession := func(path, token, email string) {
|
|
t.Helper()
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: token,
|
|
Email: email,
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(path, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
}
|
|
writeSession(firstFile, "token-a", "a@example.com")
|
|
writeSession(secondFile, "token-b", "b@example.com")
|
|
|
|
var tokenAUsed int32
|
|
var tokenBUsed int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/responses" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
switch r.Header.Get("Authorization") {
|
|
case "Bearer token-a":
|
|
atomic.AddInt32(&tokenAUsed, 1)
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"error":{"code":"insufficient_quota","message":"quota exceeded"}}`))
|
|
case "Bearer token-b":
|
|
atomic.AddInt32(&tokenBUsed, 1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-second"}`))
|
|
default:
|
|
t.Fatalf("unexpected auth header: %s", r.Header.Get("Authorization"))
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
pc := config.ProviderConfig{
|
|
APIBase: server.URL + "/v1",
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: firstFile,
|
|
CredentialFiles: []string{firstFile, secondFile},
|
|
ClientID: "test-client",
|
|
TokenURL: server.URL + "/oauth/token",
|
|
AuthURL: server.URL + "/oauth/authorize",
|
|
},
|
|
}
|
|
oauth, err := newOAuthManager(pc, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
provider := NewHTTPProvider("test-oauth-quota", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
|
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
|
if err != nil {
|
|
t.Fatalf("chat failed: %v", err)
|
|
}
|
|
if resp.Content != "ok-from-second" {
|
|
t.Fatalf("unexpected response content: %q", resp.Content)
|
|
}
|
|
if atomic.LoadInt32(&tokenAUsed) != 1 || atomic.LoadInt32(&tokenBUsed) != 1 {
|
|
t.Fatalf("expected one attempt per token, got token-a=%d token-b=%d", tokenAUsed, tokenBUsed)
|
|
}
|
|
}
|
|
|
|
func TestOAuthManagerPreRefreshesExpiringSession(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var refreshCalls int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/oauth/token" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
atomic.AddInt32(&refreshCalls, 1)
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("parse token form failed: %v", err)
|
|
}
|
|
if got := r.Form.Get("grant_type"); got != "refresh_token" {
|
|
t.Fatalf("unexpected grant_type: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"prefreshed-token","refresh_token":"refresh-token","expires_in":3600}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
credFile := filepath.Join(t.TempDir(), "prefresh.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "old-token",
|
|
RefreshToken: "refresh-token",
|
|
Expire: time.Now().Add(2 * time.Minute).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
|
|
originalAPI := defaultAntigravityAPIEndpoint
|
|
originalUserInfo := defaultAntigravityUserInfoURL
|
|
defaultAntigravityAPIEndpoint = server.URL
|
|
defaultAntigravityUserInfoURL = server.URL + "/userinfo"
|
|
t.Cleanup(func() {
|
|
defaultAntigravityAPIEndpoint = originalAPI
|
|
defaultAntigravityUserInfoURL = originalUserInfo
|
|
})
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
ClientID: "test-client",
|
|
TokenURL: server.URL + "/oauth/token",
|
|
AuthURL: server.URL + "/oauth/authorize",
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer manager.bgCancel()
|
|
|
|
result, err := manager.refreshExpiringSessions(context.Background(), 10*time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("pre-refresh failed: %v", err)
|
|
}
|
|
if result == nil || result.Refreshed != 1 {
|
|
t.Fatalf("expected one refreshed account, got %#v", result)
|
|
}
|
|
if atomic.LoadInt32(&refreshCalls) != 1 {
|
|
t.Fatalf("expected one refresh call, got %d", refreshCalls)
|
|
}
|
|
saved, err := os.ReadFile(credFile)
|
|
if err != nil {
|
|
t.Fatalf("read saved session failed: %v", err)
|
|
}
|
|
if !strings.Contains(string(saved), "prefreshed-token") {
|
|
t.Fatalf("expected prefreshed token in file, got %s", string(saved))
|
|
}
|
|
}
|
|
|
|
func TestResolveOAuthConfigSupportsAdditionalProviders(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
provider string
|
|
want string
|
|
flow string
|
|
}{
|
|
{name: "anthropic-alias", provider: "anthropic", want: "claude", flow: oauthFlowCallback},
|
|
{name: "antigravity", provider: "antigravity", want: "antigravity", flow: oauthFlowCallback},
|
|
{name: "gemini", provider: "gemini", want: "gemini", flow: oauthFlowCallback},
|
|
{name: "kimi", provider: "kimi", want: "kimi", flow: oauthFlowDevice},
|
|
{name: "qwen", provider: "qwen", want: "qwen", flow: oauthFlowDevice},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
cfg, err := resolveOAuthConfig(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: tc.provider,
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("resolve oauth config failed: %v", err)
|
|
}
|
|
if cfg.Provider != tc.want {
|
|
t.Fatalf("unexpected provider: %s", cfg.Provider)
|
|
}
|
|
if cfg.FlowKind != tc.flow {
|
|
t.Fatalf("unexpected flow kind: %s", cfg.FlowKind)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewOAuthManagerUsesAnthropicTransportForClaude(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "anthropic",
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer manager.bgCancel()
|
|
|
|
if _, ok := manager.httpClient.Transport.(*anthropicOAuthRoundTripper); !ok {
|
|
t.Fatalf("expected anthropic oauth transport, got %T", manager.httpClient.Transport)
|
|
}
|
|
}
|
|
|
|
func TestNewOAuthManagerUsesDefaultTransportForNonClaude(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer manager.bgCancel()
|
|
|
|
if manager.httpClient.Transport != nil {
|
|
t.Fatalf("expected default transport for non-claude provider, got %T", manager.httpClient.Transport)
|
|
}
|
|
}
|
|
|
|
func TestResolveOAuthConfigAppliesProviderRefreshLeadDefaults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
provider string
|
|
want time.Duration
|
|
}{
|
|
{provider: "codex", want: 5 * 24 * time.Hour},
|
|
{provider: "anthropic", want: 4 * time.Hour},
|
|
{provider: "antigravity", want: 5 * time.Minute},
|
|
{provider: "gemini", want: 30 * time.Minute},
|
|
{provider: "kimi", want: 5 * time.Minute},
|
|
{provider: "qwen", want: 3 * time.Hour},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
tc := tc
|
|
t.Run(tc.provider, func(t *testing.T) {
|
|
t.Parallel()
|
|
cfg, err := resolveOAuthConfig(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: tc.provider,
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("resolve oauth config failed: %v", err)
|
|
}
|
|
if cfg.RefreshLead != tc.want {
|
|
t.Fatalf("unexpected refresh lead for %s: got %v want %v", tc.provider, cfg.RefreshLead, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolveOAuthConfigUsesBuiltInGeminiClientDefaults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cfg, err := resolveOAuthConfig(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "gemini",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("resolve oauth config failed: %v", err)
|
|
}
|
|
if cfg.ClientID != defaultGeminiClientIDValue {
|
|
t.Fatalf("unexpected gemini client id: %q", cfg.ClientID)
|
|
}
|
|
if cfg.ClientSecret != defaultGeminiClientSecretValue {
|
|
t.Fatalf("unexpected gemini client secret: %q", cfg.ClientSecret)
|
|
}
|
|
}
|
|
|
|
func TestResolveOAuthConfigUsesBuiltInAntigravityClientDefaults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cfg, err := resolveOAuthConfig(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "antigravity",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("resolve oauth config failed: %v", err)
|
|
}
|
|
if cfg.ClientID != defaultAntigravityClientIDValue {
|
|
t.Fatalf("unexpected antigravity client id: %q", cfg.ClientID)
|
|
}
|
|
if cfg.ClientSecret != defaultAntigravityClientSecretValue {
|
|
t.Fatalf("unexpected antigravity client secret: %q", cfg.ClientSecret)
|
|
}
|
|
}
|
|
|
|
func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var refreshCalls int32
|
|
var responseCalls int32
|
|
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
atomic.AddInt32(&refreshCalls, 1)
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("parse token form failed: %v", err)
|
|
}
|
|
if got := r.Form.Get("grant_type"); got != "refresh_token" {
|
|
t.Fatalf("unexpected grant_type: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"proxied-fresh-token","refresh_token":"refresh-token","expires_in":3600}`))
|
|
case "/v1/responses":
|
|
atomic.AddInt32(&responseCalls, 1)
|
|
if got := r.Header.Get("Authorization"); got != "Bearer proxied-fresh-token" {
|
|
t.Fatalf("unexpected authorization header: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-via-proxy"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer target.Close()
|
|
|
|
var proxyCalls int32
|
|
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddInt32(&proxyCalls, 1)
|
|
targetURL := r.URL.String()
|
|
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
|
targetURL = target.URL + r.URL.Path
|
|
if rawQuery := strings.TrimSpace(r.URL.RawQuery); rawQuery != "" {
|
|
targetURL += "?" + rawQuery
|
|
}
|
|
}
|
|
req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
|
|
if err != nil {
|
|
t.Fatalf("create proxied request failed: %v", err)
|
|
}
|
|
req.Header = r.Header.Clone()
|
|
resp, err := http.DefaultTransport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("proxy round trip failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
for key, values := range resp.Header {
|
|
for _, value := range values {
|
|
w.Header().Add(key, value)
|
|
}
|
|
}
|
|
w.WriteHeader(resp.StatusCode)
|
|
_, _ = io.Copy(w, resp.Body)
|
|
}))
|
|
defer proxyServer.Close()
|
|
|
|
credFile := filepath.Join(t.TempDir(), "proxied.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "expired-token",
|
|
RefreshToken: "refresh-token",
|
|
Expire: time.Now().Add(-time.Hour).Format(time.RFC3339),
|
|
NetworkProxy: proxyServer.URL,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write credential file failed: %v", err)
|
|
}
|
|
|
|
pc := config.ProviderConfig{
|
|
APIBase: target.URL + "/v1",
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
ClientID: "test-client",
|
|
TokenURL: target.URL + "/oauth/token",
|
|
AuthURL: target.URL + "/oauth/authorize",
|
|
},
|
|
}
|
|
oauth, err := newOAuthManager(pc, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer oauth.bgCancel()
|
|
|
|
provider := NewHTTPProvider("test-oauth-proxy", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
|
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
|
if err != nil {
|
|
t.Fatalf("chat failed: %v", err)
|
|
}
|
|
if resp.Content != "ok-via-proxy" {
|
|
t.Fatalf("unexpected response content: %q", resp.Content)
|
|
}
|
|
if atomic.LoadInt32(&refreshCalls) != 1 {
|
|
t.Fatalf("expected one refresh call, got %d", refreshCalls)
|
|
}
|
|
if atomic.LoadInt32(&responseCalls) != 1 {
|
|
t.Fatalf("expected one response call, got %d", responseCalls)
|
|
}
|
|
if got := atomic.LoadInt32(&proxyCalls); got < 2 {
|
|
t.Fatalf("expected proxy to receive refresh and response requests, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestOAuthImportGeminiNestedTokenRefreshesWithTokenMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("parse form failed: %v", err)
|
|
}
|
|
if got := r.Form.Get("client_secret"); got != "secret-1" {
|
|
t.Fatalf("unexpected client_secret: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"gemini-fresh","refresh_token":"gemini-refresh","expires_in":3600}`))
|
|
case "/userinfo":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"email":"gemini@example.com"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "gemini",
|
|
TokenURL: server.URL + "/oauth/token",
|
|
ClientID: "client-1",
|
|
ClientSecret: "secret-1",
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
manager.cfg.UserInfoURL = server.URL + "/userinfo"
|
|
|
|
raw := []byte(`{
|
|
"type": "gemini",
|
|
"email": "gemini@example.com",
|
|
"project_id": "demo-project",
|
|
"token": {
|
|
"refresh_token": "gemini-refresh",
|
|
"client_id": "client-1",
|
|
"client_secret": "secret-1",
|
|
"token_uri": "` + server.URL + `/oauth/token"
|
|
}
|
|
}`)
|
|
session, err := parseImportedOAuthSession("gemini", "gemini.json", raw)
|
|
if err != nil {
|
|
t.Fatalf("parse imported oauth session failed: %v", err)
|
|
}
|
|
refreshed, err := manager.refreshImportedSession(context.Background(), session)
|
|
if err != nil {
|
|
t.Fatalf("refresh imported session failed: %v", err)
|
|
}
|
|
if refreshed.AccessToken != "gemini-fresh" {
|
|
t.Fatalf("unexpected access token: %s", refreshed.AccessToken)
|
|
}
|
|
if refreshed.Email != "gemini@example.com" {
|
|
t.Fatalf("unexpected email: %s", refreshed.Email)
|
|
}
|
|
if refreshed.ProjectID != "demo-project" {
|
|
t.Fatalf("unexpected project id: %s", refreshed.ProjectID)
|
|
}
|
|
}
|
|
|
|
func TestAntigravityEnrichSessionAddsEmailAndProjectID(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/userinfo":
|
|
if got := r.Header.Get("Authorization"); got != "Bearer antigravity-token" {
|
|
t.Fatalf("unexpected userinfo authorization: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"email":"antigravity@example.com"}`))
|
|
case "/v1internal:loadCodeAssist":
|
|
if got := r.Header.Get("Authorization"); got != "Bearer antigravity-token" {
|
|
t.Fatalf("unexpected project authorization: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"cloudaicompanionProject":"project-123"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
originalAPI := defaultAntigravityAPIEndpoint
|
|
originalUserInfo := defaultAntigravityUserInfoURL
|
|
defaultAntigravityAPIEndpoint = server.URL
|
|
defaultAntigravityUserInfoURL = server.URL + "/userinfo"
|
|
t.Cleanup(func() {
|
|
defaultAntigravityAPIEndpoint = originalAPI
|
|
defaultAntigravityUserInfoURL = originalUserInfo
|
|
})
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "antigravity",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
AuthURL: server.URL + "/oauth/authorize",
|
|
RedirectURL: "http://localhost:51121/oauth-callback",
|
|
RefreshLeadSec: 300,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer manager.bgCancel()
|
|
|
|
session, err := manager.enrichSession(context.Background(), &oauthSession{
|
|
Provider: "antigravity",
|
|
AccessToken: "antigravity-token",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("enrich session failed: %v", err)
|
|
}
|
|
if session.Email != "antigravity@example.com" {
|
|
t.Fatalf("unexpected email: %#v", session)
|
|
}
|
|
if session.ProjectID != "project-123" {
|
|
t.Fatalf("expected project id enrichment, got %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestQwenDeviceFlowRequiresAccountLabelWhenEmailMissing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/device":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"device_code":"dev-1","user_code":"user-1","verification_uri_complete":"https://chat.qwen.ai/device?code=user-1","interval":1,"expires_in":60}`))
|
|
case "/token":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"qwen-token","refresh_token":"refresh-token","expires_in":3600}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "qwen",
|
|
AuthURL: server.URL + "/device",
|
|
TokenURL: server.URL + "/token",
|
|
CredentialFile: filepath.Join(t.TempDir(), "qwen.json"),
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer manager.bgCancel()
|
|
|
|
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
|
|
if err != nil {
|
|
t.Fatalf("start device flow failed: %v", err)
|
|
}
|
|
_, _, err = manager.completeDeviceFlow(context.Background(), "", flow, OAuthLoginOptions{})
|
|
if err == nil || !strings.Contains(err.Error(), "account_label") {
|
|
t.Fatalf("expected qwen account_label error, got %v", err)
|
|
}
|
|
|
|
session, _, err := manager.completeDeviceFlow(context.Background(), "", flow, OAuthLoginOptions{AccountLabel: "qwen-alias"})
|
|
if err != nil {
|
|
t.Fatalf("complete device flow with label failed: %v", err)
|
|
}
|
|
if session.Email != "qwen-alias" {
|
|
t.Fatalf("expected qwen alias persisted as email label, got %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestImportSessionEnrichesAntigravityMetadata(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/userinfo":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"email":"imported@example.com"}`))
|
|
case "/v1internal:loadCodeAssist":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"cloudaicompanionProject":"import-project"}`))
|
|
case "/v1/models":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"data":[{"id":"g-model"}]}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
originalAPI := defaultAntigravityAPIEndpoint
|
|
originalUserInfo := defaultAntigravityUserInfoURL
|
|
defaultAntigravityAPIEndpoint = server.URL
|
|
defaultAntigravityUserInfoURL = server.URL + "/userinfo"
|
|
t.Cleanup(func() {
|
|
defaultAntigravityAPIEndpoint = originalAPI
|
|
defaultAntigravityUserInfoURL = originalUserInfo
|
|
})
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
APIBase: server.URL + "/v1",
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "antigravity",
|
|
CredentialFile: filepath.Join(t.TempDir(), "antigravity.json"),
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer manager.bgCancel()
|
|
|
|
raw := []byte(`{"access_token":"import-token","refresh_token":"refresh-token","expired":"2030-01-01T00:00:00Z"}`)
|
|
session, models, err := manager.importSession(context.Background(), server.URL+"/v1", "auth.json", raw, OAuthLoginOptions{})
|
|
if err != nil {
|
|
t.Fatalf("import session failed: %v", err)
|
|
}
|
|
if session.Email != "imported@example.com" || session.ProjectID != "import-project" {
|
|
t.Fatalf("expected antigravity enrichment, got %#v", session)
|
|
}
|
|
if len(models) != 1 || models[0] != "g-model" {
|
|
t.Fatalf("unexpected models: %#v", models)
|
|
}
|
|
}
|
|
|
|
func TestPersistSessionAddsGeminiTokenMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
credFile := filepath.Join(t.TempDir(), "gemini.json")
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "gemini",
|
|
CredentialFile: credFile,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
defer manager.bgCancel()
|
|
|
|
manager.mu.Lock()
|
|
err = manager.persistSessionLocked(&oauthSession{
|
|
Provider: "gemini",
|
|
AccessToken: "gem-access",
|
|
RefreshToken: "gem-refresh",
|
|
Expire: "2030-01-01T00:00:00Z",
|
|
Email: "gem@example.com",
|
|
FilePath: credFile,
|
|
})
|
|
manager.mu.Unlock()
|
|
if err != nil {
|
|
t.Fatalf("persist session failed: %v", err)
|
|
}
|
|
raw, err := os.ReadFile(credFile)
|
|
if err != nil {
|
|
t.Fatalf("read credential file failed: %v", err)
|
|
}
|
|
var payload map[string]any
|
|
if err := json.Unmarshal(raw, &payload); err != nil {
|
|
t.Fatalf("unmarshal credential file failed: %v", err)
|
|
}
|
|
tokenMap, _ := payload["token"].(map[string]any)
|
|
if tokenMap == nil {
|
|
t.Fatalf("expected token map in persisted gemini session, got %s", string(raw))
|
|
}
|
|
if tokenMap["token_uri"] != defaultGeminiTokenURL {
|
|
t.Fatalf("unexpected gemini token metadata: %#v", tokenMap)
|
|
}
|
|
if defaultGeminiClientID != "" && tokenMap["client_id"] != defaultGeminiClientID {
|
|
t.Fatalf("unexpected gemini client_id metadata: %#v", tokenMap)
|
|
}
|
|
if defaultGeminiClientSecret != "" && tokenMap["client_secret"] != defaultGeminiClientSecret {
|
|
t.Fatalf("unexpected gemini client_secret metadata: %#v", tokenMap)
|
|
}
|
|
}
|
|
|
|
func TestParseImportedOAuthSessionSupportsAliasProjectAndDeviceVariants(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
session, err := parseImportedOAuthSession("qwen", "auth.json", []byte(`{
|
|
"refresh_token": "rt-1",
|
|
"token": {
|
|
"access_token": "at-1",
|
|
"account_label": "alias-qwen",
|
|
"projectId": "proj-1",
|
|
"deviceId": "dev-1",
|
|
"scopes": "openid profile"
|
|
}
|
|
}`))
|
|
if err != nil {
|
|
t.Fatalf("parse imported session failed: %v", err)
|
|
}
|
|
if session.Email != "alias-qwen" || session.ProjectID != "proj-1" || session.DeviceID != "dev-1" || session.Scope != "openid profile" {
|
|
t.Fatalf("unexpected parsed session: %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestOAuthDeviceFlowQwenManualCompletes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/device":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"device_code":"dev-1","user_code":"user-1","verification_uri_complete":"https://chat.qwen.ai/device?code=user-1","interval":1,"expires_in":60}`))
|
|
case "/token":
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("parse form failed: %v", err)
|
|
}
|
|
if got := r.Form.Get("device_code"); got != "dev-1" {
|
|
t.Fatalf("unexpected device_code: %s", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"qwen-at","refresh_token":"qwen-rt","expires_in":3600}`))
|
|
case "/models":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"data":[{"id":"qwen-test"}]}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
dir := t.TempDir()
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
APIBase: server.URL,
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "qwen",
|
|
CredentialFile: filepath.Join(dir, "qwen.json"),
|
|
AuthURL: server.URL + "/device",
|
|
TokenURL: server.URL + "/token",
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
|
|
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
|
|
if err != nil {
|
|
t.Fatalf("start device flow failed: %v", err)
|
|
}
|
|
if flow.Mode != oauthFlowDevice {
|
|
t.Fatalf("unexpected flow mode: %s", flow.Mode)
|
|
}
|
|
session, models, err := manager.completeDeviceFlow(context.Background(), server.URL, flow, OAuthLoginOptions{AccountLabel: "qwen-label"})
|
|
if err != nil {
|
|
t.Fatalf("complete device flow failed: %v", err)
|
|
}
|
|
if session.AccessToken != "qwen-at" {
|
|
t.Fatalf("unexpected access token: %s", session.AccessToken)
|
|
}
|
|
if session.FilePath == "" {
|
|
t.Fatalf("expected credential file path")
|
|
}
|
|
if session.Email != "qwen-label" {
|
|
t.Fatalf("expected qwen label, got %#v", session)
|
|
}
|
|
if len(models) != 1 || models[0] != "qwen-test" {
|
|
t.Fatalf("unexpected models: %#v", models)
|
|
}
|
|
}
|
|
|
|
func TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "oauth.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
Email: "oauth@example.com",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
|
|
var apiKeyCalls int32
|
|
var oauthCalls int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/responses" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
switch r.Header.Get("Authorization") {
|
|
case "Bearer api-key-1":
|
|
atomic.AddInt32(&apiKeyCalls, 1)
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"error":{"code":"insufficient_quota","message":"quota exceeded"}}`))
|
|
case "Bearer oauth-token":
|
|
atomic.AddInt32(&oauthCalls, 1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-oauth"}`))
|
|
default:
|
|
t.Fatalf("unexpected auth header: %s", r.Header.Get("Authorization"))
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
pc := config.ProviderConfig{
|
|
APIBase: server.URL + "/v1",
|
|
APIKey: "api-key-1",
|
|
Auth: "hybrid",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
TokenURL: server.URL + "/oauth/token",
|
|
AuthURL: server.URL + "/oauth/authorize",
|
|
},
|
|
}
|
|
oauth, err := newOAuthManager(pc, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
provider := NewHTTPProvider("test-hybrid-fallback", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
|
|
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
|
if err != nil {
|
|
t.Fatalf("chat failed: %v", err)
|
|
}
|
|
if resp.Content != "ok-from-oauth" {
|
|
t.Fatalf("unexpected response content: %q", resp.Content)
|
|
}
|
|
if atomic.LoadInt32(&apiKeyCalls) != 1 || atomic.LoadInt32(&oauthCalls) != 1 {
|
|
t.Fatalf("expected one api-key and one oauth attempt, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
|
|
}
|
|
}
|
|
|
|
func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "oauth.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
Email: "oauth@example.com",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
|
|
var apiKeyCalls int32
|
|
var oauthCalls int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/responses" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
switch r.Header.Get("Authorization") {
|
|
case "Bearer api-key-1":
|
|
atomic.AddInt32(&apiKeyCalls, 1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-api"}`))
|
|
case "Bearer oauth-token":
|
|
atomic.AddInt32(&oauthCalls, 1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-oauth"}`))
|
|
default:
|
|
t.Fatalf("unexpected auth header: %s", r.Header.Get("Authorization"))
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
pc := config.ProviderConfig{
|
|
APIBase: server.URL + "/v1",
|
|
APIKey: "api-key-1",
|
|
Auth: "hybrid",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
TokenURL: server.URL + "/oauth/token",
|
|
AuthURL: server.URL + "/oauth/authorize",
|
|
},
|
|
}
|
|
oauth, err := newOAuthManager(pc, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
provider := NewHTTPProvider("test-hybrid-oauth-first", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
|
|
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
|
if err != nil {
|
|
t.Fatalf("chat failed: %v", err)
|
|
}
|
|
if resp.Content != "ok-from-api" {
|
|
t.Fatalf("unexpected response content: %q", resp.Content)
|
|
}
|
|
if atomic.LoadInt32(&apiKeyCalls) != 1 || atomic.LoadInt32(&oauthCalls) != 0 {
|
|
t.Fatalf("expected api key first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
|
|
}
|
|
}
|
|
|
|
func TestOAuthManagerCooldownSkipsExhaustedAccount(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
firstFile := filepath.Join(dir, "first.json")
|
|
secondFile := filepath.Join(dir, "second.json")
|
|
writeSession := func(path, token string) {
|
|
t.Helper()
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: token,
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(path, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
}
|
|
writeSession(firstFile, "token-a")
|
|
writeSession(secondFile, "token-b")
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: firstFile,
|
|
CredentialFiles: []string{firstFile, secondFile},
|
|
CooldownSec: 3600,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
|
|
attempts, err := manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("prepare attempts failed: %v", err)
|
|
}
|
|
if len(attempts) != 2 {
|
|
t.Fatalf("expected 2 attempts, got %d", len(attempts))
|
|
}
|
|
manager.markExhausted(attempts[0].Session, oauthFailureRateLimit)
|
|
nextAttempts, err := manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("prepare attempts after cooldown failed: %v", err)
|
|
}
|
|
if len(nextAttempts) != 1 {
|
|
t.Fatalf("expected 1 available attempt after cooldown, got %d", len(nextAttempts))
|
|
}
|
|
if nextAttempts[0].Token != "token-b" {
|
|
t.Fatalf("unexpected token after cooldown: %s", nextAttempts[0].Token)
|
|
}
|
|
accounts, err := (&OAuthLoginManager{manager: manager}).ListAccounts()
|
|
if err != nil {
|
|
t.Fatalf("list accounts failed: %v", err)
|
|
}
|
|
foundCooldown := false
|
|
for _, account := range accounts {
|
|
if account.CredentialFile == firstFile && account.CooldownUntil != "" {
|
|
foundCooldown = true
|
|
}
|
|
}
|
|
if !foundCooldown {
|
|
t.Fatalf("expected cooldown metadata to be exposed in account list")
|
|
}
|
|
}
|
|
|
|
func TestOAuthManagerDisableSessionSkipsAccount(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
firstFile := filepath.Join(dir, "first.json")
|
|
secondFile := filepath.Join(dir, "second.json")
|
|
writeSession := func(path, token string) {
|
|
t.Helper()
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: token,
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(path, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
}
|
|
writeSession(firstFile, "token-a")
|
|
writeSession(secondFile, "token-b")
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: firstFile,
|
|
CredentialFiles: []string{firstFile, secondFile},
|
|
CooldownSec: 3600,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
|
|
attempts, err := manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("prepare attempts failed: %v", err)
|
|
}
|
|
if len(attempts) != 2 {
|
|
t.Fatalf("expected 2 attempts, got %d", len(attempts))
|
|
}
|
|
manager.disableSession(attempts[0].Session, oauthFailureRevoked, "oauth token revoked")
|
|
|
|
nextAttempts, err := manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("prepare attempts after disable failed: %v", err)
|
|
}
|
|
if len(nextAttempts) != 1 {
|
|
t.Fatalf("expected 1 available attempt after disable, got %d", len(nextAttempts))
|
|
}
|
|
if nextAttempts[0].Token != "token-b" {
|
|
t.Fatalf("unexpected token after disable: %s", nextAttempts[0].Token)
|
|
}
|
|
|
|
raw, err := os.ReadFile(firstFile)
|
|
if err != nil {
|
|
t.Fatalf("read disabled session failed: %v", err)
|
|
}
|
|
var saved oauthSession
|
|
if err := json.Unmarshal(raw, &saved); err != nil {
|
|
t.Fatalf("unmarshal disabled session failed: %v", err)
|
|
}
|
|
if !saved.Disabled || saved.DisableReason != string(oauthFailureRevoked) {
|
|
t.Fatalf("expected disabled session to persist, got %#v", saved)
|
|
}
|
|
}
|
|
|
|
func TestOAuthLoginManagerListAccountsSkipsInvalidCredentialFiles(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
invalidFile := filepath.Join(dir, "invalid.json")
|
|
validFile := filepath.Join(dir, "valid.json")
|
|
if err := os.WriteFile(invalidFile, []byte(`{"not":"a valid oauth session"}`), 0o600); err != nil {
|
|
t.Fatalf("write invalid credential failed: %v", err)
|
|
}
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
RefreshToken: "refresh-token",
|
|
Email: "user@example.com",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal valid session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(validFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write valid credential failed: %v", err)
|
|
}
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: invalidFile,
|
|
CredentialFiles: []string{invalidFile, validFile},
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
|
|
accounts, err := (&OAuthLoginManager{manager: manager}).ListAccounts()
|
|
if err != nil {
|
|
t.Fatalf("list accounts failed: %v", err)
|
|
}
|
|
if len(accounts) != 1 {
|
|
t.Fatalf("expected 1 valid account, got %d (%#v)", len(accounts), accounts)
|
|
}
|
|
if accounts[0].Email != "user@example.com" {
|
|
t.Fatalf("unexpected account: %#v", accounts[0])
|
|
}
|
|
}
|
|
|
|
func TestOAuthManagerPrefersHealthierAccount(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
firstFile := filepath.Join(dir, "first.json")
|
|
secondFile := filepath.Join(dir, "second.json")
|
|
writeSession := func(path, token string) {
|
|
t.Helper()
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: token,
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(path, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
}
|
|
writeSession(firstFile, "token-a")
|
|
writeSession(secondFile, "token-b")
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: firstFile,
|
|
CredentialFiles: []string{firstFile, secondFile},
|
|
CooldownSec: 60,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
|
|
attempts, err := manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("prepare attempts failed: %v", err)
|
|
}
|
|
manager.markExhausted(attempts[0].Session, oauthFailureQuota)
|
|
delete(manager.cooldowns, attempts[0].Session.FilePath)
|
|
attempts, err = manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("prepare attempts after health drop failed: %v", err)
|
|
}
|
|
if len(attempts) != 2 {
|
|
t.Fatalf("expected 2 attempts, got %d", len(attempts))
|
|
}
|
|
if attempts[0].Token != "token-b" {
|
|
t.Fatalf("expected healthier token-b first, got %s", attempts[0].Token)
|
|
}
|
|
}
|
|
|
|
func TestOAuthLoginManagerListAccountsIncludesCodexPlanMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "codex-plan.json")
|
|
idToken := buildTestJWT(map[string]any{
|
|
"email": "plan@example.com",
|
|
"https://api.openai.com/auth": map[string]any{
|
|
"chatgpt_account_id": "acct-plan",
|
|
"chatgpt_plan_type": "pro",
|
|
"chatgpt_subscription_active_start": "2026-03-01T00:00:00Z",
|
|
"chatgpt_subscription_active_until": "2026-04-01T00:00:00Z",
|
|
},
|
|
})
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "token-plan",
|
|
RefreshToken: "refresh-plan",
|
|
IDToken: idToken,
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
|
|
accounts, err := (&OAuthLoginManager{manager: manager}).ListAccounts()
|
|
if err != nil {
|
|
t.Fatalf("list accounts failed: %v", err)
|
|
}
|
|
if len(accounts) != 1 {
|
|
t.Fatalf("expected one account, got %#v", accounts)
|
|
}
|
|
account := accounts[0]
|
|
if account.PlanType != "pro" {
|
|
t.Fatalf("expected plan type to be extracted, got %#v", account)
|
|
}
|
|
if account.BalanceLabel != "PRO" || account.SubActiveUntil != "2026-04-01T00:00:00Z" {
|
|
t.Fatalf("expected subscription metadata in account info, got %#v", account)
|
|
}
|
|
}
|
|
|
|
// buildTestJWT returns an unsigned JWT-shaped string for tests: a fixed
// {"alg":"none","typ":"JWT"} header and the given claims, each JSON-encoded
// and base64url-encoded without padding, joined by "." and followed by an
// empty signature segment.
func buildTestJWT(claims map[string]any) string {
	segment := func(v any) string {
		// Marshal errors are ignored on purpose: a failed encoding simply
		// yields an empty segment, as it does for callers today.
		encoded, _ := json.Marshal(v)
		return base64.RawURLEncoding.EncodeToString(encoded)
	}
	head := segment(map[string]any{"alg": "none", "typ": "JWT"})
	body := segment(claims)
	return head + "." + body + "."
}
|
|
|
|
func TestClassifyOAuthFailureDifferentiatesReasons(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
reason, retry := classifyOAuthFailure(http.StatusTooManyRequests, []byte(`{"error":{"code":"insufficient_quota"}}`))
|
|
if !retry || reason != oauthFailureQuota {
|
|
t.Fatalf("expected quota classification, got retry=%v reason=%s", retry, reason)
|
|
}
|
|
reason, retry = classifyOAuthFailure(http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`))
|
|
if !retry || reason != oauthFailureRateLimit {
|
|
t.Fatalf("expected rate-limit classification, got retry=%v reason=%s", retry, reason)
|
|
}
|
|
reason, retry = classifyOAuthFailure(http.StatusForbidden, []byte(`{"error":"forbidden"}`))
|
|
if !retry || reason != oauthFailureForbidden {
|
|
t.Fatalf("expected forbidden classification, got retry=%v reason=%s", retry, reason)
|
|
}
|
|
}
|
|
|
|
func TestHTTPProviderHybridSkipsAPIKeyDuringCooldown(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "oauth.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["cooldown-provider"] = providerRuntimeState{
|
|
API: providerAPIRuntimeState{
|
|
TokenMasked: "api-***",
|
|
HealthScore: 50,
|
|
CooldownUntil: time.Now().Add(10 * time.Minute).Format(time.RFC3339),
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "hybrid",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
provider := NewHTTPProvider("cooldown-provider", "api-key-1", "https://example.com/v1", "gpt-test", false, "hybrid", 5*time.Second, manager)
|
|
attempts, err := provider.authAttempts(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("auth attempts failed: %v", err)
|
|
}
|
|
if len(attempts) != 1 || attempts[0].kind != "oauth" {
|
|
t.Fatalf("expected only oauth attempt during api cooldown, got %#v", attempts)
|
|
}
|
|
}
|
|
|
|
func TestClearProviderAPICooldownRestoresAPIKeyAttempt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["clear-api-provider"] = providerRuntimeState{
|
|
API: providerAPIRuntimeState{
|
|
TokenMasked: "api-***",
|
|
HealthScore: 50,
|
|
CooldownUntil: time.Now().Add(10 * time.Minute).Format(time.RFC3339),
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
provider := NewHTTPProvider("clear-api-provider", "api-key-1", "https://example.com/v1", "gpt-test", false, "bearer", 5*time.Second, nil)
|
|
if _, err := provider.authAttempts(context.Background()); err == nil {
|
|
t.Fatalf("expected api key attempt to be blocked by cooldown")
|
|
}
|
|
ClearProviderAPICooldown("clear-api-provider")
|
|
attempts, err := provider.authAttempts(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("expected api key attempt after clear cooldown, got %v", err)
|
|
}
|
|
if len(attempts) != 1 || attempts[0].kind != "api_key" {
|
|
t.Fatalf("unexpected attempts after clear cooldown: %#v", attempts)
|
|
}
|
|
}
|
|
|
|
func TestOAuthLoginManagerClearCooldown(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "oauth.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
manager, err := newOAuthManager(config.ProviderConfig{
|
|
Auth: "oauth",
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
CooldownSec: 3600,
|
|
},
|
|
}, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
attempts, err := manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil || len(attempts) != 1 {
|
|
t.Fatalf("prepare attempts failed: %v %#v", err, attempts)
|
|
}
|
|
manager.markExhausted(attempts[0].Session, oauthFailureRateLimit)
|
|
loginMgr := &OAuthLoginManager{manager: manager}
|
|
if err := loginMgr.ClearCooldown(credFile); err != nil {
|
|
t.Fatalf("clear cooldown failed: %v", err)
|
|
}
|
|
next, err := manager.prepareAttemptsLocked(context.Background())
|
|
if err != nil || len(next) != 1 {
|
|
t.Fatalf("expected session available after cooldown clear, got err=%v attempts=%#v", err, next)
|
|
}
|
|
}
|
|
|
|
func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "oauth.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
RefreshToken: "refresh-token",
|
|
Email: "user@example.com",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
|
|
name := "runtime-snapshot-provider"
|
|
pc := config.ProviderConfig{
|
|
APIKey: "api-key-123456",
|
|
APIBase: "https://example.com/v1",
|
|
Auth: "hybrid",
|
|
TimeoutSec: 5,
|
|
RuntimePersist: true,
|
|
RuntimeHistoryFile: filepath.Join(dir, "runtime.json"),
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
},
|
|
}
|
|
ConfigureProviderRuntime(name, pc)
|
|
manager, err := newOAuthManager(pc, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
provider := NewHTTPProvider(name, pc.APIKey, pc.APIBase, "gpt-test", false, pc.Auth, 5*time.Second, manager)
|
|
attempts, err := provider.authAttempts(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("auth attempts failed: %v", err)
|
|
}
|
|
if len(attempts) != 2 || attempts[0].kind != "api_key" || attempts[1].kind != "oauth" {
|
|
t.Fatalf("unexpected attempts order: %#v", attempts)
|
|
}
|
|
provider.markAttemptSuccess(attempts[1])
|
|
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{name: pc},
|
|
},
|
|
}
|
|
snapshot := GetProviderRuntimeSnapshot(cfg)
|
|
items, _ := snapshot["items"].([]map[string]interface{})
|
|
if len(items) == 0 {
|
|
t.Fatalf("expected provider runtime items")
|
|
}
|
|
item := items[0]
|
|
candidates, _ := item["candidate_order"].([]providerRuntimeCandidate)
|
|
if len(candidates) < 2 {
|
|
t.Fatalf("expected candidate order, got %#v", item["candidate_order"])
|
|
}
|
|
if candidates[0].Kind != "api_key" || candidates[1].Kind != "oauth" {
|
|
t.Fatalf("unexpected candidate order: %#v", candidates)
|
|
}
|
|
lastSuccess, _ := item["last_success"].(*providerRuntimeEvent)
|
|
if lastSuccess == nil || lastSuccess.Kind != "oauth" || lastSuccess.Target != "user@example.com" {
|
|
t.Fatalf("unexpected last success: %#v", item["last_success"])
|
|
}
|
|
if _, err := os.Stat(pc.RuntimeHistoryFile); err != nil {
|
|
t.Fatalf("expected runtime history file, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestConfigureProviderRuntimeLoadsPersistedEvents(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
name := "persisted-runtime-provider"
|
|
historyFile := filepath.Join(dir, "runtime.json")
|
|
payload := providerRuntimeState{
|
|
RecentHits: []providerRuntimeEvent{{
|
|
When: time.Now().Add(-time.Minute).Format(time.RFC3339),
|
|
Kind: "oauth",
|
|
Target: "persisted@example.com",
|
|
Reason: "ok",
|
|
}},
|
|
LastSuccess: &providerRuntimeEvent{
|
|
When: time.Now().Add(-time.Minute).Format(time.RFC3339),
|
|
Kind: "oauth",
|
|
Target: "persisted@example.com",
|
|
},
|
|
}
|
|
raw, err := json.Marshal(payload)
|
|
if err != nil {
|
|
t.Fatalf("marshal runtime payload failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(historyFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write history file failed: %v", err)
|
|
}
|
|
|
|
ConfigureProviderRuntime(name, config.ProviderConfig{
|
|
APIBase: "https://example.com/v1",
|
|
Auth: "bearer",
|
|
RuntimePersist: true,
|
|
RuntimeHistoryFile: historyFile,
|
|
})
|
|
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {
|
|
APIBase: "https://example.com/v1",
|
|
Auth: "bearer",
|
|
RuntimePersist: true,
|
|
RuntimeHistoryFile: historyFile,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
snapshot := GetProviderRuntimeSnapshot(cfg)
|
|
items, _ := snapshot["items"].([]map[string]interface{})
|
|
if len(items) == 0 {
|
|
t.Fatalf("expected provider runtime item")
|
|
}
|
|
lastSuccess, _ := items[0]["last_success"].(*providerRuntimeEvent)
|
|
if lastSuccess == nil || lastSuccess.Target != "persisted@example.com" {
|
|
t.Fatalf("expected persisted last success, got %#v", items[0]["last_success"])
|
|
}
|
|
hits, _ := items[0]["recent_hits"].([]providerRuntimeEvent)
|
|
if len(hits) == 0 || hits[0].Target != "persisted@example.com" {
|
|
t.Fatalf("expected persisted recent hits, got %#v", items[0]["recent_hits"])
|
|
}
|
|
}
|
|
|
|
func TestClearProviderRuntimeHistoryRemovesPersistedFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
name := "clear-runtime-history-provider"
|
|
historyFile := filepath.Join(dir, "runtime.json")
|
|
ConfigureProviderRuntime(name, config.ProviderConfig{
|
|
APIBase: "https://example.com/v1",
|
|
Auth: "bearer",
|
|
RuntimePersist: true,
|
|
RuntimeHistoryFile: historyFile,
|
|
})
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
state := providerRuntimeRegistry.api[name]
|
|
state.RecentHits = []providerRuntimeEvent{{When: time.Now().Format(time.RFC3339), Kind: "api_key", Target: "api***"}}
|
|
state.LastSuccess = &providerRuntimeEvent{When: time.Now().Format(time.RFC3339), Kind: "api_key", Target: "api***"}
|
|
persistProviderRuntimeLocked(name, state)
|
|
providerRuntimeRegistry.api[name] = state
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
if _, err := os.Stat(historyFile); err != nil {
|
|
t.Fatalf("expected runtime history file, got %v", err)
|
|
}
|
|
ClearProviderRuntimeHistory(name)
|
|
if _, err := os.Stat(historyFile); !os.IsNotExist(err) {
|
|
t.Fatalf("expected runtime history file removed, got %v", err)
|
|
}
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
cleared := providerRuntimeRegistry.api[name]
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
if len(cleared.RecentHits) != 0 || len(cleared.RecentErrors) != 0 || cleared.LastSuccess != nil {
|
|
t.Fatalf("expected runtime history cleared, got %#v", cleared)
|
|
}
|
|
}
|
|
|
|
func TestUpdateCandidateOrderRecordsSchedulerChange(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "oauth.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
Email: "user@example.com",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
|
|
name := "candidate-change-provider"
|
|
pc := config.ProviderConfig{
|
|
APIKey: "api-key-123456",
|
|
APIBase: "https://example.com/v1",
|
|
Auth: "hybrid",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
},
|
|
}
|
|
manager, err := newOAuthManager(pc, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("new oauth manager failed: %v", err)
|
|
}
|
|
provider := NewHTTPProvider(name, pc.APIKey, pc.APIBase, "gpt-test", false, pc.Auth, 5*time.Second, manager)
|
|
attempts, err := provider.authAttempts(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("auth attempts failed: %v", err)
|
|
}
|
|
if len(attempts) != 2 {
|
|
t.Fatalf("unexpected attempts: %#v", attempts)
|
|
}
|
|
provider.markAPIKeyFailure(oauthFailureRateLimit)
|
|
attempts, err = provider.authAttempts(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("auth attempts after cooldown failed: %v", err)
|
|
}
|
|
if len(attempts) != 1 || attempts[0].kind != "oauth" {
|
|
t.Fatalf("unexpected attempts after cooldown: %#v", attempts)
|
|
}
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
state := providerRuntimeRegistry.api[name]
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
if len(state.RecentChanges) == 0 || state.RecentChanges[0].Reason != "candidate_order_changed" {
|
|
t.Fatalf("expected scheduler change event, got %#v", state.RecentChanges)
|
|
}
|
|
if !strings.Contains(state.RecentChanges[0].Detail, "top ") {
|
|
t.Fatalf("expected candidate order detail, got %#v", state.RecentChanges[0])
|
|
}
|
|
}
|
|
|
|
func TestGetProviderRuntimeViewFiltersEvents(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
name := "runtime-view-provider"
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api[name] = providerRuntimeState{
|
|
RecentHits: []providerRuntimeEvent{
|
|
{When: time.Now().Add(-30 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "ok"},
|
|
{When: time.Now().Add(-3 * time.Hour).Format(time.RFC3339), Kind: "api_key", Target: "api***", Reason: "ok"},
|
|
},
|
|
RecentErrors: []providerRuntimeEvent{
|
|
{When: time.Now().Add(-10 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "quota"},
|
|
},
|
|
RecentChanges: []providerRuntimeEvent{
|
|
{When: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), Kind: "scheduler", Target: name, Reason: "candidate_order_changed", Detail: "top api -> oauth"},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {APIBase: "https://example.com/v1", Auth: "hybrid", APIKey: "api-key"},
|
|
},
|
|
},
|
|
}
|
|
view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{
|
|
Provider: name,
|
|
Window: 2 * time.Hour,
|
|
EventKind: "oauth",
|
|
Limit: 1,
|
|
})
|
|
items, _ := view["items"].([]map[string]interface{})
|
|
if len(items) != 1 {
|
|
t.Fatalf("expected one runtime item, got %#v", view)
|
|
}
|
|
hits, _ := items[0]["recent_hits"].([]providerRuntimeEvent)
|
|
if len(hits) != 1 || hits[0].Kind != "oauth" {
|
|
t.Fatalf("expected filtered oauth hits, got %#v", items[0]["recent_hits"])
|
|
}
|
|
errors, _ := items[0]["recent_errors"].([]providerRuntimeEvent)
|
|
if len(errors) != 1 || errors[0].Reason != "quota" {
|
|
t.Fatalf("expected filtered oauth errors, got %#v", items[0]["recent_errors"])
|
|
}
|
|
changes, _ := items[0]["recent_changes"].([]providerRuntimeEvent)
|
|
if len(changes) != 0 {
|
|
t.Fatalf("expected no scheduler changes when filtering kind=oauth, got %#v", items[0]["recent_changes"])
|
|
}
|
|
events, _ := items[0]["events"].([]providerRuntimeEvent)
|
|
if len(events) != 1 {
|
|
t.Fatalf("expected merged paged events, got %#v", items[0]["events"])
|
|
}
|
|
}
|
|
|
|
func TestGetProviderRuntimeViewCursorPagination(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
name := "runtime-cursor-provider"
|
|
now := time.Now()
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api[name] = providerRuntimeState{
|
|
RecentHits: []providerRuntimeEvent{
|
|
{When: now.Add(-1 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "a", Reason: "ok"},
|
|
{When: now.Add(-2 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "b", Reason: "ok"},
|
|
},
|
|
RecentErrors: []providerRuntimeEvent{
|
|
{When: now.Add(-3 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "c", Reason: "quota"},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
|
},
|
|
},
|
|
}
|
|
view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{Provider: name, Limit: 2, Cursor: 0})
|
|
items, _ := view["items"].([]map[string]interface{})
|
|
if len(items) != 1 {
|
|
t.Fatalf("expected one item, got %#v", view)
|
|
}
|
|
page1, _ := items[0]["events"].([]providerRuntimeEvent)
|
|
if len(page1) != 2 || items[0]["next_cursor"].(int) != 2 {
|
|
t.Fatalf("unexpected first page %#v", items[0])
|
|
}
|
|
view = GetProviderRuntimeView(cfg, ProviderRuntimeQuery{Provider: name, Limit: 2, Cursor: 2})
|
|
items, _ = view["items"].([]map[string]interface{})
|
|
page2, _ := items[0]["events"].([]providerRuntimeEvent)
|
|
if len(page2) != 1 || items[0]["next_cursor"].(int) != 0 {
|
|
t.Fatalf("unexpected second page %#v", items[0])
|
|
}
|
|
}
|
|
|
|
func TestGetProviderRuntimeViewSortAscending(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
name := "runtime-sort-provider"
|
|
now := time.Now()
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api[name] = providerRuntimeState{
|
|
RecentHits: []providerRuntimeEvent{
|
|
{When: now.Add(-1 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "a", Reason: "ok"},
|
|
{When: now.Add(-3 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "b", Reason: "ok"},
|
|
},
|
|
RecentErrors: []providerRuntimeEvent{
|
|
{When: now.Add(-2 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "c", Reason: "quota"},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
|
},
|
|
},
|
|
}
|
|
view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{Provider: name, Limit: 10, Sort: "asc"})
|
|
items, _ := view["items"].([]map[string]interface{})
|
|
if len(items) != 1 {
|
|
t.Fatalf("expected one item, got %#v", view)
|
|
}
|
|
events, _ := items[0]["events"].([]providerRuntimeEvent)
|
|
if len(events) != 3 {
|
|
t.Fatalf("expected three events, got %#v", items[0]["events"])
|
|
}
|
|
if events[0].Target != "b" || events[1].Target != "c" || events[2].Target != "a" {
|
|
t.Fatalf("expected ascending order oldest->newest, got %#v", events)
|
|
}
|
|
}
|
|
|
|
func TestGetProviderRuntimeViewFiltersByHealthAndCooldown(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
name := "runtime-health-provider"
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api[name] = providerRuntimeState{
|
|
API: providerAPIRuntimeState{
|
|
HealthScore: 20,
|
|
CooldownUntil: time.Now().Add(-5 * time.Minute).Format(time.RFC3339),
|
|
},
|
|
CandidateOrder: []providerRuntimeCandidate{{
|
|
Kind: "api_key",
|
|
Target: "api***",
|
|
HealthScore: 20,
|
|
CooldownUntil: time.Now().Add(-5 * time.Minute).Format(time.RFC3339),
|
|
}},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
|
|
},
|
|
},
|
|
}
|
|
view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{
|
|
Provider: name,
|
|
HealthBelow: 30,
|
|
CooldownBefore: time.Now(),
|
|
})
|
|
items, _ := view["items"].([]map[string]interface{})
|
|
if len(items) != 1 {
|
|
t.Fatalf("expected one filtered runtime item, got %#v", view)
|
|
}
|
|
}
|
|
|
|
func TestGetProviderRuntimeSummaryFlagsUnhealthyProviders(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
name := "runtime-summary-provider"
|
|
lastSuccessAt := time.Now().Add(-2 * time.Hour).Format(time.RFC3339)
|
|
topChangedAt := time.Now().Add(-3 * time.Minute).Format(time.RFC3339)
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api[name] = providerRuntimeState{
|
|
API: providerAPIRuntimeState{
|
|
HealthScore: 25,
|
|
CooldownUntil: time.Now().Add(15 * time.Minute).Format(time.RFC3339),
|
|
},
|
|
RecentErrors: []providerRuntimeEvent{
|
|
{When: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), Kind: "api_key", Target: "api***", Reason: "quota"},
|
|
},
|
|
RecentChanges: []providerRuntimeEvent{
|
|
{When: topChangedAt, Kind: "scheduler", Target: name, Reason: "candidate_order_changed", Detail: "top oauth -> api"},
|
|
},
|
|
LastSuccess: &providerRuntimeEvent{
|
|
When: lastSuccessAt,
|
|
Kind: "api_key",
|
|
Target: "api***",
|
|
Reason: "ok",
|
|
},
|
|
CandidateOrder: []providerRuntimeCandidate{{
|
|
Kind: "api_key",
|
|
Target: "api***",
|
|
Available: false,
|
|
Status: "cooldown",
|
|
HealthScore: 25,
|
|
CooldownUntil: time.Now().Add(15 * time.Minute).Format(time.RFC3339),
|
|
}},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
|
|
},
|
|
},
|
|
}
|
|
summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{HealthBelow: 30, Window: time.Hour})
|
|
if summary.TotalProviders != 1 || summary.InCooldown != 1 || summary.LowHealth != 1 || summary.RecentErrors != 1 {
|
|
t.Fatalf("unexpected summary counts: %#v", summary)
|
|
}
|
|
if summary.Critical != 1 || summary.Degraded != 0 || summary.Healthy != 0 {
|
|
t.Fatalf("unexpected status counts: %#v", summary)
|
|
}
|
|
if len(summary.Providers) != 1 || summary.Providers[0].TopCandidate == nil || summary.Providers[0].TopCandidate.Kind != "api_key" {
|
|
t.Fatalf("unexpected provider summary items: %#v", summary.Providers)
|
|
}
|
|
if summary.Providers[0].Status != "critical" {
|
|
t.Fatalf("expected critical status, got %#v", summary.Providers[0])
|
|
}
|
|
if summary.Providers[0].LastError == nil || summary.Providers[0].LastErrorReason != "quota" || summary.Providers[0].LastErrorAt == "" {
|
|
t.Fatalf("expected last error details, got %#v", summary.Providers[0])
|
|
}
|
|
if summary.Providers[0].LastSuccessAt != lastSuccessAt || summary.Providers[0].TopCandidateChangedAt != topChangedAt {
|
|
t.Fatalf("expected last success and top candidate timestamps, got %#v", summary.Providers[0])
|
|
}
|
|
if summary.Providers[0].StaleForSec < 7100 || summary.Providers[0].StaleForSec > 7300 {
|
|
t.Fatalf("expected stale_for_sec around 2h, got %#v", summary.Providers[0].StaleForSec)
|
|
}
|
|
}
|
|
|
|
func TestGetProviderRuntimeSummaryMarksRecentErrorsAsDegraded(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
name := "runtime-summary-degraded-provider"
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api[name] = providerRuntimeState{
|
|
RecentErrors: []providerRuntimeEvent{
|
|
{When: time.Now().Add(-10 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "rate_limit"},
|
|
},
|
|
CandidateOrder: []providerRuntimeCandidate{{
|
|
Kind: "oauth",
|
|
Target: "user@example.com",
|
|
Available: true,
|
|
Status: "ready",
|
|
HealthScore: 90,
|
|
}},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
|
},
|
|
},
|
|
}
|
|
|
|
summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{HealthBelow: 30, Window: time.Hour})
|
|
if summary.TotalProviders != 1 || summary.Degraded != 1 || summary.Critical != 0 || summary.Healthy != 0 {
|
|
t.Fatalf("unexpected summary counts: %#v", summary)
|
|
}
|
|
if len(summary.Providers) != 1 || summary.Providers[0].Status != "degraded" {
|
|
t.Fatalf("expected degraded provider item, got %#v", summary.Providers)
|
|
}
|
|
if summary.Providers[0].LastErrorReason != "rate_limit" {
|
|
t.Fatalf("expected last error reason, got %#v", summary.Providers[0])
|
|
}
|
|
if summary.Providers[0].StaleForSec != -1 {
|
|
t.Fatalf("expected stale_for_sec=-1 without success event, got %#v", summary.Providers[0])
|
|
}
|
|
}
|
|
|
|
func TestGetProviderRuntimeSummaryIncludesOAuthAccountMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "qwen.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "qwen",
|
|
AccessToken: "qwen-token",
|
|
RefreshToken: "refresh-token",
|
|
Email: "qwen-label",
|
|
ProjectID: "proj-9",
|
|
DeviceID: "device-9",
|
|
ResourceURL: "https://chat.qwen.ai/api",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
FilePath: credFile,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
"qwen-summary": {
|
|
APIBase: "https://example.com/v1",
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "qwen",
|
|
CredentialFile: credFile,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{Provider: "qwen-summary", HealthBelow: 50})
|
|
if len(summary.Providers) != 1 {
|
|
t.Fatalf("expected one provider, got %#v", summary)
|
|
}
|
|
if len(summary.Providers[0].OAuthAccounts) != 1 {
|
|
t.Fatalf("expected oauth account metadata, got %#v", summary.Providers[0])
|
|
}
|
|
account := summary.Providers[0].OAuthAccounts[0]
|
|
if account.AccountLabel != "qwen-label" || account.ProjectID != "proj-9" || account.DeviceID != "device-9" || account.ResourceURL != "https://chat.qwen.ai/api" {
|
|
t.Fatalf("unexpected oauth account metadata: %#v", account)
|
|
}
|
|
}
|
|
|
|
func TestRefreshProviderRuntimeNowSupportsOnlyExpiring(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var refreshCalls int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/oauth/token" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
atomic.AddInt32(&refreshCalls, 1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"refreshed-token","refresh_token":"refresh-token","expires_in":3600}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
credFile := filepath.Join(t.TempDir(), "codex.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "old-token",
|
|
RefreshToken: "refresh-token",
|
|
Expire: time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write credential file failed: %v", err)
|
|
}
|
|
|
|
name := "runtime-refresh-provider"
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {
|
|
APIBase: server.URL + "/v1",
|
|
Auth: "oauth",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
ClientID: "test-client",
|
|
TokenURL: server.URL + "/oauth/token",
|
|
AuthURL: server.URL + "/oauth/authorize",
|
|
RefreshLeadSec: 1800,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
result, err := RefreshProviderRuntimeNow(cfg, name, true)
|
|
if err != nil {
|
|
t.Fatalf("refresh only expiring failed: %v", err)
|
|
}
|
|
if result == nil || result.Refreshed != 0 || result.Skipped != 1 {
|
|
t.Fatalf("expected skip for non-expiring session, got %#v", result)
|
|
}
|
|
if atomic.LoadInt32(&refreshCalls) != 0 {
|
|
t.Fatalf("expected no refresh calls for only-expiring path, got %d", refreshCalls)
|
|
}
|
|
|
|
result, err = RefreshProviderRuntimeNow(cfg, name, false)
|
|
if err != nil {
|
|
t.Fatalf("refresh all failed: %v", err)
|
|
}
|
|
if result == nil || result.Refreshed != 1 {
|
|
t.Fatalf("expected forced refresh, got %#v", result)
|
|
}
|
|
if atomic.LoadInt32(&refreshCalls) != 1 {
|
|
t.Fatalf("expected one refresh call, got %d", refreshCalls)
|
|
}
|
|
}
|
|
|
|
func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
credFile := filepath.Join(dir, "oauth.json")
|
|
raw, err := json.Marshal(oauthSession{
|
|
Provider: "codex",
|
|
AccessToken: "oauth-token",
|
|
Email: "rerank@example.com",
|
|
Expire: time.Now().Add(time.Hour).Format(time.RFC3339),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal session failed: %v", err)
|
|
}
|
|
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
|
t.Fatalf("write session failed: %v", err)
|
|
}
|
|
name := "rerank-runtime-provider"
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
name: {
|
|
APIKey: "api-key",
|
|
APIBase: "https://example.com/v1",
|
|
Auth: "hybrid",
|
|
TimeoutSec: 5,
|
|
OAuth: config.ProviderOAuthConfig{
|
|
Provider: "codex",
|
|
CredentialFile: credFile,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
order, err := RerankProviderRuntime(cfg, name)
|
|
if err != nil {
|
|
t.Fatalf("rerank provider runtime failed: %v", err)
|
|
}
|
|
if len(order) == 0 || order[0].Kind != "api_key" {
|
|
t.Fatalf("expected api-key-first rerank result, got %#v", order)
|
|
}
|
|
snapshot := GetProviderRuntimeSnapshot(cfg)
|
|
items, _ := snapshot["items"].([]map[string]interface{})
|
|
if len(items) != 1 {
|
|
t.Fatalf("expected one runtime item, got %#v", snapshot)
|
|
}
|
|
snapshotOrder, _ := items[0]["candidate_order"].([]providerRuntimeCandidate)
|
|
if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "api_key" {
|
|
t.Fatalf("expected api-key-first candidate order, got %#v", items[0]["candidate_order"])
|
|
}
|
|
}
|