feat: align google and relay providers

This commit is contained in:
lpf
2026-03-12 20:26:16 +08:00
parent e405d410c9
commit 1e9e4d8459
29 changed files with 6208 additions and 229 deletions

View File

@@ -15,6 +15,7 @@ import (
const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com"
antigravitySandboxBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com"
)
type AntigravityProvider struct {
@@ -38,6 +39,64 @@ func (p *AntigravityProvider) GetDefaultModel() string {
return p.base.GetDefaultModel()
}
func (p *AntigravityProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) {
if p == nil || p.base == nil {
return nil, fmt.Errorf("provider not configured")
}
attempts, err := p.base.authAttempts(ctx)
if err != nil {
return nil, err
}
var lastBody []byte
var lastStatus int
var lastType string
for _, attempt := range attempts {
for _, baseURL := range p.baseURLs() {
requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, false)
delete(requestBody, "project")
delete(requestBody, "model")
request := mapFromAny(requestBody["request"])
delete(request, "safetySettings")
requestBody["request"] = request
body, status, ctype, reqErr := p.performCountTokensAttempt(ctx, p.countTokensEndpoint(baseURL), requestBody, attempt)
if reqErr != nil {
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
return nil, reqErr
}
lastBody, lastStatus, lastType = nil, 0, ""
continue
}
lastBody, lastStatus, lastType = body, status, ctype
if status == http.StatusTooManyRequests {
continue
}
reason, retry := classifyOAuthFailure(status, body)
if retry {
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
p.base.oauth.markExhausted(attempt.session, reason)
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
}
if attempt.kind == "api_key" {
p.base.markAPIKeyFailure(reason)
}
break
}
if status != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
}
var payload struct {
TotalTokens int `json:"totalTokens"`
}
if err := json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("invalid countTokens response: %w", err)
}
p.base.markAttemptSuccess(attempt)
return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil
}
}
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody))
}
func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
if err != nil {
@@ -81,31 +140,39 @@ func (p *AntigravityProvider) doRequest(ctx context.Context, messages []Message,
for _, baseURL := range p.baseURLs() {
requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, stream)
endpoint := p.endpoint(baseURL, stream)
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta)
if reqErr != nil {
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
return nil, 0, "", reqErr
for retryAttempt := 0; retryAttempt < 3; retryAttempt++ {
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta)
if reqErr != nil {
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
return nil, 0, "", reqErr
}
lastBody, lastStatus, lastType = nil, 0, ""
break
}
lastBody, lastStatus, lastType = nil, 0, ""
continue
}
lastBody, lastStatus, lastType = body, status, ctype
if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway {
continue
}
reason, retry := classifyOAuthFailure(status, body)
if retry {
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
p.base.oauth.markExhausted(attempt.session, reason)
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
lastBody, lastStatus, lastType = body, status, ctype
if antigravityShouldRetryNoCapacity(status, body) && retryAttempt < 2 {
if err := antigravityWait(ctx, antigravityNoCapacityRetryDelay(retryAttempt)); err != nil {
return nil, 0, "", err
}
continue
}
if attempt.kind == "api_key" {
p.base.markAPIKeyFailure(reason)
if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway {
break
}
break
reason, retry := classifyOAuthFailure(status, body)
if retry {
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
p.base.oauth.markExhausted(attempt.session, reason)
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
}
if attempt.kind == "api_key" {
p.base.markAPIKeyFailure(reason)
}
break
}
p.base.markAttemptSuccess(attempt)
return body, status, ctype, nil
}
p.base.markAttemptSuccess(attempt)
return body, status, ctype, nil
}
}
return lastBody, lastStatus, lastType, nil
@@ -163,14 +230,24 @@ func (p *AntigravityProvider) endpoint(baseURL string, stream bool) string {
return base + path
}
func (p *AntigravityProvider) countTokensEndpoint(baseURL string) string {
base := normalizeAPIBase(baseURL)
if base == "" {
base = antigravityDailyBaseURL
}
return base + "/" + defaultAntigravityAPIVersion + ":countTokens"
}
func (p *AntigravityProvider) baseURLs() []string {
if p == nil || p.base == nil {
return []string{antigravityDailyBaseURL}
}
if custom := normalizeAPIBase(p.base.apiBase); custom != "" && !strings.Contains(strings.ToLower(custom), "api.openai.com") {
if custom := normalizeAPIBase(p.base.apiBase); custom != "" &&
!strings.Contains(strings.ToLower(custom), "api.openai.com") &&
custom != antigravityDailyBaseURL {
return []string{custom}
}
return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, defaultAntigravityAPIEndpoint}
return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, antigravityProdBaseURL, defaultAntigravityAPIEndpoint}
}
func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any {
@@ -409,6 +486,70 @@ func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte
return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
}
func (p *AntigravityProvider) performCountTokensAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt) ([]byte, int, string, error) {
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
if err != nil {
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
}
req.Close = true
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient)
req.Header.Set("Client-Metadata", defaultAntigravityClientMeta)
applyAttemptAuth(req, attempt)
client, err := p.base.httpClientForAttempt(attempt)
if err != nil {
return nil, 0, "", err
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read response: %w", readErr)
}
return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
}
func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
if statusCode != http.StatusServiceUnavailable {
return false
}
return strings.Contains(strings.ToLower(string(body)), "no capacity available")
}
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := time.Duration(attempt+1) * 250 * time.Millisecond
if delay > 2*time.Second {
delay = 2 * time.Second
}
return delay
}
func antigravityWait(ctx context.Context, wait time.Duration) error {
if wait <= 0 {
return nil
}
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
type antigravityStreamState struct {
Text string
ToolCalls []ToolCall