feat: tunnel node requests over websocket p2p channel

This commit is contained in:
lpf
2026-03-08 22:29:40 +08:00
parent abce524114
commit 2aa7db9619
7 changed files with 370 additions and 46 deletions

View File

@@ -11,6 +11,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"clawgo/pkg/config"
@@ -21,6 +22,7 @@ import (
type nodeRegisterOptions struct {
GatewayBase string
Token string
NodeToken string
ID string
Name string
Endpoint string
@@ -67,6 +69,7 @@ func printNodeHelp() {
fmt.Println("Register options:")
fmt.Println(" --gateway <url> Gateway base URL, e.g. http://host:18790")
fmt.Println(" --token <value> Gateway token (optional when gateway.token is empty)")
fmt.Println(" --node-token <value> Bearer token for this node endpoint (optional)")
fmt.Println(" --id <value> Node ID (default: hostname)")
fmt.Println(" --name <value> Node name (default: hostname)")
fmt.Println(" --endpoint <url> Public endpoint of this node")
@@ -166,6 +169,12 @@ func parseNodeRegisterArgs(args []string, cfg *config.Config) (nodeRegisterOptio
return opts, err
}
opts.Token = v
case "--node-token":
v, err := next()
if err != nil {
return opts, err
}
opts.NodeToken = v
case "--id":
v, err := next()
if err != nil {
@@ -307,6 +316,7 @@ func buildNodeInfo(opts nodeRegisterOptions) nodes.NodeInfo {
Arch: strings.TrimSpace(opts.Arch),
Version: strings.TrimSpace(opts.Version),
Endpoint: strings.TrimSpace(opts.Endpoint),
Token: strings.TrimSpace(opts.NodeToken),
Capabilities: opts.Capabilities,
Actions: append([]string(nil), opts.Actions...),
Models: append([]string(nil), opts.Models...),
@@ -353,12 +363,27 @@ func runNodeHeartbeatSocket(ctx context.Context, opts nodeRegisterOptions, info
return err
}
defer conn.Close()
var writeMu sync.Mutex
writeJSON := func(v interface{}) error {
writeMu.Lock()
defer writeMu.Unlock()
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return conn.WriteJSON(v)
}
writePing := func() error {
writeMu.Lock()
defer writeMu.Unlock()
return conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(10*time.Second))
}
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteJSON(nodes.WireMessage{Type: "register", Node: &info}); err != nil {
if err := writeJSON(nodes.WireMessage{Type: "register", Node: &info}); err != nil {
return err
}
if err := readNodeAck(conn, "registered", info.ID); err != nil {
acks := make(chan nodes.WireAck, 8)
errs := make(chan error, 1)
client := &http.Client{Timeout: 20 * time.Second}
go readNodeSocketLoop(ctx, conn, writeJSON, client, info, opts, acks, errs)
if err := waitNodeAck(ctx, acks, errs, "registered", info.ID); err != nil {
return err
}
fmt.Printf("✓ Node socket connected: %s\n", info.ID)
@@ -372,16 +397,20 @@ func runNodeHeartbeatSocket(ctx context.Context, opts nodeRegisterOptions, info
select {
case <-ctx.Done():
return nil
case err := <-errs:
if err != nil {
return err
}
return nil
case <-pingTicker.C:
if err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(10*time.Second)); err != nil {
if err := writePing(); err != nil {
return err
}
case <-ticker.C:
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteJSON(nodes.WireMessage{Type: "heartbeat", ID: info.ID}); err != nil {
if err := writeJSON(nodes.WireMessage{Type: "heartbeat", ID: info.ID}); err != nil {
return err
}
if err := readNodeAck(conn, "heartbeat", info.ID); err != nil {
if err := waitNodeAck(ctx, acks, errs, "heartbeat", info.ID); err != nil {
return err
}
fmt.Printf("✓ Heartbeat ok: %s\n", info.ID)
@@ -403,26 +432,118 @@ func nodeSocketPingInterval(heartbeatSec int) time.Duration {
return interval
}
func readNodeAck(conn *websocket.Conn, expectedType, id string) error {
_ = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
var ack nodes.WireAck
if err := conn.ReadJSON(&ack); err != nil {
return err
}
if !ack.OK {
if strings.TrimSpace(ack.Error) == "" {
ack.Error = "unknown websocket error"
func waitNodeAck(ctx context.Context, acks <-chan nodes.WireAck, errs <-chan error, expectedType, id string) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errs:
if err == nil {
return context.Canceled
}
return err
case ack := <-acks:
if !ack.OK {
if strings.TrimSpace(ack.Error) == "" {
ack.Error = "unknown websocket error"
}
return fmt.Errorf("%s", ack.Error)
}
ackType := strings.ToLower(strings.TrimSpace(ack.Type))
if expectedType != "" && ackType != strings.ToLower(strings.TrimSpace(expectedType)) {
continue
}
if strings.TrimSpace(id) != "" && strings.TrimSpace(ack.ID) != "" && strings.TrimSpace(ack.ID) != strings.TrimSpace(id) {
continue
}
return nil
}
return fmt.Errorf("%s", ack.Error)
}
ackType := strings.ToLower(strings.TrimSpace(ack.Type))
if expectedType != "" && ackType != strings.ToLower(strings.TrimSpace(expectedType)) {
return fmt.Errorf("unexpected websocket ack type: %s", ack.Type)
}
func readNodeSocketLoop(ctx context.Context, conn *websocket.Conn, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, acks chan<- nodes.WireAck, errs chan<- error) {
defer close(acks)
defer close(errs)
for {
select {
case <-ctx.Done():
errs <- nil
return
default:
}
_ = conn.SetReadDeadline(time.Now().Add(90 * time.Second))
_, data, err := conn.ReadMessage()
if err != nil {
errs <- err
return
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
continue
}
if _, hasOK := raw["ok"]; hasOK {
var ack nodes.WireAck
if err := json.Unmarshal(data, &ack); err == nil {
acks <- ack
}
continue
}
var msg nodes.WireMessage
if err := json.Unmarshal(data, &msg); err != nil {
continue
}
switch strings.ToLower(strings.TrimSpace(msg.Type)) {
case "node_request":
go handleNodeWireRequest(ctx, writeJSON, client, info, opts, msg)
case "signal_offer", "signal_answer", "signal_candidate":
fmt.Printf(" Signal received: type=%s from=%s session=%s\n", msg.Type, strings.TrimSpace(msg.From), strings.TrimSpace(msg.Session))
}
}
if strings.TrimSpace(id) != "" && strings.TrimSpace(ack.ID) != "" && strings.TrimSpace(ack.ID) != strings.TrimSpace(id) {
return fmt.Errorf("unexpected websocket ack id: %s", ack.ID)
}
func handleNodeWireRequest(ctx context.Context, writeJSON func(interface{}) error, client *http.Client, info nodes.NodeInfo, opts nodeRegisterOptions, msg nodes.WireMessage) {
resp := nodes.Response{
OK: false,
Code: "invalid_request",
Node: info.ID,
Action: "",
Error: "request missing",
}
return nil
if msg.Request != nil {
req := *msg.Request
resp.Action = req.Action
if strings.TrimSpace(opts.Endpoint) == "" {
resp.Error = "node endpoint not configured"
resp.Code = "endpoint_missing"
} else {
if req.Node == "" {
req.Node = info.ID
}
execResp, err := nodes.DoEndpointRequest(ctx, client, opts.Endpoint, opts.NodeToken, req)
if err != nil {
resp = nodes.Response{
OK: false,
Code: "transport_error",
Node: info.ID,
Action: req.Action,
Error: err.Error(),
}
} else {
resp = execResp
if strings.TrimSpace(resp.Node) == "" {
resp.Node = info.ID
}
}
}
}
_ = writeJSON(nodes.WireMessage{
Type: "node_response",
ID: msg.ID,
From: info.ID,
To: strings.TrimSpace(msg.From),
Session: strings.TrimSpace(msg.Session),
Response: &resp,
})
}
func postNodeRegister(ctx context.Context, client *http.Client, gatewayBase, token string, info nodes.NodeInfo) error {