From 2aa7db96192c13b2d85a7f6294e8c15913bd8b3d Mon Sep 17 00:00:00 2001 From: lpf Date: Sun, 8 Mar 2026 22:29:40 +0800 Subject: [PATCH] feat: tunnel node requests over websocket p2p channel --- cmd/clawgo/cmd_node.go | 167 +++++++++++++++++++++++++++++++----- pkg/agent/loop.go | 2 +- pkg/api/server.go | 22 ++++- pkg/nodes/manager.go | 99 ++++++++++++++++++++- pkg/nodes/transport.go | 53 +++++++----- pkg/nodes/transport_test.go | 71 +++++++++++++++ pkg/nodes/types.go | 2 + 7 files changed, 370 insertions(+), 46 deletions(-) create mode 100644 pkg/nodes/transport_test.go diff --git a/cmd/clawgo/cmd_node.go b/cmd/clawgo/cmd_node.go index 2833a67..a9cf8d3 100644 --- a/cmd/clawgo/cmd_node.go +++ b/cmd/clawgo/cmd_node.go @@ -11,6 +11,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" "clawgo/pkg/config" @@ -21,6 +22,7 @@ import ( type nodeRegisterOptions struct { GatewayBase string Token string + NodeToken string ID string Name string Endpoint string @@ -67,6 +69,7 @@ func printNodeHelp() { fmt.Println("Register options:") fmt.Println(" --gateway Gateway base URL, e.g. http://host:18790") fmt.Println(" --token Gateway token (optional when gateway.token is empty)") + fmt.Println(" --node-token Bearer token for this node endpoint (optional)") fmt.Println(" --id Node ID (default: hostname)") fmt.Println(" --name Node name (default: hostname)") fmt.Println(" --endpoint Public endpoint of this node") @@ -166,6 +169,12 @@ func parseNodeRegisterArgs(args []string, cfg *config.Config) (nodeRegisterOptio return opts, err } opts.Token = v + case "--node-token": + v, err := next() + if err != nil { + return opts, err + } + opts.NodeToken = v case "--id": v, err := next() if err != nil { @@ -307,6 +316,7 @@ func buildNodeInfo(opts nodeRegisterOptions) nodes.NodeInfo { Arch: strings.TrimSpace(opts.Arch), Version: strings.TrimSpace(opts.Version), Endpoint: strings.TrimSpace(opts.Endpoint), + Token: strings.TrimSpace(opts.NodeToken), Capabilities: opts.Capabilities, Actions: append([]string(nil), opts.Actions...), Models: append([]string(nil), opts.Models...), @@ -353,12 +363,27 @@ func runNodeHeartbeatSocket(ctx context.Context, opts nodeRegisterOptions, info return err } defer conn.Close() + var writeMu sync.Mutex + writeJSON := func(v interface{}) error { + writeMu.Lock() + defer writeMu.Unlock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return conn.WriteJSON(v) + } + writePing := func() error { + writeMu.Lock() + defer writeMu.Unlock() + return conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(10*time.Second)) + } - _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteJSON(nodes.WireMessage{Type: "register", Node: &info}); err != nil { + if err := writeJSON(nodes.WireMessage{Type: "register", Node: &info}); err != nil { return err } - if err := readNodeAck(conn, "registered", info.ID); err != nil { + acks := make(chan nodes.WireAck, 8) + errs := make(chan error, 1) + client := &http.Client{Timeout: 20 * time.Second} + go readNodeSocketLoop(ctx, conn, writeJSON, client, info, opts, acks, errs) + if err := waitNodeAck(ctx, acks, errs, "registered", info.ID); err != nil { return err } fmt.Printf("✓ Node socket connected: %s\n", info.ID) @@ -372,16 +397,20 @@ func runNodeHeartbeatSocket(ctx context.Context, opts nodeRegisterOptions, info select { case <-ctx.Done(): return nil + case err := <-errs: + if err != nil { + return err + } + return nil case <-pingTicker.C: - if err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(10*time.Second)); err != nil { + if err := writePing(); err != nil { return err } case <-ticker.C: - _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteJSON(nodes.WireMessage{Type: "heartbeat", ID: info.ID}); err != nil { + if err := writeJSON(nodes.WireMessage{Type: "heartbeat", ID: info.ID}); err != nil { return err } - if err := readNodeAck(conn, "heartbeat", info.ID); err != nil { + if err := waitNodeAck(ctx, acks, errs, "heartbeat", info.ID); err != nil { return err } fmt.Printf("✓ Heartbeat ok: %s\n", info.ID) @@ -403,26 +432,118 @@ func nodeSocketPingInterval(heartbeatSec int) time.Duration { return interval } -func readNodeAck(conn *websocket.Conn, expectedType, id string) error { - _ = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - var ack nodes.WireAck - if err := conn.ReadJSON(&ack); err != nil { - return err - } - if !ack.OK { - if strings.TrimSpace(ack.Error) == "" { - ack.Error = "unknown websocket error" +func waitNodeAck(ctx context.Context, acks <-chan nodes.WireAck, errs <-chan error, expectedType, id string) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errs: + if err == nil { + return context.Canceled + } + return err + case ack := <-acks: + if !ack.OK { + if strings.TrimSpace(ack.Error) == "" { + ack.Error = "unknown websocket error" + } + return fmt.Errorf("%s", ack.Error) + } + ackType := strings.ToLower(strings.TrimSpace(ack.Type)) + if expectedType != "" && ackType != strings.ToLower(strings.TrimSpace(expectedType)) { + continue + } + if strings.TrimSpace(id) != "" && strings.TrimSpace(ack.ID) != "" && strings.TrimSpace(ack.ID) != strings.TrimSpace(id) { + continue + } + return nil } - return fmt.Errorf("%s", ack.Error) } - ackType := strings.ToLower(strings.TrimSpace(ack.Type)) - if expectedType != "" && ackType != strings.ToLower(strings.TrimSpace(expectedType)) { - return fmt.Errorf("unexpected websocket ack type: %s", ack.Type) +} + +func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, acks chan<- nodes.WireAck, errs chan<- error) { + defer close(acks) + defer close(errs) + for { + select { + case <-ctx.Done(): + errs <- nil + return + default: + } + _ = conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + _, data, err := conn.ReadMessage() + if err != nil { + errs <- err + return + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + continue + } + if _, hasOK := raw["ok"]; hasOK { + var ack nodes.WireAck + if err := json.Unmarshal(data, &ack); err == nil { + acks <- ack + } + continue + } + var msg nodes.WireMessage + if err := json.Unmarshal(data, &msg); err != nil { + continue + } + switch strings.ToLower(strings.TrimSpace(msg.Type)) { + case "node_request": + go handleNodeWireRequest(ctx, writeJSON, client, info, opts, msg) + case "signal_offer", "signal_answer", "signal_candidate": + fmt.Printf("ℹ Signal received: type=%s from=%s session=%s\n", msg.Type, strings.TrimSpace(msg.From), strings.TrimSpace(msg.Session)) + } } - if strings.TrimSpace(id) != "" && strings.TrimSpace(ack.ID) != "" && strings.TrimSpace(ack.ID) != strings.TrimSpace(id) { - return fmt.Errorf("unexpected websocket ack id: %s", ack.ID) +} + +func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) { + resp := nodes.Response{ + OK: false, + Code: "invalid_request", + Node: info.ID, + Action: "", + Error: "request missing", } - return nil + if msg.Request != nil { + req := *msg.Request + resp.Action = req.Action + if strings.TrimSpace(opts.Endpoint) == "" { + resp.Error = "node endpoint not configured" + resp.Code = "endpoint_missing" + } else { + if req.Node == "" { + req.Node = info.ID + } + execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, req) + if err != nil { + resp = nodes.Response{ + OK: false, + Code: "transport_error", + Node: info.ID, + Action: req.Action, + Error: err.Error(), + } + } else { + resp = execResp + if strings.TrimSpace(resp.Node) == "" { + resp.Node = info.ID + } + } + } + } + _ = writeJSON(nodes.WireMessage{ + Type: "node_response", + ID: msg.ID, + From: info.ID, + To: strings.TrimSpace(msg.From), + Session: strings.TrimSpace(msg.Session), + Response: &resp, + }) } func postNodeRegister(ctx context.Context, client *http.Client, gatewayBase, token string, info nodes.NodeInfo) error { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f5d1278..31e7eae 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -142,7 +142,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers return nodes.Response{OK: false, Code: "unsupported_action", Node: "local", Action: req.Action, Error: "unsupported local simulated action"} } }) - nodesRouter := &nodes.Router{P2P: &nodes.StubP2PTransport{}, Relay: &nodes.HTTPRelayTransport{Manager: nodesManager}} + nodesRouter := &nodes.Router{P2P: &nodes.WebsocketP2PTransport{Manager: nodesManager}, Relay: &nodes.HTTPRelayTransport{Manager: nodesManager}} toolsRegistry.Register(tools.NewNodesTool(nodesManager, nodesRouter, filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl"))) if cs != nil { diff --git a/pkg/api/server.go b/pkg/api/server.go index 273debe..cb77ad4 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -86,6 +86,16 @@ type nodeSocketConn struct { mu sync.Mutex } +func (c *nodeSocketConn) Send(msg nodes.WireMessage) error { + if c == nil || c.conn == nil { + return fmt.Errorf("node websocket unavailable") + } + c.mu.Lock() + defer c.mu.Unlock() + _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return c.conn.WriteJSON(msg) +} + func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) } func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) } func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) } @@ -124,10 +134,14 @@ func (s *Server) bindNodeSocket(nodeID, connID string, conn *websocket.Conn) { if nodeID == "" || connID == "" || conn == nil { return } + next := &nodeSocketConn{connID: connID, conn: conn} s.nodeConnMu.Lock() prev := s.nodeSockets[nodeID] - s.nodeSockets[nodeID] = &nodeSocketConn{connID: connID, conn: conn} + s.nodeSockets[nodeID] = next s.nodeConnMu.Unlock() + if s.mgr != nil { + s.mgr.RegisterWireSender(nodeID, next) + } if prev != nil && prev.connID != connID { _ = prev.conn.Close() } @@ -148,6 +162,9 @@ func (s *Server) releaseNodeConnection(nodeID, connID string) bool { if sock := s.nodeSockets[nodeID]; sock != nil && sock.connID == connID { delete(s.nodeSockets, nodeID) } + if s.mgr != nil { + s.mgr.RegisterWireSender(nodeID, nil) + } return true } @@ -313,6 +330,9 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { return } _ = conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + if s.mgr != nil && s.mgr.HandleWireMessage(msg) { + continue + } switch strings.ToLower(strings.TrimSpace(msg.Type)) { case "register": if msg.Node == nil || strings.TrimSpace(msg.Node.ID) == "" { diff --git a/pkg/nodes/manager.go b/pkg/nodes/manager.go index b781c84..17607ec 100644 --- a/pkg/nodes/manager.go +++ b/pkg/nodes/manager.go @@ -1,7 +1,9 @@ package nodes import ( + "context" "encoding/json" + "fmt" "os" "path/filepath" "sort" @@ -14,11 +16,17 @@ const defaultNodeTTL = 60 * time.Second // Manager keeps paired node metadata and basic routing helpers. type Handler func(req Request) Response +type WireSender interface { + Send(msg WireMessage) error +} type Manager struct { mu sync.RWMutex nodes map[string]NodeInfo handlers map[string]Handler + senders map[string]WireSender + pending map[string]chan WireMessage + nextWire uint64 ttl time.Duration auditPath string statePath string @@ -29,7 +37,13 @@ var defaultManager = NewManager() func DefaultManager() *Manager { return defaultManager } func NewManager() *Manager { - m := &Manager{nodes: map[string]NodeInfo{}, handlers: map[string]Handler{}, ttl: defaultNodeTTL} + m := &Manager{ + nodes: map[string]NodeInfo{}, + handlers: map[string]Handler{}, + senders: map[string]WireSender{}, + pending: map[string]chan WireMessage{}, + ttl: defaultNodeTTL, + } go m.reaperLoop() return m } @@ -132,6 +146,89 @@ func (m *Manager) RegisterHandler(nodeID string, h Handler) { m.handlers[nodeID] = h } +func (m *Manager) RegisterWireSender(nodeID string, sender WireSender) { + nodeID = strings.TrimSpace(nodeID) + if nodeID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if sender == nil { + delete(m.senders, nodeID) + return + } + m.senders[nodeID] = sender +} + +func (m *Manager) HandleWireMessage(msg WireMessage) bool { + switch strings.ToLower(strings.TrimSpace(msg.Type)) { + case "node_response": + if strings.TrimSpace(msg.ID) == "" { + return false + } + m.mu.Lock() + ch := m.pending[msg.ID] + if ch != nil { + delete(m.pending, msg.ID) + } + m.mu.Unlock() + if ch == nil { + return false + } + select { + case ch <- msg: + default: + } + return true + default: + return false + } +} + +func (m *Manager) SendWireRequest(ctx context.Context, nodeID string, req Request) (Response, error) { + nodeID = strings.TrimSpace(nodeID) + if nodeID == "" { + return Response{}, fmt.Errorf("node id required") + } + m.mu.Lock() + sender := m.senders[nodeID] + if sender == nil { + m.mu.Unlock() + return Response{}, fmt.Errorf("node %s websocket sender unavailable", nodeID) + } + m.nextWire++ + wireID := fmt.Sprintf("wire-%d", m.nextWire) + ch := make(chan WireMessage, 1) + m.pending[wireID] = ch + m.mu.Unlock() + + msg := WireMessage{ + Type: "node_request", + ID: wireID, + To: nodeID, + Request: &req, + } + if err := sender.Send(msg); err != nil { + m.mu.Lock() + delete(m.pending, wireID) + m.mu.Unlock() + return Response{}, err + } + + select { + case <-ctx.Done(): + m.mu.Lock() + delete(m.pending, wireID) + m.mu.Unlock() + return Response{}, ctx.Err() + case incoming := <-ch: + if incoming.Response == nil { + return Response{}, fmt.Errorf("node %s returned empty response", nodeID) + } + return *incoming.Response, nil + } +} + func (m *Manager) Invoke(req Request) (Response, bool) { m.mu.RLock() h, ok := m.handlers[req.Node] diff --git a/pkg/nodes/transport.go b/pkg/nodes/transport.go index 4d82637..6e0fd3d 100644 --- a/pkg/nodes/transport.go +++ b/pkg/nodes/transport.go @@ -52,13 +52,23 @@ func (r *Router) Dispatch(ctx context.Context, req Request, mode string) (Respon } } -// StubP2PTransport provides phase-2 negotiation scaffold. -type StubP2PTransport struct{} +// WebsocketP2PTransport uses the persistent node websocket as a request/response tunnel +// while the project evolves toward a true peer data channel. +type WebsocketP2PTransport struct { + Manager *Manager +} -func (s *StubP2PTransport) Name() string { return "p2p" } -func (s *StubP2PTransport) Send(ctx context.Context, req Request) (Response, error) { - _ = ctx - return Response{OK: false, Node: req.Node, Action: req.Action, Error: "p2p session not established yet"}, nil +func (s *WebsocketP2PTransport) Name() string { return "p2p" } +func (s *WebsocketP2PTransport) Send(ctx context.Context, req Request) (Response, error) { + if s == nil || s.Manager == nil { + return Response{OK: false, Node: req.Node, Action: req.Action, Error: "p2p manager unavailable"}, nil + } + resp, err := s.Manager.SendWireRequest(ctx, req.Node, req) + if err != nil { + return Response{OK: false, Code: "p2p_unavailable", Node: req.Node, Action: req.Action, Error: err.Error()}, nil + } + resp.Payload = normalizeDevicePayload(resp.Action, resp.Payload) + return resp, nil } // HTTPRelayTransport dispatches requests to node-agent endpoints over HTTP. @@ -96,22 +106,11 @@ func actionHTTPPath(action string) string { } } -func (s *HTTPRelayTransport) Send(ctx context.Context, req Request) (Response, error) { - if s.Manager == nil { - return Response{OK: false, Code: "relay_unavailable", Node: req.Node, Action: req.Action, Error: "relay manager not configured"}, nil - } - if resp, ok := s.Manager.Invoke(req); ok { - return resp, nil - } - n, ok := s.Manager.Get(req.Node) - if !ok { - return Response{OK: false, Code: "node_not_found", Node: req.Node, Action: req.Action, Error: "node not found"}, nil - } - endpoint := strings.TrimRight(strings.TrimSpace(n.Endpoint), "/") +func DoEndpointRequest(ctx context.Context, client *http.Client, endpoint, token string, req Request) (Response, error) { + endpoint = strings.TrimRight(strings.TrimSpace(endpoint), "/") if endpoint == "" { return Response{OK: false, Code: "endpoint_missing", Node: req.Node, Action: req.Action, Error: "node endpoint not configured"}, nil } - client := s.Client if client == nil { client = &http.Client{Timeout: 20 * time.Second} } @@ -122,7 +121,7 @@ func (s *HTTPRelayTransport) Send(ctx context.Context, req Request) (Response, e return Response{}, err } hreq.Header.Set("Content-Type", "application/json") - if tok := strings.TrimSpace(n.Token); tok != "" { + if tok := strings.TrimSpace(token); tok != "" { hreq.Header.Set("Authorization", "Bearer "+tok) } hresp, err := client.Do(hreq) @@ -152,6 +151,20 @@ func (s *HTTPRelayTransport) Send(ctx context.Context, req Request) (Response, e return resp, nil } +func (s *HTTPRelayTransport) Send(ctx context.Context, req Request) (Response, error) { + if s.Manager == nil { + return Response{OK: false, Code: "relay_unavailable", Node: req.Node, Action: req.Action, Error: "relay manager not configured"}, nil + } + if resp, ok := s.Manager.Invoke(req); ok { + return resp, nil + } + n, ok := s.Manager.Get(req.Node) + if !ok { + return Response{OK: false, Code: "node_not_found", Node: req.Node, Action: req.Action, Error: "node not found"}, nil + } + return DoEndpointRequest(ctx, s.Client, n.Endpoint, n.Token, req) +} + func normalizeDevicePayload(action string, payload map[string]interface{}) map[string]interface{} { if payload == nil { payload = map[string]interface{}{} diff --git a/pkg/nodes/transport_test.go b/pkg/nodes/transport_test.go new file mode 100644 index 0000000..33f2842 --- /dev/null +++ b/pkg/nodes/transport_test.go @@ -0,0 +1,71 @@ +package nodes + +import ( + "context" + "testing" + "time" +) + +type captureWireSender struct { + send func(msg WireMessage) error +} + +func (c *captureWireSender) Send(msg WireMessage) error { + if c.send != nil { + return c.send(msg) + } + return nil +} + +func TestWebsocketP2PTransportSend(t *testing.T) { + t.Parallel() + + manager := NewManager() + manager.Upsert(NodeInfo{ + ID: "edge-dev", + Online: true, + Capabilities: Capabilities{ + Run: true, + }, + }) + manager.RegisterWireSender("edge-dev", &captureWireSender{ + send: func(msg WireMessage) error { + if msg.Type != "node_request" || msg.Request == nil || msg.Request.Action != "run" { + t.Fatalf("unexpected wire request: %+v", msg) + } + go func() { + time.Sleep(20 * time.Millisecond) + manager.HandleWireMessage(WireMessage{ + Type: "node_response", + ID: msg.ID, + Response: &Response{ + OK: true, + Code: "ok", + Node: "edge-dev", + Action: "run", + Payload: map[string]interface{}{ + "status": "done", + }, + }, + }) + }() + return nil + }, + }) + + transport := &WebsocketP2PTransport{Manager: manager} + resp, err := transport.Send(context.Background(), Request{ + Action: "run", + Node: "edge-dev", + Args: map[string]interface{}{"command": []string{"echo", "ok"}}, + }) + if err != nil { + t.Fatalf("transport send failed: %v", err) + } + if !resp.OK || resp.Node != "edge-dev" || resp.Action != "run" { + t.Fatalf("unexpected response: %+v", resp) + } + if resp.Payload["status"] != "done" { + t.Fatalf("unexpected payload: %+v", resp.Payload) + } +} diff --git a/pkg/nodes/types.go b/pkg/nodes/types.go index feb7566..bf116f2 100644 --- a/pkg/nodes/types.go +++ b/pkg/nodes/types.go @@ -57,6 +57,8 @@ type WireMessage struct { To string `json:"to,omitempty"` Session string `json:"session,omitempty"` Node *NodeInfo `json:"node,omitempty"` + Request *Request `json:"request,omitempty"` + Response *Response `json:"response,omitempty"` Payload map[string]interface{} `json:"payload,omitempty"` }