feat: add node websocket signaling relay

This commit is contained in:
lpf
2026-03-08 22:24:45 +08:00
parent 4172a57b39
commit abce524114
3 changed files with 147 additions and 3 deletions

View File

@@ -39,6 +39,7 @@ type Server struct {
server *http.Server
nodeConnMu sync.Mutex
nodeConnIDs map[string]string
nodeSockets map[string]*nodeSocketConn
gatewayVersion string
webuiVersion string
configPath string
@@ -75,9 +76,16 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server
token: strings.TrimSpace(token),
mgr: mgr,
nodeConnIDs: map[string]string{},
nodeSockets: map[string]*nodeSocketConn{},
}
}
type nodeSocketConn struct {
connID string
conn *websocket.Conn
mu sync.Mutex
}
func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) }
func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) }
func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) }
@@ -110,6 +118,21 @@ func (s *Server) rememberNodeConnection(nodeID, connID string) {
s.nodeConnIDs[nodeID] = connID
}
func (s *Server) bindNodeSocket(nodeID, connID string, conn *websocket.Conn) {
nodeID = strings.TrimSpace(nodeID)
connID = strings.TrimSpace(connID)
if nodeID == "" || connID == "" || conn == nil {
return
}
s.nodeConnMu.Lock()
prev := s.nodeSockets[nodeID]
s.nodeSockets[nodeID] = &nodeSocketConn{connID: connID, conn: conn}
s.nodeConnMu.Unlock()
if prev != nil && prev.connID != connID {
_ = prev.conn.Close()
}
}
func (s *Server) releaseNodeConnection(nodeID, connID string) bool {
nodeID = strings.TrimSpace(nodeID)
connID = strings.TrimSpace(connID)
@@ -122,9 +145,33 @@ func (s *Server) releaseNodeConnection(nodeID, connID string) bool {
return false
}
delete(s.nodeConnIDs, nodeID)
if sock := s.nodeSockets[nodeID]; sock != nil && sock.connID == connID {
delete(s.nodeSockets, nodeID)
}
return true
}
func (s *Server) getNodeSocket(nodeID string) *nodeSocketConn {
nodeID = strings.TrimSpace(nodeID)
if nodeID == "" {
return nil
}
s.nodeConnMu.Lock()
defer s.nodeConnMu.Unlock()
return s.nodeSockets[nodeID]
}
func (s *Server) sendNodeSocketMessage(nodeID string, msg nodes.WireMessage) error {
sock := s.getNodeSocket(nodeID)
if sock == nil || sock.conn == nil {
return fmt.Errorf("node %s not connected", strings.TrimSpace(nodeID))
}
sock.mu.Lock()
defer sock.mu.Unlock()
_ = sock.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return sock.conn.WriteJSON(msg)
}
func (s *Server) Start(ctx context.Context) error {
if s.mgr == nil {
return nil
@@ -275,6 +322,7 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
s.mgr.Upsert(*msg.Node)
connectedID = strings.TrimSpace(msg.Node.ID)
s.rememberNodeConnection(connectedID, connID)
s.bindNodeSocket(connectedID, connID, conn)
if err := writeAck(nodes.WireAck{OK: true, Type: "registered", ID: connectedID}); err != nil {
return
}
@@ -291,10 +339,12 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
s.mgr.Upsert(*msg.Node)
connectedID = strings.TrimSpace(msg.Node.ID)
s.rememberNodeConnection(connectedID, connID)
s.bindNodeSocket(connectedID, connID, conn)
} else if n, ok := s.mgr.Get(id); ok {
s.mgr.Upsert(n)
connectedID = id
s.rememberNodeConnection(connectedID, connID)
s.bindNodeSocket(connectedID, connID, conn)
} else {
_ = writeAck(nodes.WireAck{OK: false, Type: "heartbeat", ID: id, Error: "node not found"})
continue
@@ -302,6 +352,30 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
if err := writeAck(nodes.WireAck{OK: true, Type: "heartbeat", ID: connectedID}); err != nil {
return
}
case "signal_offer", "signal_answer", "signal_candidate":
if strings.TrimSpace(connectedID) == "" {
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, Error: "node not registered"}); err != nil {
return
}
continue
}
targetID := strings.TrimSpace(msg.To)
if targetID == "" {
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "target node required"}); err != nil {
return
}
continue
}
msg.From = connectedID
if err := s.sendNodeSocketMessage(targetID, msg); err != nil {
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}); err != nil {
return
}
continue
}
if err := writeAck(nodes.WireAck{OK: true, Type: "relayed", ID: msg.ID}); err != nil {
return
}
default:
if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "unsupported message type"}); err != nil {
return

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
@@ -211,6 +212,71 @@ func TestHandleNodeConnectReconnectKeepsNewestSessionOnline(t *testing.T) {
_ = second.Close()
}
func TestHandleNodeConnectRelaysSignalMessages(t *testing.T) {
t.Parallel()
mgr := nodes.NewManager()
srv := NewServer("127.0.0.1", 0, "", mgr)
mux := http.NewServeMux()
mux.HandleFunc("/nodes/connect", srv.handleNodeConnect)
httpSrv := httptest.NewServer(mux)
defer httpSrv.Close()
wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/nodes/connect"
connect := func(id string) *websocket.Conn {
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
if err := conn.WriteJSON(nodes.WireMessage{Type: "register", Node: &nodes.NodeInfo{ID: id, Name: id}}); err != nil {
t.Fatalf("write register: %v", err)
}
var ack nodes.WireAck
if err := conn.ReadJSON(&ack); err != nil {
t.Fatalf("read register ack: %v", err)
}
if !ack.OK {
t.Fatalf("unexpected register ack: %+v", ack)
}
return conn
}
offerer := connect("edge-a")
defer offerer.Close()
answerer := connect("edge-b")
defer answerer.Close()
signal := nodes.WireMessage{
Type: "signal_offer",
ID: "sig-1",
To: "edge-b",
Session: "sess-1",
Payload: map[string]interface{}{"sdp": "offer-sdp"},
}
if err := offerer.WriteJSON(signal); err != nil {
t.Fatalf("write signal offer: %v", err)
}
var relayAck nodes.WireAck
if err := offerer.ReadJSON(&relayAck); err != nil {
t.Fatalf("read relay ack: %v", err)
}
if !relayAck.OK || relayAck.Type != "relayed" || relayAck.ID != "sig-1" {
t.Fatalf("unexpected relay ack: %+v", relayAck)
}
var forwarded nodes.WireMessage
if err := answerer.ReadJSON(&forwarded); err != nil {
t.Fatalf("read forwarded signal: %v", err)
}
if forwarded.Type != "signal_offer" || forwarded.From != "edge-a" || forwarded.To != "edge-b" || forwarded.Session != "sess-1" {
t.Fatalf("unexpected forwarded signal envelope: %+v", forwarded)
}
if fmt.Sprintf("%v", forwarded.Payload["sdp"]) != "offer-sdp" {
t.Fatalf("unexpected forwarded payload: %+v", forwarded.Payload)
}
}
func TestHandleWebUISubagentsRuntimeLive(t *testing.T) {
t.Parallel()

View File

@@ -51,9 +51,13 @@ type Response struct {
// WireMessage is the websocket envelope for node lifecycle messages.
type WireMessage struct {
Type string `json:"type"`
ID string `json:"id,omitempty"`
Node *NodeInfo `json:"node,omitempty"`
Type string `json:"type"`
ID string `json:"id,omitempty"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Session string `json:"session,omitempty"`
Node *NodeInfo `json:"node,omitempty"`
Payload map[string]interface{} `json:"payload,omitempty"`
}
// WireAck is the websocket response envelope for node lifecycle messages.