mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-15 00:27:29 +08:00
244 lines
8.0 KiB
Go
244 lines
8.0 KiB
Go
package providers
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/YspCoder/clawgo/pkg/config"
|
|
)
|
|
|
|
func TestCreateProviderByNameRoutesVertexProvider(t *testing.T) {
|
|
cfg := &config.Config{
|
|
Models: config.ModelsConfig{
|
|
Providers: map[string]config.ProviderConfig{
|
|
"vertex": {
|
|
TimeoutSec: 30,
|
|
Models: []string{"gemini-2.5-pro"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
provider, err := CreateProviderByName(cfg, "vertex")
|
|
if err != nil {
|
|
t.Fatalf("CreateProviderByName error: %v", err)
|
|
}
|
|
if _, ok := provider.(*VertexProvider); !ok {
|
|
t.Fatalf("provider = %T, want *VertexProvider", provider)
|
|
}
|
|
}
|
|
|
|
func TestVertexEndpointWithProjectLocation(t *testing.T) {
|
|
p := NewVertexProvider("vertex", "", "", "gemini-2.5-pro", false, "oauth", 5*time.Second, nil)
|
|
endpoint := p.endpoint(authAttempt{
|
|
kind: "oauth",
|
|
session: &oauthSession{
|
|
ProjectID: "demo-project",
|
|
Token: map[string]any{
|
|
"location": "asia-east1",
|
|
},
|
|
},
|
|
}, "gemini-2.5-pro", "generateContent", false, nil)
|
|
want := "https://asia-east1-aiplatform.googleapis.com/v1/projects/demo-project/locations/asia-east1/publishers/google/models/gemini-2.5-pro:generateContent"
|
|
if endpoint != want {
|
|
t.Fatalf("endpoint = %q, want %q", endpoint, want)
|
|
}
|
|
}
|
|
|
|
func TestVertexEndpointUsesPredictForImagen(t *testing.T) {
|
|
p := NewVertexProvider("vertex", "", "", "imagen-4.0-generate-001", false, "oauth", 5*time.Second, nil)
|
|
endpoint := p.endpoint(authAttempt{
|
|
kind: "oauth",
|
|
session: &oauthSession{
|
|
ProjectID: "demo-project",
|
|
Token: map[string]any{
|
|
"location": "us-central1",
|
|
},
|
|
},
|
|
}, "imagen-4.0-generate-001", "generateContent", false, nil)
|
|
want := "https://us-central1-aiplatform.googleapis.com/v1/projects/demo-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-001:predict"
|
|
if endpoint != want {
|
|
t.Fatalf("endpoint = %q, want %q", endpoint, want)
|
|
}
|
|
}
|
|
|
|
func TestVertexEndpointUsesGlobalBaseForGlobalLocation(t *testing.T) {
|
|
p := NewVertexProvider("vertex", "", "", "gemini-2.5-pro", false, "oauth", 5*time.Second, nil)
|
|
endpoint := p.endpoint(authAttempt{
|
|
kind: "oauth",
|
|
session: &oauthSession{
|
|
ProjectID: "demo-project",
|
|
Token: map[string]any{
|
|
"location": "global",
|
|
},
|
|
},
|
|
}, "gemini-2.5-pro", "generateContent", false, nil)
|
|
want := "https://aiplatform.googleapis.com/v1/projects/demo-project/locations/global/publishers/google/models/gemini-2.5-pro:generateContent"
|
|
if endpoint != want {
|
|
t.Fatalf("endpoint = %q, want %q", endpoint, want)
|
|
}
|
|
}
|
|
|
|
func TestConvertVertexImagenRequest(t *testing.T) {
|
|
payload := map[string]any{
|
|
"contents": []map[string]any{
|
|
{"parts": []map[string]any{{"text": "draw a cat"}}},
|
|
},
|
|
"generationConfig": map[string]any{
|
|
"aspectRatio": "1:1",
|
|
"sampleCount": 2.0,
|
|
"negativePrompt": "blurry",
|
|
},
|
|
}
|
|
req, err := convertVertexImagenRequest(payload)
|
|
if err != nil {
|
|
t.Fatalf("convertVertexImagenRequest error: %v", err)
|
|
}
|
|
instances, _ := req["instances"].([]map[string]any)
|
|
if len(instances) != 1 || instances[0]["prompt"] != "draw a cat" {
|
|
t.Fatalf("unexpected instances: %+v", instances)
|
|
}
|
|
params, _ := req["parameters"].(map[string]any)
|
|
if params["aspectRatio"] != "1:1" || params["sampleCount"] != 2 {
|
|
t.Fatalf("unexpected parameters: %+v", params)
|
|
}
|
|
}
|
|
|
|
func TestConvertVertexImagenToGeminiResponse(t *testing.T) {
|
|
body := []byte(`{"predictions":[{"bytesBase64Encoded":"abcd","mimeType":"image/png"}]}`)
|
|
converted := convertVertexImagenToGeminiResponse(body, "imagen-4.0-generate-001")
|
|
resp, err := parseGeminiResponse(converted)
|
|
if err != nil {
|
|
t.Fatalf("parseGeminiResponse error: %v", err)
|
|
}
|
|
if resp.FinishReason != "STOP" {
|
|
t.Fatalf("finish reason = %q", resp.FinishReason)
|
|
}
|
|
}
|
|
|
|
func TestVertexProjectLocationFallsBackToProjectAlias(t *testing.T) {
|
|
projectID, location, ok := vertexProjectLocation(authAttempt{
|
|
kind: "oauth",
|
|
session: &oauthSession{
|
|
Token: map[string]any{
|
|
"project": "demo-project",
|
|
"location": "asia-east1",
|
|
},
|
|
},
|
|
}, nil)
|
|
if !ok || projectID != "demo-project" || location != "asia-east1" {
|
|
t.Fatalf("unexpected project/location: ok=%v project=%q location=%q", ok, projectID, location)
|
|
}
|
|
}
|
|
|
|
func TestVertexServiceAccountJSONBackfillsProjectID(t *testing.T) {
|
|
session := &oauthSession{
|
|
ProjectID: "demo-project",
|
|
Token: map[string]any{
|
|
"service_account": map[string]any{
|
|
"type": "service_account",
|
|
"client_email": "svc@example.com",
|
|
"private_key": "key",
|
|
"token_uri": "https://example.com/token",
|
|
},
|
|
},
|
|
}
|
|
data, err := vertexServiceAccountJSON(session)
|
|
if err != nil {
|
|
t.Fatalf("vertexServiceAccountJSON error: %v", err)
|
|
}
|
|
if !strings.Contains(string(data), `"project_id":"demo-project"`) {
|
|
t.Fatalf("expected project_id in service account json, got %s", string(data))
|
|
}
|
|
}
|
|
|
|
func TestVertexProviderChatWithAPIKey(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/publishers/google/models/gemini-2.5-pro:generateContent" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
if got := r.Header.Get("x-goog-api-key"); got != "token" {
|
|
t.Fatalf("x-goog-api-key = %q", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewVertexProvider("vertex", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil)
|
|
resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat error: %v", err)
|
|
}
|
|
if resp.Content != "ok" {
|
|
t.Fatalf("content = %q, want ok", resp.Content)
|
|
}
|
|
}
|
|
|
|
func TestVertexProviderChatWithServiceAccount(t *testing.T) {
|
|
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"sa-token","token_type":"Bearer","expires_in":3600}`))
|
|
}))
|
|
defer tokenServer.Close()
|
|
|
|
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/projects/demo-project/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
if got := r.Header.Get("Authorization"); got != "Bearer sa-token" {
|
|
t.Fatalf("Authorization = %q", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}]}`))
|
|
}))
|
|
defer apiServer.Close()
|
|
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("rsa.GenerateKey: %v", err)
|
|
}
|
|
pemKey := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
|
|
|
p := NewVertexProvider("vertex", "", apiServer.URL, "gemini-2.5-pro", false, "oauth", 5*time.Second, nil)
|
|
p.base.oauth = &oauthManager{}
|
|
p.base.authMode = "oauth"
|
|
p.base.oauth = &oauthManager{
|
|
cfg: oauthConfig{},
|
|
cached: []*oauthSession{{
|
|
ProjectID: "demo-project",
|
|
Token: map[string]any{
|
|
"location": "us-central1",
|
|
"service_account": map[string]any{
|
|
"type": "service_account",
|
|
"project_id": "demo-project",
|
|
"private_key_id": "key-1",
|
|
"private_key": string(pemKey),
|
|
"client_email": "svc@example.iam.gserviceaccount.com",
|
|
"client_id": "1234567890",
|
|
"token_uri": tokenServer.URL,
|
|
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
|
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
|
"client_x509_cert_url": "https://example.com/cert",
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
|
|
resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat error: %v", err)
|
|
}
|
|
if resp.Content != "ok" {
|
|
t.Fatalf("content = %q, want ok", resp.Content)
|
|
}
|
|
}
|