mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-14 05:08:58 +08:00
361 lines
12 KiB
Go
361 lines
12 KiB
Go
package providers
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"regexp"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Endpoint defaults for the Gemini CLI (Cloud Code Assist) backend.
const (
	// geminiCLIBaseURL is used whenever the configured API base normalizes to "".
	geminiCLIBaseURL = "https://cloudcode-pa.googleapis.com"
	// geminiCLIVersion is the path segment placed before ":<action>".
	geminiCLIVersion = "v1internal"
	// geminiCLIDefaultAlt is the "alt" query value appended to streaming calls.
	geminiCLIDefaultAlt = "sse"
)

// Client identity reported in the User-Agent and X-Goog-Api-Client headers.
const (
	geminiCLIClientVersion = "0.31.0"
	geminiCLIApiClient     = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
)
|
|
|
|
type GeminiCLIProvider struct {
|
|
base *HTTPProvider
|
|
}
|
|
|
|
func NewGeminiCLIProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *GeminiCLIProvider {
|
|
normalizedBase := normalizeAPIBase(apiBase)
|
|
if normalizedBase == "" {
|
|
normalizedBase = geminiCLIBaseURL
|
|
}
|
|
return &GeminiCLIProvider{
|
|
base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth),
|
|
}
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) GetDefaultModel() string {
|
|
if p == nil || p.base == nil {
|
|
return ""
|
|
}
|
|
return p.base.GetDefaultModel()
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
|
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if !json.Valid(body) {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
return parseGeminiResponse(body)
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
|
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if !json.Valid(body) {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
return parseGeminiResponse(body)
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) {
|
|
if p == nil || p.base == nil {
|
|
return nil, fmt.Errorf("provider not configured")
|
|
}
|
|
attempts, err := p.base.authAttempts(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var lastBody []byte
|
|
var lastStatus int
|
|
var lastType string
|
|
for _, attempt := range attempts {
|
|
requestBody := p.buildRequestBody(messages, nil, model, options, false, attempt.session)
|
|
delete(requestBody, "project")
|
|
delete(requestBody, "model")
|
|
request := mapFromAny(requestBody["request"])
|
|
delete(request, "safetySettings")
|
|
requestBody["request"] = request
|
|
body, status, ctype, reqErr := p.performAttempt(ctx, p.endpoint("countTokens", false), requestBody, attempt, false, nil)
|
|
if reqErr != nil {
|
|
return nil, reqErr
|
|
}
|
|
lastBody, lastStatus, lastType = body, status, ctype
|
|
reason, retry := classifyOAuthFailure(status, body)
|
|
if retry {
|
|
applyAttemptFailure(p.base, attempt, reason, geminiRetryAfter(body))
|
|
continue
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
var payload struct {
|
|
TotalTokens int `json:"totalTokens"`
|
|
}
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return nil, fmt.Errorf("invalid countTokens response: %w", err)
|
|
}
|
|
p.base.markAttemptSuccess(attempt)
|
|
return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil
|
|
}
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody))
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
|
attempts, err := p.base.authAttempts(ctx)
|
|
if err != nil {
|
|
return nil, 0, "", err
|
|
}
|
|
var lastBody []byte
|
|
var lastStatus int
|
|
var lastType string
|
|
action := "generateContent"
|
|
if stream {
|
|
action = "streamGenerateContent"
|
|
}
|
|
for _, attempt := range attempts {
|
|
requestBody := p.buildRequestBody(messages, tools, model, options, stream, attempt.session)
|
|
body, status, ctype, reqErr := p.performAttempt(ctx, p.endpoint(action, stream), requestBody, attempt, stream, onDelta)
|
|
if reqErr != nil {
|
|
return nil, 0, "", reqErr
|
|
}
|
|
lastBody, lastStatus, lastType = body, status, ctype
|
|
reason, retry := classifyOAuthFailure(status, body)
|
|
if retry {
|
|
applyAttemptFailure(p.base, attempt, reason, geminiRetryAfter(body))
|
|
continue
|
|
}
|
|
p.base.markAttemptSuccess(attempt)
|
|
return body, status, ctype, nil
|
|
}
|
|
return lastBody, lastStatus, lastType, nil
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, session *oauthSession) map[string]any {
|
|
request := map[string]any{
|
|
"request": p.buildInnerRequest(messages, tools, model, options, stream),
|
|
"model": strings.TrimSpace(qwenBaseModel(model)),
|
|
}
|
|
if projectID := geminiCLIProjectID(options, session); projectID != "" {
|
|
request["project"] = projectID
|
|
}
|
|
return request
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) buildInnerRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]any {
|
|
request := NewGeminiProvider(p.base.providerName, p.base.apiKey, p.base.apiBase, p.base.defaultModel, p.base.supportsResponsesCompact, p.base.authMode, p.base.timeout, p.base.oauth).
|
|
buildRequestBody(messages, tools, model, options, stream)
|
|
if _, ok := request["safetySettings"]; !ok {
|
|
request["safetySettings"] = []map[string]any{}
|
|
}
|
|
return request
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) endpoint(action string, stream bool) string {
|
|
base := normalizeAPIBase(p.base.apiBase)
|
|
if base == "" {
|
|
base = geminiCLIBaseURL
|
|
}
|
|
url := fmt.Sprintf("%s/%s:%s", base, geminiCLIVersion, action)
|
|
if stream {
|
|
return url + "?alt=" + geminiCLIDefaultAlt
|
|
}
|
|
return url
|
|
}
|
|
|
|
func (p *GeminiCLIProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
|
|
if err != nil {
|
|
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if stream {
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
} else {
|
|
req.Header.Set("Accept", "application/json")
|
|
}
|
|
if err := applyGeminiCLIAttemptAuth(req, attempt); err != nil {
|
|
return nil, 0, "", err
|
|
}
|
|
applyGeminiCLIHeaders(req, strings.TrimSpace(asString(payload["model"])))
|
|
client, err := p.base.httpClientForAttempt(attempt)
|
|
if err != nil {
|
|
return nil, 0, "", err
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
|
if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") {
|
|
return consumeGeminiCLIStream(resp, onDelta)
|
|
}
|
|
body, readErr := io.ReadAll(resp.Body)
|
|
if readErr != nil {
|
|
return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr)
|
|
}
|
|
return body, resp.StatusCode, ctype, nil
|
|
}
|
|
|
|
func applyGeminiCLIAttemptAuth(req *http.Request, attempt authAttempt) error {
|
|
if req == nil {
|
|
return nil
|
|
}
|
|
token := strings.TrimSpace(attempt.token)
|
|
if attempt.session != nil {
|
|
token = firstNonEmpty(strings.TrimSpace(attempt.session.AccessToken), token, asString(attempt.session.Token["access_token"]), asString(attempt.session.Token["access-token"]))
|
|
}
|
|
if token == "" {
|
|
return fmt.Errorf("missing access token for gemini-cli")
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
req.Header.Del("x-goog-api-key")
|
|
return nil
|
|
}
|
|
|
|
func consumeGeminiCLIStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) {
|
|
if onDelta == nil {
|
|
onDelta = func(string) {}
|
|
}
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
|
var dataLines []string
|
|
state := &antigravityStreamState{}
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if strings.TrimSpace(line) == "" {
|
|
if len(dataLines) > 0 {
|
|
payload := strings.Join(dataLines, "\n")
|
|
dataLines = dataLines[:0]
|
|
if strings.TrimSpace(payload) != "" && strings.TrimSpace(payload) != "[DONE]" {
|
|
if delta := state.consume([]byte(payload)); delta != "" {
|
|
onDelta(delta)
|
|
}
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
if strings.HasPrefix(line, "data:") {
|
|
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
|
|
}
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read stream: %w", err)
|
|
}
|
|
return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
|
|
}
|
|
|
|
func geminiCLIProjectID(options map[string]interface{}, session *oauthSession) string {
|
|
for _, key := range []string{"gemini_project_id", "project_id", "project"} {
|
|
if value, ok := stringOption(options, key); ok {
|
|
trimmed := strings.TrimSpace(value)
|
|
if trimmed != "" {
|
|
return trimmed
|
|
}
|
|
}
|
|
}
|
|
if session == nil {
|
|
return ""
|
|
}
|
|
return firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["projectId"]), asString(session.Token["project"]))
|
|
}
|
|
|
|
// geminiCLIRuntimeOS reports the platform name used in the User-Agent:
// "win32" on Windows, otherwise runtime.GOOS unchanged.
func geminiCLIRuntimeOS() string {
	if runtime.GOOS == "windows" {
		return "win32"
	}
	return runtime.GOOS
}
|
|
|
|
// geminiCLIRuntimeArch reports the CPU architecture for the User-Agent,
// spelled the way Node.js process.arch does (the real Gemini CLI is a Node
// program). Go names that differ from Node's are translated:
//
//	amd64   -> x64
//	386     -> ia32   (Node never reports "x86")
//	mipsle  -> mipsel
//	ppc64le -> ppc64
//
// Any other GOARCH (arm, arm64, riscv64, s390x, loong64, ...) already matches
// Node's spelling and is returned unchanged.
func geminiCLIRuntimeArch() string {
	switch runtime.GOARCH {
	case "amd64":
		return "x64"
	case "386":
		return "ia32"
	case "mipsle":
		return "mipsel"
	case "ppc64le":
		return "ppc64"
	default:
		return runtime.GOARCH
	}
}
|
|
|
|
func geminiCLIUserAgent(model string) string {
|
|
trimmedModel := strings.TrimSpace(model)
|
|
if trimmedModel == "" {
|
|
trimmedModel = "unknown"
|
|
}
|
|
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", geminiCLIClientVersion, trimmedModel, geminiCLIRuntimeOS(), geminiCLIRuntimeArch())
|
|
}
|
|
|
|
func applyGeminiCLIHeaders(req *http.Request, model string) {
|
|
if req == nil {
|
|
return
|
|
}
|
|
req.Header.Set("User-Agent", geminiCLIUserAgent(model))
|
|
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
|
}
|
|
|
|
func geminiRetryAfter(body []byte) *time.Duration {
|
|
if len(body) == 0 {
|
|
return nil
|
|
}
|
|
var root map[string]any
|
|
if err := json.Unmarshal(body, &root); err != nil {
|
|
return retryDelayFromMessage(string(body))
|
|
}
|
|
errRoot := mapFromAny(root["error"])
|
|
details, _ := errRoot["details"].([]any)
|
|
for _, raw := range details {
|
|
detail := mapFromAny(raw)
|
|
if asString(detail["@type"]) == "type.googleapis.com/google.rpc.RetryInfo" {
|
|
if d, err := time.ParseDuration(strings.TrimSpace(asString(detail["retryDelay"]))); err == nil {
|
|
return &d
|
|
}
|
|
}
|
|
}
|
|
for _, raw := range details {
|
|
detail := mapFromAny(raw)
|
|
if asString(detail["@type"]) == "type.googleapis.com/google.rpc.ErrorInfo" {
|
|
metadata := mapFromAny(detail["metadata"])
|
|
if d, err := time.ParseDuration(strings.TrimSpace(asString(metadata["quotaResetDelay"]))); err == nil {
|
|
return &d
|
|
}
|
|
}
|
|
}
|
|
return retryDelayFromMessage(asString(errRoot["message"]))
|
|
}
|
|
|
|
// geminiRetryAfterSecondsPattern matches "after <N>s" or "after <N.M>s" hints
// in error messages. It is compiled once at package initialisation rather
// than on every call.
var geminiRetryAfterSecondsPattern = regexp.MustCompile(`after\s+(\d+(?:\.\d+)?)s`)

// retryDelayFromMessage parses a retry hint such as "Please retry after 30s."
// or "quota will reset after 1.5s" out of free-form text. It returns nil when
// there is no hint or the value cannot be represented as a time.Duration.
func retryDelayFromMessage(message string) *time.Duration {
	matches := geminiRetryAfterSecondsPattern.FindStringSubmatch(message)
	if len(matches) < 2 {
		return nil
	}
	seconds, err := strconv.ParseFloat(matches[1], 64)
	if err != nil {
		return nil
	}
	// Largest number of seconds a time.Duration (int64 nanoseconds) can hold.
	// Beyond it, the float-to-int conversion below would overflow.
	const maxSeconds = float64(1<<63-1) / float64(time.Second)
	if seconds >= maxSeconds {
		return nil
	}
	d := time.Duration(seconds * float64(time.Second))
	return &d
}
|