feat: unify websocket runtime and harden node control

This commit is contained in:
lpf
2026-03-08 22:22:49 +08:00
parent 7e67619826
commit 4172a57b39
15 changed files with 2082 additions and 124 deletions

View File

@@ -96,6 +96,7 @@ func printHelp() {
fmt.Println(" config Get/set config values")
fmt.Println(" cron Manage scheduled tasks")
fmt.Println(" channel Test and manage messaging channels")
fmt.Println(" node Register remote node metadata and heartbeat")
fmt.Println(" skills Manage skills (install, list, remove)")
fmt.Println(" uninstall Uninstall clawgo components")
fmt.Println(" version Show version information")

551
cmd/clawgo/cmd_node.go Normal file
View File

@@ -0,0 +1,551 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"time"
"clawgo/pkg/config"
"clawgo/pkg/nodes"
"github.com/gorilla/websocket"
)
// nodeRegisterOptions holds the settings for `clawgo node register`, as filled
// in from defaults and command-line flags by parseNodeRegisterArgs.
type nodeRegisterOptions struct {
// GatewayBase is the gateway base URL (--gateway).
GatewayBase string
// Token is the gateway bearer token (--token); empty means no Authorization header is sent.
Token string
// ID is the node identifier (--id).
ID string
// Name is the node display name (--name).
Name string
// Endpoint is the public endpoint reported for this node (--endpoint).
Endpoint string
// OS is the reported operating system (--os).
OS string
// Arch is the reported architecture (--arch).
Arch string
// Version is the reported node version (--version).
Version string
// Actions lists the supported actions (--actions, comma separated).
Actions []string
// Models lists the supported models (--models, comma separated).
Models []string
// Capabilities are the capability flags (--capabilities, comma separated).
Capabilities nodes.Capabilities
// Watch keeps a websocket connection open and sends heartbeats after registering (--watch).
Watch bool
// HeartbeatSec is the heartbeat interval in seconds used when Watch is set (--heartbeat-sec).
HeartbeatSec int
}
// nodeHeartbeatOptions holds the settings for `clawgo node heartbeat`, as
// filled in from defaults and command-line flags by parseNodeHeartbeatArgs.
type nodeHeartbeatOptions struct {
// GatewayBase is the gateway base URL (--gateway).
GatewayBase string
// Token is the gateway bearer token (--token).
Token string
// ID is the identifier of the node the heartbeat is sent for (--id).
ID string
}
// nodeCmd dispatches `clawgo node <subcommand>` to the matching handler.
// With no subcommand, or an explicit help request, it prints the node usage
// text; an unrecognised subcommand is reported and followed by the usage text.
func nodeCmd() {
	rest := os.Args[2:]
	if len(rest) < 1 {
		printNodeHelp()
		return
	}

	sub, subArgs := rest[0], rest[1:]
	// Subcommand matching ignores case and surrounding whitespace.
	switch name := strings.ToLower(strings.TrimSpace(sub)); name {
	case "help", "--help", "-h":
		printNodeHelp()
	case "register":
		nodeRegisterCmd(subArgs)
	case "heartbeat":
		nodeHeartbeatCmd(subArgs)
	default:
		fmt.Printf("Unknown node command: %s\n", sub)
		printNodeHelp()
	}
}
// printNodeHelp writes the usage text for the `clawgo node` subcommands to
// standard output, one entry of the table below per line.
func printNodeHelp() {
	usage := []string{
		"Node commands:",
		" clawgo node register [options]",
		" clawgo node heartbeat [options]",
		"",
		"Register options:",
		" --gateway <url> Gateway base URL, e.g. http://host:18790",
		" --token <value> Gateway token (optional when gateway.token is empty)",
		" --id <value> Node ID (default: hostname)",
		" --name <value> Node name (default: hostname)",
		" --endpoint <url> Public endpoint of this node",
		" --os <value> Reported OS (default: current runtime)",
		" --arch <value> Reported arch (default: current runtime)",
		" --version <value> Reported node version (default: current clawgo version)",
		" --actions <csv> Supported actions, e.g. run,agent_task",
		" --models <csv> Supported models, e.g. gpt-4o-mini",
		" --capabilities <csv> Capability flags: run,invoke,model,camera,screen,location,canvas",
		" --watch Keep a websocket connection open and send heartbeats",
		" --heartbeat-sec <n> Heartbeat interval in seconds when --watch is set (default: 30)",
		"",
		"Heartbeat options:",
		" --gateway <url> Gateway base URL",
		" --token <value> Gateway token",
		" --id <value> Node ID",
	}
	for _, line := range usage {
		fmt.Println(line)
	}
}
// nodeRegisterCmd implements `clawgo node register`: it parses the flags,
// registers the node with the gateway over HTTP and, when --watch was given,
// stays in the heartbeat loop. Any failure is printed and the process exits
// with status 1.
func nodeRegisterCmd(args []string) {
	// The load error is ignored here; the default helpers accept a nil config.
	cfg, _ := loadConfig()
	opts, parseErr := parseNodeRegisterArgs(args, cfg)
	if parseErr != nil {
		fmt.Printf("Error: %v\n", parseErr)
		printNodeHelp()
		os.Exit(1)
	}

	httpClient := &http.Client{Timeout: 20 * time.Second}
	info := buildNodeInfo(opts)
	regErr := postNodeRegister(context.Background(), httpClient, opts.GatewayBase, opts.Token, info)
	if regErr != nil {
		fmt.Printf("Error registering node: %v\n", regErr)
		os.Exit(1)
	}
	fmt.Printf("✓ Node registered: %s -> %s\n", info.ID, opts.GatewayBase)

	if opts.Watch {
		fmt.Printf("✓ Heartbeat loop started: every %ds\n", opts.HeartbeatSec)
		if loopErr := runNodeHeartbeatLoop(httpClient, opts, info); loopErr != nil {
			fmt.Printf("Heartbeat loop stopped: %v\n", loopErr)
			os.Exit(1)
		}
	}
}
// nodeHeartbeatCmd implements `clawgo node heartbeat`: it parses the flags and
// sends a single heartbeat for the node to the gateway over HTTP. Any failure
// is printed and the process exits with status 1.
func nodeHeartbeatCmd(args []string) {
	// The load error is ignored here; the default helpers accept a nil config.
	cfg, _ := loadConfig()
	opts, parseErr := parseNodeHeartbeatArgs(args, cfg)
	if parseErr != nil {
		fmt.Printf("Error: %v\n", parseErr)
		printNodeHelp()
		os.Exit(1)
	}

	httpClient := &http.Client{Timeout: 20 * time.Second}
	sendErr := postNodeHeartbeat(context.Background(), httpClient, opts.GatewayBase, opts.Token, opts.ID)
	if sendErr != nil {
		fmt.Printf("Error sending heartbeat: %v\n", sendErr)
		os.Exit(1)
	}
	fmt.Printf("✓ Heartbeat sent: %s -> %s\n", opts.ID, opts.GatewayBase)
}
// parseNodeRegisterArgs builds the options for `clawgo node register` from the
// command-line arguments, starting from defaults.
//
// Defaults: the gateway URL and token come from defaultGatewayBase and
// defaultGatewayToken; ID and Name default to the hostname ("node" when the
// hostname is empty or unavailable); OS and Arch to the running binary's
// GOOS/GOARCH; Version to the package-level version; the heartbeat interval to
// 30 seconds; and the capabilities to run,invoke,model.
//
// It returns an error for an unknown option, a flag that is missing its value,
// a --heartbeat-sec that is not a positive integer or is too large to be
// expressed as a time.Duration, or an empty gateway or ID. On success the
// gateway URL has been passed through normalizeGatewayBase.
func parseNodeRegisterArgs(args []string, cfg *config.Config) (nodeRegisterOptions, error) {
	// Largest interval in seconds that still fits in a time.Duration. Anything
	// bigger overflows when multiplied by time.Second, which would make
	// time.NewTicker panic (non-positive interval) or tick at a garbage rate.
	const maxHeartbeatSec = (1<<63 - 1) / int64(time.Second)

	// The hostname error is ignored: an empty name falls back to "node".
	host, _ := os.Hostname()
	host = strings.TrimSpace(host)
	if host == "" {
		host = "node"
	}
	opts := nodeRegisterOptions{
		GatewayBase:  defaultGatewayBase(cfg),
		Token:        defaultGatewayToken(cfg),
		ID:           host,
		Name:         host,
		OS:           runtime.GOOS,
		Arch:         runtime.GOARCH,
		Version:      version,
		HeartbeatSec: 30,
		Capabilities: capabilitiesFromCSV("run,invoke,model"),
	}

	// Flags whose (trimmed) value is copied verbatim into a string field.
	stringFlags := map[string]*string{
		"--gateway":  &opts.GatewayBase,
		"--token":    &opts.Token,
		"--id":       &opts.ID,
		"--name":     &opts.Name,
		"--endpoint": &opts.Endpoint,
		"--os":       &opts.OS,
		"--arch":     &opts.Arch,
		"--version":  &opts.Version,
	}

	for i := 0; i < len(args); i++ {
		arg := strings.TrimSpace(args[i])
		// next consumes and returns the argument that follows the current flag.
		next := func() (string, error) {
			if i+1 >= len(args) {
				return "", fmt.Errorf("missing value for %s", arg)
			}
			i++
			return strings.TrimSpace(args[i]), nil
		}

		if dst, ok := stringFlags[arg]; ok {
			v, err := next()
			if err != nil {
				return opts, err
			}
			*dst = v
			continue
		}

		switch arg {
		case "--actions":
			v, err := next()
			if err != nil {
				return opts, err
			}
			opts.Actions = splitCSV(v)
		case "--models":
			v, err := next()
			if err != nil {
				return opts, err
			}
			opts.Models = splitCSV(v)
		case "--capabilities":
			v, err := next()
			if err != nil {
				return opts, err
			}
			opts.Capabilities = capabilitiesFromCSV(v)
		case "--watch":
			opts.Watch = true
		case "--heartbeat-sec":
			v, err := next()
			if err != nil {
				return opts, err
			}
			n, convErr := strconv.Atoi(v)
			if convErr != nil || n <= 0 || int64(n) > maxHeartbeatSec {
				return opts, fmt.Errorf("invalid --heartbeat-sec: %s", v)
			}
			opts.HeartbeatSec = n
		default:
			return opts, fmt.Errorf("unknown option: %s", arg)
		}
	}

	if strings.TrimSpace(opts.GatewayBase) == "" {
		return opts, fmt.Errorf("--gateway is required")
	}
	if strings.TrimSpace(opts.ID) == "" {
		return opts, fmt.Errorf("--id is required")
	}
	opts.GatewayBase = normalizeGatewayBase(opts.GatewayBase)
	return opts, nil
}
// parseNodeHeartbeatArgs builds the options for `clawgo node heartbeat`.
//
// The gateway URL and token default to defaultGatewayBase and
// defaultGatewayToken, and the ID to the hostname ("node" when the hostname is
// empty or unavailable). Every supported flag takes exactly one value. An
// unknown option, a missing value, or an empty gateway or ID yields an error.
// On success the gateway URL has been passed through normalizeGatewayBase.
func parseNodeHeartbeatArgs(args []string, cfg *config.Config) (nodeHeartbeatOptions, error) {
	hostname, _ := os.Hostname()
	if hostname = strings.TrimSpace(hostname); hostname == "" {
		hostname = "node"
	}
	opts := nodeHeartbeatOptions{
		GatewayBase: defaultGatewayBase(cfg),
		Token:       defaultGatewayToken(cfg),
		ID:          hostname,
	}

	// Each flag writes its trimmed value into the field it points at.
	targets := map[string]*string{
		"--gateway": &opts.GatewayBase,
		"--token":   &opts.Token,
		"--id":      &opts.ID,
	}
	for idx := 0; idx < len(args); idx++ {
		flag := strings.TrimSpace(args[idx])
		dst, known := targets[flag]
		if !known {
			return opts, fmt.Errorf("unknown option: %s", flag)
		}
		if idx+1 >= len(args) {
			return opts, fmt.Errorf("missing value for %s", flag)
		}
		idx++ // step over the value that belongs to this flag
		*dst = strings.TrimSpace(args[idx])
	}

	switch {
	case strings.TrimSpace(opts.GatewayBase) == "":
		return opts, fmt.Errorf("--gateway is required")
	case strings.TrimSpace(opts.ID) == "":
		return opts, fmt.Errorf("--id is required")
	}
	opts.GatewayBase = normalizeGatewayBase(opts.GatewayBase)
	return opts, nil
}
// buildNodeInfo converts parsed register options into the nodes.NodeInfo that
// is sent to the gateway. String fields are whitespace-trimmed, and the action
// and model lists are copied so the result does not share backing arrays with
// opts.
func buildNodeInfo(opts nodeRegisterOptions) nodes.NodeInfo {
	var info nodes.NodeInfo
	info.ID = strings.TrimSpace(opts.ID)
	info.Name = strings.TrimSpace(opts.Name)
	info.OS = strings.TrimSpace(opts.OS)
	info.Arch = strings.TrimSpace(opts.Arch)
	info.Version = strings.TrimSpace(opts.Version)
	info.Endpoint = strings.TrimSpace(opts.Endpoint)
	info.Capabilities = opts.Capabilities
	// Appending onto the zero-value (nil) slices yields fresh copies.
	info.Actions = append(info.Actions, opts.Actions...)
	info.Models = append(info.Models, opts.Models...)
	return info
}
// runNodeHeartbeatLoop keeps the node's websocket heartbeat session alive
// until the process receives an interrupt signal.
//
// Each iteration runs one websocket session. When the session ends without an
// interrupt, the node is re-registered over HTTP (failures are only logged)
// and a new session is attempted after a two second pause. The function
// returns nil once the interrupt has been observed.
func runNodeHeartbeatLoop(client *http.Client, opts nodeRegisterOptions, info nodes.NodeInfo) error {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	defer cancel()

	// interrupted reports whether the context is done, announcing the stop when it is.
	interrupted := func() bool {
		if ctx.Err() == nil {
			return false
		}
		fmt.Println("✓ Node heartbeat stopped")
		return true
	}

	for {
		sockErr := runNodeHeartbeatSocket(ctx, opts, info)
		if sockErr != nil {
			if interrupted() {
				return nil
			}
			fmt.Printf("Warning: node socket closed for %s: %v\n", info.ID, sockErr)
		}
		if interrupted() {
			return nil
		}

		regErr := postNodeRegister(ctx, client, opts.GatewayBase, opts.Token, info)
		if regErr == nil {
			fmt.Printf("✓ Node re-registered: %s\n", info.ID)
		} else {
			fmt.Printf("Warning: re-register failed for %s: %v\n", info.ID, regErr)
		}

		// Pause before reconnecting, but wake immediately on interrupt.
		select {
		case <-time.After(2 * time.Second):
		case <-ctx.Done():
			fmt.Println("✓ Node heartbeat stopped")
			return nil
		}
	}
}
// runNodeHeartbeatSocket runs one websocket session against the gateway.
//
// It dials the URL produced by nodeWebsocketURL (sending the bearer token when
// one is configured), sends a "register" message carrying info and waits for
// the "registered" ack. It then sends a "heartbeat" message every
// opts.HeartbeatSec seconds, waiting for the matching ack each time, and a
// websocket ping every nodeSocketPingInterval(opts.HeartbeatSec).
//
// It returns nil when ctx is done, and the underlying error when the dial, a
// write, or an ack fails.
//
// NOTE(review): time.NewTicker panics on a non-positive duration, so
// opts.HeartbeatSec must be positive here — confirm all callers guarantee it.
func runNodeHeartbeatSocket(ctx context.Context, opts nodeRegisterOptions, info nodes.NodeInfo) error {
wsURL := nodeWebsocketURL(opts.GatewayBase)
headers := http.Header{}
// Only send an Authorization header when a non-blank token is configured.
if strings.TrimSpace(opts.Token) != "" {
headers.Set("Authorization", "Bearer "+strings.TrimSpace(opts.Token))
}
conn, _, err := websocket.DefaultDialer.DialContext(ctx, wsURL, headers)
if err != nil {
return err
}
defer conn.Close()
// Handshake: register over the socket and require the "registered" ack.
// The error from SetWriteDeadline is deliberately discarded.
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteJSON(nodes.WireMessage{Type: "register", Node: &info}); err != nil {
return err
}
if err := readNodeAck(conn, "registered", info.ID); err != nil {
return err
}
fmt.Printf("✓ Node socket connected: %s\n", info.ID)
// ticker drives application-level heartbeats; pingTicker drives websocket pings.
ticker := time.NewTicker(time.Duration(opts.HeartbeatSec) * time.Second)
pingTicker := time.NewTicker(nodeSocketPingInterval(opts.HeartbeatSec))
defer ticker.Stop()
defer pingTicker.Stop()
for {
select {
case <-ctx.Done():
// Cancellation is a clean shutdown, not an error.
return nil
case <-pingTicker.C:
if err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(10*time.Second)); err != nil {
return err
}
case <-ticker.C:
// Each heartbeat is a write followed by a blocking read of its ack.
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteJSON(nodes.WireMessage{Type: "heartbeat", ID: info.ID}); err != nil {
return err
}
if err := readNodeAck(conn, "heartbeat", info.ID); err != nil {
return err
}
fmt.Printf("✓ Heartbeat ok: %s\n", info.ID)
}
}
}
// nodeSocketPingInterval returns how often a websocket ping is sent for a
// given heartbeat interval: half the heartbeat interval, clamped to the range
// [10s, 25s]. A non-positive heartbeatSec yields 25s.
func nodeSocketPingInterval(heartbeatSec int) time.Duration {
	const (
		floor   = 10 * time.Second
		ceiling = 25 * time.Second
	)
	if heartbeatSec <= 0 {
		return ceiling
	}
	half := time.Duration(heartbeatSec) * time.Second / 2
	switch {
	case half < floor:
		return floor
	case half > ceiling:
		return ceiling
	default:
		return half
	}
}
// readNodeAck reads one JSON ack from conn, allowing up to 30 seconds, and
// validates it.
//
// It returns the read error, the ack's own error text when the ack is not OK
// ("unknown websocket error" if that text is blank), or a mismatch error when
// the ack type differs from expectedType (compared case-insensitively, skipped
// when expectedType is empty) or when both id and the ack ID are non-blank and
// differ.
func readNodeAck(conn *websocket.Conn, expectedType, id string) error {
	// The error from SetReadDeadline is deliberately discarded.
	_ = conn.SetReadDeadline(time.Now().Add(30 * time.Second))

	var ack nodes.WireAck
	if readErr := conn.ReadJSON(&ack); readErr != nil {
		return readErr
	}
	if !ack.OK {
		reason := ack.Error
		if strings.TrimSpace(reason) == "" {
			reason = "unknown websocket error"
		}
		return fmt.Errorf("%s", reason)
	}

	wantType := strings.ToLower(strings.TrimSpace(expectedType))
	gotType := strings.ToLower(strings.TrimSpace(ack.Type))
	wantID, gotID := strings.TrimSpace(id), strings.TrimSpace(ack.ID)
	switch {
	case expectedType != "" && gotType != wantType:
		return fmt.Errorf("unexpected websocket ack type: %s", ack.Type)
	case wantID != "" && gotID != "" && gotID != wantID:
		return fmt.Errorf("unexpected websocket ack id: %s", ack.ID)
	}
	return nil
}
// postNodeRegister POSTs info as JSON to the gateway's /nodes/register
// endpoint and returns any error reported by postNodeJSON.
func postNodeRegister(ctx context.Context, client *http.Client, gatewayBase, token string, info nodes.NodeInfo) error {
	const registerPath = "/nodes/register"
	return postNodeJSON(ctx, client, gatewayBase, token, registerPath, info)
}
// postNodeHeartbeat POSTs {"id": <trimmed id>} as JSON to the gateway's
// /nodes/heartbeat endpoint and returns any error reported by postNodeJSON.
func postNodeHeartbeat(ctx context.Context, client *http.Client, gatewayBase, token, id string) error {
	payload := map[string]string{"id": strings.TrimSpace(id)}
	return postNodeJSON(ctx, client, gatewayBase, token, "/nodes/heartbeat", payload)
}
// postNodeJSON marshals payload to JSON and POSTs it to gatewayBase+path
// (trailing slashes on gatewayBase are dropped first). A non-blank token is
// sent as a bearer Authorization header.
//
// It returns nil for any 2xx response. For other statuses the error reads
// "http <code>: <body>", where <body> is the trimmed response body or, when
// that is empty, the response status line. Marshal, request-construction and
// transport errors are returned unchanged.
func postNodeJSON(ctx context.Context, client *http.Client, gatewayBase, token, path string, payload interface{}) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	target := strings.TrimRight(gatewayBase, "/") + path
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, target, bytes.NewReader(encoded))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	if bearer := strings.TrimSpace(token); bearer != "" {
		req.Header.Set("Authorization", "Bearer "+bearer)
	}

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}

	// Best effort: whatever part of the body could be read becomes the error detail.
	var raw bytes.Buffer
	_, _ = raw.ReadFrom(resp.Body)
	detail := strings.TrimSpace(raw.String())
	if detail == "" {
		detail = resp.Status
	}
	return fmt.Errorf("http %d: %s", resp.StatusCode, detail)
}
// defaultGatewayBase returns the gateway base URL to use when --gateway is not
// given.
//
// A non-blank CLAWGO_GATEWAY_URL environment variable wins and is returned
// after normalizeGatewayBase. Otherwise the URL is built as http://host:port
// from cfg.Gateway, falling back to 127.0.0.1 when the host is blank or a
// wildcard listen address (0.0.0.0 or ::), and to port 18790 when the port is
// not positive. A nil cfg yields http://127.0.0.1:18790.
func defaultGatewayBase(cfg *config.Config) string {
	if raw := strings.TrimSpace(os.Getenv("CLAWGO_GATEWAY_URL")); raw != "" {
		return normalizeGatewayBase(raw)
	}
	host := "127.0.0.1"
	port := 18790
	if cfg != nil {
		if v := strings.TrimSpace(cfg.Gateway.Host); v != "" && v != "0.0.0.0" && v != "::" {
			host = v
		}
		if cfg.Gateway.Port > 0 {
			port = cfg.Gateway.Port
		}
	}
	// An IPv6 literal (two or more colons) must be wrapped in brackets inside a
	// URL, otherwise its colons are indistinguishable from the port separator.
	if strings.Count(host, ":") >= 2 && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	return fmt.Sprintf("http://%s:%d", host, port)
}
// defaultGatewayToken returns the gateway token to use when --token is not
// given: the trimmed CLAWGO_GATEWAY_TOKEN environment variable when it is
// non-blank, otherwise the trimmed cfg.Gateway.Token, or "" when cfg is nil.
func defaultGatewayToken(cfg *config.Config) string {
	envToken := strings.TrimSpace(os.Getenv("CLAWGO_GATEWAY_TOKEN"))
	switch {
	case envToken != "":
		return envToken
	case cfg != nil:
		return strings.TrimSpace(cfg.Gateway.Token)
	default:
		return ""
	}
}
// normalizeGatewayBase canonicalises a gateway base URL.
//
// Surrounding whitespace and trailing slashes are removed. An http:// or
// https:// scheme is recognised regardless of letter case and rewritten in
// lower case; input without one of those schemes gets "http://" prepended.
// Blank input yields "".
func normalizeGatewayBase(raw string) string {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return ""
	}
	scheme, rest := "http://", raw
	for _, known := range []string{"https://", "http://"} {
		// URL schemes are case-insensitive, so "HTTPS://host" must not be
		// mistaken for a scheme-less host and prefixed a second time.
		if len(raw) >= len(known) && strings.EqualFold(raw[:len(known)], known) {
			scheme, rest = known, raw[len(known):]
			break
		}
	}
	return strings.TrimRight(scheme+rest, "/")
}
// nodeWebsocketURL derives the websocket URL of the gateway's /nodes/connect
// endpoint from its HTTP base URL: the scheme is wss:// when the (trimmed)
// input starts with https://, and ws:// otherwise.
func nodeWebsocketURL(gatewayBase string) string {
	hostPart := normalizeGatewayBase(gatewayBase)
	for _, prefix := range []string{"http://", "https://"} {
		hostPart = strings.TrimPrefix(hostPart, prefix)
	}
	if strings.HasPrefix(strings.TrimSpace(gatewayBase), "https://") {
		return "wss://" + hostPart + "/nodes/connect"
	}
	return "ws://" + hostPart + "/nodes/connect"
}
// splitCSV splits raw on commas and returns the trimmed, non-empty items in
// their original order with exact duplicates removed (comparison is
// case-sensitive). The result is never nil.
func splitCSV(raw string) []string {
	fields := strings.Split(raw, ",")
	unique := make([]string, 0, len(fields))
	known := make(map[string]struct{}, len(fields))
	for i := range fields {
		entry := strings.TrimSpace(fields[i])
		if entry == "" {
			continue
		}
		if _, dup := known[entry]; dup {
			continue
		}
		known[entry] = struct{}{}
		unique = append(unique, entry)
	}
	return unique
}
// capabilitiesFromCSV turns a comma-separated list of capability names into a
// nodes.Capabilities value. Names are matched case-insensitively; "agent_task"
// is accepted as an alias for "model"; unrecognised names are ignored.
func capabilitiesFromCSV(raw string) nodes.Capabilities {
	enable := map[string]func(*nodes.Capabilities){
		"run":        func(c *nodes.Capabilities) { c.Run = true },
		"invoke":     func(c *nodes.Capabilities) { c.Invoke = true },
		"model":      func(c *nodes.Capabilities) { c.Model = true },
		"agent_task": func(c *nodes.Capabilities) { c.Model = true },
		"camera":     func(c *nodes.Capabilities) { c.Camera = true },
		"screen":     func(c *nodes.Capabilities) { c.Screen = true },
		"location":   func(c *nodes.Capabilities) { c.Location = true },
		"canvas":     func(c *nodes.Capabilities) { c.Canvas = true },
	}

	var caps nodes.Capabilities
	for _, name := range splitCSV(raw) {
		if set, ok := enable[strings.ToLower(name)]; ok {
			set(&caps)
		}
	}
	return caps
}

