This commit is contained in:
lpf
2026-02-23 16:38:00 +08:00
parent 95e9be18b8
commit b5430b9021
13 changed files with 1197 additions and 119 deletions

View File

@@ -119,7 +119,7 @@ func (t *WriteFileTool) Name() string {
}
func (t *WriteFileTool) Description() string {
return "Write content to a file"
return "Write content to a file. Supports overwrite (default) and append mode."
}
func (t *WriteFileTool) Parameters() map[string]interface{} {
@@ -134,6 +134,11 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
"type": "string",
"description": "Content to write to the file",
},
"append": map[string]interface{}{
"type": "boolean",
"description": "If true, append content to the file instead of overwriting it",
"default": false,
},
},
"required": []string{"path", "content"},
}
@@ -149,16 +154,31 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}
if !ok {
return "", fmt.Errorf("content is required")
}
appendMode, _ := args["append"].(bool)
resolvedPath, err := resolveToolPath(t.allowedDir, path)
if err != nil {
return "", err
}
if err := os.MkdirAll(filepath.Dir(resolvedPath), 0755); err != nil {
return "", err
}
if appendMode {
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return "", err
}
defer f.Close()
if _, err := f.WriteString(content); err != nil {
return "", err
}
return fmt.Sprintf("File appended successfully: %s", path), nil
}
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
return "", err
}
return fmt.Sprintf("File written successfully: %s", path), nil
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
)
@@ -54,6 +55,12 @@ type searchResult struct {
score int
}
type fileSearchOutcome struct {
matches []searchResult
err error
file string
}
func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
query, ok := args["query"].(string)
if !ok || query == "" {
@@ -61,8 +68,14 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
}
maxResults := 5
if m, ok := args["maxResults"].(float64); ok {
maxResults = int(m)
if m, ok := parseIntArg(args["maxResults"]); ok {
maxResults = m
}
if maxResults < 1 {
maxResults = 1
}
if maxResults > 50 {
maxResults = 50
}
keywords := strings.Fields(strings.ToLower(query))
@@ -77,11 +90,14 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
// Fast path: structured memory index.
if idx, err := t.loadOrBuildIndex(files); err == nil && idx != nil {
results := t.searchInIndex(idx, keywords)
return t.renderSearchResults(query, results, maxResults), nil
// If index has entries, use it. Otherwise fallback to file scan so parser/read warnings are visible.
if len(idx.Entries) > 0 {
results := t.searchInIndex(idx, keywords)
return t.renderSearchResults(query, results, maxResults), nil
}
}
resultsChan := make(chan []searchResult, len(files))
resultsChan := make(chan fileSearchOutcome, len(files))
var wg sync.WaitGroup
// Search all files concurrently
@@ -90,9 +106,7 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
go func(f string) {
defer wg.Done()
matches, err := t.searchFile(f, keywords)
if err == nil {
resultsChan <- matches
}
resultsChan <- fileSearchOutcome{matches: matches, err: err, file: f}
}(file)
}
@@ -103,11 +117,28 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
}()
var allResults []searchResult
for matches := range resultsChan {
allResults = append(allResults, matches...)
var failedFiles []string
for outcome := range resultsChan {
if outcome.err != nil {
relPath, _ := filepath.Rel(t.workspace, outcome.file)
if relPath == "" {
relPath = outcome.file
}
failedFiles = append(failedFiles, relPath)
continue
}
allResults = append(allResults, outcome.matches...)
}
return t.renderSearchResults(query, allResults, maxResults), nil
output := t.renderSearchResults(query, allResults, maxResults)
if len(failedFiles) > 0 {
suffix := formatSearchWarningSuffix(failedFiles)
if strings.HasPrefix(output, "No memory found for query:") {
return output + suffix, nil
}
return output + "\n" + suffix, nil
}
return output, nil
}
func (t *MemorySearchTool) searchInIndex(idx *memoryIndex, keywords []string) []searchResult {
@@ -190,9 +221,13 @@ func (t *MemorySearchTool) getMemoryFiles() []string {
}
}
// Check long-term memory in both legacy and current locations.
addIfExists(filepath.Join(t.workspace, "MEMORY.md"))
addIfExists(filepath.Join(t.workspace, "memory", "MEMORY.md"))
// Prefer canonical long-term memory path.
canonical := filepath.Join(t.workspace, "memory", "MEMORY.md")
addIfExists(canonical)
// Legacy path fallback only when canonical file is absent.
if _, err := os.Stat(canonical); err != nil {
addIfExists(filepath.Join(t.workspace, "MEMORY.md"))
}
// Check memory/ directory recursively (e.g., memory/YYYYMM/YYYYMMDD.md).
memDir := filepath.Join(t.workspace, "memory")
@@ -221,6 +256,7 @@ func (t *MemorySearchTool) searchFile(path string, keywords []string) ([]searchR
var results []searchResult
scanner := bufio.NewScanner(file)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
var currentBlock strings.Builder
var blockStartLine int = 1
@@ -312,5 +348,41 @@ func (t *MemorySearchTool) searchFile(path string, keywords []string) ([]searchR
}
processBlock() // Flush last block
if err := scanner.Err(); err != nil {
return nil, err
}
return results, nil
}
func parseIntArg(value interface{}) (int, bool) {
switch v := value.(type) {
case float64:
return int(v), true
case int:
return v, true
case int64:
return int(v), true
case string:
n, err := strconv.Atoi(strings.TrimSpace(v))
if err == nil {
return n, true
}
}
return 0, false
}
func formatSearchWarningSuffix(failedFiles []string) string {
if len(failedFiles) == 0 {
return ""
}
maxShown := 3
shown := failedFiles
if len(shown) > maxShown {
shown = shown[:maxShown]
}
msg := fmt.Sprintf("Warning: memory_search skipped %d file(s) due to read/parse errors: %s", len(failedFiles), strings.Join(shown, ", "))
if len(failedFiles) > maxShown {
msg += ", ..."
}
return msg
}

