perf: implement concurrent file searching in MemorySearchTool

This commit is contained in:
DBT
2026-02-12 05:32:39 +00:00
parent 9dc73616d6
commit 440cfee620

View File

@@ -7,6 +7,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
) )
type MemorySearchTool struct { type MemorySearchTool struct {
@@ -69,37 +70,53 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
} }
files := t.getMemoryFiles() files := t.getMemoryFiles()
var results []searchResult
resultsChan := make(chan []searchResult, len(files))
var wg sync.WaitGroup
// 并发搜索所有文件
for _, file := range files { for _, file := range files {
matches, err := t.searchFile(file, keywords) wg.Add(1)
if err != nil { go func(f string) {
continue // skip unreadable files defer wg.Done()
} matches, err := t.searchFile(f, keywords)
results = append(results, matches...) if err == nil {
resultsChan <- matches
}
}(file)
}
// 异步关闭通道
go func() {
wg.Wait()
close(resultsChan)
}()
var allResults []searchResult
for matches := range resultsChan {
allResults = append(allResults, matches...)
} }
// Simple ranking: sort by score (number of keyword matches) desc // Simple ranking: sort by score (number of keyword matches) desc
// Ideally use a stable sort or more sophisticated scoring for i := 0; i < len(allResults); i++ {
for i := 0; i < len(results); i++ { for j := i + 1; j < len(allResults); j++ {
for j := i + 1; j < len(results); j++ { if allResults[j].score > allResults[i].score {
if results[j].score > results[i].score { allResults[i], allResults[j] = allResults[j], allResults[i]
results[i], results[j] = results[j], results[i]
} }
} }
} }
if len(results) > maxResults { if len(allResults) > maxResults {
results = results[:maxResults] allResults = allResults[:maxResults]
} }
if len(results) == 0 { if len(allResults) == 0 {
return fmt.Sprintf("No memory found for query: %s", query), nil return fmt.Sprintf("No memory found for query: %s", query), nil
} }
var sb strings.Builder var sb strings.Builder
sb.WriteString(fmt.Sprintf("Found %d memories for '%s':\n\n", len(results), query)) sb.WriteString(fmt.Sprintf("Found %d memories for '%s':\n\n", len(allResults), query))
for _, res := range results { for _, res := range allResults {
relPath, _ := filepath.Rel(t.workspace, res.file) relPath, _ := filepath.Rel(t.workspace, res.file)
sb.WriteString(fmt.Sprintf("--- Source: %s:%d ---\n%s\n\n", relPath, res.lineNum, res.content)) sb.WriteString(fmt.Sprintf("--- Source: %s:%d ---\n%s\n\n", relPath, res.lineNum, res.content))
} }