feat: add guarded webrtc node transport

This commit is contained in:
lpf
2026-03-08 22:47:41 +08:00
parent ad2e732f56
commit daaac53f5a
10 changed files with 615 additions and 41 deletions

View File

@@ -129,6 +129,20 @@ func gatewayCmd() {
}
registryServer := api.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, cfg.Gateway.Token, nodes.DefaultManager())
configureGatewayNodeP2P := func(loop *agent.AgentLoop, server *api.Server, runtimeCfg *config.Config) {
if loop == nil || server == nil || runtimeCfg == nil {
return
}
switch {
case runtimeCfg.Gateway.Nodes.P2P.Enabled && strings.EqualFold(strings.TrimSpace(runtimeCfg.Gateway.Nodes.P2P.Transport), "webrtc"):
webrtcTransport := nodes.NewWebRTCTransport(runtimeCfg.Gateway.Nodes.P2P.STUNServers)
loop.SetNodeP2PTransport(webrtcTransport)
server.SetNodeWebRTCTransport(webrtcTransport)
default:
server.SetNodeWebRTCTransport(nil)
}
}
configureGatewayNodeP2P(agentLoop, registryServer, cfg)
registryServer.SetGatewayVersion(version)
registryServer.SetWebUIVersion(version)
registryServer.SetConfigPath(getConfigPath())
@@ -343,7 +357,8 @@ func gatewayCmd() {
runtimeSame := reflect.DeepEqual(cfg.Agents, newCfg.Agents) &&
reflect.DeepEqual(cfg.Providers, newCfg.Providers) &&
reflect.DeepEqual(cfg.Tools, newCfg.Tools) &&
reflect.DeepEqual(cfg.Channels, newCfg.Channels)
reflect.DeepEqual(cfg.Channels, newCfg.Channels) &&
reflect.DeepEqual(cfg.Gateway.Nodes, newCfg.Gateway.Nodes)
if runtimeSame {
configureLogging(newCfg)
@@ -369,6 +384,7 @@ func gatewayCmd() {
}
cfg = newCfg
runtimecfg.Set(cfg)
configureGatewayNodeP2P(agentLoop, registryServer, cfg)
fmt.Println("✓ Config hot-reload applied (logging/metadata only)")
return
}
@@ -386,6 +402,7 @@ func gatewayCmd() {
agentLoop = newAgentLoop
cfg = newCfg
runtimecfg.Set(cfg)
configureGatewayNodeP2P(agentLoop, registryServer, cfg)
sentinelService.Stop()
sentinelService = sentinel.NewService(
getConfigPath(),

View File

@@ -17,6 +17,7 @@ import (
"clawgo/pkg/config"
"clawgo/pkg/nodes"
"github.com/gorilla/websocket"
"github.com/pion/webrtc/v4"
)
type nodeRegisterOptions struct {
@@ -42,6 +43,17 @@ type nodeHeartbeatOptions struct {
ID string
}
type nodeWebRTCSession struct {
info nodes.NodeInfo
opts nodeRegisterOptions
client *http.Client
writeJSON func(interface{}) error
mu sync.Mutex
pc *webrtc.PeerConnection
dc *webrtc.DataChannel
}
func nodeCmd() {
args := os.Args[2:]
if len(args) == 0 {
@@ -464,6 +476,7 @@ func waitNodeAck(ctx context.Context, acks <-chan nodes.WireAck, errs <-chan err
func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, acks chan<- nodes.WireAck, errs chan<- error) {
defer close(acks)
defer close(errs)
rtc := &nodeWebRTCSession{info: info, opts: opts, client: client, writeJSON: writeJSON}
for {
select {
case <-ctx.Done():
@@ -496,46 +509,128 @@ func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON fun
case "node_request":
go handleNodeWireRequest(ctx, writeJSON, client, info, opts, msg)
case "signal_offer", "signal_answer", "signal_candidate":
fmt.Printf(" Signal received: type=%s from=%s session=%s\n", msg.Type, strings.TrimSpace(msg.From), strings.TrimSpace(msg.Session))
if err := rtc.handleSignal(ctx, msg); err != nil {
fmt.Printf("Warning: node webrtc signal failed: %v\n", err)
}
}
}
}
func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) {
resp := nodes.Response{
OK: false,
Code: "invalid_request",
Node: info.ID,
Action: "",
Error: "request missing",
}
if msg.Request != nil {
req := *msg.Request
resp.Action = req.Action
if strings.TrimSpace(opts.Endpoint) == "" {
resp.Error = "node endpoint not configured"
resp.Code = "endpoint_missing"
} else {
if req.Node == "" {
req.Node = info.ID
}
execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, req)
if err != nil {
resp = nodes.Response{
OK: false,
Code: "transport_error",
Node: info.ID,
Action: req.Action,
Error: err.Error(),
}
} else {
resp = execResp
if strings.TrimSpace(resp.Node) == "" {
resp.Node = info.ID
}
}
func (s *nodeWebRTCSession) handleSignal(ctx context.Context, msg nodes.WireMessage) error {
s.mu.Lock()
defer s.mu.Unlock()
switch strings.ToLower(strings.TrimSpace(msg.Type)) {
case "signal_offer":
pc, err := s.ensurePeerConnectionLocked()
if err != nil {
return err
}
var desc webrtc.SessionDescription
if err := mapWirePayload(msg.Payload, &desc); err != nil {
return err
}
if err := pc.SetRemoteDescription(desc); err != nil {
return err
}
answer, err := pc.CreateAnswer(nil)
if err != nil {
return err
}
if err := pc.SetLocalDescription(answer); err != nil {
return err
}
return s.writeJSON(nodes.WireMessage{
Type: "signal_answer",
From: s.info.ID,
To: "gateway",
Session: strings.TrimSpace(msg.Session),
Payload: structToWirePayload(*pc.LocalDescription()),
})
case "signal_candidate":
if s.pc == nil {
return nil
}
var candidate webrtc.ICECandidateInit
if err := mapWirePayload(msg.Payload, &candidate); err != nil {
return err
}
return s.pc.AddICECandidate(candidate)
default:
return nil
}
}
func (s *nodeWebRTCSession) ensurePeerConnectionLocked() (*webrtc.PeerConnection, error) {
if s.pc != nil {
return s.pc, nil
}
config := webrtc.Configuration{}
pc, err := webrtc.NewPeerConnection(config)
if err != nil {
return nil, err
}
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
return
}
_ = s.writeJSON(nodes.WireMessage{
Type: "signal_candidate",
From: s.info.ID,
To: "gateway",
Session: s.info.ID,
Payload: structToWirePayload(candidate.ToJSON()),
})
})
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateDisconnected:
s.mu.Lock()
defer s.mu.Unlock()
if s.pc != nil {
_ = s.pc.Close()
}
s.pc = nil
s.dc = nil
}
})
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
s.mu.Lock()
s.dc = dc
s.mu.Unlock()
dc.OnMessage(func(message webrtc.DataChannelMessage) {
var msg nodes.WireMessage
if err := json.Unmarshal(message.Data, &msg); err != nil {
return
}
if strings.ToLower(strings.TrimSpace(msg.Type)) != "node_request" || msg.Request == nil {
return
}
go s.handleDataChannelRequest(context.Background(), dc, msg)
})
})
s.pc = pc
return pc, nil
}
func (s *nodeWebRTCSession) handleDataChannelRequest(ctx context.Context, dc *webrtc.DataChannel, msg nodes.WireMessage) {
resp := executeNodeRequest(ctx, s.client, s.info, s.opts, msg.Request)
b, err := json.Marshal(nodes.WireMessage{
Type: "node_response",
ID: msg.ID,
From: s.info.ID,
To: "gateway",
Session: strings.TrimSpace(msg.Session),
Response: &resp,
})
if err != nil {
return
}
_ = dc.Send(b)
}
func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) {
resp := executeNodeRequest(ctx, client, info, opts, msg.Request)
_ = writeJSON(nodes.WireMessage{
Type: "node_response",
ID: msg.ID,
@@ -546,6 +641,64 @@ func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) erro
})
}
func executeNodeRequest(ctx context.Context, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, req *nodes.Request) nodes.Response {
resp := nodes.Response{
OK: false,
Code: "invalid_request",
Node: info.ID,
Action: "",
Error: "request missing",
}
if req == nil {
return resp
}
next := *req
resp.Action = next.Action
if strings.TrimSpace(opts.Endpoint) == "" {
resp.Error = "node endpoint not configured"
resp.Code = "endpoint_missing"
return resp
}
if next.Node == "" {
next.Node = info.ID
}
execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, next)
if err != nil {
return nodes.Response{
OK: false,
Code: "transport_error",
Node: info.ID,
Action: next.Action,
Error: err.Error(),
}
}
if strings.TrimSpace(execResp.Node) == "" {
execResp.Node = info.ID
}
return execResp
}
func structToWirePayload(v interface{}) map[string]interface{} {
b, _ := json.Marshal(v)
var out map[string]interface{}
_ = json.Unmarshal(b, &out)
if out == nil {
out = map[string]interface{}{}
}
return out
}
func mapWirePayload(in map[string]interface{}, out interface{}) error {
if len(in) == 0 {
return fmt.Errorf("empty payload")
}
b, err := json.Marshal(in)
if err != nil {
return err
}
return json.Unmarshal(b, out)
}
func postNodeRegister(ctx context.Context, client *http.Client, gatewayBase, token string, info nodes.NodeInfo) error {
return postNodeJSON(ctx, client, gatewayBase, token, "/nodes/register", info)
}