Files
clawgo/pkg/agent/session_planner.go

610 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package agent
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"clawgo/pkg/bus"
"clawgo/pkg/ekg"
"clawgo/pkg/providers"
"clawgo/pkg/scheduling"
)
// plannedTask is one unit of work produced when an inbound message is split
// into several tasks (or kept whole as a single task).
type plannedTask struct {
// Index is the 1-based position assigned while the task list is built.
Index int
// Content is the trimmed text of the task.
Content string
// ResourceKeys holds the result of scheduling.DeriveResourceKeys for Content.
ResourceKeys []string
}
// plannedTaskResult records the outcome of running a single plannedTask.
type plannedTaskResult struct {
// Index is the 0-based slot of the task in the slice passed to runPlannedTasks.
Index int
// Task is the task that was executed.
Task plannedTask
// Output is the whitespace-trimmed text returned by processing the task.
Output string
// Err is the error returned by processing the task, or nil on success.
Err error
// ErrText is Err.Error() when Err is non-nil and empty otherwise.
ErrText string
}
// reLeadingNumber matches a leading numbered-list marker: one or more digits
// followed by ".", ")" or "、", plus any whitespace after it.
var reLeadingNumber = regexp.MustCompile(`^\d+[\.)、]\s*`)
// processPlannedMessage plans msg and dispatches it. When planning yields more
// than one task the tasks are run via runPlannedTasks; otherwise the message
// is handled as-is by processMessage.
func (al *AgentLoop) processPlannedMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
	if planned := al.planSessionTasks(ctx, msg); len(planned) > 1 {
		return al.runPlannedTasks(ctx, msg, planned)
	}
	return al.processMessage(ctx, msg)
}
// planSessionTasks turns an inbound message into a list of planned tasks.
//
// It returns nil for blank content. The message is kept as a single task when
// planning is disabled via the "disable_planning" metadata flag, when the
// channel is "system" or "internal", when a non-blank "trigger" metadata value
// is present, when the content starts with "/", or when rule-based splitting
// finds at most one segment. Otherwise the rule-based segments are optionally
// refined by inferPlannedSegments and converted into tasks, skipping blank
// segments.
func (al *AgentLoop) planSessionTasks(ctx context.Context, msg bus.InboundMessage) []plannedTask {
	base := strings.TrimSpace(msg.Content)
	if base == "" {
		return nil
	}

	// whole builds the "do not split" plan lazily so resource keys are only
	// derived on the path that actually returns it.
	whole := func() []plannedTask {
		return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}}
	}

	switch {
	case msg.Metadata != nil && planningDisabled(msg.Metadata["disable_planning"]):
		return whole()
	case msg.Channel == "system", msg.Channel == "internal":
		return whole()
	case msg.Metadata != nil && strings.TrimSpace(msg.Metadata["trigger"]) != "":
		return whole()
	case strings.HasPrefix(base, "/"):
		return whole()
	}

	segments := splitPlannedSegments(base)
	if len(segments) <= 1 {
		return whole()
	}
	if refined, ok := al.inferPlannedSegments(ctx, base, segments); ok {
		segments = refined
	}

	planned := make([]plannedTask, 0, len(segments))
	for pos, segment := range segments {
		text := strings.TrimSpace(segment)
		if text == "" {
			continue
		}
		planned = append(planned, plannedTask{
			Index:        pos + 1,
			Content:      text,
			ResourceKeys: scheduling.DeriveResourceKeys(text),
		})
	}

	switch len(planned) {
	case 0:
		return whole()
	case 1:
		// A lone surviving segment is replaced by the full original request.
		planned[0].Content = base
		planned[0].ResourceKeys = scheduling.DeriveResourceKeys(base)
	}
	return planned
}
// plannerDecision is the JSON object the planner model is asked to return.
type plannerDecision struct {
// ShouldSplit reports whether the model wants the request split.
ShouldSplit bool `json:"should_split"`
// Tasks lists the task groups proposed by the model.
Tasks []string `json:"tasks"`
}
// inferPlannedSegments asks the configured provider whether content should be
// split, showing it up to 12 of the rule-based candidates.
//
// It returns (nil, false) when no provider is available, when content is blank
// or there are fewer than two candidates, when the chat call fails, or when
// the reply cannot be parsed. It returns a one-element slice holding content
// when the model declines to split or proposes fewer than 2 or more than 8
// tasks after sanitizing. Otherwise it returns the sanitized tasks. The chat
// call is bounded by an 8 second timeout.
func (al *AgentLoop) inferPlannedSegments(ctx context.Context, content string, candidates []string) ([]string, bool) {
	if al == nil || al.provider == nil {
		return nil, false
	}
	content = strings.TrimSpace(content)
	if content == "" || len(candidates) <= 1 {
		return nil, false
	}

	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	plannerCtx, cancel := context.WithTimeout(parent, 8*time.Second)
	defer cancel()

	shown := candidates
	if len(shown) > 12 {
		shown = shown[:12]
	}
	numbered := make([]string, 0, len(shown))
	for i, candidate := range shown {
		numbered = append(numbered, fmt.Sprintf("%d. %s", i+1, strings.TrimSpace(candidate)))
	}
	candidateList := strings.TrimSpace(strings.Join(numbered, "\n"))

	const systemPrompt = "Decide whether the request should stay as one task or be split into a small number of high-level task groups. " +
		"Default to no split. Only split when the request contains multiple independent deliverables, roles, or workstreams. " +
		"Never split into fine-grained execution steps. Merge related steps. Return strict JSON only: " +
		`{"should_split":true|false,"tasks":["..."]}` +
		" If should_split is false, tasks must be empty. If true, tasks must contain 2 to 8 concise high-level task groups."

	messages := []providers.Message{
		{Role: "system", Content: systemPrompt},
		{
			Role: "user",
			Content: fmt.Sprintf("Original request:\n%s\n\nRule-based candidate segments (%d shown of %d):\n%s",
				content, len(shown), len(candidates), candidateList),
		},
	}
	resp, err := al.provider.Chat(plannerCtx, messages, nil, al.provider.GetDefaultModel(), map[string]interface{}{
		"max_tokens": 256,
	})
	if err != nil || resp == nil {
		return nil, false
	}

	decision, ok := parsePlannerDecision(resp.Content)
	if !ok {
		return nil, false
	}
	if decision.ShouldSplit {
		if tasks := sanitizePlannerTasks(decision.Tasks); len(tasks) >= 2 && len(tasks) <= 8 {
			return tasks, true
		}
	}
	// Either the model declined to split or its task list was out of range.
	return []string{content}, true
}
// parsePlannerDecision decodes the planner model's reply into a
// plannerDecision. If the reply contains a "{...}" span, only that span is
// decoded, which tolerates code fences or surrounding prose. The boolean is
// false when the reply is blank or is not valid JSON.
func parsePlannerDecision(raw string) (plannerDecision, bool) {
	payload := strings.TrimSpace(raw)
	if payload == "" {
		return plannerDecision{}, false
	}
	if object := extractJSONObject(payload); object != "" {
		payload = object
	}

	var decision plannerDecision
	if json.Unmarshal([]byte(payload), &decision) != nil {
		return plannerDecision{}, false
	}
	return decision, true
}
// extractJSONObject returns the substring of raw running from the first "{"
// to the last "}" inclusive, or "" when raw has no such span.
func extractJSONObject(raw string) string {
	first := strings.IndexByte(raw, '{')
	if first < 0 {
		return ""
	}
	last := strings.LastIndexByte(raw, '}')
	if last <= first {
		return ""
	}
	return strings.TrimSpace(raw[first : last+1])
}
// sanitizePlannerTasks cleans the task list returned by the planner model. It
// trims each item, strips a leading list number, drops blank items, and
// removes case-insensitive duplicates while keeping first-seen order.
func sanitizePlannerTasks(items []string) []string {
	cleaned := make([]string, 0, len(items))
	seen := make(map[string]bool, len(items))
	for _, raw := range items {
		task := reLeadingNumber.ReplaceAllString(strings.TrimSpace(raw), "")
		task = strings.TrimSpace(task)
		if task == "" {
			continue
		}
		key := strings.ToLower(task)
		if seen[key] {
			continue
		}
		seen[key] = true
		cleaned = append(cleaned, task)
	}
	return cleaned
}
// planningDisabled reports whether v, ignoring case and surrounding
// whitespace, is one of the truthy flag values "1", "true", "yes" or "on".
func planningDisabled(v string) bool {
	flag := strings.ToLower(strings.TrimSpace(v))
	return flag == "1" || flag == "true" || flag == "yes" || flag == "on"
}
// splitPlannedSegments proposes rule-based task segments for content.
//
// If at least two lines are list items ("- ", "* " or a leading number as
// matched by reLeadingNumber), the list item texts are returned and all other
// lines are ignored. Otherwise content is split on semicolons, both ASCII ";"
// and fullwidth U+FF1B, and the non-blank trimmed pieces are returned.
func splitPlannedSegments(content string) []string {
	lines := strings.Split(content, "\n")
	bullets := make([]string, 0, len(lines))
	for _, line := range lines {
		item := strings.TrimSpace(line)
		switch {
		case item == "":
			// Blank lines carry no task.
		case strings.HasPrefix(item, "- "), strings.HasPrefix(item, "* "):
			bullets = append(bullets, strings.TrimSpace(item[2:]))
		case reLeadingNumber.MatchString(item):
			bullets = append(bullets, strings.TrimSpace(reLeadingNumber.ReplaceAllString(item, "")))
		}
	}
	if len(bullets) >= 2 {
		return bullets
	}

	// Only split implicit plans on strong separators. Plain newlines are often
	// just formatting inside a single request, and "then"/"and then" frequently
	// describes execution order inside one task rather than separate tasks.
	//
	// The fullwidth semicolon is written as an escape so it cannot be lost or
	// confused with the ASCII one. An empty-string pattern must never be used
	// here: it matches at every position and would split text into characters.
	parts := strings.FieldsFunc(content, func(r rune) bool {
		return r == ';' || r == '\uFF1B'
	})
	out := make([]string, 0, len(parts))
	for _, part := range parts {
		if segment := strings.TrimSpace(part); segment != "" {
			out = append(out, segment)
		}
	}
	return out
}
// runPlannedTasks executes every task concurrently, one goroutine per task,
// and returns a combined report.
//
// Each task is processed as a copy of msg whose content is the enriched task
// text and whose metadata additionally carries "resource_keys",
// "planned_task_index" and "planned_task_total". Progress notifications are
// published as tasks finish, when shouldPublishPlannedTaskProgress allows it.
// If ctx has an error once all tasks are done, that error is returned instead
// of the report.
func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage, tasks []plannedTask) (string, error) {
	total := len(tasks)
	results := make([]plannedTaskResult, total)
	milestones := plannedProgressMilestones(total)

	// Progress counters shared by all workers; guarded by mu.
	var (
		mu        sync.Mutex
		completed int
		failed    int
		notified  = make(map[int]struct{}, len(milestones))
	)

	var wg sync.WaitGroup
	wg.Add(total)
	for slot := range tasks {
		go func(slot int, task plannedTask) {
			defer wg.Done()

			sub := msg
			sub.Content = al.enrichTaskContentWithMemoryAndEKG(ctx, task)
			sub.Metadata = cloneMetadata(msg.Metadata)
			if sub.Metadata == nil {
				sub.Metadata = map[string]string{}
			}
			sub.Metadata["resource_keys"] = strings.Join(task.ResourceKeys, ",")
			sub.Metadata["planned_task_index"] = fmt.Sprintf("%d", task.Index)
			sub.Metadata["planned_task_total"] = fmt.Sprintf("%d", total)

			output, err := al.processMessage(ctx, sub)
			res := plannedTaskResult{Index: slot, Task: task, Output: strings.TrimSpace(output), Err: err}
			if err != nil {
				res.ErrText = err.Error()
			}
			// Each worker writes only its own slot, so no lock is needed here.
			results[slot] = res

			mu.Lock()
			completed++
			if res.ErrText != "" && !isPlannedTaskCancellation(ctx, res) {
				failed++
			}
			doneNow, failedNow := completed, failed
			notify := shouldPublishPlannedTaskProgress(ctx, total, doneNow, res, milestones, notified)
			if notify && res.ErrText == "" {
				notified[doneNow] = struct{}{}
			}
			mu.Unlock()

			// Publish outside the lock, using the snapshot taken under it.
			if notify {
				al.publishPlannedTaskProgress(msg, total, doneNow, failedNow, res)
			}
		}(slot, tasks[slot])
	}
	wg.Wait()

	if err := ctx.Err(); err != nil {
		return "", err
	}

	var report strings.Builder
	fmt.Fprintf(&report, "已自动拆解为 %d 个任务并执行:\n\n", len(results))
	for i := range results {
		r := &results[i]
		fmt.Fprintf(&report, "[%d] %s\n", r.Task.Index, r.Task.Content)
		switch {
		case r.ErrText != "":
			report.WriteString("执行失败:" + r.ErrText + "\n\n")
		case r.Output == "":
			report.WriteString("(无输出)\n\n")
		default:
			report.WriteString(r.Output + "\n\n")
		}
	}
	return strings.TrimSpace(report.String()), nil
}
// plannedProgressMilestones returns the completed-task counts at which an
// intermediate progress notification may be sent: total scaled by 0.33 and
// 0.66, rounded to the nearest integer. It returns nil when total is 3 or
// less. Steps outside the open range (0, total) and duplicates are dropped.
func plannedProgressMilestones(total int) []int {
	if total <= 3 {
		return nil
	}
	fractions := [...]float64{0.33, 0.66}
	steps := make([]int, 0, len(fractions))
	for _, fraction := range fractions {
		step := int(math.Round(float64(total) * fraction))
		if step < 1 || step >= total {
			continue
		}
		duplicate := false
		for _, prev := range steps {
			if prev == step {
				duplicate = true
				break
			}
		}
		if !duplicate {
			steps = append(steps, step)
		}
	}
	return steps
}
// shouldPublishPlannedTaskProgress decides whether finishing a task warrants a
// progress notification.
//
// Nothing is published for plans of one task or fewer, or for cancellations.
// Any other failure is always published. A success is published only when
// completed is below total and equals a milestone not yet present in notified.
func shouldPublishPlannedTaskProgress(ctx context.Context, total, completed int, res plannedTaskResult, milestones []int, notified map[int]struct{}) bool {
	switch {
	case total <= 1:
		return false
	case isPlannedTaskCancellation(ctx, res):
		return false
	case strings.TrimSpace(res.ErrText) != "":
		return true
	case completed >= total:
		return false
	}
	for _, step := range milestones {
		if step == completed {
			_, already := notified[step]
			return !already
		}
	}
	return false
}
// isPlannedTaskCancellation reports whether res stems from cancellation. That
// is the case when res.Err wraps context.Canceled, when res.ErrText equals the
// context.Canceled message ignoring case and surrounding whitespace, or when a
// non-nil ctx has itself been canceled.
func isPlannedTaskCancellation(ctx context.Context, res plannedTaskResult) bool {
	switch {
	case res.Err != nil && errors.Is(res.Err, context.Canceled):
		return true
	case strings.EqualFold(strings.TrimSpace(res.ErrText), context.Canceled.Error()):
		return true
	case ctx == nil:
		return false
	default:
		return errors.Is(ctx.Err(), context.Canceled)
	}
}
// publishPlannedTaskProgress sends a progress message for the task that just
// finished to the channel and chat the request came from. It does nothing when
// there is no bus, when total is 1 or less, or when the channel is "system" or
// "internal". The message body is the task's output, or its error text on
// failure, shortened to at most 6 lines and 320 characters.
func (al *AgentLoop) publishPlannedTaskProgress(msg bus.InboundMessage, total, completed, failed int, res plannedTaskResult) {
	if al == nil || al.bus == nil || total <= 1 {
		return
	}
	switch msg.Channel {
	case "system", "internal":
		return
	}

	// Fall back to the result slot when the task carries no usable index.
	taskNo := res.Task.Index
	if taskNo <= 0 {
		taskNo = res.Index + 1
	}

	status, body := "完成", strings.TrimSpace(res.Output)
	if res.ErrText != "" {
		status, body = "失败", strings.TrimSpace(res.ErrText)
	}
	if body == "" {
		body = "(无输出)"
	}
	body = summarizePlannedTaskProgressBody(body, 6, 320)

	al.bus.PublishOutbound(bus.OutboundMessage{
		Channel: msg.Channel,
		ChatID:  msg.ChatID,
		Content: fmt.Sprintf("阶段进度 %d/%d失败 %d\n最近任务%d 已%s\n%s", completed, total, failed, taskNo, status, body),
	})
}
// summarizePlannedTaskProgressBody condenses body for a progress message.
//
// CRLF line endings are normalized, blank lines are dropped and every kept
// line is trimmed. At most maxLines non-blank lines are kept (no limit when
// maxLines <= 0). The joined text is passed through truncate when it is
// longer than maxChars bytes (no limit when maxChars <= 0). A trailing "..."
// line is appended only when non-blank lines were actually dropped by the
// line limit; skipped blank lines do not count as omitted content. A body
// with no content yields the "(无输出)" placeholder.
func summarizePlannedTaskProgressBody(body string, maxLines, maxChars int) string {
	body = strings.TrimSpace(strings.ReplaceAll(body, "\r\n", "\n"))
	if body == "" {
		return "(无输出)"
	}

	lines := strings.Split(body, "\n")
	kept := make([]string, 0, len(lines))
	omitted := false
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		if maxLines > 0 && len(kept) >= maxLines {
			// A further non-blank line exists beyond the cap.
			omitted = true
			break
		}
		kept = append(kept, line)
	}
	if len(kept) == 0 {
		return "(无输出)"
	}

	joined := strings.Join(kept, "\n")
	if maxChars > 0 && len(joined) > maxChars {
		joined = truncate(joined, maxChars)
	}
	// Skip the marker when the text already ends in an ellipsis.
	if omitted && !strings.HasSuffix(joined, "...") {
		joined += "\n..."
	}
	return joined
}
// enrichTaskContentWithMemoryAndEKG returns the task text, prefixed with a
// "Task Context" preamble when a memory hint and/or an EKG hint is available.
// Without hints, or for blank task content, the trimmed content is returned
// unchanged.
func (al *AgentLoop) enrichTaskContentWithMemoryAndEKG(ctx context.Context, task plannedTask) string {
	base := strings.TrimSpace(task.Content)
	if base == "" {
		return base
	}

	var sections []string
	if mem := al.memoryHintForTask(ctx, task); mem != "" {
		sections = append(sections, "Memory:\n"+mem)
	}
	if risk := al.ekgHintForTask(task); risk != "" {
		sections = append(sections, "EKG:\n"+risk)
	}
	if len(sections) == 0 {
		return base
	}

	var b strings.Builder
	b.WriteString("Task Context (use it as constraints, avoid repeating known failures):\n")
	b.WriteString(strings.Join(sections, "\n\n"))
	b.WriteString("\n\nTask:\n")
	b.WriteString(base)
	return strings.TrimSpace(b.String())
}
// memoryHintForTask runs the "memory_search" tool with the task content as the
// query (at most 2 results) and returns the result shortened via truncate with
// a limit of 1200. A namespace argument is added when the context's memory
// namespace is not "main". It returns "" when tools are unavailable, the tool
// fails, or the result is blank or starts with "no memory found".
func (al *AgentLoop) memoryHintForTask(ctx context.Context, task plannedTask) string {
	if al == nil || al.tools == nil {
		return ""
	}

	query := map[string]interface{}{
		"query":      task.Content,
		"maxResults": 2,
	}
	if namespace := memoryNamespaceFromContext(ctx); namespace != "main" {
		query["namespace"] = namespace
	}

	found, err := al.tools.Execute(ctx, "memory_search", query)
	if err != nil {
		// Memory lookup is best-effort; a failure just means no hint.
		return ""
	}
	hint := strings.TrimSpace(found)
	if hint == "" {
		return ""
	}
	if strings.HasPrefix(strings.ToLower(hint), "no memory found") {
		return ""
	}
	return truncate(hint, 1200)
}
// ekgHintForTask builds a warning for the task when a related, previously
// recorded error looks like a repeating failure.
//
// It looks up the most relevant error event for the task content, normalizes
// its log into an error signature and asks the EKG for advice. A hint is
// returned only when that advice says to escalate; it names the signature, the
// advice reasons, the suggested retry backoff and a shortened copy of the last
// error log. In every other case it returns "".
func (al *AgentLoop) ekgHintForTask(task plannedTask) string {
	if al == nil || al.ekg == nil || strings.TrimSpace(al.workspace) == "" {
		return ""
	}
	evt, ok := al.findRecentRelatedErrorEvent(task.Content)
	if !ok {
		return ""
	}
	errSig := ekg.NormalizeErrorSignature(evt.Log)
	if errSig == "" {
		return ""
	}
	advice := al.ekg.GetAdvice(ekg.SignalContext{
		TaskID:  evt.TaskID,
		ErrSig:  errSig,
		Source:  evt.Source,
		Channel: evt.Channel,
	})
	if !advice.ShouldEscalate {
		return ""
	}

	// The reasons were previously computed but never surfaced; include them so
	// the downstream task can see why the failure is considered repeated.
	reasons := strings.TrimSpace(strings.Join(advice.Reason, ", "))
	if reasons == "" {
		reasons = "repeated error signature"
	}
	return fmt.Sprintf("Related repeated error signature detected (%s). Reason: %s. Suggested retry backoff: %ds. Last error: %s",
		errSig, reasons, advice.RetryBackoffSec, truncate(strings.TrimSpace(evt.Log), 240))
}
// taskAuditErrorEvent is an error row read from the task audit log
// (memory/task-audit.jsonl under the workspace).
type taskAuditErrorEvent struct {
// TaskID is the row's "task_id" value.
TaskID string
// Source is the row's "source" value.
Source string
// Channel is the row's "channel" value.
Channel string
// Log is the row's "log" value, i.e. the recorded error text.
Log string
// Preview is the row's "input_preview" value, used for keyword matching.
Preview string
}
// findRecentRelatedErrorEvent scans <workspace>/memory/task-audit.jsonl for
// the error row whose "input_preview" shares the most keywords with
// taskContent. On a tie the row appearing later in the file wins.
//
// Rows are skipped when they are not valid JSON, when "status" is not "error"
// (case-insensitive), or when "log" is missing or blank. The boolean is false
// when taskContent yields no keywords, the file cannot be opened, no row
// shares a keyword, or the best row has no task ID. The lookup is best-effort:
// a read error part-way through the file ends the scan and the best row found
// so far is used.
func (al *AgentLoop) findRecentRelatedErrorEvent(taskContent string) (taskAuditErrorEvent, bool) {
	if al == nil {
		return taskAuditErrorEvent{}, false
	}
	// Tokenize first so the file is not opened when there is nothing to match.
	keywords := tokenizeTaskText(taskContent)
	if len(keywords) == 0 {
		return taskAuditErrorEvent{}, false
	}

	path := filepath.Join(strings.TrimSpace(al.workspace), "memory", "task-audit.jsonl")
	f, err := os.Open(path)
	if err != nil {
		return taskAuditErrorEvent{}, false
	}
	defer f.Close()

	// field returns the trimmed text of row[key]. Missing and null values give
	// "" rather than the "<nil>" that formatting a nil with %v would produce,
	// so the emptiness checks below behave as intended.
	field := func(row map[string]interface{}, key string) string {
		switch v := row[key].(type) {
		case nil:
			return ""
		case string:
			return strings.TrimSpace(v)
		default:
			return strings.TrimSpace(fmt.Sprintf("%v", v))
		}
	}

	var best taskAuditErrorEvent
	bestScore := 0
	s := bufio.NewScanner(f)
	// Audit rows embed log text and can exceed the scanner's default 64 KiB
	// token limit, which would stop the scan at that row and hide newer ones.
	s.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
	for s.Scan() {
		line := strings.TrimSpace(s.Text())
		if line == "" {
			continue
		}
		var row map[string]interface{}
		if json.Unmarshal([]byte(line), &row) != nil {
			continue
		}
		if strings.ToLower(field(row, "status")) != "error" {
			continue
		}
		logText := field(row, "log")
		if logText == "" {
			continue
		}
		preview := field(row, "input_preview")
		score := overlapScore(keywords, tokenizeTaskText(preview))
		// "<" rather than "<=" so that later rows win ties.
		if score < 1 || score < bestScore {
			continue
		}
		bestScore = score
		best = taskAuditErrorEvent{
			TaskID:  field(row, "task_id"),
			Source:  field(row, "source"),
			Channel: field(row, "channel"),
			Log:     logText,
			Preview: preview,
		}
	}
	// s.Err() is deliberately not treated as fatal: this is a best-effort hint
	// lookup, so whatever was matched before a read error is still used.

	if bestScore == 0 || best.TaskID == "" {
		return taskAuditErrorEvent{}, false
	}
	return best, true
}
// taskTextSeparators rewrites punctuation that separates words into spaces.
// Fullwidth and ideographic punctuation is written with escapes so it cannot
// be lost or confused with the ASCII forms. No pattern may be the empty
// string: an empty pattern matches between every byte and would break every
// word apart.
var taskTextSeparators = strings.NewReplacer(
	"\n", " ",
	"\t", " ",
	",", " ", "\uFF0C", " ", // comma, fullwidth comma
	".", " ", "\u3002", " ", // full stop, ideographic full stop
	":", " ", "\uFF1A", " ", // colon, fullwidth colon
	";", " ", "\uFF1B", " ", // semicolon, fullwidth semicolon
)

