mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-05-17 11:47:28 +08:00
feat: align google and relay providers
This commit is contained in:
192
pkg/wsrelay/manager.go
Normal file
192
pkg/wsrelay/manager.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
path string
|
||||
upgrader websocket.Upgrader
|
||||
sessions map[string]*session
|
||||
sessMutex sync.RWMutex
|
||||
|
||||
providerFactory func(*http.Request) (string, error)
|
||||
onConnected func(string)
|
||||
onDisconnected func(string, error)
|
||||
|
||||
logDebugf func(string, ...any)
|
||||
logInfof func(string, ...any)
|
||||
logWarnf func(string, ...any)
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Path string
|
||||
ProviderFactory func(*http.Request) (string, error)
|
||||
OnConnected func(string)
|
||||
OnDisconnected func(string, error)
|
||||
LogDebugf func(string, ...any)
|
||||
LogInfof func(string, ...any)
|
||||
LogWarnf func(string, ...any)
|
||||
}
|
||||
|
||||
func NewManager(opts Options) *Manager {
|
||||
path := strings.TrimSpace(opts.Path)
|
||||
if path == "" {
|
||||
path = "/v1/ws"
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
mgr := &Manager{
|
||||
path: path,
|
||||
sessions: make(map[string]*session),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(*http.Request) bool { return true },
|
||||
},
|
||||
providerFactory: opts.ProviderFactory,
|
||||
onConnected: opts.OnConnected,
|
||||
onDisconnected: opts.OnDisconnected,
|
||||
logDebugf: opts.LogDebugf,
|
||||
logInfof: opts.LogInfof,
|
||||
logWarnf: opts.LogWarnf,
|
||||
}
|
||||
if mgr.logDebugf == nil {
|
||||
mgr.logDebugf = func(string, ...any) {}
|
||||
}
|
||||
if mgr.logInfof == nil {
|
||||
mgr.logInfof = func(string, ...any) {}
|
||||
}
|
||||
if mgr.logWarnf == nil {
|
||||
mgr.logWarnf = func(string, ...any) {}
|
||||
}
|
||||
return mgr
|
||||
}
|
||||
|
||||
func (m *Manager) Path() string {
|
||||
if m == nil {
|
||||
return "/v1/ws"
|
||||
}
|
||||
return m.path
|
||||
}
|
||||
|
||||
func (m *Manager) Handler() http.Handler {
|
||||
return http.HandlerFunc(m.handleWebsocket)
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(_ context.Context) error {
|
||||
m.sessMutex.Lock()
|
||||
sessions := make([]*session, 0, len(m.sessions))
|
||||
for _, sess := range m.sessions {
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
m.sessions = make(map[string]*session)
|
||||
m.sessMutex.Unlock()
|
||||
|
||||
for _, sess := range sessions {
|
||||
if sess != nil {
|
||||
sess.cleanup(errors.New("wsrelay: manager stopped"))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
expectedPath := m.Path()
|
||||
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Method, http.MethodGet) {
|
||||
w.Header().Set("Allow", http.MethodGet)
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
conn, err := m.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
m.logWarnf("wsrelay: upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
s := newSession(conn, m, randomProviderName())
|
||||
if m.providerFactory != nil {
|
||||
name, err := m.providerFactory(r)
|
||||
if err != nil {
|
||||
s.cleanup(err)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(name) != "" {
|
||||
s.provider = strings.ToLower(strings.TrimSpace(name))
|
||||
}
|
||||
}
|
||||
if s.provider == "" {
|
||||
s.provider = strings.ToLower(s.id)
|
||||
}
|
||||
m.sessMutex.Lock()
|
||||
var replaced *session
|
||||
if existing, ok := m.sessions[s.provider]; ok {
|
||||
replaced = existing
|
||||
}
|
||||
m.sessions[s.provider] = s
|
||||
m.sessMutex.Unlock()
|
||||
|
||||
if replaced != nil {
|
||||
replaced.cleanup(errors.New("replaced by new connection"))
|
||||
}
|
||||
if m.onConnected != nil {
|
||||
m.onConnected(s.provider)
|
||||
}
|
||||
go s.run(context.Background())
|
||||
}
|
||||
|
||||
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
|
||||
s := m.session(provider)
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
|
||||
}
|
||||
return s.request(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *Manager) session(provider string) *session {
|
||||
key := strings.ToLower(strings.TrimSpace(provider))
|
||||
m.sessMutex.RLock()
|
||||
s := m.sessions[key]
|
||||
m.sessMutex.RUnlock()
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *Manager) handleSessionClosed(s *session, cause error) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(s.provider))
|
||||
m.sessMutex.Lock()
|
||||
if cur, ok := m.sessions[key]; ok && cur == s {
|
||||
delete(m.sessions, key)
|
||||
}
|
||||
m.sessMutex.Unlock()
|
||||
if m.onDisconnected != nil {
|
||||
m.onDisconnected(s.provider, cause)
|
||||
}
|
||||
}
|
||||
|
||||
func randomProviderName() string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
|
||||
}
|
||||
for i := range buf {
|
||||
buf[i] = alphabet[int(buf[i])%len(alphabet)]
|
||||
}
|
||||
return "aistudio-" + string(buf)
|
||||
}
|
||||
Reference in New Issue
Block a user