mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-14 19:37:31 +08:00
593 lines
18 KiB
Go
593 lines
18 KiB
Go
package providers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/YspCoder/clawgo/pkg/config"
|
|
"github.com/YspCoder/clawgo/pkg/wsrelay"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
func TestCreateProviderByNameRoutesAIStudioProvider(t *testing.T) {
|
|
cfg := config.DefaultConfig()
|
|
cfg.Models.Providers["aistudio"] = config.ProviderConfig{
|
|
Auth: "none",
|
|
TimeoutSec: 90,
|
|
Models: []string{"gemini-2.5-pro"},
|
|
}
|
|
provider, err := CreateProviderByName(cfg, "aistudio")
|
|
if err != nil {
|
|
t.Fatalf("CreateProviderByName() error = %v", err)
|
|
}
|
|
if _, ok := provider.(*AistudioProvider); !ok {
|
|
t.Fatalf("expected *AistudioProvider, got %T", provider)
|
|
}
|
|
}
|
|
|
|
func TestAistudioProviderChatUsesRelay(t *testing.T) {
|
|
manager, serverURL, cleanup := startRelayTestServer(t)
|
|
defer cleanup()
|
|
SetAIStudioRelayManager(manager)
|
|
|
|
reqCh := make(chan string, 1)
|
|
stopClient := connectRelayClient(t, serverURL, "aistudio", func(msg wsrelay.Message) []wsrelay.Message {
|
|
reqCh <- fmt.Sprintf("%v", msg.Payload["url"])
|
|
return []wsrelay.Message{{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeHTTPResp,
|
|
Payload: map[string]any{
|
|
"status": 200,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"application/json"},
|
|
},
|
|
"body": `{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}`,
|
|
},
|
|
}}
|
|
})
|
|
defer stopClient()
|
|
|
|
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
|
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
if resp.Content != "ok" {
|
|
t.Fatalf("expected content ok, got %q", resp.Content)
|
|
}
|
|
select {
|
|
case raw := <-reqCh:
|
|
if !strings.Contains(raw, ":generateContent") {
|
|
t.Fatalf("expected generateContent endpoint, got %q", raw)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for relay request")
|
|
}
|
|
}
|
|
|
|
func TestAistudioProviderCountTokensUsesRelay(t *testing.T) {
|
|
manager, serverURL, cleanup := startRelayTestServer(t)
|
|
defer cleanup()
|
|
SetAIStudioRelayManager(manager)
|
|
|
|
stopClient := connectRelayClient(t, serverURL, "aistudio-count", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeHTTPResp,
|
|
Payload: map[string]any{
|
|
"status": 200,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"application/json"},
|
|
},
|
|
"body": `{"totalTokens":42}`,
|
|
},
|
|
}}
|
|
})
|
|
defer stopClient()
|
|
|
|
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
|
usage, err := provider.CountTokens(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", map[string]interface{}{"aistudio_channel": "aistudio-count"})
|
|
if err != nil {
|
|
t.Fatalf("CountTokens() error = %v", err)
|
|
}
|
|
if usage.TotalTokens != 42 {
|
|
t.Fatalf("expected total tokens 42, got %d", usage.TotalTokens)
|
|
}
|
|
}
|
|
|
|
func TestAIStudioRelayRuntimeTracksConnectedChannels(t *testing.T) {
|
|
providerRuntimeRegistry.mu.Lock()
|
|
delete(providerRuntimeRegistry.api, "aistudio")
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
NotifyAIStudioRelayConnected("aistudio")
|
|
NotifyAIStudioRelayConnected("aistudio-alt")
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
state := providerRuntimeRegistry.api["aistudio"]
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
if len(state.CandidateOrder) != 2 {
|
|
t.Fatalf("expected 2 relay candidates, got %d", len(state.CandidateOrder))
|
|
}
|
|
if state.CandidateOrder[0].Kind != "relay" {
|
|
t.Fatalf("expected relay candidate kind, got %q", state.CandidateOrder[0].Kind)
|
|
}
|
|
|
|
NotifyAIStudioRelayDisconnected("aistudio-alt", nil)
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
state = providerRuntimeRegistry.api["aistudio"]
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
if len(state.CandidateOrder) != 1 || state.CandidateOrder[0].Target != "aistudio" {
|
|
t.Fatalf("unexpected candidate order after disconnect: %+v", state.CandidateOrder)
|
|
}
|
|
|
|
NotifyAIStudioRelayDisconnected("aistudio", nil)
|
|
}
|
|
|
|
func TestGetProviderRuntimeSnapshotIncludesAIStudioRelayAccounts(t *testing.T) {
|
|
providerRuntimeRegistry.mu.Lock()
|
|
delete(providerRuntimeRegistry.api, "aistudio")
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
NotifyAIStudioRelayConnected("aistudio")
|
|
defer NotifyAIStudioRelayDisconnected("aistudio", nil)
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Models.Providers["aistudio"] = config.ProviderConfig{
|
|
TimeoutSec: 30,
|
|
Models: []string{"gemini-2.5-pro"},
|
|
}
|
|
|
|
snapshot := GetProviderRuntimeSnapshot(cfg)
|
|
items, _ := snapshot["items"].([]map[string]interface{})
|
|
if len(items) == 0 {
|
|
t.Fatal("expected snapshot items")
|
|
}
|
|
|
|
var found map[string]interface{}
|
|
for _, item := range items {
|
|
if strings.TrimSpace(fmt.Sprintf("%v", item["name"])) == "aistudio" {
|
|
found = item
|
|
break
|
|
}
|
|
}
|
|
if found == nil {
|
|
t.Fatal("expected aistudio snapshot item")
|
|
}
|
|
|
|
accounts, _ := found["oauth_accounts"].([]OAuthAccountInfo)
|
|
if len(accounts) != 1 || accounts[0].AccountLabel != "aistudio" {
|
|
t.Fatalf("unexpected relay accounts: %+v", accounts)
|
|
}
|
|
}
|
|
|
|
func TestAIStudioRelayAccountsUseLastSuccessAsRefresh(t *testing.T) {
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
aistudioRelayRegistry.mu.Lock()
|
|
aistudioRelayRegistry.connected = map[string]time.Time{"aistudio": now.Add(-10 * time.Minute)}
|
|
aistudioRelayRegistry.succeeded = map[string]time.Time{"aistudio": now}
|
|
aistudioRelayRegistry.mu.Unlock()
|
|
|
|
accounts := listAIStudioRelayAccounts()
|
|
if len(accounts) != 1 {
|
|
t.Fatalf("expected one account, got %d", len(accounts))
|
|
}
|
|
if accounts[0].LastRefresh != now.Format(time.RFC3339) {
|
|
t.Fatalf("expected last refresh %s, got %s", now.Format(time.RFC3339), accounts[0].LastRefresh)
|
|
}
|
|
}
|
|
|
|
func TestAIStudioRelayRuntimeRecordsFailureAndRecovery(t *testing.T) {
|
|
providerRuntimeRegistry.mu.Lock()
|
|
delete(providerRuntimeRegistry.api, "aistudio")
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
NotifyAIStudioRelayConnected("aistudio")
|
|
defer NotifyAIStudioRelayDisconnected("aistudio", nil)
|
|
|
|
recordAIStudioRelayFailure("aistudio", fmt.Errorf("boom"))
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
state := providerRuntimeRegistry.api["aistudio"]
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
if state.API.FailureCount == 0 {
|
|
t.Fatal("expected api failure count to increase")
|
|
}
|
|
if len(state.RecentErrors) == 0 {
|
|
t.Fatal("expected recent errors to be recorded")
|
|
}
|
|
if len(state.CandidateOrder) == 0 || state.CandidateOrder[0].Status != "cooldown" {
|
|
t.Fatalf("expected relay candidate cooldown, got %+v", state.CandidateOrder)
|
|
}
|
|
|
|
recordAIStudioRelaySuccess("aistudio")
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
state = providerRuntimeRegistry.api["aistudio"]
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
if state.LastSuccess == nil {
|
|
t.Fatal("expected last success event")
|
|
}
|
|
if len(state.CandidateOrder) == 0 || state.CandidateOrder[0].Status != "ready" {
|
|
t.Fatalf("expected relay candidate recovery, got %+v", state.CandidateOrder)
|
|
}
|
|
}
|
|
|
|
func TestAIStudioChannelIDPrefersHealthyAvailableRelay(t *testing.T) {
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
|
CandidateOrder: []providerRuntimeCandidate{
|
|
{Kind: "relay", Target: "aistudio-bad", Available: false, Status: "cooldown", CooldownUntil: time.Now().Add(5 * time.Minute).Format(time.RFC3339), HealthScore: 90},
|
|
{Kind: "relay", Target: "aistudio-good", Available: true, Status: "ready", HealthScore: 100},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
got := aistudioChannelCandidates("aistudio", nil)[0]
|
|
if got != "aistudio-good" {
|
|
t.Fatalf("expected aistudio-good, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestAIStudioChannelIDExplicitOptionWins(t *testing.T) {
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
|
CandidateOrder: []providerRuntimeCandidate{
|
|
{Kind: "relay", Target: "aistudio-good", Available: true, Status: "ready", HealthScore: 100},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
got := aistudioChannelCandidates("aistudio", map[string]interface{}{"aistudio_channel": "manual"})[0]
|
|
if got != "manual" {
|
|
t.Fatalf("expected explicit channel manual, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestAIStudioChannelIDPrefersMostRecentSuccessfulRelay(t *testing.T) {
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
|
CandidateOrder: []providerRuntimeCandidate{
|
|
{Kind: "relay", Target: "aistudio-a", Available: true, Status: "ready", HealthScore: 100},
|
|
{Kind: "relay", Target: "aistudio-b", Available: true, Status: "ready", HealthScore: 100},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
aistudioRelayRegistry.mu.Lock()
|
|
if aistudioRelayRegistry.succeeded == nil {
|
|
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
|
}
|
|
aistudioRelayRegistry.succeeded["aistudio-a"] = time.Now().Add(-1 * time.Minute)
|
|
aistudioRelayRegistry.succeeded["aistudio-b"] = time.Now()
|
|
aistudioRelayRegistry.mu.Unlock()
|
|
|
|
got := aistudioChannelCandidates("aistudio", nil)[0]
|
|
if got != "aistudio-b" {
|
|
t.Fatalf("expected most recent successful relay aistudio-b, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestAistudioProviderChatFailsOverToNextRelay(t *testing.T) {
|
|
manager, serverURL, cleanup := startRelayTestServer(t)
|
|
defer cleanup()
|
|
SetAIStudioRelayManager(manager)
|
|
aistudioRelayRegistry.mu.Lock()
|
|
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
|
aistudioRelayRegistry.mu.Unlock()
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
|
CandidateOrder: []providerRuntimeCandidate{
|
|
{Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100},
|
|
{Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeHTTPResp,
|
|
Payload: map[string]any{
|
|
"status": 503,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"application/json"},
|
|
},
|
|
"body": `{"error":"no capacity"}`,
|
|
},
|
|
}}
|
|
})
|
|
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeHTTPResp,
|
|
Payload: map[string]any{
|
|
"status": 200,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"application/json"},
|
|
},
|
|
"body": `{"candidates":[{"content":{"parts":[{"text":"ok-2"}]},"finishReason":"STOP"}]}`,
|
|
},
|
|
}}
|
|
})
|
|
defer stopSecond()
|
|
|
|
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
|
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
if resp.Content != "ok-2" {
|
|
t.Fatalf("expected failover response ok-2, got %q", resp.Content)
|
|
}
|
|
}
|
|
|
|
func TestAistudioProviderChatExplicitChannelDoesNotFailOver(t *testing.T) {
|
|
manager, serverURL, cleanup := startRelayTestServer(t)
|
|
defer cleanup()
|
|
SetAIStudioRelayManager(manager)
|
|
aistudioRelayRegistry.mu.Lock()
|
|
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
|
aistudioRelayRegistry.mu.Unlock()
|
|
|
|
stopFirst := connectRelayClient(t, serverURL, "manual", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeHTTPResp,
|
|
Payload: map[string]any{
|
|
"status": 503,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"application/json"},
|
|
},
|
|
"body": `{"error":"no capacity"}`,
|
|
},
|
|
}}
|
|
})
|
|
defer stopFirst()
|
|
|
|
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeHTTPResp,
|
|
Payload: map[string]any{
|
|
"status": 200,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"application/json"},
|
|
},
|
|
"body": `{"candidates":[{"content":{"parts":[{"text":"ok-2"}]},"finishReason":"STOP"}]}`,
|
|
},
|
|
}}
|
|
})
|
|
defer stopSecond()
|
|
|
|
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
|
_, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", map[string]interface{}{"aistudio_channel": "manual"})
|
|
if err == nil {
|
|
t.Fatal("expected explicit relay failure without failover")
|
|
}
|
|
}
|
|
|
|
func TestAistudioProviderStreamFailsOverBeforeFirstChunk(t *testing.T) {
|
|
manager, serverURL, cleanup := startRelayTestServer(t)
|
|
defer cleanup()
|
|
SetAIStudioRelayManager(manager)
|
|
aistudioRelayRegistry.mu.Lock()
|
|
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
|
aistudioRelayRegistry.mu.Unlock()
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
|
CandidateOrder: []providerRuntimeCandidate{
|
|
{Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100},
|
|
{Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
stopFirst := connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamStart,
|
|
Payload: map[string]any{
|
|
"status": 503,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"text/event-stream"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamEnd,
|
|
},
|
|
}
|
|
})
|
|
defer stopFirst()
|
|
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamStart,
|
|
Payload: map[string]any{
|
|
"status": 200,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"text/event-stream"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamChunk,
|
|
Payload: map[string]any{
|
|
"data": `{"candidates":[{"content":{"parts":[{"text":"ok-stream"}]}}]}`,
|
|
},
|
|
},
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamEnd,
|
|
},
|
|
}
|
|
})
|
|
defer stopSecond()
|
|
|
|
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
|
var deltas []string
|
|
resp, err := provider.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil, func(delta string) {
|
|
deltas = append(deltas, delta)
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("ChatStream() error = %v", err)
|
|
}
|
|
if resp.Content != "ok-stream" {
|
|
t.Fatalf("expected failover stream content ok-stream, got %q", resp.Content)
|
|
}
|
|
if len(deltas) == 0 {
|
|
t.Fatal("expected stream deltas after failover")
|
|
}
|
|
}
|
|
|
|
func TestAistudioProviderStreamDoesNotFailOverAfterChunk(t *testing.T) {
|
|
manager, serverURL, cleanup := startRelayTestServer(t)
|
|
defer cleanup()
|
|
SetAIStudioRelayManager(manager)
|
|
aistudioRelayRegistry.mu.Lock()
|
|
aistudioRelayRegistry.succeeded = map[string]time.Time{}
|
|
aistudioRelayRegistry.mu.Unlock()
|
|
|
|
providerRuntimeRegistry.mu.Lock()
|
|
providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{
|
|
CandidateOrder: []providerRuntimeCandidate{
|
|
{Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100},
|
|
{Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90},
|
|
},
|
|
}
|
|
providerRuntimeRegistry.mu.Unlock()
|
|
|
|
stopFirst := connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamStart,
|
|
Payload: map[string]any{
|
|
"status": 200,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"text/event-stream"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamChunk,
|
|
Payload: map[string]any{
|
|
"data": `{"candidates":[{"content":{"parts":[{"text":"partial"}]}}]}`,
|
|
},
|
|
},
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeError,
|
|
Payload: map[string]any{"error": "stream broke", "status": 502.0},
|
|
},
|
|
}
|
|
})
|
|
defer stopFirst()
|
|
stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message {
|
|
return []wsrelay.Message{
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamStart,
|
|
Payload: map[string]any{
|
|
"status": 200,
|
|
"headers": map[string]any{
|
|
"Content-Type": []string{"text/event-stream"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamChunk,
|
|
Payload: map[string]any{
|
|
"data": `{"candidates":[{"content":{"parts":[{"text":"should-not-use"}]}}]}`,
|
|
},
|
|
},
|
|
{
|
|
ID: msg.ID,
|
|
Type: wsrelay.MessageTypeStreamEnd,
|
|
},
|
|
}
|
|
})
|
|
defer stopSecond()
|
|
|
|
provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil)
|
|
_, err := provider.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil, nil)
|
|
if err == nil {
|
|
t.Fatal("expected stream error without mid-stream failover")
|
|
}
|
|
if strings.Contains(err.Error(), "should-not-use") {
|
|
t.Fatalf("unexpected failover after stream started: %v", err)
|
|
}
|
|
}
|
|
|
|
func startRelayTestServer(t *testing.T) (*wsrelay.Manager, string, func()) {
|
|
t.Helper()
|
|
manager := wsrelay.NewManager(wsrelay.Options{
|
|
Path: "/v1/ws",
|
|
ProviderFactory: func(r *http.Request) (string, error) {
|
|
return strings.ToLower(strings.TrimSpace(r.URL.Query().Get("provider"))), nil
|
|
},
|
|
})
|
|
srv := httptest.NewServer(manager.Handler())
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + manager.Path()
|
|
cleanup := func() {
|
|
_ = manager.Stop(context.Background())
|
|
srv.Close()
|
|
SetAIStudioRelayManager(nil)
|
|
}
|
|
return manager, wsURL, cleanup
|
|
}
|
|
|
|
func connectRelayClient(t *testing.T, wsURL, provider string, handle func(wsrelay.Message) []wsrelay.Message) func() {
|
|
t.Helper()
|
|
u, err := url.Parse(wsURL)
|
|
if err != nil {
|
|
t.Fatalf("url.Parse() error = %v", err)
|
|
}
|
|
q := u.Query()
|
|
q.Set("provider", provider)
|
|
u.RawQuery = q.Encode()
|
|
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if err != nil {
|
|
t.Fatalf("Dial() error = %v", err)
|
|
}
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
for {
|
|
var msg wsrelay.Message
|
|
if err := conn.ReadJSON(&msg); err != nil {
|
|
return
|
|
}
|
|
for _, out := range handle(msg) {
|
|
if err := conn.WriteJSON(out); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
return func() {
|
|
_ = conn.Close()
|
|
<-done
|
|
}
|
|
}
|