Files
clawgo/pkg/providers/gemini_cli_provider.go

361 lines
12 KiB
Go

package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"runtime"
"strconv"
"strings"
"time"
)
// Defaults and identification strings used by the Gemini CLI provider.
const (
// geminiCLIBaseURL is the API base used when the configured base normalizes to empty.
geminiCLIBaseURL = "https://cloudcode-pa.googleapis.com"
// geminiCLIVersion is the path segment placed between the base URL and the action name.
geminiCLIVersion = "v1internal"
// geminiCLIDefaultAlt is the value of the "alt" query parameter appended to streaming endpoints.
geminiCLIDefaultAlt = "sse"
// geminiCLIClientVersion is the client version embedded in the User-Agent header.
geminiCLIClientVersion = "0.31.0"
// geminiCLIApiClient is the value sent in the X-Goog-Api-Client header.
geminiCLIApiClient = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
)
// GeminiCLIProvider talks to the Gemini CLI endpoints. Authentication
// attempts, HTTP client selection and the default model all come from the
// wrapped HTTPProvider.
type GeminiCLIProvider struct {
// base is the underlying HTTPProvider. GetDefaultModel and CountTokens
// guard against it being nil.
base *HTTPProvider
}
// NewGeminiCLIProvider builds a GeminiCLIProvider on top of a new HTTPProvider.
//
// apiBase is passed through normalizeAPIBase; when the result is empty the
// provider targets geminiCLIBaseURL instead. Every other argument is forwarded
// to NewHTTPProvider unchanged.
func NewGeminiCLIProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *GeminiCLIProvider {
	effectiveBase := geminiCLIBaseURL
	if normalized := normalizeAPIBase(apiBase); normalized != "" {
		effectiveBase = normalized
	}
	httpProvider := NewHTTPProvider(providerName, apiKey, effectiveBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)
	return &GeminiCLIProvider{base: httpProvider}
}
// GetDefaultModel returns the default model reported by the wrapped
// HTTPProvider, or "" when the receiver or its base provider is nil.
func (p *GeminiCLIProvider) GetDefaultModel() string {
	if p != nil && p.base != nil {
		return p.base.GetDefaultModel()
	}
	return ""
}
// Chat sends a non-streaming generateContent request and parses the reply.
//
// It returns the transport error from doRequest as-is, an "API error" when the
// status is not 200 or the body is not valid JSON, and otherwise whatever
// parseGeminiResponse produces for the body.
func (p *GeminiCLIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
	raw, code, contentType, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
	switch {
	case err != nil:
		return nil, err
	case code != http.StatusOK:
		return nil, fmt.Errorf("API error (status %d, content-type %q): %s", code, contentType, previewResponseBody(raw))
	case !json.Valid(raw):
		return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", code, contentType, previewResponseBody(raw))
	}
	return parseGeminiResponse(raw)
}
// ChatStream sends a streaming streamGenerateContent request, forwarding text
// deltas to onDelta while the response is read, and parses the final body.
//
// Error handling mirrors Chat: transport errors are returned unchanged, and a
// non-200 status or a non-JSON final body is reported as an "API error".
func (p *GeminiCLIProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
	raw, code, contentType, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta)
	switch {
	case err != nil:
		return nil, err
	case code != http.StatusOK:
		return nil, fmt.Errorf("API error (status %d, content-type %q): %s", code, contentType, previewResponseBody(raw))
	case !json.Valid(raw):
		return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", code, contentType, previewResponseBody(raw))
	}
	return parseGeminiResponse(raw)
}
// CountTokens asks the countTokens endpoint how many tokens the given messages
// use and reports the total as both PromptTokens and TotalTokens.
//
// The request is built like a chat request without tools, then stripped of the
// top-level "project" and "model" keys and of the inner "safetySettings" key.
// Each auth attempt is tried in order: a transport error aborts immediately, a
// response that classifyOAuthFailure marks as retryable is recorded as a
// failure and the next attempt is tried. If every attempt was retryable, the
// last response is reported as an API error.
func (p *GeminiCLIProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) {
	if p == nil || p.base == nil {
		return nil, fmt.Errorf("provider not configured")
	}
	attempts, err := p.base.authAttempts(ctx)
	if err != nil {
		return nil, err
	}

	var (
		lastBody   []byte
		lastStatus int
		lastType   string
	)
	for _, attempt := range attempts {
		payload := p.buildRequestBody(messages, nil, model, options, false, attempt.session)
		delete(payload, "project")
		delete(payload, "model")
		inner := mapFromAny(payload["request"])
		delete(inner, "safetySettings")
		payload["request"] = inner

		body, status, ctype, reqErr := p.performAttempt(ctx, p.endpoint("countTokens", false), payload, attempt, false, nil)
		if reqErr != nil {
			return nil, reqErr
		}
		lastBody, lastStatus, lastType = body, status, ctype

		if reason, retry := classifyOAuthFailure(status, body); retry {
			applyAttemptFailure(p.base, attempt, reason, geminiRetryAfter(body))
			continue
		}
		if status != http.StatusOK {
			return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
		}

		var counted struct {
			TotalTokens int `json:"totalTokens"`
		}
		if decodeErr := json.Unmarshal(body, &counted); decodeErr != nil {
			return nil, fmt.Errorf("invalid countTokens response: %w", decodeErr)
		}
		p.base.markAttemptSuccess(attempt)
		total := counted.TotalTokens
		return &UsageInfo{PromptTokens: total, TotalTokens: total}, nil
	}
	return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody))
}
// doRequest runs a generateContent (or, when stream is true,
// streamGenerateContent) call, trying each auth attempt in order.
//
// It returns the response body, HTTP status and Content-Type of the attempt
// that settled the call. The status is NOT checked here; callers decide what a
// non-200 status means. Behaviour per attempt:
//   - a transport/build error from performAttempt aborts immediately;
//   - a response that classifyOAuthFailure marks as retryable is recorded via
//     applyAttemptFailure and the next attempt is tried;
//   - any other response marks the attempt successful and is returned.
//
// If every attempt was retryable, the last response is returned with a nil
// error. A nil receiver or nil base provider yields a "provider not
// configured" error instead of a nil-pointer panic, matching CountTokens.
func (p *GeminiCLIProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
	if p == nil || p.base == nil {
		return nil, 0, "", fmt.Errorf("provider not configured")
	}
	attempts, err := p.base.authAttempts(ctx)
	if err != nil {
		return nil, 0, "", err
	}

	action := "generateContent"
	if stream {
		action = "streamGenerateContent"
	}

	var (
		lastBody   []byte
		lastStatus int
		lastType   string
	)
	for _, attempt := range attempts {
		requestBody := p.buildRequestBody(messages, tools, model, options, stream, attempt.session)
		body, status, ctype, reqErr := p.performAttempt(ctx, p.endpoint(action, stream), requestBody, attempt, stream, onDelta)
		if reqErr != nil {
			return nil, 0, "", reqErr
		}
		lastBody, lastStatus, lastType = body, status, ctype

		if reason, retry := classifyOAuthFailure(status, body); retry {
			applyAttemptFailure(p.base, attempt, reason, geminiRetryAfter(body))
			continue
		}
		p.base.markAttemptSuccess(attempt)
		return body, status, ctype, nil
	}
	return lastBody, lastStatus, lastType, nil
}
// buildRequestBody assembles the outer request envelope: the inner request
// under "request", the trimmed result of qwenBaseModel(model) under "model",
// and, when geminiCLIProjectID finds a non-empty id, that id under "project".
func (p *GeminiCLIProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, session *oauthSession) map[string]any {
	envelope := make(map[string]any, 3)
	envelope["request"] = p.buildInnerRequest(messages, tools, model, options, stream)
	envelope["model"] = strings.TrimSpace(qwenBaseModel(model))

	projectID := geminiCLIProjectID(options, session)
	if projectID == "" {
		return envelope
	}
	envelope["project"] = projectID
	return envelope
}
// buildInnerRequest produces the inner request by delegating to a
// GeminiProvider configured with the same settings as p.base, then guarantees
// a "safetySettings" key exists (an empty list when the delegate set none).
func (p *GeminiCLIProvider) buildInnerRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]any {
	delegate := NewGeminiProvider(p.base.providerName, p.base.apiKey, p.base.apiBase, p.base.defaultModel, p.base.supportsResponsesCompact, p.base.authMode, p.base.timeout, p.base.oauth)
	inner := delegate.buildRequestBody(messages, tools, model, options, stream)
	if _, present := inner["safetySettings"]; present {
		return inner
	}
	inner["safetySettings"] = []map[string]any{}
	return inner
}
// endpoint returns the URL for action in the form "<base>/<version>:<action>".
// The base is the normalized configured base, or geminiCLIBaseURL when that is
// empty. Streaming endpoints additionally carry "?alt=<geminiCLIDefaultAlt>".
func (p *GeminiCLIProvider) endpoint(action string, stream bool) string {
	root := normalizeAPIBase(p.base.apiBase)
	if root == "" {
		root = geminiCLIBaseURL
	}
	target := fmt.Sprintf("%s/%s:%s", root, geminiCLIVersion, action)
	if !stream {
		return target
	}
	return target + "?alt=" + geminiCLIDefaultAlt
}
// performAttempt issues one POST of payload (JSON-encoded) to endpoint using
// the credentials of attempt, and returns the response body, HTTP status and
// trimmed Content-Type.
//
// When stream is true and the server answers with a text/event-stream content
// type, the body is consumed incrementally by consumeGeminiCLIStream and
// onDelta receives text deltas; otherwise the whole body is read at once.
// Errors cover marshalling, request construction, auth, client selection,
// sending and reading.
func (p *GeminiCLIProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
	}

	accept := "application/json"
	if stream {
		accept = "text/event-stream"
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", accept)

	if authErr := applyGeminiCLIAttemptAuth(req, attempt); authErr != nil {
		return nil, 0, "", authErr
	}
	applyGeminiCLIHeaders(req, strings.TrimSpace(asString(payload["model"])))

	client, err := p.base.httpClientForAttempt(attempt)
	if err != nil {
		return nil, 0, "", err
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
	isEventStream := strings.Contains(strings.ToLower(contentType), "text/event-stream")
	if stream && isEventStream {
		return consumeGeminiCLIStream(resp, onDelta)
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, resp.StatusCode, contentType, fmt.Errorf("failed to read response: %w", err)
	}
	return raw, resp.StatusCode, contentType, nil
}
// applyGeminiCLIAttemptAuth sets a Bearer Authorization header on req and
// removes any "x-goog-api-key" header.
//
// The token is the trimmed attempt token; when the attempt carries a session,
// firstNonEmpty picks among the session's AccessToken, the attempt token, and
// the session Token entries "access_token" and "access-token", in that order.
// A nil req is a no-op. An empty token yields an error.
func applyGeminiCLIAttemptAuth(req *http.Request, attempt authAttempt) error {
	if req == nil {
		return nil
	}

	bearer := strings.TrimSpace(attempt.token)
	if sess := attempt.session; sess != nil {
		bearer = firstNonEmpty(
			strings.TrimSpace(sess.AccessToken),
			bearer,
			asString(sess.Token["access_token"]),
			asString(sess.Token["access-token"]),
		)
	}
	if bearer == "" {
		return fmt.Errorf("missing access token for gemini-cli")
	}

	headers := req.Header
	headers.Set("Authorization", "Bearer "+bearer)
	headers.Del("x-goog-api-key")
	return nil
}
// consumeGeminiCLIStream reads a server-sent-events response line by line.
//
// Lines starting with "data:" are collected (prefix removed, trimmed). A blank
// line ends the current event: its data lines are joined with "\n" and, unless
// the result is empty or "[DONE]", handed to the stream state. Any non-empty
// delta the state returns is passed to onDelta (a nil onDelta is replaced by a
// no-op). Lines with other prefixes are ignored. Data lines still pending when
// the stream ends without a blank line are not dispatched.
//
// It returns the state's final body plus the response status and trimmed
// Content-Type; a scanner failure is reported as "failed to read stream".
func consumeGeminiCLIStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) {
	emit := onDelta
	if emit == nil {
		emit = func(string) {}
	}

	sc := bufio.NewScanner(resp.Body)
	// Allow single lines of up to 2 MiB; the scanner's default limit is smaller.
	sc.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)

	state := &antigravityStreamState{}
	var pending []string
	for sc.Scan() {
		line := sc.Text()
		switch {
		case strings.TrimSpace(line) != "":
			if strings.HasPrefix(line, "data:") {
				pending = append(pending, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
			}
		case len(pending) > 0:
			event := strings.Join(pending, "\n")
			pending = pending[:0]
			trimmed := strings.TrimSpace(event)
			if trimmed == "" || trimmed == "[DONE]" {
				continue
			}
			if delta := state.consume([]byte(event)); delta != "" {
				emit(delta)
			}
		}
	}

	contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
	if scanErr := sc.Err(); scanErr != nil {
		return nil, resp.StatusCode, contentType, fmt.Errorf("failed to read stream: %w", scanErr)
	}
	return state.finalBody(), resp.StatusCode, contentType, nil
}
// geminiCLIProjectID resolves the project id for a request.
//
// The options "gemini_project_id", "project_id" and "project" are checked in
// that order and the first non-blank value wins (trimmed). Otherwise, when a
// session is present, firstNonEmpty picks among the session's ProjectID and
// its Token entries "project_id", "project-id", "projectId" and "project".
// With no session and no matching option the result is "".
func geminiCLIProjectID(options map[string]interface{}, session *oauthSession) string {
	optionKeys := [...]string{"gemini_project_id", "project_id", "project"}
	for i := range optionKeys {
		raw, found := stringOption(options, optionKeys[i])
		if !found {
			continue
		}
		if id := strings.TrimSpace(raw); id != "" {
			return id
		}
	}

	if session != nil {
		return firstNonEmpty(
			strings.TrimSpace(session.ProjectID),
			asString(session.Token["project_id"]),
			asString(session.Token["project-id"]),
			asString(session.Token["projectId"]),
			asString(session.Token["project"]),
		)
	}
	return ""
}
// geminiCLIRuntimeOS returns the operating-system label used in the
// User-Agent: "win32" on Windows, runtime.GOOS unchanged everywhere else.
func geminiCLIRuntimeOS() string {
	if goos := runtime.GOOS; goos != "windows" {
		return goos
	}
	return "win32"
}
// geminiCLIRuntimeArch returns the architecture label used in the User-Agent:
// "x64" for amd64, "x86" for 386, and runtime.GOARCH unchanged otherwise.
func geminiCLIRuntimeArch() string {
	arch := runtime.GOARCH
	if arch == "amd64" {
		return "x64"
	}
	if arch == "386" {
		return "x86"
	}
	return arch
}
func geminiCLIUserAgent(model string) string {
trimmedModel := strings.TrimSpace(model)
if trimmedModel == "" {
trimmedModel = "unknown"
}
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", geminiCLIClientVersion, trimmedModel, geminiCLIRuntimeOS(), geminiCLIRuntimeArch())
}
// applyGeminiCLIHeaders sets the User-Agent (derived from model) and
// X-Goog-Api-Client headers on req. A nil req is a no-op.
func applyGeminiCLIHeaders(req *http.Request, model string) {
	if req != nil {
		headers := req.Header
		headers.Set("User-Agent", geminiCLIUserAgent(model))
		headers.Set("X-Goog-Api-Client", geminiCLIApiClient)
	}
}
// geminiRetryAfter extracts a retry delay hint from an error response body.
//
// An empty body yields nil. A body that does not decode as a JSON object is
// searched as plain text by retryDelayFromMessage. Otherwise the entries of
// error.details are examined in two passes: first every RetryInfo entry's
// "retryDelay", then every ErrorInfo entry's metadata "quotaResetDelay"; the
// first value that parses with time.ParseDuration wins. If none does, the
// error message text is searched by retryDelayFromMessage.
func geminiRetryAfter(body []byte) *time.Duration {
	if len(body) == 0 {
		return nil
	}
	var root map[string]any
	if err := json.Unmarshal(body, &root); err != nil {
		return retryDelayFromMessage(string(body))
	}

	const (
		retryInfoType = "type.googleapis.com/google.rpc.RetryInfo"
		errorInfoType = "type.googleapis.com/google.rpc.ErrorInfo"
	)
	errRoot := mapFromAny(root["error"])
	details, _ := errRoot["details"].([]any)

	// RetryInfo takes precedence over ErrorInfo regardless of their order in
	// the details list, hence one full pass per type.
	for _, wanted := range [...]string{retryInfoType, errorInfoType} {
		for _, raw := range details {
			detail := mapFromAny(raw)
			if asString(detail["@type"]) != wanted {
				continue
			}
			var text string
			switch wanted {
			case retryInfoType:
				text = asString(detail["retryDelay"])
			case errorInfoType:
				text = asString(mapFromAny(detail["metadata"])["quotaResetDelay"])
			}
			if d, err := time.ParseDuration(strings.TrimSpace(text)); err == nil {
				return &d
			}
		}
	}
	return retryDelayFromMessage(asString(errRoot["message"]))
}
// geminiRetryAfterPattern matches an "after <N>s" hint inside an error
// message. It is compiled once at package scope instead of on every call.
var geminiRetryAfterPattern = regexp.MustCompile(`after\s+(\d+)s\.?`)

// retryDelayFromMessage looks for an "after <N>s" hint in message and returns
// N seconds as a duration.
//
// It returns nil when no hint is present, when the number cannot be parsed,
// or when N seconds would overflow time.Duration (which would otherwise wrap
// around into a meaningless, possibly negative delay).
func retryDelayFromMessage(message string) *time.Duration {
	matches := geminiRetryAfterPattern.FindStringSubmatch(strings.TrimSpace(message))
	if len(matches) < 2 {
		return nil
	}

	// Largest whole number of seconds a time.Duration can represent.
	const maxSeconds = int64((1<<63 - 1) / time.Second)

	// ParseInt with an explicit 64-bit size keeps the accepted range
	// independent of the platform's int width.
	seconds, err := strconv.ParseInt(matches[1], 10, 64)
	if err != nil || seconds > maxSeconds {
		return nil
	}
	d := time.Duration(seconds) * time.Second
	return &d
}