feat: align google and relay providers

This commit is contained in:
lpf
2026-03-12 20:26:16 +08:00
parent e405d410c9
commit 1e9e4d8459
29 changed files with 6208 additions and 229 deletions

View File

@@ -2,7 +2,11 @@ package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)
func TestAntigravityBuildRequestBody(t *testing.T) {
@@ -99,3 +103,66 @@ func TestParseAntigravityResponse(t *testing.T) {
t.Fatalf("expected tool args, got %#v", args)
}
}
func TestAntigravityProviderCountTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1internal:countTokens" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"totalTokens":42}`))
}))
defer server.Close()
p := NewAntigravityProvider("antigravity", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil)
usage, err := p.CountTokens(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil)
if err != nil {
t.Fatalf("CountTokens error: %v", err)
}
if usage == nil || usage.PromptTokens != 42 || usage.TotalTokens != 42 {
t.Fatalf("usage = %#v, want 42", usage)
}
}
func TestAntigravityProviderRetriesNoCapacity(t *testing.T) {
var hits int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1internal:generateContent" {
http.NotFound(w, r)
return
}
if atomic.AddInt32(&hits, 1) == 1 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte(`{"error":{"message":"no capacity available"}}`))
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}}`))
}))
defer server.Close()
p := NewAntigravityProvider("antigravity", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil)
resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil)
if err != nil {
t.Fatalf("Chat error: %v", err)
}
if resp.Content != "ok" {
t.Fatalf("content = %q, want ok", resp.Content)
}
if got := atomic.LoadInt32(&hits); got != 2 {
t.Fatalf("hits = %d, want 2", got)
}
}
func TestAntigravityBaseURLsIncludeProdFallback(t *testing.T) {
p := NewAntigravityProvider("antigravity", "", "", "gemini-2.5-pro", false, "oauth", 0, nil)
got := p.baseURLs()
if len(got) < 3 {
t.Fatalf("baseURLs = %#v", got)
}
if got[0] != antigravityDailyBaseURL || got[1] != antigravitySandboxBaseURL || got[2] != antigravityProdBaseURL {
t.Fatalf("unexpected fallback order: %#v", got)
}
}