mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 05:37:29 +08:00
feat: add guarded webrtc node transport
This commit is contained in:
@@ -129,6 +129,20 @@ func gatewayCmd() {
|
||||
}
|
||||
|
||||
registryServer := api.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, cfg.Gateway.Token, nodes.DefaultManager())
|
||||
configureGatewayNodeP2P := func(loop *agent.AgentLoop, server *api.Server, runtimeCfg *config.Config) {
|
||||
if loop == nil || server == nil || runtimeCfg == nil {
|
||||
return
|
||||
}
|
||||
switch {
|
||||
case runtimeCfg.Gateway.Nodes.P2P.Enabled && strings.EqualFold(strings.TrimSpace(runtimeCfg.Gateway.Nodes.P2P.Transport), "webrtc"):
|
||||
webrtcTransport := nodes.NewWebRTCTransport(runtimeCfg.Gateway.Nodes.P2P.STUNServers)
|
||||
loop.SetNodeP2PTransport(webrtcTransport)
|
||||
server.SetNodeWebRTCTransport(webrtcTransport)
|
||||
default:
|
||||
server.SetNodeWebRTCTransport(nil)
|
||||
}
|
||||
}
|
||||
configureGatewayNodeP2P(agentLoop, registryServer, cfg)
|
||||
registryServer.SetGatewayVersion(version)
|
||||
registryServer.SetWebUIVersion(version)
|
||||
registryServer.SetConfigPath(getConfigPath())
|
||||
@@ -343,7 +357,8 @@ func gatewayCmd() {
|
||||
runtimeSame := reflect.DeepEqual(cfg.Agents, newCfg.Agents) &&
|
||||
reflect.DeepEqual(cfg.Providers, newCfg.Providers) &&
|
||||
reflect.DeepEqual(cfg.Tools, newCfg.Tools) &&
|
||||
reflect.DeepEqual(cfg.Channels, newCfg.Channels)
|
||||
reflect.DeepEqual(cfg.Channels, newCfg.Channels) &&
|
||||
reflect.DeepEqual(cfg.Gateway.Nodes, newCfg.Gateway.Nodes)
|
||||
|
||||
if runtimeSame {
|
||||
configureLogging(newCfg)
|
||||
@@ -369,6 +384,7 @@ func gatewayCmd() {
|
||||
}
|
||||
cfg = newCfg
|
||||
runtimecfg.Set(cfg)
|
||||
configureGatewayNodeP2P(agentLoop, registryServer, cfg)
|
||||
fmt.Println("✓ Config hot-reload applied (logging/metadata only)")
|
||||
return
|
||||
}
|
||||
@@ -386,6 +402,7 @@ func gatewayCmd() {
|
||||
agentLoop = newAgentLoop
|
||||
cfg = newCfg
|
||||
runtimecfg.Set(cfg)
|
||||
configureGatewayNodeP2P(agentLoop, registryServer, cfg)
|
||||
sentinelService.Stop()
|
||||
sentinelService = sentinel.NewService(
|
||||
getConfigPath(),
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"clawgo/pkg/config"
|
||||
"clawgo/pkg/nodes"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
type nodeRegisterOptions struct {
|
||||
@@ -42,6 +43,17 @@ type nodeHeartbeatOptions struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
type nodeWebRTCSession struct {
|
||||
info nodes.NodeInfo
|
||||
opts nodeRegisterOptions
|
||||
client *http.Client
|
||||
writeJSON func(interface{}) error
|
||||
|
||||
mu sync.Mutex
|
||||
pc *webrtc.PeerConnection
|
||||
dc *webrtc.DataChannel
|
||||
}
|
||||
|
||||
func nodeCmd() {
|
||||
args := os.Args[2:]
|
||||
if len(args) == 0 {
|
||||
@@ -464,6 +476,7 @@ func waitNodeAck(ctx context.Context, acks <-chan nodes.WireAck, errs <-chan err
|
||||
func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, acks chan<- nodes.WireAck, errs chan<- error) {
|
||||
defer close(acks)
|
||||
defer close(errs)
|
||||
rtc := &nodeWebRTCSession{info: info, opts: opts, client: client, writeJSON: writeJSON}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -496,46 +509,128 @@ func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON fun
|
||||
case "node_request":
|
||||
go handleNodeWireRequest(ctx, writeJSON, client, info, opts, msg)
|
||||
case "signal_offer", "signal_answer", "signal_candidate":
|
||||
fmt.Printf("ℹ Signal received: type=%s from=%s session=%s\n", msg.Type, strings.TrimSpace(msg.From), strings.TrimSpace(msg.Session))
|
||||
if err := rtc.handleSignal(ctx, msg); err != nil {
|
||||
fmt.Printf("Warning: node webrtc signal failed: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) {
|
||||
resp := nodes.Response{
|
||||
OK: false,
|
||||
Code: "invalid_request",
|
||||
Node: info.ID,
|
||||
Action: "",
|
||||
Error: "request missing",
|
||||
}
|
||||
if msg.Request != nil {
|
||||
req := *msg.Request
|
||||
resp.Action = req.Action
|
||||
if strings.TrimSpace(opts.Endpoint) == "" {
|
||||
resp.Error = "node endpoint not configured"
|
||||
resp.Code = "endpoint_missing"
|
||||
} else {
|
||||
if req.Node == "" {
|
||||
req.Node = info.ID
|
||||
}
|
||||
execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, req)
|
||||
if err != nil {
|
||||
resp = nodes.Response{
|
||||
OK: false,
|
||||
Code: "transport_error",
|
||||
Node: info.ID,
|
||||
Action: req.Action,
|
||||
Error: err.Error(),
|
||||
}
|
||||
} else {
|
||||
resp = execResp
|
||||
if strings.TrimSpace(resp.Node) == "" {
|
||||
resp.Node = info.ID
|
||||
}
|
||||
}
|
||||
func (s *nodeWebRTCSession) handleSignal(ctx context.Context, msg nodes.WireMessage) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(msg.Type)) {
|
||||
case "signal_offer":
|
||||
pc, err := s.ensurePeerConnectionLocked()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var desc webrtc.SessionDescription
|
||||
if err := mapWirePayload(msg.Payload, &desc); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pc.SetRemoteDescription(desc); err != nil {
|
||||
return err
|
||||
}
|
||||
answer, err := pc.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pc.SetLocalDescription(answer); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(nodes.WireMessage{
|
||||
Type: "signal_answer",
|
||||
From: s.info.ID,
|
||||
To: "gateway",
|
||||
Session: strings.TrimSpace(msg.Session),
|
||||
Payload: structToWirePayload(*pc.LocalDescription()),
|
||||
})
|
||||
case "signal_candidate":
|
||||
if s.pc == nil {
|
||||
return nil
|
||||
}
|
||||
var candidate webrtc.ICECandidateInit
|
||||
if err := mapWirePayload(msg.Payload, &candidate); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.pc.AddICECandidate(candidate)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *nodeWebRTCSession) ensurePeerConnectionLocked() (*webrtc.PeerConnection, error) {
|
||||
if s.pc != nil {
|
||||
return s.pc, nil
|
||||
}
|
||||
config := webrtc.Configuration{}
|
||||
pc, err := webrtc.NewPeerConnection(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
_ = s.writeJSON(nodes.WireMessage{
|
||||
Type: "signal_candidate",
|
||||
From: s.info.ID,
|
||||
To: "gateway",
|
||||
Session: s.info.ID,
|
||||
Payload: structToWirePayload(candidate.ToJSON()),
|
||||
})
|
||||
})
|
||||
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateDisconnected:
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.pc != nil {
|
||||
_ = s.pc.Close()
|
||||
}
|
||||
s.pc = nil
|
||||
s.dc = nil
|
||||
}
|
||||
})
|
||||
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
s.mu.Lock()
|
||||
s.dc = dc
|
||||
s.mu.Unlock()
|
||||
dc.OnMessage(func(message webrtc.DataChannelMessage) {
|
||||
var msg nodes.WireMessage
|
||||
if err := json.Unmarshal(message.Data, &msg); err != nil {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(msg.Type)) != "node_request" || msg.Request == nil {
|
||||
return
|
||||
}
|
||||
go s.handleDataChannelRequest(context.Background(), dc, msg)
|
||||
})
|
||||
})
|
||||
s.pc = pc
|
||||
return pc, nil
|
||||
}
|
||||
|
||||
func (s *nodeWebRTCSession) handleDataChannelRequest(ctx context.Context, dc *webrtc.DataChannel, msg nodes.WireMessage) {
|
||||
resp := executeNodeRequest(ctx, s.client, s.info, s.opts, msg.Request)
|
||||
b, err := json.Marshal(nodes.WireMessage{
|
||||
Type: "node_response",
|
||||
ID: msg.ID,
|
||||
From: s.info.ID,
|
||||
To: "gateway",
|
||||
Session: strings.TrimSpace(msg.Session),
|
||||
Response: &resp,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = dc.Send(b)
|
||||
}
|
||||
|
||||
func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) {
|
||||
resp := executeNodeRequest(ctx, client, info, opts, msg.Request)
|
||||
_ = writeJSON(nodes.WireMessage{
|
||||
Type: "node_response",
|
||||
ID: msg.ID,
|
||||
@@ -546,6 +641,64 @@ func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) erro
|
||||
})
|
||||
}
|
||||
|
||||
func executeNodeRequest(ctx context.Context, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, req *nodes.Request) nodes.Response {
|
||||
resp := nodes.Response{
|
||||
OK: false,
|
||||
Code: "invalid_request",
|
||||
Node: info.ID,
|
||||
Action: "",
|
||||
Error: "request missing",
|
||||
}
|
||||
if req == nil {
|
||||
return resp
|
||||
}
|
||||
next := *req
|
||||
resp.Action = next.Action
|
||||
if strings.TrimSpace(opts.Endpoint) == "" {
|
||||
resp.Error = "node endpoint not configured"
|
||||
resp.Code = "endpoint_missing"
|
||||
return resp
|
||||
}
|
||||
if next.Node == "" {
|
||||
next.Node = info.ID
|
||||
}
|
||||
execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, next)
|
||||
if err != nil {
|
||||
return nodes.Response{
|
||||
OK: false,
|
||||
Code: "transport_error",
|
||||
Node: info.ID,
|
||||
Action: next.Action,
|
||||
Error: err.Error(),
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(execResp.Node) == "" {
|
||||
execResp.Node = info.ID
|
||||
}
|
||||
return execResp
|
||||
}
|
||||
|
||||
func structToWirePayload(v interface{}) map[string]interface{} {
|
||||
b, _ := json.Marshal(v)
|
||||
var out map[string]interface{}
|
||||
_ = json.Unmarshal(b, &out)
|
||||
if out == nil {
|
||||
out = map[string]interface{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapWirePayload(in map[string]interface{}, out interface{}) error {
|
||||
if len(in) == 0 {
|
||||
return fmt.Errorf("empty payload")
|
||||
}
|
||||
b, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, out)
|
||||
}
|
||||
|
||||
func postNodeRegister(ctx context.Context, client *http.Client, gatewayBase, token string, info nodes.NodeInfo) error {
|
||||
return postNodeJSON(ctx, client, gatewayBase, token, "/nodes/register", info)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user