From dd705e5e93f6a33120e8e44579efc62782b13da6 Mon Sep 17 00:00:00 2001 From: DBT Date: Tue, 24 Feb 2026 09:43:40 +0000 Subject: [PATCH] apply go-level optimizations for lock-free snapshots, pooling, worker flow and typed errors --- pkg/channels/manager.go | 40 +++++++++++++++++++------------- pkg/tools/errors.go | 8 +++++++ pkg/tools/message.go | 23 +++++++++++++------ pkg/tools/registry.go | 51 ++++++++++++++++++++--------------------- 4 files changed, 73 insertions(+), 49 deletions(-) create mode 100644 pkg/tools/errors.go diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index ee29c0e..c4de86d 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -11,6 +11,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "clawgo/pkg/bus" "clawgo/pkg/config" @@ -27,6 +28,7 @@ type Manager struct { dispatchSem chan struct{} outboundLimit *rate.Limiter mu sync.RWMutex + snapshot atomic.Value // map[string]Channel } type asyncTask struct { @@ -42,6 +44,7 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error dispatchSem: make(chan struct{}, 32), outboundLimit: rate.NewLimiter(rate.Limit(40), 80), } + m.snapshot.Store(map[string]Channel{}) if err := m.initChannels(); err != nil { return nil, err @@ -159,10 +162,19 @@ func (m *Manager) initChannels() error { logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{ "enabled_channels": len(m.channels), }) + m.refreshSnapshot() return nil } +func (m *Manager) refreshSnapshot() { + next := make(map[string]Channel, len(m.channels)) + for k, v := range m.channels { + next[k] = v + } + m.snapshot.Store(next) +} + func (m *Manager) StartAll(ctx context.Context) error { m.mu.Lock() if len(m.channels) == 0 { @@ -276,9 +288,8 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { return } - m.mu.RLock() - channel, exists := m.channels[msg.Channel] - m.mu.RUnlock() + cur, _ := m.snapshot.Load().(map[string]Channel) + channel, exists := cur[msg.Channel] if !exists { logger.WarnCF("channels", "Unknown channel for outbound message", map[string]interface{}{ @@ -323,29 +334,24 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { } func (m *Manager) GetChannel(name string) (Channel, bool) { - m.mu.RLock() - defer m.mu.RUnlock() - channel, ok := m.channels[name] + cur, _ := m.snapshot.Load().(map[string]Channel) + channel, ok := cur[name] return channel, ok } func (m *Manager) GetStatus() map[string]interface{} { - m.mu.RLock() - defer m.mu.RUnlock() - - status := make(map[string]interface{}) - for name := range m.channels { + cur, _ := m.snapshot.Load().(map[string]Channel) + status := make(map[string]interface{}, len(cur)) + for name := range cur { status[name] = map[string]interface{}{} } return status } func (m *Manager) GetEnabledChannels() []string { - m.mu.RLock() - defer m.mu.RUnlock() - - names := make([]string, 0, len(m.channels)) - for name := range m.channels { + cur, _ := m.snapshot.Load().(map[string]Channel) + names := make([]string, 0, len(cur)) + for name := range cur { names = append(names, name) } return names @@ -355,12 +361,14 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { m.mu.Lock() defer m.mu.Unlock() m.channels[name] = channel + m.refreshSnapshot() } func (m *Manager) UnregisterChannel(name string) { m.mu.Lock() defer m.mu.Unlock() delete(m.channels, name) + m.refreshSnapshot() } func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { diff --git a/pkg/tools/errors.go b/pkg/tools/errors.go new file mode 100644 index 0000000..b44be91 --- /dev/null +++ b/pkg/tools/errors.go @@ -0,0 +1,8 @@ +package tools + +import "errors" + +var ( + ErrUnsupportedAction = errors.New("unsupported action") + ErrMissingField = errors.New("missing required field") +) diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 0f5c003..9b23328 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -3,6 +3,9 @@ package tools import ( "context" "fmt" + "strings" + "sync" + "clawgo/pkg/bus" ) @@ -14,6 +17,8 @@ type MessageTool struct { defaultChatID string } +var buttonRowPool = sync.Pool{New: func() interface{} { return make([]bus.Button, 0, 8) }} + func NewMessageTool() *MessageTool { return &MessageTool{} } @@ -93,6 +98,7 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) { func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { action, _ := args["action"].(string) + action = strings.ToLower(strings.TrimSpace(action)) if action == "" { action = "send" } @@ -106,22 +112,22 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) switch action { case "send": if content == "" { - return "", fmt.Errorf("message/content is required for action=send") + return "", fmt.Errorf("%w: message/content for action=send", ErrMissingField) } case "edit": if messageID == "" || content == "" { - return "", fmt.Errorf("message_id and message/content are required for action=edit") + return "", fmt.Errorf("%w: message_id and message/content for action=edit", ErrMissingField) } case "delete": if messageID == "" { - return "", fmt.Errorf("message_id is required for action=delete") + return "", fmt.Errorf("%w: message_id for action=delete", ErrMissingField) } case "react": if messageID == "" || emoji == "" { - return "", fmt.Errorf("message_id and emoji are required for action=react") + return "", fmt.Errorf("%w: message_id and emoji for action=react", ErrMissingField) } default: - return fmt.Sprintf("Unsupported action: %s", action), nil + return "", fmt.Errorf("%w: %s", ErrUnsupportedAction, action) } channel, _ := args["channel"].(string) @@ -149,7 +155,8 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) if btns, ok := args["buttons"].([]interface{}); ok { for _, row := range btns { if rowArr, ok := row.([]interface{}); ok { - var buttonRow []bus.Button + pooled := buttonRowPool.Get().([]bus.Button) + buttonRow := pooled[:0] for _, b := range rowArr { if bMap, ok := b.(map[string]interface{}); ok { text, _ := bMap["text"].(string) @@ -160,8 +167,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) } } if len(buttonRow) > 0 { - buttons = append(buttons, buttonRow) + copied := append([]bus.Button(nil), buttonRow...) + buttons = append(buttons, copied) } + buttonRowPool.Put(buttonRow[:0]) } } } diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 53bcb51..59e23fe 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -4,32 +4,38 @@ import ( "context" "fmt" "sync" + "sync/atomic" "time" "clawgo/pkg/logger" ) type ToolRegistry struct { - tools map[string]Tool - mu sync.RWMutex + tools map[string]Tool + mu sync.RWMutex + snapshot atomic.Value // map[string]Tool (copy-on-write) } func NewToolRegistry() *ToolRegistry { - return &ToolRegistry{ - tools: make(map[string]Tool), - } + r := &ToolRegistry{tools: make(map[string]Tool)} + r.snapshot.Store(map[string]Tool{}) + return r } func (r *ToolRegistry) Register(tool Tool) { r.mu.Lock() defer r.mu.Unlock() r.tools[tool.Name()] = tool + next := make(map[string]Tool, len(r.tools)) + for k, v := range r.tools { + next[k] = v + } + r.snapshot.Store(next) } func (r *ToolRegistry) Get(name string) (Tool, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - tool, ok := r.tools[name] + cur, _ := r.snapshot.Load().(map[string]Tool) + tool, ok := cur[name] return tool, ok } @@ -73,11 +79,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string } func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { - r.mu.RLock() - defer r.mu.RUnlock() - - definitions := make([]map[string]interface{}, 0, len(r.tools)) - for _, tool := range r.tools { + cur, _ := r.snapshot.Load().(map[string]Tool) + definitions := make([]map[string]interface{}, 0, len(cur)) + for _, tool := range cur { definitions = append(definitions, ToolToSchema(tool)) } return definitions @@ -85,11 +89,9 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { // List returns a list of all registered tool names. func (r *ToolRegistry) List() []string { - r.mu.RLock() - defer r.mu.RUnlock() - - names := make([]string, 0, len(r.tools)) - for name := range r.tools { + cur, _ := r.snapshot.Load().(map[string]Tool) + names := make([]string, 0, len(cur)) + for name := range cur { names = append(names, name) } return names @@ -97,19 +99,16 @@ func (r *ToolRegistry) List() []string { // Count returns the number of registered tools. func (r *ToolRegistry) Count() int { - r.mu.RLock() - defer r.mu.RUnlock() - return len(r.tools) + cur, _ := r.snapshot.Load().(map[string]Tool) + return len(cur) } // GetSummaries returns human-readable summaries of all registered tools. // Returns a slice of "name - description" strings. func (r *ToolRegistry) GetSummaries() []string { - r.mu.RLock() - defer r.mu.RUnlock() - - summaries := make([]string, 0, len(r.tools)) - for _, tool := range r.tools { + cur, _ := r.snapshot.Load().(map[string]Tool) + summaries := make([]string, 0, len(cur)) + for _, tool := range cur { summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description())) } return summaries