feat: add node websocket signaling relay

This commit is contained in:
lpf
2026-03-08 22:24:45 +08:00
parent 4172a57b39
commit abce524114
3 changed files with 147 additions and 3 deletions

View File

@@ -39,6 +39,7 @@ type Server struct {
server *http.Server
nodeConnMu sync.Mutex
nodeConnIDs map[string]string
nodeSockets map[string]*nodeSocketConn
gatewayVersion string
webuiVersion string
configPath string
@@ -75,9 +76,16 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server
token: strings.TrimSpace(token),
mgr: mgr,
nodeConnIDs: map[string]string{},
nodeSockets: map[string]*nodeSocketConn{},
}
}
type nodeSocketConn struct {
connID string
conn *websocket.Conn
mu sync.Mutex
}
func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) }
func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) }
func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) }
@@ -110,6 +118,21 @@ func (s *Server) rememberNodeConnection(nodeID, connID string) {
s.nodeConnIDs[nodeID] = connID
}
func (s *Server) bindNodeSocket(nodeID, connID string, conn *websocket.Conn) {
nodeID = strings.TrimSpace(nodeID)
connID = strings.TrimSpace(connID)
if nodeID == "" || connID == "" || conn == nil {
return
}
s.nodeConnMu.Lock()
prev := s.nodeSockets[nodeID]
s.nodeSockets[nodeID] = &nodeSocketConn{connID: connID, conn: conn}
s.nodeConnMu.Unlock()
if prev != nil && prev.connID != connID {
_ = prev.conn.Close()
}
}
func (s *Server) releaseNodeConnection(nodeID, connID string) bool {
nodeID = strings.TrimSpace(nodeID)
connID = strings.TrimSpace(connID)
@@ -122,9 +145,33 @@ func (s *Server) releaseNodeConnection(nodeID, connID string) bool {
return false
}
delete(s.nodeConnIDs, nodeID)
if sock := s.nodeSockets[nodeID]; sock != nil && sock.connID == connID {
delete(s.nodeSockets, nodeID)
}
return true
}
func (s *Server) getNodeSocket(nodeID string) *nodeSocketConn {
nodeID = strings.TrimSpace(nodeID)
if nodeID == "" {
return nil
}
s.nodeConnMu.Lock()
defer s.nodeConnMu.Unlock()
return s.nodeSockets[nodeID]
}
func (s *Server) sendNodeSocketMessage(nodeID string, msg nodes.WireMessage) error {
sock := s.getNodeSocket(nodeID)
if sock == nil || sock.conn == nil {
return fmt.Errorf("node %s not connected", strings.TrimSpace(nodeID))
}
sock.mu.Lock()
defer sock.mu.Unlock()
_ = sock.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return sock.conn.WriteJSON(msg)
}
func (s *Server) Start(ctx context.Context) error {
if s.mgr == nil {
return nil
@@ -275,6 +322,7 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
s.mgr.Upsert(*msg.Node)
connectedID = strings.TrimSpace(msg.Node.ID)
s.rememberNodeConnection(connectedID, connID)
s.bindNodeSocket(connectedID, connID, conn)
if err := writeAck(nodes.WireAck{OK: true, Type: "registered", ID: connectedID}); err != nil {
return
}
@@ -291,10 +339,12 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
s.mgr.Upsert(*msg.Node)
connectedID = strings.TrimSpace(msg.Node.ID)
s.rememberNodeConnection(connectedID, connID)
s.bindNodeSocket(connectedID, connID, conn)
} else if n, ok := s.mgr.Get(id); ok {
s.mgr.Upsert(n)
connectedID = id
s.rememberNodeConnection(connectedID, connID)
s.bindNodeSocket(connectedID, connID, conn)
} else {
_ = writeAck(nodes.WireAck{OK: false, Type: "heartbeat", ID: id, Error: "node not found"})
continue
@@ -302,6 +352,30 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
if err := writeAck(nodes.WireAck{OK: true, Type: "heartbeat", ID: connectedID}); err != nil {
return
}
case "signal_offer", "signal_answer", "signal_candidate":
if strings.TrimSpace(connectedID) == "" {
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, Error: "node not registered"}); err != nil {
return
}
continue
}
targetID := strings.TrimSpace(msg.To)
if targetID == "" {
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "target node required"}); err != nil {
return
}
continue
}
msg.From = connectedID
if err := s.sendNodeSocketMessage(targetID, msg); err != nil {
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}); err != nil {
return
}
continue
}
if err := writeAck(nodes.WireAck{OK: true, Type: "relayed", ID: msg.ID}); err != nil {
return
}
default:
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "unsupported message type"}); err != nil {
return