mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-05-20 07:17:29 +08:00
feat: align google and relay providers
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com"
|
||||
antigravitySandboxBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
)
|
||||
|
||||
type AntigravityProvider struct {
|
||||
@@ -38,6 +39,64 @@ func (p *AntigravityProvider) GetDefaultModel() string {
|
||||
return p.base.GetDefaultModel()
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) {
|
||||
if p == nil || p.base == nil {
|
||||
return nil, fmt.Errorf("provider not configured")
|
||||
}
|
||||
attempts, err := p.base.authAttempts(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var lastBody []byte
|
||||
var lastStatus int
|
||||
var lastType string
|
||||
for _, attempt := range attempts {
|
||||
for _, baseURL := range p.baseURLs() {
|
||||
requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, false)
|
||||
delete(requestBody, "project")
|
||||
delete(requestBody, "model")
|
||||
request := mapFromAny(requestBody["request"])
|
||||
delete(request, "safetySettings")
|
||||
requestBody["request"] = request
|
||||
body, status, ctype, reqErr := p.performCountTokensAttempt(ctx, p.countTokensEndpoint(baseURL), requestBody, attempt)
|
||||
if reqErr != nil {
|
||||
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
|
||||
return nil, reqErr
|
||||
}
|
||||
lastBody, lastStatus, lastType = nil, 0, ""
|
||||
continue
|
||||
}
|
||||
lastBody, lastStatus, lastType = body, status, ctype
|
||||
if status == http.StatusTooManyRequests {
|
||||
continue
|
||||
}
|
||||
reason, retry := classifyOAuthFailure(status, body)
|
||||
if retry {
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
|
||||
p.base.oauth.markExhausted(attempt.session, reason)
|
||||
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
|
||||
}
|
||||
if attempt.kind == "api_key" {
|
||||
p.base.markAPIKeyFailure(reason)
|
||||
}
|
||||
break
|
||||
}
|
||||
if status != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
var payload struct {
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("invalid countTokens response: %w", err)
|
||||
}
|
||||
p.base.markAttemptSuccess(attempt)
|
||||
return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody))
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
|
||||
if err != nil {
|
||||
@@ -81,31 +140,39 @@ func (p *AntigravityProvider) doRequest(ctx context.Context, messages []Message,
|
||||
for _, baseURL := range p.baseURLs() {
|
||||
requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, stream)
|
||||
endpoint := p.endpoint(baseURL, stream)
|
||||
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta)
|
||||
if reqErr != nil {
|
||||
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
|
||||
return nil, 0, "", reqErr
|
||||
for retryAttempt := 0; retryAttempt < 3; retryAttempt++ {
|
||||
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta)
|
||||
if reqErr != nil {
|
||||
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
|
||||
return nil, 0, "", reqErr
|
||||
}
|
||||
lastBody, lastStatus, lastType = nil, 0, ""
|
||||
break
|
||||
}
|
||||
lastBody, lastStatus, lastType = nil, 0, ""
|
||||
continue
|
||||
}
|
||||
lastBody, lastStatus, lastType = body, status, ctype
|
||||
if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway {
|
||||
continue
|
||||
}
|
||||
reason, retry := classifyOAuthFailure(status, body)
|
||||
if retry {
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
|
||||
p.base.oauth.markExhausted(attempt.session, reason)
|
||||
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
|
||||
lastBody, lastStatus, lastType = body, status, ctype
|
||||
if antigravityShouldRetryNoCapacity(status, body) && retryAttempt < 2 {
|
||||
if err := antigravityWait(ctx, antigravityNoCapacityRetryDelay(retryAttempt)); err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attempt.kind == "api_key" {
|
||||
p.base.markAPIKeyFailure(reason)
|
||||
if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway {
|
||||
break
|
||||
}
|
||||
break
|
||||
reason, retry := classifyOAuthFailure(status, body)
|
||||
if retry {
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
|
||||
p.base.oauth.markExhausted(attempt.session, reason)
|
||||
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
|
||||
}
|
||||
if attempt.kind == "api_key" {
|
||||
p.base.markAPIKeyFailure(reason)
|
||||
}
|
||||
break
|
||||
}
|
||||
p.base.markAttemptSuccess(attempt)
|
||||
return body, status, ctype, nil
|
||||
}
|
||||
p.base.markAttemptSuccess(attempt)
|
||||
return body, status, ctype, nil
|
||||
}
|
||||
}
|
||||
return lastBody, lastStatus, lastType, nil
|
||||
@@ -163,14 +230,24 @@ func (p *AntigravityProvider) endpoint(baseURL string, stream bool) string {
|
||||
return base + path
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) countTokensEndpoint(baseURL string) string {
|
||||
base := normalizeAPIBase(baseURL)
|
||||
if base == "" {
|
||||
base = antigravityDailyBaseURL
|
||||
}
|
||||
return base + "/" + defaultAntigravityAPIVersion + ":countTokens"
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) baseURLs() []string {
|
||||
if p == nil || p.base == nil {
|
||||
return []string{antigravityDailyBaseURL}
|
||||
}
|
||||
if custom := normalizeAPIBase(p.base.apiBase); custom != "" && !strings.Contains(strings.ToLower(custom), "api.openai.com") {
|
||||
if custom := normalizeAPIBase(p.base.apiBase); custom != "" &&
|
||||
!strings.Contains(strings.ToLower(custom), "api.openai.com") &&
|
||||
custom != antigravityDailyBaseURL {
|
||||
return []string{custom}
|
||||
}
|
||||
return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, defaultAntigravityAPIEndpoint}
|
||||
return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, antigravityProdBaseURL, defaultAntigravityAPIEndpoint}
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any {
|
||||
@@ -409,6 +486,70 @@ func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte
|
||||
return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) performCountTokensAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt) ([]byte, int, string, error) {
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Close = true
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", defaultAntigravityClientMeta)
|
||||
applyAttemptAuth(req, attempt)
|
||||
client, err := p.base.httpClientForAttempt(attempt)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
|
||||
}
|
||||
|
||||
func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusServiceUnavailable {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(string(body)), "no capacity available")
|
||||
}
|
||||
|
||||
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
delay := time.Duration(attempt+1) * 250 * time.Millisecond
|
||||
if delay > 2*time.Second {
|
||||
delay = 2 * time.Second
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func antigravityWait(ctx context.Context, wait time.Duration) error {
|
||||
if wait <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(wait)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type antigravityStreamState struct {
|
||||
Text string
|
||||
ToolCalls []ToolCall
|
||||
|
||||
Reference in New Issue
Block a user