mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 05:37:29 +08:00
feat: add guarded webrtc node transport
This commit is contained in:
@@ -79,6 +79,13 @@ func (al *AgentLoop) SetConfigPath(path string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) SetNodeP2PTransport(t nodes.Transport) {
|
||||
if al == nil || al.nodeRouter == nil {
|
||||
return
|
||||
}
|
||||
al.nodeRouter.P2P = t
|
||||
}
|
||||
|
||||
// StartupCompactionReport provides startup memory/session maintenance stats.
|
||||
type StartupCompactionReport struct {
|
||||
TotalSessions int `json:"total_sessions"`
|
||||
|
||||
@@ -40,6 +40,7 @@ type Server struct {
|
||||
nodeConnMu sync.Mutex
|
||||
nodeConnIDs map[string]string
|
||||
nodeSockets map[string]*nodeSocketConn
|
||||
nodeWebRTC *nodes.WebRTCTransport
|
||||
gatewayVersion string
|
||||
webuiVersion string
|
||||
configPath string
|
||||
@@ -116,6 +117,9 @@ func (s *Server) SetToolsCatalogHandler(fn func() interface{}) { s.onToolsCatalo
|
||||
func (s *Server) SetWebUIDir(dir string) { s.webUIDir = strings.TrimSpace(dir) }
|
||||
func (s *Server) SetGatewayVersion(v string) { s.gatewayVersion = strings.TrimSpace(v) }
|
||||
func (s *Server) SetWebUIVersion(v string) { s.webuiVersion = strings.TrimSpace(v) }
|
||||
func (s *Server) SetNodeWebRTCTransport(t *nodes.WebRTCTransport) {
|
||||
s.nodeWebRTC = t
|
||||
}
|
||||
|
||||
func (s *Server) rememberNodeConnection(nodeID, connID string) {
|
||||
nodeID = strings.TrimSpace(nodeID)
|
||||
@@ -142,6 +146,9 @@ func (s *Server) bindNodeSocket(nodeID, connID string, conn *websocket.Conn) {
|
||||
if s.mgr != nil {
|
||||
s.mgr.RegisterWireSender(nodeID, next)
|
||||
}
|
||||
if s.nodeWebRTC != nil {
|
||||
s.nodeWebRTC.BindSignaler(nodeID, next)
|
||||
}
|
||||
if prev != nil && prev.connID != connID {
|
||||
_ = prev.conn.Close()
|
||||
}
|
||||
@@ -165,6 +172,9 @@ func (s *Server) releaseNodeConnection(nodeID, connID string) bool {
|
||||
if s.mgr != nil {
|
||||
s.mgr.RegisterWireSender(nodeID, nil)
|
||||
}
|
||||
if s.nodeWebRTC != nil {
|
||||
s.nodeWebRTC.UnbindSignaler(nodeID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -373,13 +383,23 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
case "signal_offer", "signal_answer", "signal_candidate":
|
||||
targetID := strings.TrimSpace(msg.To)
|
||||
if s.nodeWebRTC != nil && (targetID == "" || strings.EqualFold(targetID, "gateway")) {
|
||||
if err := s.nodeWebRTC.HandleSignal(msg); err != nil {
|
||||
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}); err != nil {
|
||||
return
|
||||
}
|
||||
} else if err := writeAck(nodes.WireAck{OK: true, Type: "signaled", ID: msg.ID}); err != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(connectedID) == "" {
|
||||
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, Error: "node not registered"}); err != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
targetID := strings.TrimSpace(msg.To)
|
||||
if targetID == "" {
|
||||
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "target node required"}); err != nil {
|
||||
return
|
||||
|
||||
@@ -299,8 +299,9 @@ type GatewayNodesConfig struct {
|
||||
}
|
||||
|
||||
type GatewayNodesP2PConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Transport string `json:"transport,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Transport string `json:"transport,omitempty"`
|
||||
STUNServers []string `json:"stun_servers,omitempty"`
|
||||
}
|
||||
|
||||
type CronConfig struct {
|
||||
@@ -546,8 +547,9 @@ func DefaultConfig() *Config {
|
||||
Token: generateGatewayToken(),
|
||||
Nodes: GatewayNodesConfig{
|
||||
P2P: GatewayNodesP2PConfig{
|
||||
Enabled: false,
|
||||
Transport: "websocket_tunnel",
|
||||
Enabled: false,
|
||||
Transport: "websocket_tunnel",
|
||||
STUNServers: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -124,6 +124,7 @@ func Validate(cfg *Config) []error {
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("gateway.nodes.p2p.transport must be one of: websocket_tunnel, webrtc"))
|
||||
}
|
||||
errs = append(errs, validateNonEmptyStringList("gateway.nodes.p2p.stun_servers", cfg.Gateway.Nodes.P2P.STUNServers)...)
|
||||
if cfg.Cron.MinSleepSec <= 0 {
|
||||
errs = append(errs, fmt.Errorf("cron.min_sleep_sec must be > 0"))
|
||||
}
|
||||
|
||||
322
pkg/nodes/webrtc.go
Normal file
322
pkg/nodes/webrtc.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
type gatewayRTCSession struct {
|
||||
nodeID string
|
||||
pc *webrtc.PeerConnection
|
||||
dc *webrtc.DataChannel
|
||||
ready chan struct{}
|
||||
readyMu sync.Once
|
||||
writeMu sync.Mutex
|
||||
pending map[string]chan Response
|
||||
mu sync.Mutex
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
func (s *gatewayRTCSession) markReady() {
|
||||
s.readyMu.Do(func() { close(s.ready) })
|
||||
}
|
||||
|
||||
func (s *gatewayRTCSession) send(msg WireMessage) error {
|
||||
if s == nil || s.dc == nil {
|
||||
return fmt.Errorf("webrtc data channel unavailable")
|
||||
}
|
||||
b, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
return s.dc.Send(b)
|
||||
}
|
||||
|
||||
func (s *gatewayRTCSession) nextRequestID() string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.nextID++
|
||||
return fmt.Sprintf("rtc-%s-%d", s.nodeID, s.nextID)
|
||||
}
|
||||
|
||||
type WebRTCTransport struct {
|
||||
stunServers []string
|
||||
|
||||
mu sync.Mutex
|
||||
sessions map[string]*gatewayRTCSession
|
||||
signal map[string]WireSender
|
||||
}
|
||||
|
||||
func NewWebRTCTransport(stunServers []string) *WebRTCTransport {
|
||||
out := make([]string, 0, len(stunServers))
|
||||
for _, server := range stunServers {
|
||||
if v := strings.TrimSpace(server); v != "" {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return &WebRTCTransport{
|
||||
stunServers: out,
|
||||
sessions: map[string]*gatewayRTCSession{},
|
||||
signal: map[string]WireSender{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebRTCTransport) Name() string { return "p2p-webrtc" }
|
||||
|
||||
func (t *WebRTCTransport) BindSignaler(nodeID string, sender WireSender) {
|
||||
nodeID = strings.TrimSpace(nodeID)
|
||||
if nodeID == "" {
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if sender == nil {
|
||||
delete(t.signal, nodeID)
|
||||
return
|
||||
}
|
||||
t.signal[nodeID] = sender
|
||||
}
|
||||
|
||||
func (t *WebRTCTransport) UnbindSignaler(nodeID string) {
|
||||
t.BindSignaler(nodeID, nil)
|
||||
t.mu.Lock()
|
||||
session := t.sessions[nodeID]
|
||||
delete(t.sessions, nodeID)
|
||||
t.mu.Unlock()
|
||||
if session != nil && session.pc != nil {
|
||||
_ = session.pc.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebRTCTransport) currentSignaler(nodeID string) WireSender {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.signal[strings.TrimSpace(nodeID)]
|
||||
}
|
||||
|
||||
func (t *WebRTCTransport) HandleSignal(msg WireMessage) error {
|
||||
nodeID := strings.TrimSpace(msg.From)
|
||||
if nodeID == "" {
|
||||
return fmt.Errorf("signal missing from")
|
||||
}
|
||||
session, err := t.ensureSession(nodeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(msg.Type)) {
|
||||
case "signal_answer":
|
||||
var desc webrtc.SessionDescription
|
||||
if err := mapInto(msg.Payload, &desc); err != nil {
|
||||
return err
|
||||
}
|
||||
return session.pc.SetRemoteDescription(desc)
|
||||
case "signal_candidate":
|
||||
var candidate webrtc.ICECandidateInit
|
||||
if err := mapInto(msg.Payload, &candidate); err != nil {
|
||||
return err
|
||||
}
|
||||
return session.pc.AddICECandidate(candidate)
|
||||
default:
|
||||
return fmt.Errorf("unsupported signal type: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebRTCTransport) Send(ctx context.Context, req Request) (Response, error) {
|
||||
session, err := t.ensureSession(req.Node)
|
||||
if err != nil {
|
||||
return Response{OK: false, Code: "p2p_unavailable", Node: req.Node, Action: req.Action, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return Response{}, ctx.Err()
|
||||
case <-session.ready:
|
||||
case <-time.After(8 * time.Second):
|
||||
return Response{OK: false, Code: "p2p_timeout", Node: req.Node, Action: req.Action, Error: "webrtc session not ready"}, nil
|
||||
}
|
||||
|
||||
reqID := session.nextRequestID()
|
||||
respCh := make(chan Response, 1)
|
||||
session.mu.Lock()
|
||||
session.pending[reqID] = respCh
|
||||
session.mu.Unlock()
|
||||
|
||||
if err := session.send(WireMessage{
|
||||
Type: "node_request",
|
||||
ID: reqID,
|
||||
To: req.Node,
|
||||
Request: &req,
|
||||
}); err != nil {
|
||||
session.mu.Lock()
|
||||
delete(session.pending, reqID)
|
||||
session.mu.Unlock()
|
||||
return Response{OK: false, Code: "p2p_send_failed", Node: req.Node, Action: req.Action, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
session.mu.Lock()
|
||||
delete(session.pending, reqID)
|
||||
session.mu.Unlock()
|
||||
return Response{}, ctx.Err()
|
||||
case resp := <-respCh:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebRTCTransport) ensureSession(nodeID string) (*gatewayRTCSession, error) {
|
||||
nodeID = strings.TrimSpace(nodeID)
|
||||
if nodeID == "" {
|
||||
return nil, fmt.Errorf("node id required")
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if session := t.sessions[nodeID]; session != nil {
|
||||
t.mu.Unlock()
|
||||
return session, nil
|
||||
}
|
||||
t.mu.Unlock()
|
||||
if t.currentSignaler(nodeID) == nil {
|
||||
return nil, fmt.Errorf("node %s signaling unavailable", nodeID)
|
||||
}
|
||||
|
||||
config := webrtc.Configuration{}
|
||||
if len(t.stunServers) > 0 {
|
||||
config.ICEServers = []webrtc.ICEServer{{URLs: append([]string(nil), t.stunServers...)}}
|
||||
}
|
||||
pc, err := webrtc.NewPeerConnection(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dc, err := pc.CreateDataChannel("clawgo", nil)
|
||||
if err != nil {
|
||||
_ = pc.Close()
|
||||
return nil, err
|
||||
}
|
||||
session := &gatewayRTCSession{
|
||||
nodeID: nodeID,
|
||||
pc: pc,
|
||||
dc: dc,
|
||||
ready: make(chan struct{}),
|
||||
pending: map[string]chan Response{},
|
||||
}
|
||||
|
||||
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
sender := t.currentSignaler(nodeID)
|
||||
if sender == nil {
|
||||
return
|
||||
}
|
||||
_ = sender.Send(WireMessage{
|
||||
Type: "signal_candidate",
|
||||
From: "gateway",
|
||||
To: nodeID,
|
||||
Session: nodeID,
|
||||
Payload: structToMap(candidate.ToJSON()),
|
||||
})
|
||||
})
|
||||
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateDisconnected:
|
||||
t.mu.Lock()
|
||||
if t.sessions[nodeID] == session {
|
||||
delete(t.sessions, nodeID)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
})
|
||||
dc.OnOpen(func() {
|
||||
session.markReady()
|
||||
})
|
||||
dc.OnMessage(func(message webrtc.DataChannelMessage) {
|
||||
var msg WireMessage
|
||||
if err := json.Unmarshal(message.Data, &msg); err != nil {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(msg.Type)) != "node_response" || msg.Response == nil {
|
||||
return
|
||||
}
|
||||
session.mu.Lock()
|
||||
respCh := session.pending[msg.ID]
|
||||
if respCh != nil {
|
||||
delete(session.pending, msg.ID)
|
||||
}
|
||||
session.mu.Unlock()
|
||||
if respCh != nil {
|
||||
respCh <- *msg.Response
|
||||
}
|
||||
})
|
||||
|
||||
offer, err := pc.CreateOffer(nil)
|
||||
if err != nil {
|
||||
_ = pc.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := pc.SetLocalDescription(offer); err != nil {
|
||||
_ = pc.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if existing := t.sessions[nodeID]; existing != nil {
|
||||
t.mu.Unlock()
|
||||
_ = pc.Close()
|
||||
return existing, nil
|
||||
}
|
||||
t.sessions[nodeID] = session
|
||||
t.mu.Unlock()
|
||||
|
||||
sender := t.currentSignaler(nodeID)
|
||||
if sender == nil {
|
||||
t.mu.Lock()
|
||||
delete(t.sessions, nodeID)
|
||||
t.mu.Unlock()
|
||||
_ = pc.Close()
|
||||
return nil, fmt.Errorf("node %s signaling unavailable", nodeID)
|
||||
}
|
||||
if err := sender.Send(WireMessage{
|
||||
Type: "signal_offer",
|
||||
From: "gateway",
|
||||
To: nodeID,
|
||||
Session: nodeID,
|
||||
Payload: structToMap(*pc.LocalDescription()),
|
||||
}); err != nil {
|
||||
t.mu.Lock()
|
||||
delete(t.sessions, nodeID)
|
||||
t.mu.Unlock()
|
||||
_ = pc.Close()
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func structToMap(v interface{}) map[string]interface{} {
|
||||
b, _ := json.Marshal(v)
|
||||
var out map[string]interface{}
|
||||
_ = json.Unmarshal(b, &out)
|
||||
if out == nil {
|
||||
out = map[string]interface{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapInto(in map[string]interface{}, out interface{}) error {
|
||||
if len(in) == 0 {
|
||||
return fmt.Errorf("empty payload")
|
||||
}
|
||||
b, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, out)
|
||||
}
|
||||
Reference in New Issue
Block a user