mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-05-17 23:57:30 +08:00
release: v0.2.0
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -386,6 +387,118 @@ func TestResolveOAuthConfigAppliesProviderRefreshLeadDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var refreshCalls int32
|
||||
var responseCalls int32
|
||||
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/oauth/token":
|
||||
atomic.AddInt32(&refreshCalls, 1)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("parse token form failed: %v", err)
|
||||
}
|
||||
if got := r.Form.Get("grant_type"); got != "refresh_token" {
|
||||
t.Fatalf("unexpected grant_type: %s", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"access_token":"proxied-fresh-token","refresh_token":"refresh-token","expires_in":3600}`))
|
||||
case "/v1/responses":
|
||||
atomic.AddInt32(&responseCalls, 1)
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer proxied-fresh-token" {
|
||||
t.Fatalf("unexpected authorization header: %s", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-via-proxy"}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer target.Close()
|
||||
|
||||
var proxyCalls int32
|
||||
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&proxyCalls, 1)
|
||||
targetURL := r.URL.String()
|
||||
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
||||
targetURL = target.URL + r.URL.Path
|
||||
if rawQuery := strings.TrimSpace(r.URL.RawQuery); rawQuery != "" {
|
||||
targetURL += "?" + rawQuery
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("create proxied request failed: %v", err)
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
resp, err := http.DefaultTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy round trip failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}))
|
||||
defer proxyServer.Close()
|
||||
|
||||
credFile := filepath.Join(t.TempDir(), "proxied.json")
|
||||
raw, err := json.Marshal(oauthSession{
|
||||
Provider: "codex",
|
||||
AccessToken: "expired-token",
|
||||
RefreshToken: "refresh-token",
|
||||
Expire: time.Now().Add(-time.Hour).Format(time.RFC3339),
|
||||
NetworkProxy: proxyServer.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal session failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
||||
t.Fatalf("write credential file failed: %v", err)
|
||||
}
|
||||
|
||||
pc := config.ProviderConfig{
|
||||
APIBase: target.URL + "/v1",
|
||||
Auth: "oauth",
|
||||
TimeoutSec: 5,
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
ClientID: "test-client",
|
||||
TokenURL: target.URL + "/oauth/token",
|
||||
AuthURL: target.URL + "/oauth/authorize",
|
||||
},
|
||||
}
|
||||
oauth, err := newOAuthManager(pc, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
defer oauth.bgCancel()
|
||||
|
||||
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
if resp.Content != "ok-via-proxy" {
|
||||
t.Fatalf("unexpected response content: %q", resp.Content)
|
||||
}
|
||||
if atomic.LoadInt32(&refreshCalls) != 1 {
|
||||
t.Fatalf("expected one refresh call, got %d", refreshCalls)
|
||||
}
|
||||
if atomic.LoadInt32(&responseCalls) != 1 {
|
||||
t.Fatalf("expected one response call, got %d", responseCalls)
|
||||
}
|
||||
if got := atomic.LoadInt32(&proxyCalls); got < 2 {
|
||||
t.Fatalf("expected proxy to receive refresh and response requests, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthImportGeminiNestedTokenRefreshesWithTokenMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -547,7 +660,7 @@ func TestQwenDeviceFlowRequiresAccountLabelWhenEmailMissing(t *testing.T) {
|
||||
}
|
||||
defer manager.bgCancel()
|
||||
|
||||
flow, err := manager.startDeviceFlow(context.Background())
|
||||
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("start device flow failed: %v", err)
|
||||
}
|
||||
@@ -734,7 +847,7 @@ func TestOAuthDeviceFlowQwenManualCompletes(t *testing.T) {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
|
||||
flow, err := manager.startDeviceFlow(context.Background())
|
||||
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("start device flow failed: %v", err)
|
||||
}
|
||||
@@ -879,7 +992,6 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
|
||||
CredentialFile: credFile,
|
||||
TokenURL: server.URL + "/oauth/token",
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
HybridPriority: "oauth_first",
|
||||
},
|
||||
}
|
||||
oauth, err := newOAuthManager(pc, 5*time.Second)
|
||||
@@ -891,11 +1003,11 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
if resp.Content != "ok-from-oauth" {
|
||||
if resp.Content != "ok-from-api" {
|
||||
t.Fatalf("unexpected response content: %q", resp.Content)
|
||||
}
|
||||
if atomic.LoadInt32(&oauthCalls) != 1 || atomic.LoadInt32(&apiKeyCalls) != 0 {
|
||||
t.Fatalf("expected oauth first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
|
||||
if atomic.LoadInt32(&apiKeyCalls) != 1 || atomic.LoadInt32(&oauthCalls) != 0 {
|
||||
t.Fatalf("expected api key first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1187,7 +1299,6 @@ func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
HybridPriority: "api_first",
|
||||
},
|
||||
}
|
||||
ConfigureProviderRuntime(name, pc)
|
||||
@@ -1206,8 +1317,8 @@ func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.
|
||||
provider.markAttemptSuccess(attempts[1])
|
||||
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{name: pc},
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{name: pc},
|
||||
},
|
||||
}
|
||||
snapshot := GetProviderRuntimeSnapshot(cfg)
|
||||
@@ -1267,8 +1378,8 @@ func TestConfigureProviderRuntimeLoadsPersistedEvents(t *testing.T) {
|
||||
})
|
||||
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {
|
||||
APIBase: "https://example.com/v1",
|
||||
Auth: "bearer",
|
||||
@@ -1357,7 +1468,6 @@ func TestUpdateCandidateOrderRecordsSchedulerChange(t *testing.T) {
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
HybridPriority: "api_first",
|
||||
},
|
||||
}
|
||||
manager, err := newOAuthManager(pc, 5*time.Second)
|
||||
@@ -1412,8 +1522,8 @@ func TestGetProviderRuntimeViewFiltersEvents(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "hybrid", APIKey: "api-key"},
|
||||
},
|
||||
},
|
||||
@@ -1463,8 +1573,8 @@ func TestGetProviderRuntimeViewCursorPagination(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
||||
},
|
||||
},
|
||||
@@ -1503,8 +1613,8 @@ func TestGetProviderRuntimeViewSortAscending(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
||||
},
|
||||
},
|
||||
@@ -1542,8 +1652,8 @@ func TestGetProviderRuntimeViewFiltersByHealthAndCooldown(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
|
||||
},
|
||||
},
|
||||
@@ -1594,8 +1704,8 @@ func TestGetProviderRuntimeSummaryFlagsUnhealthyProviders(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
|
||||
},
|
||||
},
|
||||
@@ -1643,8 +1753,8 @@ func TestGetProviderRuntimeSummaryMarksRecentErrorsAsDegraded(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
||||
},
|
||||
},
|
||||
@@ -1688,8 +1798,8 @@ func TestGetProviderRuntimeSummaryIncludesOAuthAccountMetadata(t *testing.T) {
|
||||
t.Fatalf("write session failed: %v", err)
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
"qwen-summary": {
|
||||
APIBase: "https://example.com/v1",
|
||||
Auth: "oauth",
|
||||
@@ -1746,8 +1856,8 @@ func TestRefreshProviderRuntimeNowSupportsOnlyExpiring(t *testing.T) {
|
||||
|
||||
name := "runtime-refresh-provider"
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {
|
||||
APIBase: server.URL + "/v1",
|
||||
Auth: "oauth",
|
||||
@@ -1807,8 +1917,8 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
}
|
||||
name := "rerank-runtime-provider"
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {
|
||||
APIKey: "api-key",
|
||||
APIBase: "https://example.com/v1",
|
||||
@@ -1817,7 +1927,6 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
HybridPriority: "oauth_first",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1827,8 +1936,8 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("rerank provider runtime failed: %v", err)
|
||||
}
|
||||
if len(order) == 0 || order[0].Kind != "oauth" {
|
||||
t.Fatalf("expected oauth-first rerank result, got %#v", order)
|
||||
if len(order) == 0 || order[0].Kind != "api_key" {
|
||||
t.Fatalf("expected api-key-first rerank result, got %#v", order)
|
||||
}
|
||||
snapshot := GetProviderRuntimeSnapshot(cfg)
|
||||
items, _ := snapshot["items"].([]map[string]interface{})
|
||||
@@ -1836,7 +1945,7 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
t.Fatalf("expected one runtime item, got %#v", snapshot)
|
||||
}
|
||||
snapshotOrder, _ := items[0]["candidate_order"].([]providerRuntimeCandidate)
|
||||
if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "oauth" {
|
||||
t.Fatalf("expected oauth-first candidate order, got %#v", items[0]["candidate_order"])
|
||||
if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "api_key" {
|
||||
t.Fatalf("expected api-key-first candidate order, got %#v", items[0]["candidate_order"])
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user