mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 21:57:29 +08:00
agent: implement session-sharded concurrent dispatcher
This commit is contained in:
@@ -10,9 +10,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -247,34 +250,26 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running = true
|
||||
|
||||
shards := al.buildSessionShards(ctx)
|
||||
defer func() {
|
||||
for _, ch := range shards {
|
||||
close(ch)
|
||||
}
|
||||
}()
|
||||
|
||||
for al.running {
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
idx := sessionShardIndex(msg.SessionKey, len(shards))
|
||||
select {
|
||||
case shards[idx] <- msg:
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
}
|
||||
|
||||
trigger := al.getTrigger(msg)
|
||||
suppressed := false
|
||||
if response != "" {
|
||||
if outbound, ok := al.prepareOutbound(msg, response); ok {
|
||||
al.bus.PublishOutbound(outbound)
|
||||
} else {
|
||||
suppressed = true
|
||||
}
|
||||
}
|
||||
al.audit.Record(trigger, msg.Channel, msg.SessionKey, suppressed, err)
|
||||
if suppressed {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,6 +280,71 @@ func (al *AgentLoop) Stop() {
|
||||
al.running = false
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundMessage {
|
||||
count := sessionShardCount()
|
||||
shards := make([]chan bus.InboundMessage, count)
|
||||
for i := 0; i < count; i++ {
|
||||
shards[i] = make(chan bus.InboundMessage, 64)
|
||||
go func(ch <-chan bus.InboundMessage) {
|
||||
for msg := range ch {
|
||||
al.processInbound(ctx, msg)
|
||||
}
|
||||
}(shards[i])
|
||||
}
|
||||
logger.InfoCF("agent", "Session-sharded dispatcher enabled", map[string]interface{}{"shards": count})
|
||||
return shards
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processInbound(ctx context.Context, msg bus.InboundMessage) {
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
}
|
||||
|
||||
trigger := al.getTrigger(msg)
|
||||
suppressed := false
|
||||
if response != "" {
|
||||
if outbound, ok := al.prepareOutbound(msg, response); ok {
|
||||
al.bus.PublishOutbound(outbound)
|
||||
} else {
|
||||
suppressed = true
|
||||
}
|
||||
}
|
||||
al.audit.Record(trigger, msg.Channel, msg.SessionKey, suppressed, err)
|
||||
}
|
||||
|
||||
func sessionShardCount() int {
|
||||
if v := strings.TrimSpace(os.Getenv("CLAWGO_SESSION_SHARDS")); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
if n > 64 {
|
||||
return 64
|
||||
}
|
||||
return n
|
||||
}
|
||||
}
|
||||
n := runtime.NumCPU()
|
||||
if n < 2 {
|
||||
n = 2
|
||||
}
|
||||
if n > 16 {
|
||||
n = 16
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func sessionShardIndex(sessionKey string, shardCount int) int {
|
||||
if shardCount <= 1 {
|
||||
return 0
|
||||
}
|
||||
key := strings.TrimSpace(sessionKey)
|
||||
if key == "" {
|
||||
key = "default"
|
||||
}
|
||||
h := fnv.New32a()
|
||||
_, _ = h.Write([]byte(key))
|
||||
return int(h.Sum32() % uint32(shardCount))
|
||||
}
|
||||
|
||||
func (al *AgentLoop) getTrigger(msg bus.InboundMessage) string {
|
||||
if msg.Metadata != nil {
|
||||
if t := strings.TrimSpace(msg.Metadata["trigger"]); t != "" {
|
||||
|
||||
Reference in New Issue
Block a user