mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 05:37:29 +08:00
feat: harden concurrency scheduling and task watchdog
This commit is contained in:
@@ -734,23 +734,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return fmt.Sprintf(tpl, lang), nil
|
||||
}
|
||||
|
||||
// Update tool contexts
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
mt.SetContext(msg.Channel, msg.ChatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(*tools.SpawnTool); ok {
|
||||
st.SetContext(msg.Channel, msg.ChatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("remind"); ok {
|
||||
if rt, ok := tool.(*tools.RemindTool); ok {
|
||||
rt.SetContext(msg.Channel, msg.ChatID)
|
||||
}
|
||||
}
|
||||
|
||||
history := al.sessions.GetHistory(msg.SessionKey)
|
||||
summary := al.sessions.GetSummary(msg.SessionKey)
|
||||
memoryRecallUsed := false
|
||||
@@ -948,7 +931,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
|
||||
execArgs := withToolContextArgs(tc.Name, tc.Arguments, msg.Channel, msg.ChatID)
|
||||
result, err := al.tools.Execute(ctx, tc.Name, execArgs)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("Error: %v", err)
|
||||
}
|
||||
@@ -1168,18 +1152,6 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
// Use the origin session for context
|
||||
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
|
||||
|
||||
// Update tool contexts to original channel/chatID
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
mt.SetContext(originChannel, originChatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(*tools.SpawnTool); ok {
|
||||
st.SetContext(originChannel, originChatID)
|
||||
}
|
||||
}
|
||||
|
||||
// Build messages with the announce content
|
||||
history := al.sessions.GetHistory(sessionKey)
|
||||
summary := al.sessions.GetSummary(sessionKey)
|
||||
@@ -1273,7 +1245,8 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
al.sessions.AddMessageFull(sessionKey, assistantMsg)
|
||||
|
||||
for _, tc := range response.ToolCalls {
|
||||
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
|
||||
execArgs := withToolContextArgs(tc.Name, tc.Arguments, originChannel, originChatID)
|
||||
result, err := al.tools.Execute(ctx, tc.Name, execArgs)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("Error: %v", err)
|
||||
}
|
||||
@@ -1657,6 +1630,42 @@ func truncateString(s string, maxLen int) string {
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
func withToolContextArgs(toolName string, args map[string]interface{}, channel, chatID string) map[string]interface{} {
|
||||
if channel == "" || chatID == "" {
|
||||
return args
|
||||
}
|
||||
switch toolName {
|
||||
case "message", "spawn", "remind":
|
||||
default:
|
||||
return args
|
||||
}
|
||||
|
||||
next := make(map[string]interface{}, len(args)+2)
|
||||
for k, v := range args {
|
||||
next[k] = v
|
||||
}
|
||||
|
||||
if toolName == "message" {
|
||||
if _, ok := next["channel"]; !ok {
|
||||
next["channel"] = channel
|
||||
}
|
||||
if _, hasChat := next["chat_id"]; !hasChat {
|
||||
if _, hasTo := next["to"]; !hasTo {
|
||||
next["chat_id"] = chatID
|
||||
}
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
if _, ok := next["channel"]; !ok {
|
||||
next["channel"] = channel
|
||||
}
|
||||
if _, ok := next["chat_id"]; !ok {
|
||||
next["chat_id"] = chatID
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func shouldRecallMemory(text string, keywords []string) bool {
|
||||
s := strings.ToLower(strings.TrimSpace(text))
|
||||
if s == "" {
|
||||
|
||||
36
pkg/agent/loop_tool_context_test.go
Normal file
36
pkg/agent/loop_tool_context_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestWithToolContextArgsInjectsDefaults(t *testing.T) {
|
||||
args := map[string]interface{}{"message": "hello"}
|
||||
got := withToolContextArgs("message", args, "telegram", "chat-1")
|
||||
if got["channel"] != "telegram" {
|
||||
t.Fatalf("expected channel injected, got %v", got["channel"])
|
||||
}
|
||||
if got["chat_id"] != "chat-1" {
|
||||
t.Fatalf("expected chat_id injected, got %v", got["chat_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithToolContextArgsPreservesExplicitTarget(t *testing.T) {
|
||||
args := map[string]interface{}{"message": "hello", "to": "target-2"}
|
||||
got := withToolContextArgs("message", args, "telegram", "chat-1")
|
||||
if _, ok := got["chat_id"]; ok {
|
||||
t.Fatalf("chat_id should not be injected when 'to' is provided")
|
||||
}
|
||||
if got["to"] != "target-2" {
|
||||
t.Fatalf("expected to preserved, got %v", got["to"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithToolContextArgsSkipsUnrelatedTools(t *testing.T) {
|
||||
args := map[string]interface{}{"query": "x"}
|
||||
got := withToolContextArgs("memory_search", args, "telegram", "chat-1")
|
||||
if len(got) != len(args) {
|
||||
t.Fatalf("expected unchanged args for unrelated tool")
|
||||
}
|
||||
if _, ok := got["channel"]; ok {
|
||||
t.Fatalf("unexpected channel key for unrelated tool")
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -140,11 +139,10 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage
|
||||
res.ErrText = err.Error()
|
||||
}
|
||||
results[index] = res
|
||||
al.publishPlannedTaskProgress(msg, len(tasks), res)
|
||||
}(i, task)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
sort.SliceStable(results, func(i, j int) bool { return results[i].Task.Index < results[j].Task.Index })
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("已自动拆解为 %d 个任务并执行:\n\n", len(results)))
|
||||
for _, r := range results {
|
||||
@@ -162,6 +160,35 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage
|
||||
return strings.TrimSpace(b.String()), nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) publishPlannedTaskProgress(msg bus.InboundMessage, total int, res plannedTaskResult) {
|
||||
if al == nil || al.bus == nil || total <= 1 {
|
||||
return
|
||||
}
|
||||
if msg.Channel == "system" || msg.Channel == "internal" {
|
||||
return
|
||||
}
|
||||
idx := res.Task.Index
|
||||
if idx <= 0 {
|
||||
idx = res.Index + 1
|
||||
}
|
||||
status := "完成"
|
||||
body := strings.TrimSpace(res.Output)
|
||||
if res.ErrText != "" {
|
||||
status = "失败"
|
||||
body = strings.TrimSpace(res.ErrText)
|
||||
}
|
||||
if body == "" {
|
||||
body = "(无输出)"
|
||||
}
|
||||
body = truncate(strings.ReplaceAll(body, "\n", " "), 280)
|
||||
content := fmt.Sprintf("进度 %d/%d:任务%d已%s\n%s", idx, total, idx, status, body)
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) enrichTaskContentWithMemoryAndEKG(ctx context.Context, task plannedTask) string {
|
||||
base := strings.TrimSpace(task.Content)
|
||||
if base == "" {
|
||||
|
||||
@@ -4,9 +4,14 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"clawgo/pkg/bus"
|
||||
"clawgo/pkg/config"
|
||||
"clawgo/pkg/ekg"
|
||||
"clawgo/pkg/providers"
|
||||
)
|
||||
@@ -54,6 +59,115 @@ func TestProcessPlannedMessage_AggregatesResults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type probeProvider struct {
|
||||
mu sync.Mutex
|
||||
inFlight int
|
||||
maxInFlight int
|
||||
delayPerCall time.Duration
|
||||
responseCount int
|
||||
}
|
||||
|
||||
func (p *probeProvider) Chat(_ context.Context, _ []providers.Message, _ []providers.ToolDefinition, _ string, _ map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.inFlight++
|
||||
if p.inFlight > p.maxInFlight {
|
||||
p.maxInFlight = p.inFlight
|
||||
}
|
||||
p.responseCount++
|
||||
p.mu.Unlock()
|
||||
|
||||
time.Sleep(p.delayPerCall)
|
||||
|
||||
p.mu.Lock()
|
||||
n := p.responseCount
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
resp := providers.LLMResponse{Content: "done-" + strconv.Itoa(n), FinishReason: "stop"}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (p *probeProvider) GetDefaultModel() string { return "test-model" }
|
||||
|
||||
func TestRunPlannedTasks_NonConflictingKeysCanRunInParallel(t *testing.T) {
|
||||
p := &probeProvider{delayPerCall: 100 * time.Millisecond}
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.MaxToolIterations = 2
|
||||
cfg.Agents.Defaults.ContextCompaction.Enabled = false
|
||||
loop := NewAgentLoop(cfg, bus.NewMessageBus(), p, nil)
|
||||
|
||||
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "cli",
|
||||
SenderID: "u",
|
||||
ChatID: "direct",
|
||||
SessionKey: "sess-plan-parallel",
|
||||
Content: "[resource_keys: file:pkg/a.go] 修复 a;[resource_keys: file:pkg/b.go] 修复 b",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processPlannedMessage error: %v", err)
|
||||
}
|
||||
if p.maxInFlight < 2 {
|
||||
t.Fatalf("expected parallel execution for non-conflicting keys, got maxInFlight=%d", p.maxInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunPlannedTasks_ConflictingKeysMutuallyExclusive(t *testing.T) {
|
||||
p := &probeProvider{delayPerCall: 100 * time.Millisecond}
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.MaxToolIterations = 2
|
||||
cfg.Agents.Defaults.ContextCompaction.Enabled = false
|
||||
loop := NewAgentLoop(cfg, bus.NewMessageBus(), p, nil)
|
||||
|
||||
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "cli",
|
||||
SenderID: "u",
|
||||
ChatID: "direct",
|
||||
SessionKey: "sess-plan-locked",
|
||||
Content: "[resource_keys: file:pkg/a.go] 修复 a;[resource_keys: file:pkg/a.go] 补测试",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processPlannedMessage error: %v", err)
|
||||
}
|
||||
if p.maxInFlight != 1 {
|
||||
t.Fatalf("expected mutual exclusion for conflicting keys, got maxInFlight=%d", p.maxInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunPlannedTasks_PublishesStepProgress(t *testing.T) {
|
||||
rp := &recordingProvider{responses: []providers.LLMResponse{
|
||||
{Content: "done-a", FinishReason: "stop"},
|
||||
{Content: "done-b", FinishReason: "stop"},
|
||||
}}
|
||||
loop := setupLoop(t, rp)
|
||||
|
||||
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "cli",
|
||||
SenderID: "u",
|
||||
ChatID: "direct",
|
||||
SessionKey: "sess-plan-progress",
|
||||
Content: "修复 pkg/a.go;补充 pkg/b.go 测试",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processPlannedMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
out1, ok := loop.bus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatalf("expected first progress outbound")
|
||||
}
|
||||
out2, ok := loop.bus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatalf("expected second progress outbound")
|
||||
}
|
||||
all := out1.Content + "\n" + out2.Content
|
||||
if !strings.Contains(all, "进度 1/2") || !strings.Contains(all, "进度 2/2") {
|
||||
t.Fatalf("unexpected progress outputs:\n%s", all)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRecentRelatedErrorEvent(t *testing.T) {
|
||||
ws := filepath.Join(t.TempDir(), "workspace")
|
||||
_ = os.MkdirAll(filepath.Join(ws, "memory"), 0o755)
|
||||
|
||||
Reference in New Issue
Block a user