diff --git a/pkg/tools/memory.go b/pkg/tools/memory.go index bc9b2b9..c0a5946 100644 --- a/pkg/tools/memory.go +++ b/pkg/tools/memory.go @@ -6,17 +6,34 @@ import ( "fmt" "os" "path/filepath" + "sort" "strings" "sync" + "time" ) type MemorySearchTool struct { workspace string + mu sync.RWMutex + cache map[string]cachedMemoryFile +} + +type memoryBlock struct { + lineNum int + content string + heading string + lower string +} + +type cachedMemoryFile struct { + modTime time.Time + blocks []memoryBlock } func NewMemorySearchTool(workspace string) *MemorySearchTool { return &MemorySearchTool{ workspace: workspace, + cache: make(map[string]cachedMemoryFile), } } @@ -63,6 +80,9 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac if m, ok := args["maxResults"].(float64); ok { maxResults = int(m) } + if maxResults <= 0 { + maxResults = 5 + } keywords := strings.Fields(strings.ToLower(query)) if len(keywords) == 0 { @@ -74,11 +94,16 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac resultsChan := make(chan []searchResult, len(files)) var wg sync.WaitGroup - // 并发搜索所有文件 for _, file := range files { + if ctx.Err() != nil { + break + } wg.Add(1) go func(f string) { defer wg.Done() + if ctx.Err() != nil { + return + } matches, err := t.searchFile(f, keywords) if err == nil { resultsChan <- matches @@ -86,7 +111,6 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac }(file) } - // 异步关闭通道 go func() { wg.Wait() close(resultsChan) @@ -97,14 +121,15 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac allResults = append(allResults, matches...) } - // Simple ranking: sort by score (number of keyword matches) desc - for i := 0; i < len(allResults); i++ { - for j := i + 1; j < len(allResults); j++ { - if allResults[j].score > allResults[i].score { - allResults[i], allResults[j] = allResults[j], allResults[i] - } + sort.SliceStable(allResults, func(i, j int) bool { + if allResults[i].score != allResults[j].score { + return allResults[i].score > allResults[j].score } - } + if allResults[i].file != allResults[j].file { + return allResults[i].file < allResults[j].file + } + return allResults[i].lineNum < allResults[j].lineNum + }) if len(allResults) > maxResults { allResults = allResults[:maxResults] @@ -179,101 +204,113 @@ func dedupeStrings(items []string) []string { return out } -// searchFile parses the markdown file into blocks (paragraphs/list items) and searches them +// searchFile searches parsed markdown blocks with cache by file modtime. func (t *MemorySearchTool) searchFile(path string, keywords []string) ([]searchResult, error) { + blocks, err := t.getOrParseBlocks(path) + if err != nil { + return nil, err + } + results := make([]searchResult, 0, 8) + for _, b := range blocks { + score := 0 + for _, kw := range keywords { + if strings.Contains(b.lower, kw) { + score++ + } + if b.heading != "" && strings.Contains(strings.ToLower(b.heading), kw) { + score++ + } + } + if score == 0 { + continue + } + content := b.content + if b.heading != "" && !strings.HasPrefix(strings.TrimSpace(content), "#") { + content = fmt.Sprintf("[%s]\n%s", b.heading, content) + } + results = append(results, searchResult{file: path, lineNum: b.lineNum, content: content, score: score}) + } + return results, nil +} + +func (t *MemorySearchTool) getOrParseBlocks(path string) ([]memoryBlock, error) { + st, err := os.Stat(path) + if err != nil { + return nil, err + } + mod := st.ModTime() + t.mu.RLock() + if c, ok := t.cache[path]; ok && c.modTime.Equal(mod) { + blocks := c.blocks + t.mu.RUnlock() + return blocks, nil + } + t.mu.RUnlock() + + blocks, err := parseMarkdownBlocks(path) + if err != nil { + return nil, err + } + t.mu.Lock() + t.cache[path] = cachedMemoryFile{modTime: mod, blocks: blocks} + t.mu.Unlock() + return blocks, nil +} + +func parseMarkdownBlocks(path string) ([]memoryBlock, error) { file, err := os.Open(path) if err != nil { return nil, err } defer file.Close() - var results []searchResult + blocks := make([]memoryBlock, 0, 32) scanner := bufio.NewScanner(file) + var current strings.Builder + blockStartLine := 1 + currentLine := 0 + lastHeading := "" - var currentBlock strings.Builder - var blockStartLine int = 1 - var currentLineNum int = 0 - var lastHeading string - - processBlock := func() { - content := strings.TrimSpace(currentBlock.String()) - if content != "" { - lowerContent := strings.ToLower(content) - score := 0 - // Calculate score: how many keywords are present? - for _, kw := range keywords { - if strings.Contains(lowerContent, kw) { - score++ - } - } - - // Add bonus if heading matches - if lastHeading != "" { - lowerHeading := strings.ToLower(lastHeading) - for _, kw := range keywords { - if strings.Contains(lowerHeading, kw) { - score++ - } - } - // Prepend heading context if not already part of block - if !strings.HasPrefix(content, "#") { - content = fmt.Sprintf("[%s]\n%s", lastHeading, content) - } - } - - // Only keep if at least one keyword matched - if score > 0 { - results = append(results, searchResult{ - file: path, - lineNum: blockStartLine, - content: content, - score: score, - }) - } + flush := func() { + content := strings.TrimSpace(current.String()) + if content == "" { + current.Reset() + return } - currentBlock.Reset() + blocks = append(blocks, memoryBlock{lineNum: blockStartLine, content: content, heading: lastHeading, lower: strings.ToLower(content)}) + current.Reset() } for scanner.Scan() { - currentLineNum++ + currentLine++ line := scanner.Text() trimmed := strings.TrimSpace(line) - - // Markdown Block Logic: - // 1. Headers start new blocks - // 2. Empty lines separate blocks - // 3. List items start new blocks (optional, but good for logs) - isHeader := strings.HasPrefix(trimmed, "#") isEmpty := trimmed == "" isList := strings.HasPrefix(trimmed, "- ") || strings.HasPrefix(trimmed, "* ") || (len(trimmed) > 3 && trimmed[1] == '.' && trimmed[2] == ' ') if isHeader { - processBlock() // Flush previous + flush() lastHeading = strings.TrimLeft(trimmed, "# ") - blockStartLine = currentLineNum - currentBlock.WriteString(line + "\n") - processBlock() // Headers are their own blocks too + blockStartLine = currentLine + current.WriteString(line + "\n") + flush() continue } - if isEmpty { - processBlock() // Flush previous - blockStartLine = currentLineNum + 1 + flush() + blockStartLine = currentLine + 1 continue } - if isList { - processBlock() // Flush previous (treat list items as atomic for better granularity) - blockStartLine = currentLineNum + flush() + blockStartLine = currentLine } - - if currentBlock.Len() == 0 { - blockStartLine = currentLineNum + if current.Len() == 0 { + blockStartLine = currentLine } - currentBlock.WriteString(line + "\n") + current.WriteString(line + "\n") } - - processBlock() // Flush last block - return results, nil + flush() + return blocks, nil }