102
pkg/tools/memory_test.go Normal file
View File

@@ -0,0 +1,102 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
func TestMemorySearchToolClampsMaxResults(t *testing.T) {
workspace := t.TempDir()
memDir := filepath.Join(workspace, "memory")
if err := os.MkdirAll(memDir, 0755); err != nil {
t.Fatalf("mkdir failed: %v", err)
}
content := "# Long-term Memory\n\nalpha one\n\nalpha two\n"
if err := os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte(content), 0644); err != nil {
t.Fatalf("write failed: %v", err)
}
tool := NewMemorySearchTool(workspace)
out, err := tool.Execute(context.Background(), map[string]interface{}{
"query": "alpha",
"maxResults": -5,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(out, "Found 1 memories") {
t.Fatalf("expected clamped result count, got: %s", out)
}
}
func TestMemorySearchToolScannerHandlesLargeLine(t *testing.T) {
workspace := t.TempDir()
memDir := filepath.Join(workspace, "memory")
if err := os.MkdirAll(memDir, 0755); err != nil {
t.Fatalf("mkdir failed: %v", err)
}
large := strings.Repeat("x", 80*1024) + " needle"
if err := os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte(large), 0644); err != nil {
t.Fatalf("write failed: %v", err)
}
tool := NewMemorySearchTool(workspace)
out, err := tool.Execute(context.Background(), map[string]interface{}{
"query": "needle",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(out, "needle") {
t.Fatalf("expected search hit in large line, got: %s", out)
}
}
func TestMemorySearchToolPrefersCanonicalMemoryPath(t *testing.T) {
workspace := t.TempDir()
memDir := filepath.Join(workspace, "memory")
if err := os.MkdirAll(memDir, 0755); err != nil {
t.Fatalf("mkdir failed: %v", err)
}
if err := os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("canonical"), 0644); err != nil {
t.Fatalf("write canonical failed: %v", err)
}
if err := os.WriteFile(filepath.Join(workspace, "MEMORY.md"), []byte("legacy"), 0644); err != nil {
t.Fatalf("write legacy failed: %v", err)
}
tool := NewMemorySearchTool(workspace)
files := tool.getMemoryFiles()
for _, file := range files {
if file == filepath.Join(workspace, "MEMORY.md") {
t.Fatalf("legacy path should be ignored when canonical exists: %v", files)
}
}
}
func TestMemorySearchToolReportsFileScanWarnings(t *testing.T) {
workspace := t.TempDir()
memDir := filepath.Join(workspace, "memory")
if err := os.MkdirAll(memDir, 0755); err != nil {
t.Fatalf("mkdir failed: %v", err)
}
tooLargeLine := strings.Repeat("x", 2*1024*1024) + "\n"
if err := os.WriteFile(filepath.Join(memDir, "bad.md"), []byte(tooLargeLine), 0644); err != nil {
t.Fatalf("write bad file failed: %v", err)
}
tool := NewMemorySearchTool(workspace)
out, err := tool.Execute(context.Background(), map[string]interface{}{
"query": "needle",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(out, "Warning: memory_search skipped") {
t.Fatalf("expected warning suffix when scan errors happen, got: %s", out)
}
}