feat: harden concurrency scheduling and task watchdog

This commit is contained in:
lpf
2026-03-05 11:32:06 +08:00
parent 0f3196f305
commit 2fbb98bccd
20 changed files with 1526 additions and 159 deletions

View File

@@ -734,23 +734,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return fmt.Sprintf(tpl, lang), nil
}
// Update tool contexts
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
mt.SetContext(msg.Channel, msg.ChatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
st.SetContext(msg.Channel, msg.ChatID)
}
}
if tool, ok := al.tools.Get("remind"); ok {
if rt, ok := tool.(*tools.RemindTool); ok {
rt.SetContext(msg.Channel, msg.ChatID)
}
}
history := al.sessions.GetHistory(msg.SessionKey)
summary := al.sessions.GetSummary(msg.SessionKey)
memoryRecallUsed := false
@@ -948,7 +931,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
"iteration": iteration,
})
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
execArgs := withToolContextArgs(tc.Name, tc.Arguments, msg.Channel, msg.ChatID)
result, err := al.tools.Execute(ctx, tc.Name, execArgs)
if err != nil {
result = fmt.Sprintf("Error: %v", err)
}
@@ -1168,18 +1152,6 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
// Use the origin session for context
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
// Update tool contexts to original channel/chatID
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
mt.SetContext(originChannel, originChatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
st.SetContext(originChannel, originChatID)
}
}
// Build messages with the announce content
history := al.sessions.GetHistory(sessionKey)
summary := al.sessions.GetSummary(sessionKey)
@@ -1273,7 +1245,8 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
al.sessions.AddMessageFull(sessionKey, assistantMsg)
for _, tc := range response.ToolCalls {
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
execArgs := withToolContextArgs(tc.Name, tc.Arguments, originChannel, originChatID)
result, err := al.tools.Execute(ctx, tc.Name, execArgs)
if err != nil {
result = fmt.Sprintf("Error: %v", err)
}
@@ -1657,6 +1630,42 @@ func truncateString(s string, maxLen int) string {
return s[:maxLen-3] + "..."
}
func withToolContextArgs(toolName string, args map[string]interface{}, channel, chatID string) map[string]interface{} {
if channel == "" || chatID == "" {
return args
}
switch toolName {
case "message", "spawn", "remind":
default:
return args
}
next := make(map[string]interface{}, len(args)+2)
for k, v := range args {
next[k] = v
}
if toolName == "message" {
if _, ok := next["channel"]; !ok {
next["channel"] = channel
}
if _, hasChat := next["chat_id"]; !hasChat {
if _, hasTo := next["to"]; !hasTo {
next["chat_id"] = chatID
}
}
return next
}
if _, ok := next["channel"]; !ok {
next["channel"] = channel
}
if _, ok := next["chat_id"]; !ok {
next["chat_id"] = chatID
}
return next
}
func shouldRecallMemory(text string, keywords []string) bool {
s := strings.ToLower(strings.TrimSpace(text))
if s == "" {

View File

@@ -0,0 +1,36 @@
package agent
import "testing"
func TestWithToolContextArgsInjectsDefaults(t *testing.T) {
args := map[string]interface{}{"message": "hello"}
got := withToolContextArgs("message", args, "telegram", "chat-1")
if got["channel"] != "telegram" {
t.Fatalf("expected channel injected, got %v", got["channel"])
}
if got["chat_id"] != "chat-1" {
t.Fatalf("expected chat_id injected, got %v", got["chat_id"])
}
}
func TestWithToolContextArgsPreservesExplicitTarget(t *testing.T) {
args := map[string]interface{}{"message": "hello", "to": "target-2"}
got := withToolContextArgs("message", args, "telegram", "chat-1")
if _, ok := got["chat_id"]; ok {
t.Fatalf("chat_id should not be injected when 'to' is provided")
}
if got["to"] != "target-2" {
t.Fatalf("expected to preserved, got %v", got["to"])
}
}
func TestWithToolContextArgsSkipsUnrelatedTools(t *testing.T) {
args := map[string]interface{}{"query": "x"}
got := withToolContextArgs("memory_search", args, "telegram", "chat-1")
if len(got) != len(args) {
t.Fatalf("expected unchanged args for unrelated tool")
}
if _, ok := got["channel"]; ok {
t.Fatalf("unexpected channel key for unrelated tool")
}
}

View File

@@ -8,7 +8,6 @@ import (
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
@@ -140,11 +139,10 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage
res.ErrText = err.Error()
}
results[index] = res
al.publishPlannedTaskProgress(msg, len(tasks), res)
}(i, task)
}
wg.Wait()
sort.SliceStable(results, func(i, j int) bool { return results[i].Task.Index < results[j].Task.Index })
var b strings.Builder
b.WriteString(fmt.Sprintf("已自动拆解为 %d 个任务并执行:\n\n", len(results)))
for _, r := range results {
@@ -162,6 +160,35 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage
return strings.TrimSpace(b.String()), nil
}
func (al *AgentLoop) publishPlannedTaskProgress(msg bus.InboundMessage, total int, res plannedTaskResult) {
if al == nil || al.bus == nil || total <= 1 {
return
}
if msg.Channel == "system" || msg.Channel == "internal" {
return
}
idx := res.Task.Index
if idx <= 0 {
idx = res.Index + 1
}
status := "完成"
body := strings.TrimSpace(res.Output)
if res.ErrText != "" {
status = "失败"
body = strings.TrimSpace(res.ErrText)
}
if body == "" {
body = "(无输出)"
}
body = truncate(strings.ReplaceAll(body, "\n", " "), 280)
content := fmt.Sprintf("进度 %d/%d任务%d已%s\n%s", idx, total, idx, status, body)
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: content,
})
}
func (al *AgentLoop) enrichTaskContentWithMemoryAndEKG(ctx context.Context, task plannedTask) string {
base := strings.TrimSpace(task.Content)
if base == "" {

View File

@@ -4,9 +4,14 @@ import (
"context"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
"clawgo/pkg/bus"
"clawgo/pkg/config"
"clawgo/pkg/ekg"
"clawgo/pkg/providers"
)
@@ -54,6 +59,115 @@ func TestProcessPlannedMessage_AggregatesResults(t *testing.T) {
}
}
type probeProvider struct {
mu sync.Mutex
inFlight int
maxInFlight int
delayPerCall time.Duration
responseCount int
}
func (p *probeProvider) Chat(_ context.Context, _ []providers.Message, _ []providers.ToolDefinition, _ string, _ map[string]interface{}) (*providers.LLMResponse, error) {
p.mu.Lock()
p.inFlight++
if p.inFlight > p.maxInFlight {
p.maxInFlight = p.inFlight
}
p.responseCount++
p.mu.Unlock()
time.Sleep(p.delayPerCall)
p.mu.Lock()
n := p.responseCount
p.inFlight--
p.mu.Unlock()
resp := providers.LLMResponse{Content: "done-" + strconv.Itoa(n), FinishReason: "stop"}
return &resp, nil
}
func (p *probeProvider) GetDefaultModel() string { return "test-model" }
func TestRunPlannedTasks_NonConflictingKeysCanRunInParallel(t *testing.T) {
p := &probeProvider{delayPerCall: 100 * time.Millisecond}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace")
cfg.Agents.Defaults.MaxToolIterations = 2
cfg.Agents.Defaults.ContextCompaction.Enabled = false
loop := NewAgentLoop(cfg, bus.NewMessageBus(), p, nil)
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
Channel: "cli",
SenderID: "u",
ChatID: "direct",
SessionKey: "sess-plan-parallel",
Content: "[resource_keys: file:pkg/a.go] 修复 a[resource_keys: file:pkg/b.go] 修复 b",
})
if err != nil {
t.Fatalf("processPlannedMessage error: %v", err)
}
if p.maxInFlight < 2 {
t.Fatalf("expected parallel execution for non-conflicting keys, got maxInFlight=%d", p.maxInFlight)
}
}
func TestRunPlannedTasks_ConflictingKeysMutuallyExclusive(t *testing.T) {
p := &probeProvider{delayPerCall: 100 * time.Millisecond}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace")
cfg.Agents.Defaults.MaxToolIterations = 2
cfg.Agents.Defaults.ContextCompaction.Enabled = false
loop := NewAgentLoop(cfg, bus.NewMessageBus(), p, nil)
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
Channel: "cli",
SenderID: "u",
ChatID: "direct",
SessionKey: "sess-plan-locked",
Content: "[resource_keys: file:pkg/a.go] 修复 a[resource_keys: file:pkg/a.go] 补测试",
})
if err != nil {
t.Fatalf("processPlannedMessage error: %v", err)
}
if p.maxInFlight != 1 {
t.Fatalf("expected mutual exclusion for conflicting keys, got maxInFlight=%d", p.maxInFlight)
}
}
func TestRunPlannedTasks_PublishesStepProgress(t *testing.T) {
rp := &recordingProvider{responses: []providers.LLMResponse{
{Content: "done-a", FinishReason: "stop"},
{Content: "done-b", FinishReason: "stop"},
}}
loop := setupLoop(t, rp)
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
Channel: "cli",
SenderID: "u",
ChatID: "direct",
SessionKey: "sess-plan-progress",
Content: "修复 pkg/a.go补充 pkg/b.go 测试",
})
if err != nil {
t.Fatalf("processPlannedMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
out1, ok := loop.bus.SubscribeOutbound(ctx)
if !ok {
t.Fatalf("expected first progress outbound")
}
out2, ok := loop.bus.SubscribeOutbound(ctx)
if !ok {
t.Fatalf("expected second progress outbound")
}
all := out1.Content + "\n" + out2.Content
if !strings.Contains(all, "进度 1/2") || !strings.Contains(all, "进度 2/2") {
t.Fatalf("unexpected progress outputs:\n%s", all)
}
}
func TestFindRecentRelatedErrorEvent(t *testing.T) {
ws := filepath.Join(t.TempDir(), "workspace")
_ = os.MkdirAll(filepath.Join(ws, "memory"), 0o755)