mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-15 01:37:31 +08:00
272 lines
8.9 KiB
Go
272 lines
8.9 KiB
Go
package providers
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/YspCoder/clawgo/pkg/wsrelay"
|
|
)
|
|
|
|
type AistudioProvider struct {
|
|
base *HTTPProvider
|
|
relay *wsrelay.Manager
|
|
}
|
|
|
|
func NewAistudioProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *AistudioProvider {
|
|
normalizedBase := normalizeAPIBase(apiBase)
|
|
if normalizedBase == "" {
|
|
normalizedBase = geminiBaseURL
|
|
}
|
|
return &AistudioProvider{
|
|
base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth),
|
|
relay: getAIStudioRelayManager(),
|
|
}
|
|
}
|
|
|
|
func (p *AistudioProvider) GetDefaultModel() string {
|
|
if p == nil || p.base == nil {
|
|
return ""
|
|
}
|
|
return p.base.GetDefaultModel()
|
|
}
|
|
|
|
func (p *AistudioProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
|
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if !json.Valid(body) {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
return parseGeminiResponse(body)
|
|
}
|
|
|
|
func (p *AistudioProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
|
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if !json.Valid(body) {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
return parseGeminiResponse(body)
|
|
}
|
|
|
|
func (p *AistudioProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) {
|
|
requestBody := p.buildRequestBody(messages, nil, model, options, false)
|
|
delete(requestBody, "tools")
|
|
delete(requestBody, "toolConfig")
|
|
delete(requestBody, "generationConfig")
|
|
body, status, ctype, err := p.perform(ctx, p.endpoint(model, "countTokens", false), requestBody, options, false, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
var payload struct {
|
|
TotalTokens int `json:"totalTokens"`
|
|
}
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return nil, fmt.Errorf("invalid countTokens response: %w", err)
|
|
}
|
|
return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil
|
|
}
|
|
|
|
func (p *AistudioProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
|
requestBody := p.buildRequestBody(messages, tools, model, options, stream)
|
|
return p.perform(ctx, p.endpoint(model, "generateContent", stream), requestBody, options, stream, onDelta)
|
|
}
|
|
|
|
func (p *AistudioProvider) perform(ctx context.Context, endpoint string, payload map[string]any, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
|
if p == nil || p.base == nil {
|
|
return nil, 0, "", fmt.Errorf("provider not configured")
|
|
}
|
|
if p.relay == nil {
|
|
p.relay = getAIStudioRelayManager()
|
|
}
|
|
if p.relay == nil {
|
|
return nil, 0, "", fmt.Errorf("aistudio relay not configured")
|
|
}
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
req := &wsrelay.HTTPRequest{
|
|
Method: http.MethodPost,
|
|
URL: endpoint,
|
|
Headers: http.Header{
|
|
"Content-Type": []string{"application/json"},
|
|
"Accept": []string{"application/json"},
|
|
},
|
|
Body: jsonData,
|
|
}
|
|
if stream {
|
|
req.Headers.Set("Accept", "text/event-stream")
|
|
}
|
|
channelIDs := aistudioChannelCandidates(p.base.providerName, options)
|
|
if len(channelIDs) == 0 {
|
|
return nil, 0, "", fmt.Errorf("aistudio relay channel not specified")
|
|
}
|
|
if !stream {
|
|
var lastErr error
|
|
for _, channelID := range channelIDs {
|
|
resp, err := p.relay.NonStream(ctx, channelID, req)
|
|
if err != nil {
|
|
recordAIStudioRelayFailure(channelID, err)
|
|
lastErr = err
|
|
continue
|
|
}
|
|
if resp.Status >= 200 && resp.Status < 300 {
|
|
recordAIStudioRelaySuccess(channelID)
|
|
return resp.Body, resp.Status, strings.TrimSpace(resp.Headers.Get("Content-Type")), nil
|
|
}
|
|
retryErr := fmt.Errorf("status=%d", resp.Status)
|
|
recordAIStudioRelayFailure(channelID, retryErr)
|
|
lastErr = retryErr
|
|
if resp.Status < 500 && resp.Status != http.StatusTooManyRequests {
|
|
return resp.Body, resp.Status, strings.TrimSpace(resp.Headers.Get("Content-Type")), nil
|
|
}
|
|
}
|
|
if lastErr == nil {
|
|
lastErr = fmt.Errorf("aistudio relay request failed")
|
|
}
|
|
return nil, 0, "", lastErr
|
|
}
|
|
if onDelta == nil {
|
|
onDelta = func(string) {}
|
|
}
|
|
var lastErr error
|
|
for _, channelID := range channelIDs {
|
|
streamCh, err := p.relay.Stream(ctx, channelID, req)
|
|
if err != nil {
|
|
recordAIStudioRelayFailure(channelID, err)
|
|
lastErr = err
|
|
continue
|
|
}
|
|
state := &antigravityStreamState{}
|
|
status := http.StatusOK
|
|
ctype := "text/event-stream"
|
|
var full bytes.Buffer
|
|
started := false
|
|
retryable := false
|
|
failed := false
|
|
for event := range streamCh {
|
|
if event.Err != nil {
|
|
recordAIStudioRelayFailure(channelID, event.Err)
|
|
lastErr = event.Err
|
|
retryable = !started
|
|
failed = true
|
|
break
|
|
}
|
|
switch event.Type {
|
|
case wsrelay.MessageTypeStreamStart:
|
|
if event.Status > 0 {
|
|
status = event.Status
|
|
}
|
|
if v := strings.TrimSpace(event.Headers.Get("Content-Type")); v != "" {
|
|
ctype = v
|
|
}
|
|
case wsrelay.MessageTypeStreamChunk:
|
|
if len(event.Payload) == 0 {
|
|
continue
|
|
}
|
|
started = true
|
|
full.Write(event.Payload)
|
|
filtered := filterGeminiSSEUsageMetadata(event.Payload)
|
|
if delta := state.consume(filtered); delta != "" {
|
|
onDelta(delta)
|
|
}
|
|
case wsrelay.MessageTypeHTTPResp:
|
|
if event.Status > 0 {
|
|
status = event.Status
|
|
}
|
|
if v := strings.TrimSpace(event.Headers.Get("Content-Type")); v != "" {
|
|
ctype = v
|
|
}
|
|
if len(event.Payload) > 0 {
|
|
if status >= 200 && status < 300 {
|
|
recordAIStudioRelaySuccess(channelID)
|
|
} else {
|
|
recordAIStudioRelayFailure(channelID, fmt.Errorf("status=%d", status))
|
|
}
|
|
if status >= 500 || status == http.StatusTooManyRequests {
|
|
lastErr = fmt.Errorf("status=%d", status)
|
|
retryable = !started
|
|
failed = true
|
|
break
|
|
}
|
|
return event.Payload, status, ctype, nil
|
|
}
|
|
if status >= 200 && status < 300 {
|
|
recordAIStudioRelaySuccess(channelID)
|
|
} else {
|
|
recordAIStudioRelayFailure(channelID, fmt.Errorf("status=%d", status))
|
|
}
|
|
if status >= 500 || status == http.StatusTooManyRequests {
|
|
lastErr = fmt.Errorf("status=%d", status)
|
|
retryable = !started
|
|
failed = true
|
|
break
|
|
}
|
|
return state.finalBody(), status, ctype, nil
|
|
case wsrelay.MessageTypeStreamEnd:
|
|
if status >= 200 && status < 300 {
|
|
recordAIStudioRelaySuccess(channelID)
|
|
} else {
|
|
recordAIStudioRelayFailure(channelID, fmt.Errorf("status=%d", status))
|
|
}
|
|
if status >= 500 || status == http.StatusTooManyRequests {
|
|
lastErr = fmt.Errorf("status=%d", status)
|
|
retryable = !started
|
|
failed = true
|
|
break
|
|
}
|
|
return state.finalBody(), status, ctype, nil
|
|
}
|
|
}
|
|
if failed && started {
|
|
break
|
|
}
|
|
if !failed && full.Len() > 0 {
|
|
recordAIStudioRelaySuccess(channelID)
|
|
return state.finalBody(), status, ctype, nil
|
|
}
|
|
if !retryable {
|
|
break
|
|
}
|
|
}
|
|
if lastErr == nil {
|
|
lastErr = fmt.Errorf("wsrelay: stream closed")
|
|
}
|
|
return nil, 0, "", lastErr
|
|
}
|
|
|
|
func (p *AistudioProvider) endpoint(model, action string, stream bool) string {
|
|
base := geminiBaseURL
|
|
if p != nil && p.base != nil && strings.TrimSpace(p.base.apiBase) != "" && !strings.Contains(strings.ToLower(p.base.apiBase), "api.openai.com") {
|
|
base = normalizeGeminiBaseURL(p.base.apiBase)
|
|
}
|
|
baseModel := strings.TrimSpace(qwenBaseModel(model))
|
|
if stream {
|
|
return fmt.Sprintf("%s/%s/models/%s:streamGenerateContent?alt=sse", base, geminiAPIVersion, baseModel)
|
|
}
|
|
return fmt.Sprintf("%s/%s/models/%s:%s", base, geminiAPIVersion, baseModel, action)
|
|
}
|
|
|
|
func (p *AistudioProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]any {
|
|
gemini := &GeminiProvider{base: p.base}
|
|
return gemini.buildRequestBody(messages, tools, model, options, stream)
|
|
}
|