mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-05-06 18:07:28 +08:00
optimize channel orchestration with errgroup and rate limiter
This commit is contained in:
@@ -15,15 +15,18 @@ import (
|
||||
"clawgo/pkg/bus"
|
||||
"clawgo/pkg/config"
|
||||
"clawgo/pkg/logger"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
channels map[string]Channel
|
||||
bus *bus.MessageBus
|
||||
config *config.Config
|
||||
dispatchTask *asyncTask
|
||||
dispatchSem chan struct{}
|
||||
mu sync.RWMutex
|
||||
channels map[string]Channel
|
||||
bus *bus.MessageBus
|
||||
config *config.Config
|
||||
dispatchTask *asyncTask
|
||||
dispatchSem chan struct{}
|
||||
outboundLimit *rate.Limiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type asyncTask struct {
|
||||
@@ -36,7 +39,8 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error
|
||||
bus: messageBus,
|
||||
config: cfg,
|
||||
// Limit concurrent outbound sends to avoid unbounded goroutine growth.
|
||||
dispatchSem: make(chan struct{}, 32),
|
||||
dispatchSem: make(chan struct{}, 32),
|
||||
outboundLimit: rate.NewLimiter(rate.Limit(40), 80),
|
||||
}
|
||||
|
||||
if err := m.initChannels(); err != nil {
|
||||
@@ -161,59 +165,73 @@ func (m *Manager) initChannels() error {
|
||||
|
||||
func (m *Manager) StartAll(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if len(m.channels) == 0 {
|
||||
m.mu.Unlock()
|
||||
logger.WarnC("channels", "No channels enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "Starting all channels")
|
||||
|
||||
channelsSnapshot := make(map[string]Channel, len(m.channels))
|
||||
for k, v := range m.channels {
|
||||
channelsSnapshot[k] = v
|
||||
}
|
||||
dispatchCtx, cancel := context.WithCancel(ctx)
|
||||
m.dispatchTask = &asyncTask{cancel: cancel}
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.InfoC("channels", "Starting all channels")
|
||||
go m.dispatchOutbound(dispatchCtx)
|
||||
|
||||
for name, channel := range m.channels {
|
||||
logger.InfoCF("channels", "Starting channel", map[string]interface{}{
|
||||
logger.FieldChannel: name,
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
for name, channel := range channelsSnapshot {
|
||||
name := name
|
||||
channel := channel
|
||||
g.Go(func() error {
|
||||
logger.InfoCF("channels", "Starting channel", map[string]interface{}{logger.FieldChannel: name})
|
||||
if err := channel.Start(gctx); err != nil {
|
||||
logger.ErrorCF("channels", "Failed to start channel", map[string]interface{}{logger.FieldChannel: name, logger.FieldError: err.Error()})
|
||||
return fmt.Errorf("%s: %w", name, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := channel.Start(ctx); err != nil {
|
||||
logger.ErrorCF("channels", "Failed to start channel", map[string]interface{}{
|
||||
logger.FieldChannel: name,
|
||||
logger.FieldError: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.InfoC("channels", "All channels started")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) StopAll(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
channelsSnapshot := make(map[string]Channel, len(m.channels))
|
||||
for k, v := range m.channels {
|
||||
channelsSnapshot[k] = v
|
||||
}
|
||||
task := m.dispatchTask
|
||||
m.dispatchTask = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.InfoC("channels", "Stopping all channels")
|
||||
|
||||
if m.dispatchTask != nil {
|
||||
m.dispatchTask.cancel()
|
||||
m.dispatchTask = nil
|
||||
if task != nil {
|
||||
task.cancel()
|
||||
}
|
||||
|
||||
for name, channel := range m.channels {
|
||||
logger.InfoCF("channels", "Stopping channel", map[string]interface{}{
|
||||
logger.FieldChannel: name,
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
for name, channel := range channelsSnapshot {
|
||||
name := name
|
||||
channel := channel
|
||||
g.Go(func() error {
|
||||
logger.InfoCF("channels", "Stopping channel", map[string]interface{}{logger.FieldChannel: name})
|
||||
if err := channel.Stop(gctx); err != nil {
|
||||
logger.ErrorCF("channels", "Error stopping channel", map[string]interface{}{logger.FieldChannel: name, logger.FieldError: err.Error()})
|
||||
return fmt.Errorf("%s: %w", name, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := channel.Stop(ctx); err != nil {
|
||||
logger.ErrorCF("channels", "Error stopping channel", map[string]interface{}{
|
||||
logger.FieldChannel: name,
|
||||
logger.FieldError: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.InfoC("channels", "All channels stopped")
|
||||
return nil
|
||||
}
|
||||
@@ -283,6 +301,12 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if m.outboundLimit != nil {
|
||||
if err := m.outboundLimit.Wait(ctx); err != nil {
|
||||
logger.WarnCF("channels", "Outbound rate limiter canceled", map[string]interface{}{logger.FieldError: err.Error()})
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Bound fan-out concurrency to prevent goroutine explosion under burst traffic.
|
||||
m.dispatchSem <- struct{}{}
|
||||
go func(c Channel, outbound bus.OutboundMessage) {
|
||||
|
||||
Reference in New Issue
Block a user