Files
clawgo/pkg/providers/aistudio_provider.go
2026-03-12 20:26:16 +08:00

272 lines
8.9 KiB
Go

package providers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/YspCoder/clawgo/pkg/wsrelay"
)
// AistudioProvider is a provider that sends Gemini-style requests through a
// wsrelay.Manager channel instead of issuing the HTTP calls itself.
type AistudioProvider struct {
// base carries the shared provider configuration; this file reads its
// provider name, API base and default model.
base *HTTPProvider
// relay transports requests via its NonStream and Stream methods. When nil,
// perform re-resolves it with getAIStudioRelayManager.
relay *wsrelay.Manager
}
// NewAistudioProvider builds an AistudioProvider on top of an HTTPProvider.
//
// apiBase is passed through normalizeAPIBase; when that yields an empty
// string, geminiBaseURL is used instead. All remaining arguments are forwarded
// unchanged to NewHTTPProvider. The relay manager is obtained from
// getAIStudioRelayManager.
func NewAistudioProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *AistudioProvider {
	resolvedBase := normalizeAPIBase(apiBase)
	if resolvedBase == "" {
		resolvedBase = geminiBaseURL
	}

	httpProvider := NewHTTPProvider(providerName, apiKey, resolvedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)
	relayManager := getAIStudioRelayManager()

	return &AistudioProvider{
		base:  httpProvider,
		relay: relayManager,
	}
}
// GetDefaultModel returns the default model of the underlying base provider,
// or an empty string when the receiver or its base provider is nil.
func (p *AistudioProvider) GetDefaultModel() string {
	if p != nil && p.base != nil {
		return p.base.GetDefaultModel()
	}
	return ""
}
// Chat issues a non-streaming generateContent request and decodes the reply.
//
// An error is returned when the request itself fails, when the status is not
// 200 OK, or when the body is not valid JSON. Otherwise the body is handed to
// parseGeminiResponse and its result is returned as-is.
func (p *AistudioProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
	raw, code, contentType, reqErr := p.doRequest(ctx, messages, tools, model, options, false, nil)

	switch {
	case reqErr != nil:
		return nil, reqErr
	case code != http.StatusOK:
		return nil, fmt.Errorf("API error (status %d, content-type %q): %s", code, contentType, previewResponseBody(raw))
	case !json.Valid(raw):
		return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", code, contentType, previewResponseBody(raw))
	}

	return parseGeminiResponse(raw)
}
// ChatStream issues a streaming generateContent request, forwarding onDelta to
// the request layer, and decodes the body that the request layer returns once
// the stream has finished.
//
// The returned body is validated exactly like in the non-streaming path: the
// status must be 200 OK and the body must be valid JSON before it is passed to
// parseGeminiResponse.
func (p *AistudioProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
	payload, httpStatus, mediaType, callErr := p.doRequest(ctx, messages, tools, model, options, true, onDelta)
	if callErr != nil {
		return nil, callErr
	}

	statusOK := httpStatus == http.StatusOK
	if !statusOK || !json.Valid(payload) {
		preview := previewResponseBody(payload)
		if !statusOK {
			return nil, fmt.Errorf("API error (status %d, content-type %q): %s", httpStatus, mediaType, preview)
		}
		return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", httpStatus, mediaType, preview)
	}

	return parseGeminiResponse(payload)
}
// CountTokens asks the countTokens endpoint how many tokens the given messages
// occupy for model.
//
// The request body is built without tool definitions (the tools argument is
// not forwarded), and the "tools", "toolConfig" and "generationConfig" keys
// are removed before sending. The reported totalTokens value is returned as
// both PromptTokens and TotalTokens.
//
// An error is returned when the provider is not configured, when the relay
// request fails, when the status is not 200 OK, or when the response cannot
// be decoded.
func (p *AistudioProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) {
	// Guard before buildRequestBody: it dereferences p.base, so a nil
	// receiver would otherwise panic before perform's own check could run.
	if p == nil || p.base == nil {
		return nil, fmt.Errorf("provider not configured")
	}

	requestBody := p.buildRequestBody(messages, nil, model, options, false)
	for _, key := range []string{"tools", "toolConfig", "generationConfig"} {
		delete(requestBody, key)
	}

	body, status, ctype, err := p.perform(ctx, p.endpoint(model, "countTokens", false), requestBody, options, false, nil)
	if err != nil {
		return nil, err
	}
	if status != http.StatusOK {
		return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
	}

	var payload struct {
		TotalTokens int `json:"totalTokens"`
	}
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, fmt.Errorf("invalid countTokens response: %w", err)
	}
	return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil
}
// doRequest builds the request body for a generateContent call and sends it
// through perform. It returns the response body, the HTTP status, the
// response content type and any transport error.
//
// When stream is true the streaming endpoint is used and onDelta is forwarded
// to perform.
func (p *AistudioProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
	// Guard before buildRequestBody: it dereferences p.base, so a nil
	// receiver would otherwise panic before perform's own check could run.
	if p == nil || p.base == nil {
		return nil, 0, "", fmt.Errorf("provider not configured")
	}

	requestBody := p.buildRequestBody(messages, tools, model, options, stream)
	return p.perform(ctx, p.endpoint(model, "generateContent", stream), requestBody, options, stream, onDelta)
}
// perform marshals payload to JSON and POSTs it to endpoint through the relay,
// trying the channels returned by aistudioChannelCandidates in order.
//
// It returns the response body, the HTTP status, the response content type
// and an error. When every candidate channel fails, the body is nil, the
// status is 0 and the error is the last failure observed.
//
// Non-streaming requests: a 2xx response is returned immediately. A 5xx or
// 429 response moves on to the next channel. Any other status is returned to
// the caller as-is, with a nil error.
//
// Streaming requests: chunks are fed to an antigravityStreamState and every
// non-empty delta it yields is passed to onDelta (a nil onDelta is replaced
// by a no-op). A failure is retried on the next channel only if no non-empty
// chunk has been received yet on the current one.
//
// Each attempt's outcome is reported through recordAIStudioRelaySuccess or
// recordAIStudioRelayFailure.
func (p *AistudioProvider) perform(ctx context.Context, endpoint string, payload map[string]any, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
	if p == nil || p.base == nil {
		return nil, 0, "", fmt.Errorf("provider not configured")
	}
	// Re-resolve the relay manager if it was not available earlier.
	if p.relay == nil {
		p.relay = getAIStudioRelayManager()
	}
	if p.relay == nil {
		return nil, 0, "", fmt.Errorf("aistudio relay not configured")
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
	}
	req := &wsrelay.HTTPRequest{
		Method: http.MethodPost,
		URL:    endpoint,
		Headers: http.Header{
			"Content-Type": []string{"application/json"},
			"Accept":       []string{"application/json"},
		},
		Body: jsonData,
	}
	if stream {
		req.Headers.Set("Accept", "text/event-stream")
	}

	channelIDs := aistudioChannelCandidates(p.base.providerName, options)
	if len(channelIDs) == 0 {
		return nil, 0, "", fmt.Errorf("aistudio relay channel not specified")
	}

	if !stream {
		var lastErr error
		for _, channelID := range channelIDs {
			resp, err := p.relay.NonStream(ctx, channelID, req)
			if err != nil {
				recordAIStudioRelayFailure(channelID, err)
				lastErr = err
				continue
			}
			if resp.Status >= 200 && resp.Status < 300 {
				recordAIStudioRelaySuccess(channelID)
				return resp.Body, resp.Status, strings.TrimSpace(resp.Headers.Get("Content-Type")), nil
			}
			retryErr := fmt.Errorf("status=%d", resp.Status)
			recordAIStudioRelayFailure(channelID, retryErr)
			lastErr = retryErr
			// Only 5xx and 429 fall through to the next channel; any other
			// status is handed back to the caller unchanged.
			if resp.Status < 500 && resp.Status != http.StatusTooManyRequests {
				return resp.Body, resp.Status, strings.TrimSpace(resp.Headers.Get("Content-Type")), nil
			}
		}
		if lastErr == nil {
			lastErr = fmt.Errorf("aistudio relay request failed")
		}
		return nil, 0, "", lastErr
	}

	if onDelta == nil {
		onDelta = func(string) {}
	}
	var lastErr error
	for _, channelID := range channelIDs {
		streamCh, err := p.relay.Stream(ctx, channelID, req)
		if err != nil {
			recordAIStudioRelayFailure(channelID, err)
			lastErr = err
			continue
		}

		state := &antigravityStreamState{}
		status := http.StatusOK
		ctype := "text/event-stream"
		// full collects the raw chunk bytes; only its length is consulted
		// below, to tell whether any data arrived before the stream closed.
		var full bytes.Buffer
		started := false   // a non-empty chunk has been received
		retryable := false // the next channel may be tried
		failed := false    // this channel ended in an error

		// settle records the channel outcome for the current status and
		// reports whether that status (5xx or 429) counts as a failure of
		// this channel, updating the retry bookkeeping when it does.
		settle := func() bool {
			if status >= 200 && status < 300 {
				recordAIStudioRelaySuccess(channelID)
			} else {
				recordAIStudioRelayFailure(channelID, fmt.Errorf("status=%d", status))
			}
			if status >= 500 || status == http.StatusTooManyRequests {
				lastErr = fmt.Errorf("status=%d", status)
				retryable = !started
				failed = true
				return true
			}
			return false
		}

		// The loop is labeled because a bare break inside the switch would
		// only leave the switch and keep waiting on the failed stream.
	eventLoop:
		for event := range streamCh {
			if event.Err != nil {
				recordAIStudioRelayFailure(channelID, event.Err)
				lastErr = event.Err
				retryable = !started
				failed = true
				break eventLoop
			}
			switch event.Type {
			case wsrelay.MessageTypeStreamStart:
				if event.Status > 0 {
					status = event.Status
				}
				if v := strings.TrimSpace(event.Headers.Get("Content-Type")); v != "" {
					ctype = v
				}
			case wsrelay.MessageTypeStreamChunk:
				if len(event.Payload) == 0 {
					continue
				}
				started = true
				full.Write(event.Payload)
				filtered := filterGeminiSSEUsageMetadata(event.Payload)
				if delta := state.consume(filtered); delta != "" {
					onDelta(delta)
				}
			case wsrelay.MessageTypeHTTPResp:
				if event.Status > 0 {
					status = event.Status
				}
				if v := strings.TrimSpace(event.Headers.Get("Content-Type")); v != "" {
					ctype = v
				}
				if settle() {
					break eventLoop
				}
				// A complete response with a body is returned verbatim;
				// without one, fall back to what the stream state has.
				if len(event.Payload) > 0 {
					return event.Payload, status, ctype, nil
				}
				return state.finalBody(), status, ctype, nil
			case wsrelay.MessageTypeStreamEnd:
				if settle() {
					break eventLoop
				}
				return state.finalBody(), status, ctype, nil
			}
		}

		// Data was already received on this channel, so do not try another.
		if failed && started {
			break
		}
		// The stream closed without an explicit end event but produced data.
		if !failed && full.Len() > 0 {
			recordAIStudioRelaySuccess(channelID)
			return state.finalBody(), status, ctype, nil
		}
		if !retryable {
			break
		}
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("wsrelay: stream closed")
	}
	return nil, 0, "", lastErr
}
// endpoint assembles the request URL for model.
//
// The base URL defaults to geminiBaseURL. A configured API base replaces it
// (after normalizeGeminiBaseURL) unless it is blank or contains
// "api.openai.com". The model name is passed through qwenBaseModel and
// trimmed. When stream is true the action argument is ignored and the
// streamGenerateContent SSE endpoint is returned.
func (p *AistudioProvider) endpoint(model, action string, stream bool) string {
	root := geminiBaseURL
	if p != nil && p.base != nil {
		configured := p.base.apiBase
		isBlank := strings.TrimSpace(configured) == ""
		isOpenAI := strings.Contains(strings.ToLower(configured), "api.openai.com")
		if !isBlank && !isOpenAI {
			root = normalizeGeminiBaseURL(configured)
		}
	}

	modelName := strings.TrimSpace(qwenBaseModel(model))
	if !stream {
		return fmt.Sprintf("%s/%s/models/%s:%s", root, geminiAPIVersion, modelName, action)
	}
	return fmt.Sprintf("%s/%s/models/%s:streamGenerateContent?alt=sse", root, geminiAPIVersion, modelName)
}
// buildRequestBody delegates request construction to a GeminiProvider that
// wraps this provider's base, and returns the request map it produces.
func (p *AistudioProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]any {
	return (&GeminiProvider{base: p.base}).buildRequestBody(messages, tools, model, options, stream)
}