mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-14 20:47:30 +08:00
526 lines
18 KiB
Go
526 lines
18 KiB
Go
package providers
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/jwt"
|
|
)
|
|
|
|
const (
|
|
vertexDefaultBaseURL = "https://aiplatform.googleapis.com"
|
|
vertexAPIVersion = "v1"
|
|
vertexDefaultRegion = "us-central1"
|
|
)
|
|
|
|
type VertexProvider struct {
|
|
base *HTTPProvider
|
|
}
|
|
|
|
func NewVertexProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *VertexProvider {
|
|
normalizedBase := normalizeAPIBase(apiBase)
|
|
if normalizedBase == "" {
|
|
normalizedBase = vertexDefaultBaseURL
|
|
}
|
|
return &VertexProvider{
|
|
base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth),
|
|
}
|
|
}
|
|
|
|
func (p *VertexProvider) GetDefaultModel() string {
|
|
if p == nil || p.base == nil {
|
|
return ""
|
|
}
|
|
return p.base.GetDefaultModel()
|
|
}
|
|
|
|
func (p *VertexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
|
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if !json.Valid(body) {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if isVertexImagenModel(model) {
|
|
body = convertVertexImagenToGeminiResponse(body, strings.TrimSpace(qwenBaseModel(model)))
|
|
}
|
|
return parseGeminiResponse(body)
|
|
}
|
|
|
|
func (p *VertexProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
|
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if !json.Valid(body) {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
if isVertexImagenModel(model) {
|
|
body = convertVertexImagenToGeminiResponse(body, strings.TrimSpace(qwenBaseModel(model)))
|
|
}
|
|
return parseGeminiResponse(body)
|
|
}
|
|
|
|
func (p *VertexProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) {
|
|
if p == nil || p.base == nil {
|
|
return nil, fmt.Errorf("provider not configured")
|
|
}
|
|
attempts, err := p.vertexAttempts(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var lastBody []byte
|
|
var lastStatus int
|
|
var lastType string
|
|
for _, attempt := range attempts {
|
|
requestBody := p.buildRequestBody(messages, nil, model, options)
|
|
if isVertexImagenModel(model) {
|
|
requestBody, err = convertVertexImagenRequest(requestBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
delete(requestBody, "tools")
|
|
delete(requestBody, "toolConfig")
|
|
delete(requestBody, "generationConfig")
|
|
endpoint := p.endpoint(attempt, model, "countTokens", false, options)
|
|
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, false, nil)
|
|
if reqErr != nil {
|
|
return nil, reqErr
|
|
}
|
|
lastBody, lastStatus, lastType = body, status, ctype
|
|
reason, retry := classifyOAuthFailure(status, body)
|
|
if retry {
|
|
applyAttemptFailure(p.base, attempt, reason, nil)
|
|
continue
|
|
}
|
|
if status != http.StatusOK {
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
|
}
|
|
var payload struct {
|
|
TotalTokens int `json:"totalTokens"`
|
|
}
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return nil, fmt.Errorf("invalid countTokens response: %w", err)
|
|
}
|
|
p.base.markAttemptSuccess(attempt)
|
|
return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil
|
|
}
|
|
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody))
|
|
}
|
|
|
|
func (p *VertexProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
|
if p == nil || p.base == nil {
|
|
return nil, 0, "", fmt.Errorf("provider not configured")
|
|
}
|
|
attempts, err := p.vertexAttempts(ctx)
|
|
if err != nil {
|
|
return nil, 0, "", err
|
|
}
|
|
var lastBody []byte
|
|
var lastStatus int
|
|
var lastType string
|
|
for _, attempt := range attempts {
|
|
requestBody := p.buildRequestBody(messages, tools, model, options)
|
|
if isVertexImagenModel(model) {
|
|
requestBody, err = convertVertexImagenRequest(requestBody)
|
|
if err != nil {
|
|
return nil, 0, "", err
|
|
}
|
|
}
|
|
endpoint := p.endpoint(attempt, model, "generateContent", stream, options)
|
|
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta)
|
|
if reqErr != nil {
|
|
return nil, 0, "", reqErr
|
|
}
|
|
lastBody, lastStatus, lastType = body, status, ctype
|
|
reason, retry := classifyOAuthFailure(status, body)
|
|
if retry {
|
|
applyAttemptFailure(p.base, attempt, reason, nil)
|
|
continue
|
|
}
|
|
p.base.markAttemptSuccess(attempt)
|
|
return body, status, ctype, nil
|
|
}
|
|
return lastBody, lastStatus, lastType, nil
|
|
}
|
|
|
|
func (p *VertexProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
|
reqBody, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
|
|
if err != nil {
|
|
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if stream {
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
} else {
|
|
req.Header.Set("Accept", "application/json")
|
|
}
|
|
if err := p.applyVertexAttemptAuth(ctx, req, attempt); err != nil {
|
|
return nil, 0, "", err
|
|
}
|
|
client, err := p.base.httpClientForAttempt(attempt)
|
|
if err != nil {
|
|
return nil, 0, "", err
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
|
if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") {
|
|
return consumeGeminiStream(resp, onDelta)
|
|
}
|
|
body, readErr := ioReadAll(resp)
|
|
if readErr != nil {
|
|
return nil, resp.StatusCode, ctype, readErr
|
|
}
|
|
return body, resp.StatusCode, ctype, nil
|
|
}
|
|
|
|
func (p *VertexProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]any {
|
|
return (&GeminiProvider{base: p.base}).buildRequestBody(messages, tools, model, options, false)
|
|
}
|
|
|
|
func (p *VertexProvider) endpoint(attempt authAttempt, model, action string, stream bool, options map[string]interface{}) string {
|
|
base := vertexBaseURLForAttempt(p.base, attempt, options)
|
|
baseModel := strings.TrimSpace(qwenBaseModel(model))
|
|
vertexAction := vertexAction(baseModel, stream)
|
|
if projectID, location, ok := vertexProjectLocation(attempt, options); ok {
|
|
if stream {
|
|
if isVertexImagenModel(baseModel) {
|
|
return fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, projectID, location, baseModel, vertexAction)
|
|
}
|
|
return fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s?alt=sse", base, vertexAPIVersion, projectID, location, baseModel, vertexAction)
|
|
}
|
|
return fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, projectID, location, baseModel, vertexAction)
|
|
}
|
|
if stream {
|
|
if isVertexImagenModel(baseModel) {
|
|
return fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, baseModel, vertexAction)
|
|
}
|
|
return fmt.Sprintf("%s/%s/publishers/google/models/%s:%s?alt=sse", base, vertexAPIVersion, baseModel, vertexAction)
|
|
}
|
|
return fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, baseModel, vertexAction)
|
|
}
|
|
|
|
func vertexBaseURLForAttempt(base *HTTPProvider, attempt authAttempt, options map[string]interface{}) string {
|
|
customBase := ""
|
|
if attempt.session != nil && attempt.session.Token != nil {
|
|
if raw := firstNonEmpty(
|
|
strings.TrimSpace(asString(attempt.session.Token["base_url"])),
|
|
strings.TrimSpace(asString(attempt.session.Token["base-url"])),
|
|
); raw != "" {
|
|
customBase = normalizeVertexBaseURL(raw)
|
|
}
|
|
}
|
|
if customBase == "" && base != nil && strings.TrimSpace(base.apiBase) != "" && !strings.Contains(strings.ToLower(base.apiBase), "api.openai.com") {
|
|
customBase = normalizeVertexBaseURL(base.apiBase)
|
|
}
|
|
if customBase != "" && !strings.EqualFold(customBase, vertexDefaultBaseURL) {
|
|
return customBase
|
|
}
|
|
location := vertexLocationForAttempt(attempt, options)
|
|
if location != "" && !strings.EqualFold(location, "global") {
|
|
return fmt.Sprintf("https://%s-aiplatform.googleapis.com", location)
|
|
}
|
|
return vertexDefaultBaseURL
|
|
}
|
|
|
|
func normalizeVertexBaseURL(raw string) string {
|
|
trimmed := strings.TrimSpace(raw)
|
|
if trimmed == "" {
|
|
return vertexDefaultBaseURL
|
|
}
|
|
if !strings.Contains(trimmed, "://") {
|
|
trimmed = "https://" + trimmed
|
|
}
|
|
return normalizeAPIBase(trimmed)
|
|
}
|
|
|
|
func vertexProjectLocation(attempt authAttempt, options map[string]interface{}) (string, string, bool) {
|
|
projectID := ""
|
|
for _, key := range []string{"vertex_project_id", "project_id", "project"} {
|
|
if value, ok := stringOption(options, key); ok {
|
|
projectID = strings.TrimSpace(value)
|
|
if projectID != "" {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if attempt.session != nil {
|
|
projectID = firstNonEmpty(projectID, strings.TrimSpace(attempt.session.ProjectID), asString(attempt.session.Token["project_id"]), asString(attempt.session.Token["project-id"]), asString(attempt.session.Token["projectId"]), asString(attempt.session.Token["project"]))
|
|
if projectID == "" {
|
|
projectID = strings.TrimSpace(asString(mapFromAny(attempt.session.Token["service_account"])["project_id"]))
|
|
}
|
|
}
|
|
location := vertexLocationForAttempt(attempt, options)
|
|
if strings.TrimSpace(projectID) == "" {
|
|
return "", "", false
|
|
}
|
|
return projectID, location, true
|
|
}
|
|
|
|
func vertexLocationForAttempt(attempt authAttempt, options map[string]interface{}) string {
|
|
location := ""
|
|
for _, key := range []string{"vertex_location", "location", "region"} {
|
|
if value, ok := stringOption(options, key); ok {
|
|
location = strings.TrimSpace(value)
|
|
if location != "" {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if attempt.session != nil {
|
|
location = firstNonEmpty(location, asString(attempt.session.Token["location"]), asString(attempt.session.Token["region"]), asString(mapFromAny(attempt.session.Token["service_account"])["location"]))
|
|
}
|
|
if strings.TrimSpace(location) == "" {
|
|
location = vertexDefaultRegion
|
|
}
|
|
return location
|
|
}
|
|
|
|
func isVertexImagenModel(model string) bool {
|
|
return strings.Contains(strings.ToLower(strings.TrimSpace(model)), "imagen")
|
|
}
|
|
|
|
func vertexAction(model string, stream bool) string {
|
|
if isVertexImagenModel(model) {
|
|
return "predict"
|
|
}
|
|
if stream {
|
|
return "streamGenerateContent"
|
|
}
|
|
return "generateContent"
|
|
}
|
|
|
|
func convertVertexImagenRequest(payload map[string]any) (map[string]any, error) {
|
|
prompt := strings.TrimSpace(geminiPromptTextFromPayload(payload))
|
|
if prompt == "" {
|
|
return nil, fmt.Errorf("imagen: no prompt found in request")
|
|
}
|
|
request := map[string]any{
|
|
"instances": []map[string]any{{"prompt": prompt}},
|
|
"parameters": map[string]any{
|
|
"sampleCount": 1,
|
|
},
|
|
}
|
|
if generationConfig := mapFromAny(payload["generationConfig"]); len(generationConfig) > 0 {
|
|
if aspectRatio := strings.TrimSpace(asString(generationConfig["aspectRatio"])); aspectRatio != "" {
|
|
request["parameters"].(map[string]any)["aspectRatio"] = aspectRatio
|
|
}
|
|
switch count := generationConfig["sampleCount"].(type) {
|
|
case int:
|
|
if count > 0 {
|
|
request["parameters"].(map[string]any)["sampleCount"] = count
|
|
}
|
|
case float64:
|
|
if int(count) > 0 {
|
|
request["parameters"].(map[string]any)["sampleCount"] = int(count)
|
|
}
|
|
}
|
|
if negativePrompt := strings.TrimSpace(asString(generationConfig["negativePrompt"])); negativePrompt != "" {
|
|
request["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt
|
|
}
|
|
}
|
|
return request, nil
|
|
}
|
|
|
|
func geminiPromptTextFromPayload(payload map[string]any) string {
|
|
contents, _ := payload["contents"].([]map[string]any)
|
|
for _, content := range contents {
|
|
parts, _ := content["parts"].([]map[string]any)
|
|
for _, part := range parts {
|
|
if text := strings.TrimSpace(asString(part["text"])); text != "" {
|
|
return text
|
|
}
|
|
}
|
|
}
|
|
return strings.TrimSpace(asString(payload["prompt"]))
|
|
}
|
|
|
|
func convertVertexImagenToGeminiResponse(data []byte, model string) []byte {
|
|
var raw struct {
|
|
Predictions []struct {
|
|
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
|
MIMEType string `json:"mimeType"`
|
|
} `json:"predictions"`
|
|
}
|
|
if err := json.Unmarshal(data, &raw); err != nil || len(raw.Predictions) == 0 {
|
|
return data
|
|
}
|
|
parts := make([]map[string]any, 0, len(raw.Predictions))
|
|
for _, prediction := range raw.Predictions {
|
|
if strings.TrimSpace(prediction.BytesBase64Encoded) == "" {
|
|
continue
|
|
}
|
|
parts = append(parts, map[string]any{
|
|
"inlineData": map[string]any{
|
|
"mimeType": firstNonEmpty(strings.TrimSpace(prediction.MIMEType), "image/png"),
|
|
"data": prediction.BytesBase64Encoded,
|
|
},
|
|
})
|
|
}
|
|
if len(parts) == 0 {
|
|
return data
|
|
}
|
|
converted := map[string]any{
|
|
"candidates": []map[string]any{{
|
|
"content": map[string]any{
|
|
"parts": parts,
|
|
"role": "model",
|
|
},
|
|
"finishReason": "STOP",
|
|
}},
|
|
"responseId": fmt.Sprintf("imagen-%d", time.Now().UnixNano()),
|
|
"modelVersion": model,
|
|
"usageMetadata": map[string]any{
|
|
"promptTokenCount": 0,
|
|
"candidatesTokenCount": 0,
|
|
"totalTokenCount": 0,
|
|
},
|
|
}
|
|
out, err := json.Marshal(converted)
|
|
if err != nil {
|
|
return data
|
|
}
|
|
return out
|
|
}
|
|
|
|
func ioReadAll(resp *http.Response) ([]byte, error) {
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (p *VertexProvider) vertexAttempts(ctx context.Context) ([]authAttempt, error) {
|
|
if p == nil || p.base == nil {
|
|
return nil, fmt.Errorf("provider not configured")
|
|
}
|
|
mode := strings.ToLower(strings.TrimSpace(p.base.authMode))
|
|
if mode != "oauth" && mode != "hybrid" {
|
|
return p.base.authAttempts(ctx)
|
|
}
|
|
out := make([]authAttempt, 0, 4)
|
|
if mode == "hybrid" {
|
|
if apiAttempt, apiReady := p.base.apiKeyAttempt(); apiReady {
|
|
out = append(out, apiAttempt)
|
|
}
|
|
}
|
|
if p.base.oauth == nil {
|
|
if len(out) > 0 {
|
|
p.base.updateCandidateOrder(out)
|
|
return out, nil
|
|
}
|
|
return nil, fmt.Errorf("oauth is enabled but provider session manager is not configured")
|
|
}
|
|
manager := p.base.oauth
|
|
manager.mu.Lock()
|
|
sessions, err := manager.loadAllLocked()
|
|
manager.mu.Unlock()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, session := range sessions {
|
|
if session == nil || manager.sessionOnCooldown(session) {
|
|
continue
|
|
}
|
|
token := strings.TrimSpace(session.AccessToken)
|
|
if token == "" && !vertexHasServiceAccount(session) {
|
|
continue
|
|
}
|
|
out = append(out, authAttempt{session: session, token: token, kind: "oauth"})
|
|
}
|
|
if len(out) == 0 {
|
|
return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first")
|
|
}
|
|
p.base.updateCandidateOrder(out)
|
|
return out, nil
|
|
}
|
|
|
|
func (p *VertexProvider) applyVertexAttemptAuth(ctx context.Context, req *http.Request, attempt authAttempt) error {
|
|
if req == nil {
|
|
return nil
|
|
}
|
|
if attempt.kind == "api_key" {
|
|
applyGeminiAttemptAuth(req, attempt)
|
|
return nil
|
|
}
|
|
if strings.TrimSpace(attempt.token) != "" {
|
|
applyGeminiAttemptAuth(req, attempt)
|
|
return nil
|
|
}
|
|
saJSON, err := vertexServiceAccountJSON(attempt.session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
client, err := p.base.httpClientForAttempt(attempt)
|
|
if err == nil && client != nil {
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
|
}
|
|
var serviceAccount struct {
|
|
ClientEmail string `json:"client_email"`
|
|
PrivateKey string `json:"private_key"`
|
|
TokenURI string `json:"token_uri"`
|
|
}
|
|
if err := json.Unmarshal(saJSON, &serviceAccount); err != nil {
|
|
return fmt.Errorf("vertex service account parse failed: %w", err)
|
|
}
|
|
cfg := &jwt.Config{
|
|
Email: strings.TrimSpace(serviceAccount.ClientEmail),
|
|
PrivateKey: []byte(serviceAccount.PrivateKey),
|
|
TokenURL: strings.TrimSpace(serviceAccount.TokenURI),
|
|
Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"},
|
|
}
|
|
tok, err := cfg.TokenSource(ctx).Token()
|
|
if err != nil {
|
|
return fmt.Errorf("vertex service account token failed: %w", err)
|
|
}
|
|
req.Header.Del("x-goog-api-key")
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(tok.AccessToken))
|
|
return nil
|
|
}
|
|
|
|
func vertexHasServiceAccount(session *oauthSession) bool {
|
|
if session == nil || session.Token == nil {
|
|
return false
|
|
}
|
|
return len(mapFromAny(session.Token["service_account"])) > 0
|
|
}
|
|
|
|
func vertexServiceAccountJSON(session *oauthSession) ([]byte, error) {
|
|
if !vertexHasServiceAccount(session) {
|
|
return nil, fmt.Errorf("vertex service account missing")
|
|
}
|
|
raw := mapFromAny(session.Token["service_account"])
|
|
if projectID := firstNonEmpty(asString(raw["project_id"]), asString(raw["project-id"]), strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project-id"]), asString(session.Token["project"])); projectID != "" {
|
|
raw["project_id"] = projectID
|
|
}
|
|
data, err := json.Marshal(raw)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("vertex service account marshal failed: %w", err)
|
|
}
|
|
return data, nil
|
|
}
|