optimize memory search with modtime block cache and faster stable ranking

This commit is contained in:
DBT
2026-02-24 15:20:00 +00:00
parent 9635d48e67
commit 532f01e4ee

View File

@@ -6,17 +6,34 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
// MemorySearchTool performs keyword search over markdown memory files
// under workspace. Parsed files are kept in cache, keyed by file path
// and invalidated by modification time; mu guards cache because files
// are searched concurrently from per-file goroutines.
type MemorySearchTool struct {
	workspace string
	mu        sync.RWMutex
	cache     map[string]cachedMemoryFile
}
// memoryBlock is one searchable unit of a markdown file: a heading
// line, a single list item, or a blank-line-separated paragraph.
type memoryBlock struct {
	lineNum int    // 1-based line number where the block starts
	content string // whitespace-trimmed block text
	heading string // nearest preceding heading text ("#" prefix stripped), "" if none
	lower   string // pre-lowercased content for case-insensitive matching
}
// cachedMemoryFile pairs a file's parsed blocks with its modification
// time at parse time, so the cache entry can be discarded when the
// file changes on disk.
type cachedMemoryFile struct {
	modTime time.Time
	blocks  []memoryBlock
}
// NewMemorySearchTool returns a search tool rooted at workspace with an
// empty parse cache.
func NewMemorySearchTool(workspace string) *MemorySearchTool {
	tool := &MemorySearchTool{workspace: workspace}
	tool.cache = make(map[string]cachedMemoryFile)
	return tool
}
@@ -63,6 +80,9 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
if m, ok := args["maxResults"].(float64); ok {
maxResults = int(m)
}
if maxResults <= 0 {
maxResults = 5
}
keywords := strings.Fields(strings.ToLower(query))
if len(keywords) == 0 {
@@ -74,11 +94,16 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
resultsChan := make(chan []searchResult, len(files))
var wg sync.WaitGroup
// 并发搜索所有文件
for _, file := range files {
if ctx.Err() != nil {
break
}
wg.Add(1)
go func(f string) {
defer wg.Done()
if ctx.Err() != nil {
return
}
matches, err := t.searchFile(f, keywords)
if err == nil {
resultsChan <- matches
@@ -86,7 +111,6 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
}(file)
}
// 异步关闭通道
go func() {
wg.Wait()
close(resultsChan)
@@ -97,14 +121,15 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
allResults = append(allResults, matches...)
}
// Simple ranking: sort by score (number of keyword matches) desc
for i := 0; i < len(allResults); i++ {
for j := i + 1; j < len(allResults); j++ {
if allResults[j].score > allResults[i].score {
allResults[i], allResults[j] = allResults[j], allResults[i]
}
sort.SliceStable(allResults, func(i, j int) bool {
if allResults[i].score != allResults[j].score {
return allResults[i].score > allResults[j].score
}
}
if allResults[i].file != allResults[j].file {
return allResults[i].file < allResults[j].file
}
return allResults[i].lineNum < allResults[j].lineNum
})
if len(allResults) > maxResults {
allResults = allResults[:maxResults]
@@ -179,101 +204,113 @@ func dedupeStrings(items []string) []string {
return out
}
// searchFile scores the (cached) markdown blocks of path against
// keywords. A block earns one point per keyword found in its body and
// one more per keyword found in its section heading; zero-score blocks
// are dropped. The heading is prepended as "[heading]" context unless
// the block text itself starts with a "#" heading marker.
func (t *MemorySearchTool) searchFile(path string, keywords []string) ([]searchResult, error) {
	blocks, err := t.getOrParseBlocks(path)
	if err != nil {
		return nil, err
	}
	results := make([]searchResult, 0, 8)
	for _, b := range blocks {
		// Hoisted: lowercasing the heading is loop-invariant; the old
		// code recomputed strings.ToLower(b.heading) for every keyword.
		lowerHeading := ""
		if b.heading != "" {
			lowerHeading = strings.ToLower(b.heading)
		}
		score := 0
		for _, kw := range keywords {
			if strings.Contains(b.lower, kw) {
				score++
			}
			if lowerHeading != "" && strings.Contains(lowerHeading, kw) {
				score++
			}
		}
		if score == 0 {
			continue
		}
		content := b.content
		if b.heading != "" && !strings.HasPrefix(strings.TrimSpace(content), "#") {
			content = fmt.Sprintf("[%s]\n%s", b.heading, content)
		}
		results = append(results, searchResult{file: path, lineNum: b.lineNum, content: content, score: score})
	}
	return results, nil
}
// getOrParseBlocks returns the parsed blocks for path, reusing the
// cached parse when the file's modification time is unchanged and
// re-parsing (then refreshing the cache entry) otherwise. Concurrent
// callers may race to parse the same file; the last writer wins, which
// is harmless since both parses see the same content.
func (t *MemorySearchTool) getOrParseBlocks(path string) ([]memoryBlock, error) {
	info, err := os.Stat(path)
	if err != nil {
		return nil, err
	}
	modTime := info.ModTime()

	// Fast path: serve from cache while the mtime still matches.
	t.mu.RLock()
	entry, found := t.cache[path]
	t.mu.RUnlock()
	if found && entry.modTime.Equal(modTime) {
		return entry.blocks, nil
	}

	parsed, err := parseMarkdownBlocks(path)
	if err != nil {
		return nil, err
	}

	t.mu.Lock()
	t.cache[path] = cachedMemoryFile{modTime: modTime, blocks: parsed}
	t.mu.Unlock()
	return parsed, nil
}
func parseMarkdownBlocks(path string) ([]memoryBlock, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
var results []searchResult
blocks := make([]memoryBlock, 0, 32)
scanner := bufio.NewScanner(file)
var current strings.Builder
blockStartLine := 1
currentLine := 0
lastHeading := ""
var currentBlock strings.Builder
var blockStartLine int = 1
var currentLineNum int = 0
var lastHeading string
processBlock := func() {
content := strings.TrimSpace(currentBlock.String())
if content != "" {
lowerContent := strings.ToLower(content)
score := 0
// Calculate score: how many keywords are present?
for _, kw := range keywords {
if strings.Contains(lowerContent, kw) {
score++
}
}
// Add bonus if heading matches
if lastHeading != "" {
lowerHeading := strings.ToLower(lastHeading)
for _, kw := range keywords {
if strings.Contains(lowerHeading, kw) {
score++
}
}
// Prepend heading context if not already part of block
if !strings.HasPrefix(content, "#") {
content = fmt.Sprintf("[%s]\n%s", lastHeading, content)
}
}
// Only keep if at least one keyword matched
if score > 0 {
results = append(results, searchResult{
file: path,
lineNum: blockStartLine,
content: content,
score: score,
})
}
flush := func() {
content := strings.TrimSpace(current.String())
if content == "" {
current.Reset()
return
}
currentBlock.Reset()
blocks = append(blocks, memoryBlock{lineNum: blockStartLine, content: content, heading: lastHeading, lower: strings.ToLower(content)})
current.Reset()
}
for scanner.Scan() {
currentLineNum++
currentLine++
line := scanner.Text()
trimmed := strings.TrimSpace(line)
// Markdown Block Logic:
// 1. Headers start new blocks
// 2. Empty lines separate blocks
// 3. List items start new blocks (optional, but good for logs)
isHeader := strings.HasPrefix(trimmed, "#")
isEmpty := trimmed == ""
isList := strings.HasPrefix(trimmed, "- ") || strings.HasPrefix(trimmed, "* ") || (len(trimmed) > 3 && trimmed[1] == '.' && trimmed[2] == ' ')
if isHeader {
processBlock() // Flush previous
flush()
lastHeading = strings.TrimLeft(trimmed, "# ")
blockStartLine = currentLineNum
currentBlock.WriteString(line + "\n")
processBlock() // Headers are their own blocks too
blockStartLine = currentLine
current.WriteString(line + "\n")
flush()
continue
}
if isEmpty {
processBlock() // Flush previous
blockStartLine = currentLineNum + 1
flush()
blockStartLine = currentLine + 1
continue
}
if isList {
processBlock() // Flush previous (treat list items as atomic for better granularity)
blockStartLine = currentLineNum
flush()
blockStartLine = currentLine
}
if currentBlock.Len() == 0 {
blockStartLine = currentLineNum
if current.Len() == 0 {
blockStartLine = currentLine
}
currentBlock.WriteString(line + "\n")
current.WriteString(line + "\n")
}
processBlock() // Flush last block
return results, nil
flush()
return blocks, nil
}