diff --git a/cmd/clawgo/cmd_gateway.go b/cmd/clawgo/cmd_gateway.go index 2000da7..9a9f47c 100644 --- a/cmd/clawgo/cmd_gateway.go +++ b/cmd/clawgo/cmd_gateway.go @@ -129,6 +129,20 @@ func gatewayCmd() { } registryServer := api.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, cfg.Gateway.Token, nodes.DefaultManager()) + configureGatewayNodeP2P := func(loop *agent.AgentLoop, server *api.Server, runtimeCfg *config.Config) { + if loop == nil || server == nil || runtimeCfg == nil { + return + } + switch { + case runtimeCfg.Gateway.Nodes.P2P.Enabled && strings.EqualFold(strings.TrimSpace(runtimeCfg.Gateway.Nodes.P2P.Transport), "webrtc"): + webrtcTransport := nodes.NewWebRTCTransport(runtimeCfg.Gateway.Nodes.P2P.STUNServers) + loop.SetNodeP2PTransport(webrtcTransport) + server.SetNodeWebRTCTransport(webrtcTransport) + default: + server.SetNodeWebRTCTransport(nil) + } + } + configureGatewayNodeP2P(agentLoop, registryServer, cfg) registryServer.SetGatewayVersion(version) registryServer.SetWebUIVersion(version) registryServer.SetConfigPath(getConfigPath()) @@ -343,7 +357,8 @@ func gatewayCmd() { runtimeSame := reflect.DeepEqual(cfg.Agents, newCfg.Agents) && reflect.DeepEqual(cfg.Providers, newCfg.Providers) && reflect.DeepEqual(cfg.Tools, newCfg.Tools) && - reflect.DeepEqual(cfg.Channels, newCfg.Channels) + reflect.DeepEqual(cfg.Channels, newCfg.Channels) && + reflect.DeepEqual(cfg.Gateway.Nodes, newCfg.Gateway.Nodes) if runtimeSame { configureLogging(newCfg) @@ -369,6 +384,7 @@ func gatewayCmd() { } cfg = newCfg runtimecfg.Set(cfg) + configureGatewayNodeP2P(agentLoop, registryServer, cfg) fmt.Println("✓ Config hot-reload applied (logging/metadata only)") return } @@ -386,6 +402,7 @@ func gatewayCmd() { agentLoop = newAgentLoop cfg = newCfg runtimecfg.Set(cfg) + configureGatewayNodeP2P(agentLoop, registryServer, cfg) sentinelService.Stop() sentinelService = sentinel.NewService( getConfigPath(), diff --git a/cmd/clawgo/cmd_node.go b/cmd/clawgo/cmd_node.go index a9cf8d3..9a56e6d 100644 --- a/cmd/clawgo/cmd_node.go +++ b/cmd/clawgo/cmd_node.go @@ -17,6 +17,7 @@ import ( "clawgo/pkg/config" "clawgo/pkg/nodes" "github.com/gorilla/websocket" + "github.com/pion/webrtc/v4" ) type nodeRegisterOptions struct { @@ -42,6 +43,17 @@ type nodeHeartbeatOptions struct { ID string } +type nodeWebRTCSession struct { + info nodes.NodeInfo + opts nodeRegisterOptions + client *http.Client + writeJSON func(interface{}) error + + mu sync.Mutex + pc *webrtc.PeerConnection + dc *webrtc.DataChannel +} + func nodeCmd() { args := os.Args[2:] if len(args) == 0 { @@ -464,6 +476,7 @@ func waitNodeAck(ctx context.Context, acks <-chan nodes.WireAck, errs <-chan err func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, acks chan<- nodes.WireAck, errs chan<- error) { defer close(acks) defer close(errs) + rtc := &nodeWebRTCSession{info: info, opts: opts, client: client, writeJSON: writeJSON} for { select { case <-ctx.Done(): @@ -496,46 +509,128 @@ func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON fun case "node_request": go handleNodeWireRequest(ctx, writeJSON, client, info, opts, msg) case "signal_offer", "signal_answer", "signal_candidate": - fmt.Printf("ℹ Signal received: type=%s from=%s session=%s\n", msg.Type, strings.TrimSpace(msg.From), strings.TrimSpace(msg.Session)) + if err := rtc.handleSignal(ctx, msg); err != nil { + fmt.Printf("Warning: node webrtc signal failed: %v\n", err) + } } } } -func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) { - resp := nodes.Response{ - OK: false, - Code: "invalid_request", - Node: info.ID, - Action: "", - Error: "request missing", - } - if msg.Request != nil { - req := *msg.Request - resp.Action = req.Action - if strings.TrimSpace(opts.Endpoint) == "" { - resp.Error = "node endpoint not configured" - resp.Code = "endpoint_missing" - } else { - if req.Node == "" { - req.Node = info.ID - } - execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, req) - if err != nil { - resp = nodes.Response{ - OK: false, - Code: "transport_error", - Node: info.ID, - Action: req.Action, - Error: err.Error(), - } - } else { - resp = execResp - if strings.TrimSpace(resp.Node) == "" { - resp.Node = info.ID - } - } +func (s *nodeWebRTCSession) handleSignal(ctx context.Context, msg nodes.WireMessage) error { + s.mu.Lock() + defer s.mu.Unlock() + + switch strings.ToLower(strings.TrimSpace(msg.Type)) { + case "signal_offer": + pc, err := s.ensurePeerConnectionLocked() + if err != nil { + return err } + var desc webrtc.SessionDescription + if err := mapWirePayload(msg.Payload, &desc); err != nil { + return err + } + if err := pc.SetRemoteDescription(desc); err != nil { + return err + } + answer, err := pc.CreateAnswer(nil) + if err != nil { + return err + } + if err := pc.SetLocalDescription(answer); err != nil { + return err + } + return s.writeJSON(nodes.WireMessage{ + Type: "signal_answer", + From: s.info.ID, + To: "gateway", + Session: strings.TrimSpace(msg.Session), + Payload: structToWirePayload(*pc.LocalDescription()), + }) + case "signal_candidate": + if s.pc == nil { + return nil + } + var candidate webrtc.ICECandidateInit + if err := mapWirePayload(msg.Payload, &candidate); err != nil { + return err + } + return s.pc.AddICECandidate(candidate) + default: + return nil } +} + +func (s *nodeWebRTCSession) ensurePeerConnectionLocked() (*webrtc.PeerConnection, error) { + if s.pc != nil { + return s.pc, nil + } + config := webrtc.Configuration{} + pc, err := webrtc.NewPeerConnection(config) + if err != nil { + return nil, err + } + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate == nil { + return + } + _ = s.writeJSON(nodes.WireMessage{ + Type: "signal_candidate", + From: s.info.ID, + To: "gateway", + Session: s.info.ID, + Payload: structToWirePayload(candidate.ToJSON()), + }) + }) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateDisconnected: + s.mu.Lock() + defer s.mu.Unlock() + if s.pc != nil { + _ = s.pc.Close() + } + s.pc = nil + s.dc = nil + } + }) + pc.OnDataChannel(func(dc *webrtc.DataChannel) { + s.mu.Lock() + s.dc = dc + s.mu.Unlock() + dc.OnMessage(func(message webrtc.DataChannelMessage) { + var msg nodes.WireMessage + if err := json.Unmarshal(message.Data, &msg); err != nil { + return + } + if strings.ToLower(strings.TrimSpace(msg.Type)) != "node_request" || msg.Request == nil { + return + } + go s.handleDataChannelRequest(context.Background(), dc, msg) + }) + }) + s.pc = pc + return pc, nil +} + +func (s *nodeWebRTCSession) handleDataChannelRequest(ctx context.Context, dc *webrtc.DataChannel, msg nodes.WireMessage) { + resp := executeNodeRequest(ctx, s.client, s.info, s.opts, msg.Request) + b, err := json.Marshal(nodes.WireMessage{ + Type: "node_response", + ID: msg.ID, + From: s.info.ID, + To: "gateway", + Session: strings.TrimSpace(msg.Session), + Response: &resp, + }) + if err != nil { + return + } + _ = dc.Send(b) +} + +func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) { + resp := executeNodeRequest(ctx, client, info, opts, msg.Request) _ = writeJSON(nodes.WireMessage{ Type: "node_response", ID: msg.ID, @@ -546,6 +641,64 @@ func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) erro }) } +func executeNodeRequest(ctx context.Context, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, req *nodes.Request) nodes.Response { + resp := nodes.Response{ + OK: false, + Code: "invalid_request", + Node: info.ID, + Action: "", + Error: "request missing", + } + if req == nil { + return resp + } + next := *req + resp.Action = next.Action + if strings.TrimSpace(opts.Endpoint) == "" { + resp.Error = "node endpoint not configured" + resp.Code = "endpoint_missing" + return resp + } + if next.Node == "" { + next.Node = info.ID + } + execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, next) + if err != nil { + return nodes.Response{ + OK: false, + Code: "transport_error", + Node: info.ID, + Action: next.Action, + Error: err.Error(), + } + } + if strings.TrimSpace(execResp.Node) == "" { + execResp.Node = info.ID + } + return execResp +} + +func structToWirePayload(v interface{}) map[string]interface{} { + b, _ := json.Marshal(v) + var out map[string]interface{} + _ = json.Unmarshal(b, &out) + if out == nil { + out = map[string]interface{}{} + } + return out +} + +func mapWirePayload(in map[string]interface{}, out interface{}) error { + if len(in) == 0 { + return fmt.Errorf("empty payload") + } + b, err := json.Marshal(in) + if err != nil { + return err + } + return json.Unmarshal(b, out) +} + func postNodeRegister(ctx context.Context, client *http.Client, gatewayBase, token string, info nodes.NodeInfo) error { return postNodeJSON(ctx, client, gatewayBase, token, "/nodes/register", info) } diff --git a/config.example.json b/config.example.json index bdceee0..379d29a 100644 --- a/config.example.json +++ b/config.example.json @@ -279,7 +279,8 @@ "nodes": { "p2p": { "enabled": false, - "transport": "websocket_tunnel" + "transport": "websocket_tunnel", + "stun_servers": [] } } }, diff --git a/go.mod b/go.mod index 5e67b8c..2658967 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,22 @@ require ( github.com/grbit/go-json v0.11.0 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/pion/datachannel v1.5.10 // indirect + github.com/pion/dtls/v3 v3.0.7 // indirect + github.com/pion/ice/v4 v4.0.10 // indirect + github.com/pion/interceptor v0.1.41 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns/v2 v2.0.7 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtcp v1.2.15 // indirect + github.com/pion/rtp v1.8.23 // indirect + github.com/pion/sctp v1.8.40 // indirect + github.com/pion/sdp/v3 v3.0.16 // indirect + github.com/pion/srtp/v3 v3.0.8 // indirect + github.com/pion/stun/v3 v3.0.0 // indirect + github.com/pion/transport/v3 v3.0.8 // indirect + github.com/pion/turn/v4 v4.1.1 // indirect + github.com/pion/webrtc/v4 v4.1.6 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -36,6 +52,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fastjson v1.6.7 // indirect + github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect diff --git a/go.sum b/go.sum index 0da6fa1..e5308e1 100644 --- a/go.sum +++ b/go.sum @@ -86,6 +86,38 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= +github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= +github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= +github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= +github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= +github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4= +github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= +github.com/pion/interceptor v0.1.41 h1:NpvX3HgWIukTf2yTBVjVGFXtpSpWgXjqz7IIpu7NsOw= +github.com/pion/interceptor v0.1.41/go.mod h1:nEt4187unvRXJFyjiw00GKo+kIuXMWQI9K89fsosDLY= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= +github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= +github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= +github.com/pion/rtp v1.8.23 h1:kxX3bN4nM97DPrVBGq5I/Xcl332HnTHeP1Swx3/MCnU= +github.com/pion/rtp v1.8.23/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM= +github.com/pion/sctp v1.8.40 h1:bqbgWYOrUhsYItEnRObUYZuzvOMsVplS3oNgzedBlG8= +github.com/pion/sctp v1.8.40/go.mod h1:SPBBUENXE6ThkEksN5ZavfAhFYll+h+66ZiG6IZQuzo= +github.com/pion/sdp/v3 v3.0.16 h1:0dKzYO6gTAvuLaAKQkC02eCPjMIi4NuAr/ibAwrGDCo= +github.com/pion/sdp/v3 v3.0.16/go.mod h1:9tyKzznud3qiweZcD86kS0ff1pGYB3VX+Bcsmkx6IXo= +github.com/pion/srtp/v3 v3.0.8 h1:RjRrjcIeQsilPzxvdaElN0CpuQZdMvcl9VZ5UY9suUM= +github.com/pion/srtp/v3 v3.0.8/go.mod h1:2Sq6YnDH7/UDCvkSoHSDNDeyBcFgWL0sAVycVbAsXFg= +github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= +github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= +github.com/pion/transport/v3 v3.0.8 h1:oI3myyYnTKUSTthu/NZZ8eu2I5sHbxbUNNFW62olaYc= +github.com/pion/transport/v3 v3.0.8/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= +github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= +github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= +github.com/pion/webrtc/v4 v4.1.6 h1:srHH2HwvCGwPba25EYJgUzgLqCQoXl1VCUnrGQMSzUw= +github.com/pion/webrtc/v4 v4.1.6/go.mod h1:wKecGRlkl3ox/As/MYghJL+b/cVXMEhoPMJWPuGQFhU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -124,6 +156,8 @@ github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZy github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw= github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM= github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 71cd171..6e5f180 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -79,6 +79,13 @@ func (al *AgentLoop) SetConfigPath(path string) { } } +func (al *AgentLoop) SetNodeP2PTransport(t nodes.Transport) { + if al == nil || al.nodeRouter == nil { + return + } + al.nodeRouter.P2P = t +} + // StartupCompactionReport provides startup memory/session maintenance stats. type StartupCompactionReport struct { TotalSessions int `json:"total_sessions"` diff --git a/pkg/api/server.go b/pkg/api/server.go index cb77ad4..28a2451 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -40,6 +40,7 @@ type Server struct { nodeConnMu sync.Mutex nodeConnIDs map[string]string nodeSockets map[string]*nodeSocketConn + nodeWebRTC *nodes.WebRTCTransport gatewayVersion string webuiVersion string configPath string @@ -116,6 +117,9 @@ func (s *Server) SetToolsCatalogHandler(fn func() interface{}) { s.onToolsCatalo func (s *Server) SetWebUIDir(dir string) { s.webUIDir = strings.TrimSpace(dir) } func (s *Server) SetGatewayVersion(v string) { s.gatewayVersion = strings.TrimSpace(v) } func (s *Server) SetWebUIVersion(v string) { s.webuiVersion = strings.TrimSpace(v) } +func (s *Server) SetNodeWebRTCTransport(t *nodes.WebRTCTransport) { + s.nodeWebRTC = t +} func (s *Server) rememberNodeConnection(nodeID, connID string) { nodeID = strings.TrimSpace(nodeID) @@ -142,6 +146,9 @@ func (s *Server) bindNodeSocket(nodeID, connID string, conn *websocket.Conn) { if s.mgr != nil { s.mgr.RegisterWireSender(nodeID, next) } + if s.nodeWebRTC != nil { + s.nodeWebRTC.BindSignaler(nodeID, next) + } if prev != nil && prev.connID != connID { _ = prev.conn.Close() } @@ -165,6 +172,9 @@ func (s *Server) releaseNodeConnection(nodeID, connID string) bool { if s.mgr != nil { s.mgr.RegisterWireSender(nodeID, nil) } + if s.nodeWebRTC != nil { + s.nodeWebRTC.UnbindSignaler(nodeID) + } return true } @@ -373,13 +383,23 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { return } case "signal_offer", "signal_answer", "signal_candidate": + targetID := strings.TrimSpace(msg.To) + if s.nodeWebRTC != nil && (targetID == "" || strings.EqualFold(targetID, "gateway")) { + if err := s.nodeWebRTC.HandleSignal(msg); err != nil { + if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: err.Error()}); err != nil { + return + } + } else if err := writeAck(nodes.WireAck{OK: true, Type: "signaled", ID: msg.ID}); err != nil { + return + } + continue + } if strings.TrimSpace(connectedID) == "" { if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, Error: "node not registered"}); err != nil { return } continue } - targetID := strings.TrimSpace(msg.To) if targetID == "" { if err := writeAck(nodes.WireAck{OK: false, Type: msg.Type, ID: msg.ID, Error: "target node required"}); err != nil { return diff --git a/pkg/config/config.go b/pkg/config/config.go index 800d1fa..4e0edc9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -299,8 +299,9 @@ type GatewayNodesConfig struct { } type GatewayNodesP2PConfig struct { - Enabled bool `json:"enabled"` - Transport string `json:"transport,omitempty"` + Enabled bool `json:"enabled"` + Transport string `json:"transport,omitempty"` + STUNServers []string `json:"stun_servers,omitempty"` } type CronConfig struct { @@ -546,8 +547,9 @@ func DefaultConfig() *Config { Token: generateGatewayToken(), Nodes: GatewayNodesConfig{ P2P: GatewayNodesP2PConfig{ - Enabled: false, - Transport: "websocket_tunnel", + Enabled: false, + Transport: "websocket_tunnel", + STUNServers: []string{}, }, }, }, diff --git a/pkg/config/validate.go b/pkg/config/validate.go index f002d89..d625d00 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -124,6 +124,7 @@ func Validate(cfg *Config) []error { default: errs = append(errs, fmt.Errorf("gateway.nodes.p2p.transport must be one of: websocket_tunnel, webrtc")) } + errs = append(errs, validateNonEmptyStringList("gateway.nodes.p2p.stun_servers", cfg.Gateway.Nodes.P2P.STUNServers)...) if cfg.Cron.MinSleepSec <= 0 { errs = append(errs, fmt.Errorf("cron.min_sleep_sec must be > 0")) } diff --git a/pkg/nodes/webrtc.go b/pkg/nodes/webrtc.go new file mode 100644 index 0000000..548c4df --- /dev/null +++ b/pkg/nodes/webrtc.go @@ -0,0 +1,322 @@ +package nodes + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/pion/webrtc/v4" +) + +type gatewayRTCSession struct { + nodeID string + pc *webrtc.PeerConnection + dc *webrtc.DataChannel + ready chan struct{} + readyMu sync.Once + writeMu sync.Mutex + pending map[string]chan Response + mu sync.Mutex + nextID uint64 +} + +func (s *gatewayRTCSession) markReady() { + s.readyMu.Do(func() { close(s.ready) }) +} + +func (s *gatewayRTCSession) send(msg WireMessage) error { + if s == nil || s.dc == nil { + return fmt.Errorf("webrtc data channel unavailable") + } + b, err := json.Marshal(msg) + if err != nil { + return err + } + s.writeMu.Lock() + defer s.writeMu.Unlock() + return s.dc.Send(b) +} + +func (s *gatewayRTCSession) nextRequestID() string { + s.mu.Lock() + defer s.mu.Unlock() + s.nextID++ + return fmt.Sprintf("rtc-%s-%d", s.nodeID, s.nextID) +} + +type WebRTCTransport struct { + stunServers []string + + mu sync.Mutex + sessions map[string]*gatewayRTCSession + signal map[string]WireSender +} + +func NewWebRTCTransport(stunServers []string) *WebRTCTransport { + out := make([]string, 0, len(stunServers)) + for _, server := range stunServers { + if v := strings.TrimSpace(server); v != "" { + out = append(out, v) + } + } + return &WebRTCTransport{ + stunServers: out, + sessions: map[string]*gatewayRTCSession{}, + signal: map[string]WireSender{}, + } +} + +func (t *WebRTCTransport) Name() string { return "p2p-webrtc" } + +func (t *WebRTCTransport) BindSignaler(nodeID string, sender WireSender) { + nodeID = strings.TrimSpace(nodeID) + if nodeID == "" { + return + } + t.mu.Lock() + defer t.mu.Unlock() + if sender == nil { + delete(t.signal, nodeID) + return + } + t.signal[nodeID] = sender +} + +func (t *WebRTCTransport) UnbindSignaler(nodeID string) { + t.BindSignaler(nodeID, nil) + t.mu.Lock() + session := t.sessions[nodeID] + delete(t.sessions, nodeID) + t.mu.Unlock() + if session != nil && session.pc != nil { + _ = session.pc.Close() + } +} + +func (t *WebRTCTransport) currentSignaler(nodeID string) WireSender { + t.mu.Lock() + defer t.mu.Unlock() + return t.signal[strings.TrimSpace(nodeID)] +} + +func (t *WebRTCTransport) HandleSignal(msg WireMessage) error { + nodeID := strings.TrimSpace(msg.From) + if nodeID == "" { + return fmt.Errorf("signal missing from") + } + session, err := t.ensureSession(nodeID) + if err != nil { + return err + } + switch strings.ToLower(strings.TrimSpace(msg.Type)) { + case "signal_answer": + var desc webrtc.SessionDescription + if err := mapInto(msg.Payload, &desc); err != nil { + return err + } + return session.pc.SetRemoteDescription(desc) + case "signal_candidate": + var candidate webrtc.ICECandidateInit + if err := mapInto(msg.Payload, &candidate); err != nil { + return err + } + return session.pc.AddICECandidate(candidate) + default: + return fmt.Errorf("unsupported signal type: %s", msg.Type) + } +} + +func (t *WebRTCTransport) Send(ctx context.Context, req Request) (Response, error) { + session, err := t.ensureSession(req.Node) + if err != nil { + return Response{OK: false, Code: "p2p_unavailable", Node: req.Node, Action: req.Action, Error: err.Error()}, nil + } + + select { + case <-ctx.Done(): + return Response{}, ctx.Err() + case <-session.ready: + case <-time.After(8 * time.Second): + return Response{OK: false, Code: "p2p_timeout", Node: req.Node, Action: req.Action, Error: "webrtc session not ready"}, nil + } + + reqID := session.nextRequestID() + respCh := make(chan Response, 1) + session.mu.Lock() + session.pending[reqID] = respCh + session.mu.Unlock() + + if err := session.send(WireMessage{ + Type: "node_request", + ID: reqID, + To: req.Node, + Request: &req, + }); err != nil { + session.mu.Lock() + delete(session.pending, reqID) + session.mu.Unlock() + return Response{OK: false, Code: "p2p_send_failed", Node: req.Node, Action: req.Action, Error: err.Error()}, nil + } + + select { + case <-ctx.Done(): + session.mu.Lock() + delete(session.pending, reqID) + session.mu.Unlock() + return Response{}, ctx.Err() + case resp := <-respCh: + return resp, nil + } +} + +func (t *WebRTCTransport) ensureSession(nodeID string) (*gatewayRTCSession, error) { + nodeID = strings.TrimSpace(nodeID) + if nodeID == "" { + return nil, fmt.Errorf("node id required") + } + + t.mu.Lock() + if session := t.sessions[nodeID]; session != nil { + t.mu.Unlock() + return session, nil + } + t.mu.Unlock() + if t.currentSignaler(nodeID) == nil { + return nil, fmt.Errorf("node %s signaling unavailable", nodeID) + } + + config := webrtc.Configuration{} + if len(t.stunServers) > 0 { + config.ICEServers = []webrtc.ICEServer{{URLs: append([]string(nil), t.stunServers...)}} + } + pc, err := webrtc.NewPeerConnection(config) + if err != nil { + return nil, err + } + dc, err := pc.CreateDataChannel("clawgo", nil) + if err != nil { + _ = pc.Close() + return nil, err + } + session := &gatewayRTCSession{ + nodeID: nodeID, + pc: pc, + dc: dc, + ready: make(chan struct{}), + pending: map[string]chan Response{}, + } + + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate == nil { + return + } + sender := t.currentSignaler(nodeID) + if sender == nil { + return + } + _ = sender.Send(WireMessage{ + Type: "signal_candidate", + From: "gateway", + To: nodeID, + Session: nodeID, + Payload: structToMap(candidate.ToJSON()), + }) + }) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateDisconnected: + t.mu.Lock() + if t.sessions[nodeID] == session { + delete(t.sessions, nodeID) + } + t.mu.Unlock() + } + }) + dc.OnOpen(func() { + session.markReady() + }) + dc.OnMessage(func(message webrtc.DataChannelMessage) { + var msg WireMessage + if err := json.Unmarshal(message.Data, &msg); err != nil { + return + } + if strings.ToLower(strings.TrimSpace(msg.Type)) != "node_response" || msg.Response == nil { + return + } + session.mu.Lock() + respCh := session.pending[msg.ID] + if respCh != nil { + delete(session.pending, msg.ID) + } + session.mu.Unlock() + if respCh != nil { + respCh <- *msg.Response + } + }) + + offer, err := pc.CreateOffer(nil) + if err != nil { + _ = pc.Close() + return nil, err + } + if err := pc.SetLocalDescription(offer); err != nil { + _ = pc.Close() + return nil, err + } + + t.mu.Lock() + if existing := t.sessions[nodeID]; existing != nil { + t.mu.Unlock() + _ = pc.Close() + return existing, nil + } + t.sessions[nodeID] = session + t.mu.Unlock() + + sender := t.currentSignaler(nodeID) + if sender == nil { + t.mu.Lock() + delete(t.sessions, nodeID) + t.mu.Unlock() + _ = pc.Close() + return nil, fmt.Errorf("node %s signaling unavailable", nodeID) + } + if err := sender.Send(WireMessage{ + Type: "signal_offer", + From: "gateway", + To: nodeID, + Session: nodeID, + Payload: structToMap(*pc.LocalDescription()), + }); err != nil { + t.mu.Lock() + delete(t.sessions, nodeID) + t.mu.Unlock() + _ = pc.Close() + return nil, err + } + return session, nil +} + +func structToMap(v interface{}) map[string]interface{} { + b, _ := json.Marshal(v) + var out map[string]interface{} + _ = json.Unmarshal(b, &out) + if out == nil { + out = map[string]interface{}{} + } + return out +} + +func mapInto(in map[string]interface{}, out interface{}) error { + if len(in) == 0 { + return fmt.Errorf("empty payload") + } + b, err := json.Marshal(in) + if err != nil { + return err + } + return json.Unmarshal(b, out) +}