diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 06bc61b..7b4d956 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -10,9 +10,12 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"hash/fnv"
 	"os"
 	"path/filepath"
 	"regexp"
+	"runtime"
+	"strconv"
 	"strings"
 	"sync"
 
@@ -247,34 +250,29 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
 
 func (al *AgentLoop) Run(ctx context.Context) error {
 	al.running = true
+	shards, wait := al.buildSessionShards(ctx)
+	defer func() {
+		// Closing the shards lets every worker drain its queue and exit;
+		// waiting keeps Run from returning while messages are in flight.
+		for _, ch := range shards {
+			close(ch)
+		}
+		wait()
+	}()
+
 	for al.running {
+		msg, ok := al.bus.ConsumeInbound(ctx)
+		if !ok {
+			if ctx.Err() != nil {
+				return nil
+			}
+			continue
+		}
+		idx := sessionShardIndex(msg.SessionKey, len(shards))
 		select {
+		case shards[idx] <- msg:
 		case <-ctx.Done():
 			return nil
-		default:
-			msg, ok := al.bus.ConsumeInbound(ctx)
-			if !ok {
-				continue
-			}
-
-			response, err := al.processMessage(ctx, msg)
-			if err != nil {
-				response = fmt.Sprintf("Error processing message: %v", err)
-			}
-
-			trigger := al.getTrigger(msg)
-			suppressed := false
-			if response != "" {
-				if outbound, ok := al.prepareOutbound(msg, response); ok {
-					al.bus.PublishOutbound(outbound)
-				} else {
-					suppressed = true
-				}
-			}
-			al.audit.Record(trigger, msg.Channel, msg.SessionKey, suppressed, err)
-			if suppressed {
-				continue
-			}
 		}
 	}
 
@@ -285,6 +283,88 @@ func (al *AgentLoop) Stop() {
 	al.running = false
 }
 
+// buildSessionShards starts one worker goroutine per shard and returns the
+// shard channels together with a function that blocks until every worker has
+// exited. A worker handles its queue sequentially, so messages routed to the
+// same shard are processed in arrival order. Workers stop once their channel
+// is closed and drained.
+func (al *AgentLoop) buildSessionShards(ctx context.Context) ([]chan bus.InboundMessage, func()) {
+	count := sessionShardCount()
+	shards := make([]chan bus.InboundMessage, count)
+	var wg sync.WaitGroup
+	wg.Add(count)
+	for i := 0; i < count; i++ {
+		shards[i] = make(chan bus.InboundMessage, 64)
+		go func(ch <-chan bus.InboundMessage) {
+			defer wg.Done()
+			for msg := range ch {
+				al.processInbound(ctx, msg)
+			}
+		}(shards[i])
+	}
+	logger.InfoCF("agent", "Session-sharded dispatcher enabled", map[string]interface{}{"shards": count})
+	return shards, wg.Wait
+}
+
+// processInbound runs one message through processMessage, publishes the reply
+// unless it is empty or prepareOutbound rejects it, and records the outcome
+// via al.audit. A processing error replaces the reply with an error string.
+func (al *AgentLoop) processInbound(ctx context.Context, msg bus.InboundMessage) {
+	response, err := al.processMessage(ctx, msg)
+	if err != nil {
+		response = fmt.Sprintf("Error processing message: %v", err)
+	}
+
+	trigger := al.getTrigger(msg)
+	suppressed := false
+	if response != "" {
+		if outbound, ok := al.prepareOutbound(msg, response); ok {
+			al.bus.PublishOutbound(outbound)
+		} else {
+			suppressed = true
+		}
+	}
+	al.audit.Record(trigger, msg.Channel, msg.SessionKey, suppressed, err)
+}
+
+// sessionShardCount returns the number of dispatcher shards. A positive integer
+// in CLAWGO_SESSION_SHARDS takes precedence (capped at 64); otherwise the count
+// is runtime.NumCPU() clamped to the range [2, 16].
+func sessionShardCount() int {
+	if v := strings.TrimSpace(os.Getenv("CLAWGO_SESSION_SHARDS")); v != "" {
+		if n, err := strconv.Atoi(v); err == nil && n > 0 {
+			if n > 64 {
+				return 64
+			}
+			return n
+		}
+	}
+	n := runtime.NumCPU()
+	if n < 2 {
+		n = 2
+	}
+	if n > 16 {
+		n = 16
+	}
+	return n
+}
+
+// sessionShardIndex maps a session key to a shard in [0, shardCount) using a
+// 32-bit FNV-1a hash of the trimmed key. Blank keys hash as "default", and a
+// shardCount of 1 or less always yields 0.
+func sessionShardIndex(sessionKey string, shardCount int) int {
+	if shardCount <= 1 {
+		return 0
+	}
+	key := strings.TrimSpace(sessionKey)
+	if key == "" {
+		key = "default"
+	}
+	h := fnv.New32a()
+	_, _ = h.Write([]byte(key))
+	return int(h.Sum32() % uint32(shardCount))
+}
+
 func (al *AgentLoop) getTrigger(msg bus.InboundMessage) string {
 	if msg.Metadata != nil {
 		if t := strings.TrimSpace(msg.Metadata["trigger"]); t != "" {