132
cmd/clawgo/cmd_node_test.go Normal file
View File

@@ -0,0 +1,132 @@
package main
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"clawgo/pkg/config"
"clawgo/pkg/nodes"
)
// TestParseNodeRegisterArgsDefaults checks that, with only --id given, the
// gateway URL and token come from the config and the default capabilities are
// run, invoke and model.
//
// The test is not parallel: it pins the CLAWGO_GATEWAY_URL and
// CLAWGO_GATEWAY_TOKEN environment variables to empty with t.Setenv (which is
// incompatible with t.Parallel), so that values set in the developer's or CI
// environment cannot override the config under test.
func TestParseNodeRegisterArgsDefaults(t *testing.T) {
	t.Setenv("CLAWGO_GATEWAY_URL", "")
	t.Setenv("CLAWGO_GATEWAY_TOKEN", "")

	cfg := config.DefaultConfig()
	cfg.Gateway.Host = "gateway.example"
	cfg.Gateway.Port = 7788
	cfg.Gateway.Token = "cfg-token"
	opts, err := parseNodeRegisterArgs([]string{"--id", "edge-dev"}, cfg)
	if err != nil {
		t.Fatalf("parseNodeRegisterArgs failed: %v", err)
	}
	if opts.GatewayBase != "http://gateway.example:7788" {
		t.Fatalf("unexpected gateway base: %s", opts.GatewayBase)
	}
	if opts.Token != "cfg-token" {
		t.Fatalf("unexpected token: %s", opts.Token)
	}
	if opts.ID != "edge-dev" {
		t.Fatalf("unexpected id: %s", opts.ID)
	}
	if !opts.Capabilities.Run || !opts.Capabilities.Invoke || !opts.Capabilities.Model {
		t.Fatalf("expected default run/invoke/model capabilities, got %+v", opts.Capabilities)
	}
}
// TestPostNodeRegisterSendsNodeInfo checks that postNodeRegister POSTs the
// node info to /nodes/register with the bearer token attached.
func TestPostNodeRegisterSendsNodeInfo(t *testing.T) {
	t.Parallel()
	var gotAuth string
	var got nodes.NodeInfo
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatalf must not be
		// called. Record the failure with t.Errorf and answer with an error
		// status so the client call in the test goroutine fails as well.
		if r.URL.Path != "/nodes/register" {
			t.Errorf("unexpected path: %s", r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		gotAuth = r.Header.Get("Authorization")
		if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
			t.Errorf("decode body: %v", err)
			http.Error(w, "bad body", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"ok":true}`))
	}))
	defer srv.Close()
	info := nodes.NodeInfo{
		ID:       "edge-dev",
		Name:     "Edge Dev",
		Endpoint: "http://edge.example:18790",
		Capabilities: nodes.Capabilities{
			Run: true, Invoke: true, Model: true,
		},
		Actions: []string{"run", "agent_task"},
		Models:  []string{"gpt-4o-mini"},
	}
	client := &http.Client{Timeout: 2 * time.Second}
	if err := postNodeRegister(context.Background(), client, srv.URL, "secret", info); err != nil {
		t.Fatalf("postNodeRegister failed: %v", err)
	}
	if gotAuth != "Bearer secret" {
		t.Fatalf("unexpected auth header: %s", gotAuth)
	}
	if got.ID != "edge-dev" || got.Endpoint != "http://edge.example:18790" {
		t.Fatalf("unexpected node payload: %+v", got)
	}
	if len(got.Actions) != 2 || got.Actions[1] != "agent_task" {
		t.Fatalf("unexpected actions: %+v", got.Actions)
	}
}
// TestPostNodeHeartbeatSendsNodeID checks that postNodeHeartbeat POSTs a JSON
// body carrying the node ID to /nodes/heartbeat.
func TestPostNodeHeartbeatSendsNodeID(t *testing.T) {
	t.Parallel()
	var body map[string]string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatalf must not be
		// called. Record the failure with t.Errorf and answer with an error
		// status so the client call in the test goroutine fails as well.
		if r.URL.Path != "/nodes/heartbeat" {
			t.Errorf("unexpected path: %s", r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode body: %v", err)
			http.Error(w, "bad body", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()
	client := &http.Client{Timeout: 2 * time.Second}
	if err := postNodeHeartbeat(context.Background(), client, srv.URL, "", "edge-dev"); err != nil {
		t.Fatalf("postNodeHeartbeat failed: %v", err)
	}
	if strings.TrimSpace(body["id"]) != "edge-dev" {
		t.Fatalf("unexpected heartbeat body: %+v", body)
	}
}
// TestNodeWebsocketURL checks the http->ws and https->wss mapping and the
// /nodes/connect path suffix.
func TestNodeWebsocketURL(t *testing.T) {
	t.Parallel()
	cases := []struct {
		base    string
		want    string
		failFmt string
	}{
		{"http://gateway.example:18790", "ws://gateway.example:18790/nodes/connect", "unexpected ws url: %s"},
		{"https://gateway.example", "wss://gateway.example/nodes/connect", "unexpected wss url: %s"},
	}
	for _, tc := range cases {
		got := nodeWebsocketURL(tc.base)
		if got != tc.want {
			t.Fatalf(tc.failFmt, got)
		}
	}
}
// TestNodeSocketPingInterval checks the 25s cap, the 10s floor and the
// half-heartbeat value in between.
func TestNodeSocketPingInterval(t *testing.T) {
	t.Parallel()
	cases := []struct {
		heartbeatSec int
		want         time.Duration
		failFmt      string
	}{
		{120, 25 * time.Second, "expected 25s cap, got %s"},
		{20, 10 * time.Second, "expected 10s floor, got %s"},
		{30, 15 * time.Second, "expected half heartbeat, got %s"},
	}
	for _, tc := range cases {
		got := nodeSocketPingInterval(tc.heartbeatSec)
		if got != tc.want {
			t.Fatalf(tc.failFmt, got)
		}
	}
}

View File

@@ -67,6 +67,8 @@ func main() {
cronCmd()
case "channel":
channelCmd()
case "node":
nodeCmd()
case "skills":
skillsCmd()
case "version", "--version", "-v":