diff --git a/pkg/nodes/transport_test.go b/pkg/nodes/transport_test.go index 33f2842..92fbf7f 100644 --- a/pkg/nodes/transport_test.go +++ b/pkg/nodes/transport_test.go @@ -2,8 +2,12 @@ package nodes import ( "context" + "encoding/json" + "sync" "testing" "time" + + "github.com/pion/webrtc/v4" ) type captureWireSender struct { @@ -69,3 +73,140 @@ func TestWebsocketP2PTransportSend(t *testing.T) { t.Fatalf("unexpected payload: %+v", resp.Payload) } } + +func TestWebRTCTransportSendEndToEnd(t *testing.T) { + t.Parallel() + + transport := NewWebRTCTransport(nil) + nodeID := "edge-webrtc" + + var remotePC *webrtc.PeerConnection + var remoteMu sync.Mutex + handleRemoteSignal := func(msg WireMessage) error { + remoteMu.Lock() + defer remoteMu.Unlock() + + ensureRemote := func() (*webrtc.PeerConnection, error) { + if remotePC != nil { + return remotePC, nil + } + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + return nil, err + } + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate == nil { + return + } + _ = transport.HandleSignal(WireMessage{ + Type: "signal_candidate", + From: nodeID, + To: "gateway", + Session: nodeID, + Payload: structToMap(candidate.ToJSON()), + }) + }) + pc.OnDataChannel(func(dc *webrtc.DataChannel) { + dc.OnMessage(func(message webrtc.DataChannelMessage) { + var wire WireMessage + if err := json.Unmarshal(message.Data, &wire); err != nil { + return + } + if wire.Type != "node_request" || wire.Request == nil { + return + } + resp := Response{ + OK: true, + Code: "ok", + Node: nodeID, + Action: wire.Request.Action, + Payload: map[string]interface{}{ + "status": "done-over-webrtc", + }, + } + b, err := json.Marshal(WireMessage{ + Type: "node_response", + ID: wire.ID, + From: nodeID, + To: "gateway", + Response: &resp, + }) + if err != nil { + return + } + _ = dc.Send(b) + }) + }) + remotePC = pc + return remotePC, nil + } + + pc, err := ensureRemote() + if err != nil { + return err + } + + switch msg.Type { + case "signal_offer": + var desc webrtc.SessionDescription + if err := mapInto(msg.Payload, &desc); err != nil { + return err + } + if err := pc.SetRemoteDescription(desc); err != nil { + return err + } + answer, err := pc.CreateAnswer(nil) + if err != nil { + return err + } + if err := pc.SetLocalDescription(answer); err != nil { + return err + } + return transport.HandleSignal(WireMessage{ + Type: "signal_answer", + From: nodeID, + To: "gateway", + Session: nodeID, + Payload: structToMap(*pc.LocalDescription()), + }) + case "signal_candidate": + var candidate webrtc.ICECandidateInit + if err := mapInto(msg.Payload, &candidate); err != nil { + return err + } + return pc.AddICECandidate(candidate) + default: + return nil + } + } + + transport.BindSignaler(nodeID, &captureWireSender{ + send: handleRemoteSignal, + }) + defer func() { + transport.UnbindSignaler(nodeID) + remoteMu.Lock() + defer remoteMu.Unlock() + if remotePC != nil { + _ = remotePC.Close() + } + }() + + resp, err := transport.Send(context.Background(), Request{ + Action: "run", + Node: nodeID, + Args: map[string]interface{}{"command": []string{"echo", "ok"}}, + }) + if err != nil { + t.Fatalf("webrtc transport send failed: %v", err) + } + if !resp.OK { + t.Fatalf("expected ok response, got %+v", resp) + } + if resp.Payload["status"] != "done-over-webrtc" { + t.Fatalf("unexpected payload: %+v", resp.Payload) + } + if resp.Payload["used_transport"] != nil { + t.Fatalf("transport annotations should not be added at transport layer: %+v", resp.Payload) + } +}