Files
clawgo/pkg/tools/memory.go
2026-02-23 16:38:00 +08:00

389 lines
9.5 KiB
Go

package tools
import (
"bufio"
"context"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
)
// MemorySearchTool implements the "memory_search" tool: keyword search over a
// workspace's markdown memory files (MEMORY.md and the *.md files under memory/).
type MemorySearchTool struct {
// workspace is the root directory that memory file paths are built from and
// that result paths are reported relative to.
workspace string
}
// NewMemorySearchTool returns a memory_search tool whose memory files are
// located under the given workspace directory.
func NewMemorySearchTool(workspace string) *MemorySearchTool {
	tool := new(MemorySearchTool)
	tool.workspace = workspace
	return tool
}
// Name returns the tool's identifier, "memory_search".
func (t *MemorySearchTool) Name() string {
	const toolName = "memory_search"
	return toolName
}
// Description returns the tool description presented to the model.
//
// The search performed by this tool is keyword based (the query is lower-cased
// and split into terms that are matched against lower-cased text), not
// semantic, and the memory/ directory is walked recursively. The description
// states both so that callers send keyword queries rather than natural-language
// questions.
func (t *MemorySearchTool) Description() string {
	return "Keyword search over MEMORY.md and the *.md files under memory/ (searched recursively). Returns relevant snippets (paragraphs) containing the query terms."
}
// Parameters returns the JSON-schema-style description of the tool arguments:
// a required string "query" and an optional integer "maxResults" (default 5).
func (t *MemorySearchTool) Parameters() map[string]interface{} {
	queryProp := map[string]interface{}{
		"type":        "string",
		"description": "Search query keywords (e.g., 'docker deploy project')",
	}
	maxResultsProp := map[string]interface{}{
		"type":        "integer",
		"description": "Maximum number of results to return",
		"default":     5,
	}
	properties := map[string]interface{}{
		"query":      queryProp,
		"maxResults": maxResultsProp,
	}

	schema := make(map[string]interface{}, 3)
	schema["type"] = "object"
	schema["properties"] = properties
	schema["required"] = []string{"query"}
	return schema
}
// searchResult is one matching snippet, produced by searchFile (file scan) or
// searchInIndex (index lookup) and formatted by renderSearchResults.
type searchResult struct {
// file is the path of the memory file the snippet came from.
file string
// lineNum is the line on which the snippet starts (1-based for file-scan
// results; for index results it is whatever the index entry stores).
lineNum int
// content is the snippet text; file-scan results may be prefixed with a
// "[heading]" context line.
content string
// score is the relevance score; higher scores are rendered first.
score int
}
// fileSearchOutcome carries the result of scanning a single file during
// Execute's concurrent fallback search.
type fileSearchOutcome struct {
// matches are the scored blocks found in file (ignored when err is set).
matches []searchResult
// err is non-nil when the file could not be searched; Execute then lists the
// file in a warning instead of failing the whole search.
err error
// file is the path that was scanned.
file string
}
// Execute runs a memory_search call and returns the formatted result text.
//
// Arguments:
//   - args["query"] (string, required): lower-cased and split on whitespace into
//     keywords. A repeated keyword is counted once, so repeating a word does not
//     inflate scores.
//   - args["maxResults"] (optional, default 5): parsed by parseIntArg and
//     clamped to [1, 50].
//
// When the structured memory index loads and has entries, it is searched.
// Otherwise every memory file is scanned concurrently with searchFile. Files
// that cannot be searched are skipped, and their workspace-relative paths are
// appended as a warning. The paths are sorted so the message is deterministic.
//
// If ctx is non-nil and is cancelled before or during the file scan, the
// context error is returned. A nil ctx is tolerated because the previous
// implementation never dereferenced it.
func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
	query, ok := args["query"].(string)
	if !ok || query == "" {
		return "", fmt.Errorf("query is required")
	}

	maxResults := 5
	if m, ok := parseIntArg(args["maxResults"]); ok {
		maxResults = m
	}
	if maxResults < 1 {
		maxResults = 1
	}
	if maxResults > 50 {
		maxResults = 50
	}

	// De-duplicate keywords while keeping first-occurrence order.
	var keywords []string
	seenKeyword := make(map[string]struct{})
	for _, kw := range strings.Fields(strings.ToLower(query)) {
		if _, dup := seenKeyword[kw]; dup {
			continue
		}
		seenKeyword[kw] = struct{}{}
		keywords = append(keywords, kw)
	}
	if len(keywords) == 0 {
		return "Please provide search keywords.", nil
	}

	// ctxErr reports cancellation without requiring a non-nil ctx.
	ctxErr := func() error {
		if ctx == nil {
			return nil
		}
		return ctx.Err()
	}
	if err := ctxErr(); err != nil {
		return "", err
	}

	files := t.getMemoryFiles()
	if len(files) == 0 {
		return fmt.Sprintf("No memory files found for query: %s", query), nil
	}

	// Fast path: structured memory index.
	if idx, err := t.loadOrBuildIndex(files); err == nil && idx != nil {
		// If index has entries, use it. Otherwise fallback to file scan so parser/read warnings are visible.
		if len(idx.Entries) > 0 {
			results := t.searchInIndex(idx, keywords)
			return t.renderSearchResults(query, results, maxResults), nil
		}
	}

	// Slow path: search all files concurrently. The channel has room for one
	// outcome per file, so no sender ever blocks.
	resultsChan := make(chan fileSearchOutcome, len(files))
	var wg sync.WaitGroup
	for _, file := range files {
		wg.Add(1)
		go func(f string) {
			defer wg.Done()
			if err := ctxErr(); err != nil {
				resultsChan <- fileSearchOutcome{err: err, file: f}
				return
			}
			matches, err := t.searchFile(f, keywords)
			resultsChan <- fileSearchOutcome{matches: matches, err: err, file: f}
		}(file)
	}
	// Close the channel once every worker has reported.
	go func() {
		wg.Wait()
		close(resultsChan)
	}()

	var allResults []searchResult
	var failedFiles []string
	for outcome := range resultsChan {
		if outcome.err != nil {
			relPath, err := filepath.Rel(t.workspace, outcome.file)
			if err != nil || relPath == "" {
				relPath = outcome.file
			}
			failedFiles = append(failedFiles, relPath)
			continue
		}
		allResults = append(allResults, outcome.matches...)
	}
	if err := ctxErr(); err != nil {
		return "", err
	}

	output := t.renderSearchResults(query, allResults, maxResults)
	if len(failedFiles) == 0 {
		return output, nil
	}
	// Outcomes arrive in completion order; sort for a stable warning.
	sort.Strings(failedFiles)
	suffix := formatSearchWarningSuffix(failedFiles)
	if strings.HasPrefix(output, "No memory found for query:") {
		return output + suffix, nil
	}
	return output + "\n" + suffix, nil
}
// searchInIndex ranks index entries by how many keywords map to them in the
// index's inverted table.
//
// Each keyword is trimmed and lower-cased, then looked up verbatim in
// idx.Inverted. How tokens are produced is up to the index builder. Every
// entry ID listed under a keyword earns one point. IDs outside idx.Entries are
// skipped.
//
// Results are ordered by descending score, then ascending line number, then
// file path, then entry ID. That is a total order, so the output is
// deterministic even though scores are accumulated in a map. A nil idx yields
// no results.
func (t *MemorySearchTool) searchInIndex(idx *memoryIndex, keywords []string) []searchResult {
	if idx == nil {
		return nil
	}

	hits := make(map[int]int)
	for _, kw := range keywords {
		token := strings.ToLower(strings.TrimSpace(kw))
		for _, entryID := range idx.Inverted[token] {
			hits[entryID]++
		}
	}

	type ranked struct {
		id  int
		res searchResult
	}
	items := make([]ranked, 0, len(hits))
	for entryID, score := range hits {
		// Guard against IDs that do not refer to an entry (inconsistent index).
		if entryID < 0 || entryID >= len(idx.Entries) || score <= 0 {
			continue
		}
		entry := idx.Entries[entryID]
		items = append(items, ranked{
			id: entryID,
			res: searchResult{
				file:    entry.File,
				lineNum: entry.LineNum,
				content: entry.Content,
				score:   score,
			},
		})
	}

	sort.Slice(items, func(i, j int) bool {
		a, b := items[i].res, items[j].res
		switch {
		case a.score != b.score:
			return a.score > b.score
		case a.lineNum != b.lineNum:
			return a.lineNum < b.lineNum
		case a.file != b.file:
			return a.file < b.file
		default:
			return items[i].id < items[j].id
		}
	})

	results := make([]searchResult, len(items))
	for i, it := range items {
		results[i] = it.res
	}
	return results
}
// renderSearchResults formats at most maxResults results as the tool output.
//
// allResults is sorted in place. The order is descending score, then ascending
// line number, then file path, then content. The tie-breakers after line number
// make the output deterministic when results were gathered concurrently or from
// a map. A negative maxResults is treated as 0 instead of panicking on the
// slice expression.
//
// Each source is printed relative to the workspace. When no relative path can
// be computed, the stored path is printed instead. If nothing remains, a
// "No memory found for query: ..." message is returned.
func (t *MemorySearchTool) renderSearchResults(query string, allResults []searchResult, maxResults int) string {
	sort.Slice(allResults, func(i, j int) bool {
		a, b := allResults[i], allResults[j]
		switch {
		case a.score != b.score:
			return a.score > b.score
		case a.lineNum != b.lineNum:
			return a.lineNum < b.lineNum
		case a.file != b.file:
			return a.file < b.file
		default:
			return a.content < b.content
		}
	})

	if maxResults < 0 {
		maxResults = 0
	}
	if len(allResults) > maxResults {
		allResults = allResults[:maxResults]
	}
	if len(allResults) == 0 {
		return fmt.Sprintf("No memory found for query: %s", query)
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "Found %d memories for '%s':\n\n", len(allResults), query)
	for _, res := range allResults {
		relPath, err := filepath.Rel(t.workspace, res.file)
		if err != nil || relPath == "" {
			relPath = res.file
		}
		fmt.Fprintf(&sb, "--- Source: %s:%d ---\n%s\n\n", relPath, res.lineNum, res.content)
	}
	return sb.String()
}
// getMemoryFiles lists the memory files to search, without duplicates, in
// this order:
//  1. memory/MEMORY.md (the canonical long-term memory). If it does not
//     exist, the legacy top-level MEMORY.md is used instead, when present.
//  2. Every file under memory/ whose name ends in ".md" (case-insensitive),
//     walked recursively in lexical order, e.g. memory/YYYYMM/YYYYMMDD.md.
//
// Walk errors are ignored, so an unreadable entry never hides the other files.
func (t *MemorySearchTool) getMemoryFiles() []string {
	var found []string
	already := make(map[string]struct{})
	remember := func(p string) {
		if _, dup := already[p]; dup {
			return
		}
		already[p] = struct{}{}
		found = append(found, p)
	}

	memoryDir := filepath.Join(t.workspace, "memory")
	canonical := filepath.Join(memoryDir, "MEMORY.md")
	if _, err := os.Stat(canonical); err == nil {
		remember(canonical)
	} else {
		legacy := filepath.Join(t.workspace, "MEMORY.md")
		if _, err := os.Stat(legacy); err == nil {
			remember(legacy)
		}
	}

	_ = filepath.Walk(memoryDir, func(p string, fi os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil, fi == nil, fi.IsDir():
			return nil
		case strings.HasSuffix(strings.ToLower(fi.Name()), ".md"):
			remember(p)
		}
		return nil
	})

	return found
}
// searchFile splits the markdown file at path into blocks and returns the
// blocks that match keywords, each with a relevance score.
//
// Block boundaries:
//   - A heading (a trimmed line starting with "#") is emitted as a block on
//     its own and becomes the context heading for the blocks that follow.
//   - A blank line ends the current block.
//   - A list item starts a new block, so each item is scored separately.
//     Recognized markers are "- ", "* ", "+ ", and ordered markers made of 1-9
//     digits followed by ". " or ") " (e.g. "1. ", "12) "). Previously only a
//     single character followed by ". " was checked. That missed "10. " and
//     wrongly matched lines such as "a. ...".
//   - Any other line is appended to the current block.
//
// Scoring: one point for each keyword contained (as a substring) in the
// lower-cased block, plus one point for each keyword contained in the current
// heading. Keywords are compared as given, so callers should pass them
// lower-cased. A heading block is scored against itself as the current
// heading, so a matching heading earns both points. When a heading is
// current, other blocks are returned with "[heading]\n" prepended for context.
//
// Only blocks with a positive score are returned. When keywords is empty,
// every non-empty block is kept with score 1 (used for index building). lineNum
// is the 1-based line on which the block starts. A line longer than 1 MiB makes
// the scan fail, and the scanner's error is returned.
func (t *MemorySearchTool) searchFile(path string, keywords []string) ([]searchResult, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	var results []searchResult
	scanner := bufio.NewScanner(file)
	scanner.Buffer(make([]byte, 64*1024), 1024*1024)

	var (
		currentBlock   strings.Builder
		blockStartLine = 1
		currentLineNum = 0
		lastHeading    string
	)

	// countHits returns how many keywords occur in the (lower-cased) text.
	countHits := func(text string) int {
		n := 0
		for _, kw := range keywords {
			if strings.Contains(text, kw) {
				n++
			}
		}
		return n
	}

	// isListItem reports whether a trimmed line starts a markdown list item.
	isListItem := func(s string) bool {
		if strings.HasPrefix(s, "- ") || strings.HasPrefix(s, "* ") || strings.HasPrefix(s, "+ ") {
			return true
		}
		digits := 0
		for digits < len(s) && s[digits] >= '0' && s[digits] <= '9' {
			digits++
		}
		if digits == 0 || digits > 9 {
			return false
		}
		rest := s[digits:]
		// Require content after the marker, as the original len>3 check did.
		return len(rest) > 2 && (rest[0] == '.' || rest[0] == ')') && rest[1] == ' '
	}

	// processBlock scores the accumulated block, records it if it qualifies,
	// and clears the buffer.
	processBlock := func() {
		content := strings.TrimSpace(currentBlock.String())
		currentBlock.Reset()
		if content == "" {
			return
		}
		score := countHits(strings.ToLower(content))
		if lastHeading != "" {
			score += countHits(strings.ToLower(lastHeading))
			// Prepend heading context unless the block is the heading itself.
			if !strings.HasPrefix(content, "#") {
				content = fmt.Sprintf("[%s]\n%s", lastHeading, content)
			}
		}
		// Keep all blocks when keywords are empty (index build).
		if len(keywords) == 0 {
			score = 1
		}
		if score > 0 {
			results = append(results, searchResult{
				file:    path,
				lineNum: blockStartLine,
				content: content,
				score:   score,
			})
		}
	}

	for scanner.Scan() {
		currentLineNum++
		line := scanner.Text()
		trimmed := strings.TrimSpace(line)

		switch {
		case strings.HasPrefix(trimmed, "#"):
			// Headings flush the previous block and form their own block.
			processBlock()
			lastHeading = strings.TrimLeft(trimmed, "# ")
			blockStartLine = currentLineNum
			currentBlock.WriteString(line + "\n")
			processBlock()
			continue
		case trimmed == "":
			processBlock()
			blockStartLine = currentLineNum + 1
			continue
		case isListItem(trimmed):
			// Treat list items as atomic for better granularity.
			processBlock()
			blockStartLine = currentLineNum
		}

		if currentBlock.Len() == 0 {
			blockStartLine = currentLineNum
		}
		currentBlock.WriteString(line + "\n")
	}
	processBlock() // Flush last block

	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
// parseIntArg converts a loosely typed tool argument to an int.
//
// Accepted forms:
//   - int, int32 and int64 values;
//   - float64 and float32 values, as produced by JSON decoding, with the
//     fractional part truncated toward zero;
//   - values with an Int64() (int64, error) method, such as json.Number;
//   - decimal strings, with surrounding whitespace ignored.
//
// It returns ok=false for any other type and for unparsable strings. It also
// returns ok=false for NaN, infinite, or out-of-range numbers. Converting
// those to int would otherwise give an implementation-dependent value.
func parseIntArg(value interface{}) (int, bool) {
	const (
		maxInt = 1<<(strconv.IntSize-1) - 1
		minInt = -1 << (strconv.IntSize - 1)
		// floatLimit is 2^(IntSize-1), which is exactly representable as float64.
		floatLimit = float64(1 << (strconv.IntSize - 1))
	)

	fromFloat := func(f float64) (int, bool) {
		// The negated range test also rejects NaN, which fails every comparison.
		if !(f >= -floatLimit && f < floatLimit) {
			return 0, false
		}
		return int(f), true
	}
	fromInt64 := func(n int64) (int, bool) {
		// Only relevant on 32-bit platforms, where int is narrower than int64.
		if n < minInt || n > maxInt {
			return 0, false
		}
		return int(n), true
	}

	switch v := value.(type) {
	case int:
		return v, true
	case int32:
		return int(v), true
	case int64:
		return fromInt64(v)
	case float64:
		return fromFloat(v)
	case float32:
		return fromFloat(float64(v))
	case string:
		n, err := strconv.Atoi(strings.TrimSpace(v))
		if err != nil {
			return 0, false
		}
		return n, true
	case interface{ Int64() (int64, error) }:
		n, err := v.Int64()
		if err != nil {
			return 0, false
		}
		return fromInt64(n)
	}
	return 0, false
}
// formatSearchWarningSuffix builds the warning appended to search output when
// some files could not be searched. At most three paths are listed, followed
// by ", ..." when more were skipped. An empty list yields "".
func formatSearchWarningSuffix(failedFiles []string) string {
	const maxShown = 3

	total := len(failedFiles)
	if total == 0 {
		return ""
	}

	listed := failedFiles
	truncated := total > maxShown
	if truncated {
		listed = failedFiles[:maxShown]
	}

	names := strings.Join(listed, ", ")
	if truncated {
		names += ", ..."
	}
	return fmt.Sprintf("Warning: memory_search skipped %d file(s) due to read/parse errors: %s", total, names)
}