mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-06-09 14:43:08 +08:00
feat: align google and relay providers
This commit is contained in:
592
pkg/providers/aistudio_provider_test.go
Normal file
592
pkg/providers/aistudio_provider_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/YspCoder/clawgo/pkg/config"
|
||||
"github.com/YspCoder/clawgo/pkg/wsrelay"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestCreateProviderByNameRoutesAIStudioProvider(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Models.Providers["aistudio"] = config.ProviderConfig{
|
||||
Auth: "none",
|
||||
TimeoutSec: 90,
|
||||
Models: []string{"gemini-2.5-pro"},
|
||||
}
|
||||
provider, err := CreateProviderByName(cfg, "aistudio")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderByName() error = %v", err)
|
||||
}
|
||||
if _, ok := provider.(*AistudioProvider); !ok {
|
||||
t.Fatalf("expected *AistudioProvider, got %T", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAistudioProviderChatUsesRelay(t *testing.T) {
|
||||
manager, serverURL, cleanup := startRelayTestServer(t)
|
||||
defer cleanup()
|
||||
SetAIStudioRelayManager(manager)
|
||||
|
||||
reqCh := make(chan string, 1)
|
||||
stopClient := connectRelayClient(t, serverURL, "aistudio", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
reqCh <- fmt.Sprintf("%v", msg.Payload["url"])
|
||||
return []wsrelay.Message{{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeHTTPResp,
|
||||
Payload: map[string]any{
|
||||
"status": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
"body": `{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}`,
|
||||
},
|
||||
}}
|
||||
})
|
||||
defer stopClient()
|
||||
|
||||
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if resp.Content != "ok" {
|
||||
t.Fatalf("expected content ok, got %q", resp.Content)
|
||||
}
|
||||
select {
|
||||
case raw := <-reqCh:
|
||||
if !strings.Contains(raw, ":generateContent") {
|
||||
t.Fatalf("expected generateContent endpoint, got %q", raw)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for relay request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAistudioProviderCountTokensUsesRelay(t *testing.T) {
|
||||
manager, serverURL, cleanup := startRelayTestServer(t)
|
||||
defer cleanup()
|
||||
SetAIStudioRelayManager(manager)
|
||||
|
||||
stopClient := connectRelayClient(t, serverURL, "aistudio-count", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeHTTPResp,
|
||||
Payload: map[string]any{
|
||||
"status": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
"body": `{"totalTokens":42}`,
|
||||
},
|
||||
}}
|
||||
})
|
||||
defer stopClient()
|
||||
|
||||
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
||||
usage, err := provider.CountTokens(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", map[string]interface{}{"aistudio_channel": "aistudio-count"})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens() error = %v", err)
|
||||
}
|
||||
if usage.TotalTokens != 42 {
|
||||
t.Fatalf("expected total tokens 42, got %d", usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIStudioRelayRuntimeTracksConnectedChannels(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
delete(providerRuntimeRegistry.api, "aistudio")
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
NotifyAIStudioRelayConnected("aistudio")
|
||||
NotifyAIStudioRelayConnected("aistudio-alt")
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
state := providerRuntimeRegistry.api["aistudio"]
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
if len(state.CandidateOrder) != 2 {
|
||||
t.Fatalf("expected 2 relay candidates, got %d", len(state.CandidateOrder))
|
||||
}
|
||||
if state.CandidateOrder[0].Kind != "relay" {
|
||||
t.Fatalf("expected relay candidate kind, got %q", state.CandidateOrder[0].Kind)
|
||||
}
|
||||
|
||||
NotifyAIStudioRelayDisconnected("aistudio-alt", nil)
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
state = providerRuntimeRegistry.api["aistudio"]
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
if len(state.CandidateOrder) != 1 || state.CandidateOrder[0].Target != "aistudio" {
|
||||
t.Fatalf("unexpected candidate order after disconnect: %+v", state.CandidateOrder)
|
||||
}
|
||||
|
||||
NotifyAIStudioRelayDisconnected("aistudio", nil)
|
||||
}
|
||||
|
||||
func TestGetProviderRuntimeSnapshotIncludesAIStudioRelayAccounts(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
delete(providerRuntimeRegistry.api, "aistudio")
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
NotifyAIStudioRelayConnected("aistudio")
|
||||
defer NotifyAIStudioRelayDisconnected("aistudio", nil)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Models.Providers["aistudio"] = config.ProviderConfig{
|
||||
TimeoutSec: 30,
|
||||
Models: []string{"gemini-2.5-pro"},
|
||||
}
|
||||
|
||||
snapshot := GetProviderRuntimeSnapshot(cfg)
|
||||
items, _ := snapshot["items"].([]map[string]interface{})
|
||||
if len(items) == 0 {
|
||||
t.Fatal("expected snapshot items")
|
||||
}
|
||||
|
||||
var found map[string]interface{}
|
||||
for _, item := range items {
|
||||
if strings.TrimSpace(fmt.Sprintf("%v", item["name"])) == "aistudio" {
|
||||
found = item
|
||||
break
|
||||
}
|
||||
}
|
||||
if found == nil {
|
||||
t.Fatal("expected aistudio snapshot item")
|
||||
}
|
||||
|
||||
accounts, _ := found["oauth_accounts"].([]OAuthAccountInfo)
|
||||
if len(accounts) != 1 || accounts[0].AccountLabel != "aistudio" {
|
||||
t.Fatalf("unexpected relay accounts: %+v", accounts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIStudioRelayAccountsUseLastSuccessAsRefresh(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
aistudioRelayRegistry.mu.Lock()
|
||||
aistudioRelayRegistry.connected = map[string]time.Time{"aistudio": now.Add(-10 * time.Minute)}
|
||||
aistudioRelayRegistry.succeeded = map[string]time.Time{"aistudio": now}
|
||||
aistudioRelayRegistry.mu.Unlock()
|
||||
|
||||
accounts := listAIStudioRelayAccounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected one account, got %d", len(accounts))
|
||||
}
|
||||
if accounts[0].LastRefresh != now.Format(time.RFC3339) {
|
||||
t.Fatalf("expected last refresh %s, got %s", now.Format(time.RFC3339), accounts[0].LastRefresh)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIStudioRelayRuntimeRecordsFailureAndRecovery(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
delete(providerRuntimeRegistry.api, "aistudio")
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
NotifyAIStudioRelayConnected("aistudio")
|
||||
defer NotifyAIStudioRelayDisconnected("aistudio", nil)
|
||||
|
||||
recordAIStudioRelayFailure("aistudio", fmt.Errorf("boom"))
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
state := providerRuntimeRegistry.api["aistudio"]
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
if state.API.FailureCount == 0 {
|
||||
t.Fatal("expected api failure count to increase")
|
||||
}
|
||||
if len(state.RecentErrors) == 0 {
|
||||
t.Fatal("expected recent errors to be recorded")
|
||||
}
|
||||
if len(state.CandidateOrder) == 0 || state.CandidateOrder[0].Status != "cooldown" {
|
||||
t.Fatalf("expected relay candidate cooldown, got %+v", state.CandidateOrder)
|
||||
}
|
||||
|
||||
recordAIStudioRelaySuccess("aistudio")
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
state = providerRuntimeRegistry.api["aistudio"]
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
if state.LastSuccess == nil {
|
||||
t.Fatal("expected last success event")
|
||||
}
|
||||
if len(state.CandidateOrder) == 0 || state.CandidateOrder[0].Status != "ready" {
|
||||
t.Fatalf("expected relay candidate recovery, got %+v", state.CandidateOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIStudioChannelIDPrefersHealthyAvailableRelay(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
||||
CandidateOrder: []providerRuntimeCandidate{
|
||||
{Kind: "relay", Target: "aistudio-bad", Available: false, Status: "cooldown", CooldownUntil: time.Now().Add(5 * time.Minute).Format(time.RFC3339), HealthScore: 90},
|
||||
{Kind: "relay", Target: "aistudio-good", Available: true, Status: "ready", HealthScore: 100},
|
||||
},
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
got := aistudioChannelID("aistudio", nil)
|
||||
if got != "aistudio-good" {
|
||||
t.Fatalf("expected aistudio-good, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIStudioChannelIDExplicitOptionWins(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
||||
CandidateOrder: []providerRuntimeCandidate{
|
||||
{Kind: "relay", Target: "aistudio-good", Available: true, Status: "ready", HealthScore: 100},
|
||||
},
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
got := aistudioChannelID("aistudio", map[string]interface{}{"aistudio_channel": "manual"})
|
||||
if got != "manual" {
|
||||
t.Fatalf("expected explicit channel manual, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIStudioChannelIDPrefersMostRecentSuccessfulRelay(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
||||
CandidateOrder: []providerRuntimeCandidate{
|
||||
{Kind: "relay", Target: "aistudio-a", Available: true, Status: "ready", HealthScore: 100},
|
||||
{Kind: "relay", Target: "aistudio-b", Available: true, Status: "ready", HealthScore: 100},
|
||||
},
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
aistudioRelayRegistry.mu.Lock()
|
||||
if aistudioRelayRegistry.succeeded == nil {
|
||||
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
||||
}
|
||||
aistudioRelayRegistry.succeeded["aistudio-a"] = time.Now().Add(-1 * time.Minute)
|
||||
aistudioRelayRegistry.succeeded["aistudio-b"] = time.Now()
|
||||
aistudioRelayRegistry.mu.Unlock()
|
||||
|
||||
got := aistudioChannelID("aistudio", nil)
|
||||
if got != "aistudio-b" {
|
||||
t.Fatalf("expected most recent successful relay aistudio-b, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAistudioProviderChatFailsOverToNextRelay(t *testing.T) {
|
||||
manager, serverURL, cleanup := startRelayTestServer(t)
|
||||
defer cleanup()
|
||||
SetAIStudioRelayManager(manager)
|
||||
aistudioRelayRegistry.mu.Lock()
|
||||
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
||||
aistudioRelayRegistry.mu.Unlock()
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
||||
CandidateOrder: []providerRuntimeCandidate{
|
||||
{Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100},
|
||||
{Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90},
|
||||
},
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeHTTPResp,
|
||||
Payload: map[string]any{
|
||||
"status": 503,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
"body": `{"error":"no capacity"}`,
|
||||
},
|
||||
}}
|
||||
})
|
||||
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeHTTPResp,
|
||||
Payload: map[string]any{
|
||||
"status": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
"body": `{"candidates":[{"content":{"parts":[{"text":"ok-2"}]},"finishReason":"STOP"}]}`,
|
||||
},
|
||||
}}
|
||||
})
|
||||
defer stopSecond()
|
||||
|
||||
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if resp.Content != "ok-2" {
|
||||
t.Fatalf("expected failover response ok-2, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAistudioProviderChatExplicitChannelDoesNotFailOver(t *testing.T) {
|
||||
manager, serverURL, cleanup := startRelayTestServer(t)
|
||||
defer cleanup()
|
||||
SetAIStudioRelayManager(manager)
|
||||
aistudioRelayRegistry.mu.Lock()
|
||||
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
||||
aistudioRelayRegistry.mu.Unlock()
|
||||
|
||||
stopFirst := connectRelayClient(t, serverURL, "manual", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeHTTPResp,
|
||||
Payload: map[string]any{
|
||||
"status": 503,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
"body": `{"error":"no capacity"}`,
|
||||
},
|
||||
}}
|
||||
})
|
||||
defer stopFirst()
|
||||
|
||||
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeHTTPResp,
|
||||
Payload: map[string]any{
|
||||
"status": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
"body": `{"candidates":[{"content":{"parts":[{"text":"ok-2"}]},"finishReason":"STOP"}]}`,
|
||||
},
|
||||
}}
|
||||
})
|
||||
defer stopSecond()
|
||||
|
||||
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
||||
_, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", map[string]interface{}{"aistudio_channel": "manual"})
|
||||
if err == nil {
|
||||
t.Fatal("expected explicit relay failure without failover")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAistudioProviderStreamFailsOverBeforeFirstChunk(t *testing.T) {
|
||||
manager, serverURL, cleanup := startRelayTestServer(t)
|
||||
defer cleanup()
|
||||
SetAIStudioRelayManager(manager)
|
||||
aistudioRelayRegistry.mu.Lock()
|
||||
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
||||
aistudioRelayRegistry.mu.Unlock()
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
||||
CandidateOrder: []providerRuntimeCandidate{
|
||||
{Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100},
|
||||
{Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90},
|
||||
},
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
stopFirst := connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamStart,
|
||||
Payload: map[string]any{
|
||||
"status": 503,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamEnd,
|
||||
},
|
||||
}
|
||||
})
|
||||
defer stopFirst()
|
||||
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamStart,
|
||||
Payload: map[string]any{
|
||||
"status": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamChunk,
|
||||
Payload: map[string]any{
|
||||
"data": `{"candidates":[{"content":{"parts":[{"text":"ok-stream"}]}}]}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamEnd,
|
||||
},
|
||||
}
|
||||
})
|
||||
defer stopSecond()
|
||||
|
||||
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
||||
var deltas []string
|
||||
resp, err := provider.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil, func(delta string) {
|
||||
deltas = append(deltas, delta)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStream() error = %v", err)
|
||||
}
|
||||
if resp.Content != "ok-stream" {
|
||||
t.Fatalf("expected failover stream content ok-stream, got %q", resp.Content)
|
||||
}
|
||||
if len(deltas) == 0 {
|
||||
t.Fatal("expected stream deltas after failover")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAistudioProviderStreamDoesNotFailOverAfterChunk(t *testing.T) {
|
||||
manager, serverURL, cleanup := startRelayTestServer(t)
|
||||
defer cleanup()
|
||||
SetAIStudioRelayManager(manager)
|
||||
aistudioRelayRegistry.mu.Lock()
|
||||
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
||||
aistudioRelayRegistry.mu.Unlock()
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
||||
CandidateOrder: []providerRuntimeCandidate{
|
||||
{Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100},
|
||||
{Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90},
|
||||
},
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
stopFirst := connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamStart,
|
||||
Payload: map[string]any{
|
||||
"status": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamChunk,
|
||||
Payload: map[string]any{
|
||||
"data": `{"candidates":[{"content":{"parts":[{"text":"partial"}]}}]}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeError,
|
||||
Payload: map[string]any{"error": "stream broke", "status": 502.0},
|
||||
},
|
||||
}
|
||||
})
|
||||
defer stopFirst()
|
||||
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
||||
return []wsrelay.Message{
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamStart,
|
||||
Payload: map[string]any{
|
||||
"status": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamChunk,
|
||||
Payload: map[string]any{
|
||||
"data": `{"candidates":[{"content":{"parts":[{"text":"should-not-use"}]}}]}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: msg.ID,
|
||||
Type: wsrelay.MessageTypeStreamEnd,
|
||||
},
|
||||
}
|
||||
})
|
||||
defer stopSecond()
|
||||
|
||||
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
||||
_, err := provider.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected stream error without mid-stream failover")
|
||||
}
|
||||
if strings.Contains(err.Error(), "should-not-use") {
|
||||
t.Fatalf("unexpected failover after stream started: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startRelayTestServer(t *testing.T) (*wsrelay.Manager, string, func()) {
|
||||
t.Helper()
|
||||
manager := wsrelay.NewManager(wsrelay.Options{
|
||||
Path: "/v1/ws",
|
||||
ProviderFactory: func(r *http.Request) (string, error) {
|
||||
return strings.ToLower(strings.TrimSpace(r.URL.Query().Get("provider"))), nil
|
||||
},
|
||||
})
|
||||
srv := httptest.NewServer(manager.Handler())
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + manager.Path()
|
||||
cleanup := func() {
|
||||
_ = manager.Stop(context.Background())
|
||||
srv.Close()
|
||||
SetAIStudioRelayManager(nil)
|
||||
}
|
||||
return manager, wsURL, cleanup
|
||||
}
|
||||
|
||||
func connectRelayClient(t *testing.T, wsURL, provider string, handle func(wsrelay.Message) []wsrelay.Message) func() {
|
||||
t.Helper()
|
||||
u, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("provider", provider)
|
||||
u.RawQuery = q.Encode()
|
||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial() error = %v", err)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
var msg wsrelay.Message
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return
|
||||
}
|
||||
for _, out := range handle(msg) {
|
||||
if err := conn.WriteJSON(out); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
_ = conn.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user