// tokenizeTaskText lower-cases s, treats whitespace and common ASCII and CJK
// punctuation as separators, and returns the resulting words that are at
// least 3 bytes long, in their original order.
func tokenizeTaskText(s string) []string {
	parts := strings.Fields(taskTextSeparators.Replace(strings.ToLower(s)))
	out := make([]string, 0, len(parts))
	for _, p := range parts {
		// Length is measured in bytes, so a single CJK character (3 bytes in
		// UTF-8) is kept while one- and two-letter ASCII words are dropped.
		if len(p) < 3 {
			continue
		}
		out = append(out, p)
	}
	return out
}
// overlapScore counts how many elements of b also occur in a. Repeated
// elements of b are counted each time. It returns 0 if either slice is empty.
func overlapScore(a, b []string) int {
	if len(a) == 0 || len(b) == 0 {
		return 0
	}
	known := make(map[string]bool, len(a))
	for _, token := range a {
		known[token] = true
	}
	hits := 0
	for _, token := range b {
		if known[token] {
			hits++
		}
	}
	return hits
}
// cloneMetadata returns an independent copy of m, or nil when m is nil or
// empty.
func cloneMetadata(m map[string]string) map[string]string {
	if len(m) == 0 {
		return nil
	}
	dup := make(map[string]string, len(m))
	for key := range m {
		dup[key] = m[key]
	}
	return dup
}