mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 06:47:30 +08:00
184 lines
4.1 KiB
Go
package wsrelay
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Connection tuning shared by every relay session.
const (
	// readTimeout is the read-deadline window applied to the connection; it
	// is re-armed whenever a pong arrives from the peer.
	readTimeout = 60 * time.Second

	// writeTimeout bounds each individual write, heartbeat pings included.
	writeTimeout = 10 * time.Second

	// maxInboundMessageLen is the read limit for a single inbound message
	// (64 MiB).
	maxInboundMessageLen = 64 * 1024 * 1024

	// heartbeatInterval is the period between outbound ping control frames.
	heartbeatInterval = 30 * time.Second
)
|
|
|
|
var (
	// errClosed is returned by session.send once the session has shut down,
	// and is the cause handed to cleanup when the read loop exits.
	errClosed = errors.New("websocket session closed")
)
|
|
|
|
type pendingRequest struct {
|
|
ch chan Message
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
func (pr *pendingRequest) close() {
|
|
if pr == nil {
|
|
return
|
|
}
|
|
pr.closeOnce.Do(func() { close(pr.ch) })
|
|
}
|
|
|
|
type session struct {
|
|
conn *websocket.Conn
|
|
manager *Manager
|
|
provider string
|
|
id string
|
|
closed chan struct{}
|
|
closeOnce sync.Once
|
|
writeMutex sync.Mutex
|
|
pending sync.Map
|
|
}
|
|
|
|
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
|
|
s := &session{
|
|
conn: conn,
|
|
manager: mgr,
|
|
id: id,
|
|
closed: make(chan struct{}),
|
|
}
|
|
conn.SetReadLimit(maxInboundMessageLen)
|
|
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
conn.SetPongHandler(func(string) error {
|
|
return conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
})
|
|
s.startHeartbeat()
|
|
return s
|
|
}
|
|
|
|
func (s *session) startHeartbeat() {
|
|
if s == nil || s.conn == nil {
|
|
return
|
|
}
|
|
ticker := time.NewTicker(heartbeatInterval)
|
|
go func() {
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-s.closed:
|
|
return
|
|
case <-ticker.C:
|
|
s.writeMutex.Lock()
|
|
err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
|
|
s.writeMutex.Unlock()
|
|
if err != nil {
|
|
s.cleanup(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *session) run(_ context.Context) {
|
|
defer s.cleanup(errClosed)
|
|
for {
|
|
var msg Message
|
|
if err := s.conn.ReadJSON(&msg); err != nil {
|
|
s.cleanup(err)
|
|
return
|
|
}
|
|
s.dispatch(msg)
|
|
}
|
|
}
|
|
|
|
func (s *session) dispatch(msg Message) {
|
|
if msg.Type == MessageTypePing {
|
|
_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
|
|
return
|
|
}
|
|
if value, ok := s.pending.Load(msg.ID); ok {
|
|
req := value.(*pendingRequest)
|
|
select {
|
|
case req.ch <- msg:
|
|
default:
|
|
}
|
|
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
|
actual.(*pendingRequest).close()
|
|
}
|
|
}
|
|
return
|
|
}
|
|
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
|
s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
|
|
}
|
|
}
|
|
|
|
func (s *session) send(_ context.Context, msg Message) error {
|
|
select {
|
|
case <-s.closed:
|
|
return errClosed
|
|
default:
|
|
}
|
|
s.writeMutex.Lock()
|
|
defer s.writeMutex.Unlock()
|
|
if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
|
return fmt.Errorf("set write deadline: %w", err)
|
|
}
|
|
if err := s.conn.WriteJSON(msg); err != nil {
|
|
return fmt.Errorf("write json: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
|
|
if msg.ID == "" {
|
|
return nil, fmt.Errorf("wsrelay: message id is required")
|
|
}
|
|
if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded {
|
|
return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
|
|
}
|
|
value, _ := s.pending.Load(msg.ID)
|
|
req := value.(*pendingRequest)
|
|
if err := s.send(ctx, msg); err != nil {
|
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
|
actual.(*pendingRequest).close()
|
|
}
|
|
return nil, err
|
|
}
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
|
actual.(*pendingRequest).close()
|
|
}
|
|
case <-s.closed:
|
|
}
|
|
}()
|
|
return req.ch, nil
|
|
}
|
|
|
|
func (s *session) cleanup(cause error) {
|
|
s.closeOnce.Do(func() {
|
|
close(s.closed)
|
|
s.pending.Range(func(key, value any) bool {
|
|
req := value.(*pendingRequest)
|
|
msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}}
|
|
select {
|
|
case req.ch <- msg:
|
|
default:
|
|
}
|
|
req.close()
|
|
return true
|
|
})
|
|
s.pending = sync.Map{}
|
|
_ = s.conn.Close()
|
|
if s.manager != nil {
|
|
s.manager.handleSessionClosed(s, cause)
|
|
}
|
|
})
|
|
}
|