From abce524114f118c1de96f421adfe788a9e247e45 Mon Sep 17 00:00:00 2001 From: lpf Date: Sun, 8 Mar 2026 22:24:45 +0800 Subject: [PATCH] feat: add node websocket signaling relay --- pkg/api/server.go | 74 ++++++++++++++++++++++++++++++++++++++++++ pkg/api/server_test.go | 66 +++++++++++++++++++++++++++++++++++++ pkg/nodes/types.go | 10 ++++-- 3 files changed, 147 insertions(+), 3 deletions(-) diff --git a/pkg/api/server.go b/pkg/api/server.go index c09c999..273debe 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -39,6 +39,7 @@ type Server struct { server *http.Server nodeConnMu sync.Mutex nodeConnIDs map[string]string + nodeSockets map[string]*nodeSocketConn gatewayVersion string webuiVersion string configPath string @@ -75,9 +76,16 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server token: strings.TrimSpace(token), mgr: mgr, nodeConnIDs: map[string]string{}, + nodeSockets: map[string]*nodeSocketConn{}, } } +type nodeSocketConn struct { + connID string + conn *websocket.Conn + mu sync.Mutex +} + func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) } func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) } func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) } @@ -110,6 +118,21 @@ func (s *Server) rememberNodeConnection(nodeID, connID string) { s.nodeConnIDs[nodeID] = connID } +func (s *Server) bindNodeSocket(nodeID, connID string, conn *websocket.Conn) { + nodeID = strings.TrimSpace(nodeID) + connID = strings.TrimSpace(connID) + if nodeID == "" || connID == "" || conn == nil { + return + } + s.nodeConnMu.Lock() + prev := s.nodeSockets[nodeID] + s.nodeSockets[nodeID] = &nodeSocketConn{connID: connID, conn: conn} + s.nodeConnMu.Unlock() + if prev != nil && prev.connID != connID { + _ = prev.conn.Close() + } +} + func (s *Server) releaseNodeConnection(nodeID, connID string) bool { nodeID = strings.TrimSpace(nodeID) connID = strings.TrimSpace(connID) @@ -122,9 +145,33 @@ func (s *Server) releaseNodeConnection(nodeID, connID string) bool { return false } delete(s.nodeConnIDs, nodeID) + if sock := s.nodeSockets[nodeID]; sock != nil && sock.connID == connID { + delete(s.nodeSockets, nodeID) + } return true } +func (s *Server) getNodeSocket(nodeID string) *nodeSocketConn { + nodeID = strings.TrimSpace(nodeID) + if nodeID == "" { + return nil + } + s.nodeConnMu.Lock() + defer s.nodeConnMu.Unlock() + return s.nodeSockets[nodeID] +} + +func (s *Server) sendNodeSocketMessage(nodeID string, msg nodes.WireMessage) error { + sock := s.getNodeSocket(nodeID) + if sock == nil || sock.conn == nil { + return fmt.Errorf("node %s not connected", strings.TrimSpace(nodeID)) + } + sock.mu.Lock() + defer sock.mu.Unlock() + _ = sock.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return sock.conn.WriteJSON(msg) +} + func (s *Server) Start(ctx context.Context) error { if s.mgr == nil { return nil @@ -275,6 +322,7 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { s.mgr.Upsert(*msg.Node) connectedID = strings.TrimSpace(msg.Node.ID) s.rememberNodeConnection(connectedID, connID) + s.bindNodeSocket(connectedID, connID, conn) if err := writeAck(nodes.WireAck{OK: true, Type: "registered", ID: connectedID}); err != nil { return } @@ -291,10 +339,12 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { s.mgr.Upsert(*msg.Node) connectedID = strings.TrimSpace(msg.Node.ID) s.rememberNodeConnection(connectedID, connID) + s.bindNodeSocket(connectedID, connID, conn) } else if n, ok := s.mgr.Get(id); ok { s.mgr.Upsert(n) connectedID = id s.rememberNodeConnection(connectedID, connID) + s.bindNodeSocket(connectedID, connID, conn) } else { _ = writeAck(nodes.WireAck{OK: false, Type: "heartbeat", ID: id, Error: "node not found"}) continue @@ -302,6 +352,30 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { if err := writeAck(nodes.WireAck{OK: true, Type: "heartbeat", ID: connectedID}); err != nil { return } + case "signal_offer", "signal_answer", "signal_candidate": + if strings.TrimSpace(connectedID) == "" { + if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, Error: "node not registered"}); err != nil { + return + } + continue + } + targetID := strings.TrimSpace(msg.To) + if targetID == "" { + if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "target node required"}); err != nil { + return + } + continue + } + msg.From = connectedID + if err := s.sendNodeSocketMessage(targetID, msg); err != nil { + if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}); err != nil { + return + } + continue + } + if err := writeAck(nodes.WireAck{OK: true, Type: "relayed", ID: msg.ID}); err != nil { + return + } default: if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "unsupported message type"}); err != nil { return diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index 10066da..d106878 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "os" @@ -211,6 +212,71 @@ func TestHandleNodeConnectReconnectKeepsNewestSessionOnline(t *testing.T) { _ = second.Close() } +func TestHandleNodeConnectRelaysSignalMessages(t *testing.T) { + t.Parallel() + + mgr := nodes.NewManager() + srv := NewServer("127.0.0.1", 0, "", mgr) + mux := http.NewServeMux() + mux.HandleFunc("/nodes/connect", srv.handleNodeConnect) + httpSrv := httptest.NewServer(mux) + defer httpSrv.Close() + + wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/nodes/connect" + connect := func(id string) *websocket.Conn { + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + if err := conn.WriteJSON(nodes.WireMessage{Type: "register", Node: &nodes.NodeInfo{ID: id, Name: id}}); err != nil { + t.Fatalf("write register: %v", err) + } + var ack nodes.WireAck + if err := conn.ReadJSON(&ack); err != nil { + t.Fatalf("read register ack: %v", err) + } + if !ack.OK { + t.Fatalf("unexpected register ack: %+v", ack) + } + return conn + } + + offerer := connect("edge-a") + defer offerer.Close() + answerer := connect("edge-b") + defer answerer.Close() + + signal := nodes.WireMessage{ + Type: "signal_offer", + ID: "sig-1", + To: "edge-b", + Session: "sess-1", + Payload: map[string]interface{}{"sdp": "offer-sdp"}, + } + if err := offerer.WriteJSON(signal); err != nil { + t.Fatalf("write signal offer: %v", err) + } + + var relayAck nodes.WireAck + if err := offerer.ReadJSON(&relayAck); err != nil { + t.Fatalf("read relay ack: %v", err) + } + if !relayAck.OK || relayAck.Type != "relayed" || relayAck.ID != "sig-1" { + t.Fatalf("unexpected relay ack: %+v", relayAck) + } + + var forwarded nodes.WireMessage + if err := answerer.ReadJSON(&forwarded); err != nil { + t.Fatalf("read forwarded signal: %v", err) + } + if forwarded.Type != "signal_offer" || forwarded.From != "edge-a" || forwarded.To != "edge-b" || forwarded.Session != "sess-1" { + t.Fatalf("unexpected forwarded signal envelope: %+v", forwarded) + } + if fmt.Sprintf("%v", forwarded.Payload["sdp"]) != "offer-sdp" { + t.Fatalf("unexpected forwarded payload: %+v", forwarded.Payload) + } +} + func TestHandleWebUISubagentsRuntimeLive(t *testing.T) { t.Parallel() diff --git a/pkg/nodes/types.go b/pkg/nodes/types.go index 039d771..feb7566 100644 --- a/pkg/nodes/types.go +++ b/pkg/nodes/types.go @@ -51,9 +51,13 @@ type Response struct { // WireMessage is the websocket envelope for node lifecycle messages. type WireMessage struct { - Type string `json:"type"` - ID string `json:"id,omitempty"` - Node *NodeInfo `json:"node,omitempty"` + Type string `json:"type"` + ID string `json:"id,omitempty"` + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + Session string `json:"session,omitempty"` + Node *NodeInfo `json:"node,omitempty"` + Payload map[string]interface{} `json:"payload,omitempty"` } // WireAck is the websocket response envelope for node lifecycle messages.