mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-14 22:09:37 +08:00
fix
This commit is contained in:
@@ -119,7 +119,7 @@ func (t *WriteFileTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Description() string {
|
||||
return "Write content to a file"
|
||||
return "Write content to a file. Supports overwrite (default) and append mode."
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
@@ -134,6 +134,11 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file",
|
||||
},
|
||||
"append": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "If true, append content to the file instead of overwriting it",
|
||||
"default": false,
|
||||
},
|
||||
},
|
||||
"required": []string{"path", "content"},
|
||||
}
|
||||
@@ -149,16 +154,31 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
}
|
||||
appendMode, _ := args["append"].(bool)
|
||||
|
||||
resolvedPath, err := resolveToolPath(t.allowedDir, path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(resolvedPath), 0755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if appendMode {
|
||||
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("File appended successfully: %s", path), nil
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("File written successfully: %s", path), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@@ -54,6 +55,12 @@ type searchResult struct {
|
||||
score int
|
||||
}
|
||||
|
||||
type fileSearchOutcome struct {
|
||||
matches []searchResult
|
||||
err error
|
||||
file string
|
||||
}
|
||||
|
||||
func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || query == "" {
|
||||
@@ -61,8 +68,14 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
|
||||
}
|
||||
|
||||
maxResults := 5
|
||||
if m, ok := args["maxResults"].(float64); ok {
|
||||
maxResults = int(m)
|
||||
if m, ok := parseIntArg(args["maxResults"]); ok {
|
||||
maxResults = m
|
||||
}
|
||||
if maxResults < 1 {
|
||||
maxResults = 1
|
||||
}
|
||||
if maxResults > 50 {
|
||||
maxResults = 50
|
||||
}
|
||||
|
||||
keywords := strings.Fields(strings.ToLower(query))
|
||||
@@ -77,11 +90,14 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
|
||||
|
||||
// Fast path: structured memory index.
|
||||
if idx, err := t.loadOrBuildIndex(files); err == nil && idx != nil {
|
||||
results := t.searchInIndex(idx, keywords)
|
||||
return t.renderSearchResults(query, results, maxResults), nil
|
||||
// If index has entries, use it. Otherwise fallback to file scan so parser/read warnings are visible.
|
||||
if len(idx.Entries) > 0 {
|
||||
results := t.searchInIndex(idx, keywords)
|
||||
return t.renderSearchResults(query, results, maxResults), nil
|
||||
}
|
||||
}
|
||||
|
||||
resultsChan := make(chan []searchResult, len(files))
|
||||
resultsChan := make(chan fileSearchOutcome, len(files))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Search all files concurrently
|
||||
@@ -90,9 +106,7 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
|
||||
go func(f string) {
|
||||
defer wg.Done()
|
||||
matches, err := t.searchFile(f, keywords)
|
||||
if err == nil {
|
||||
resultsChan <- matches
|
||||
}
|
||||
resultsChan <- fileSearchOutcome{matches: matches, err: err, file: f}
|
||||
}(file)
|
||||
}
|
||||
|
||||
@@ -103,11 +117,28 @@ func (t *MemorySearchTool) Execute(ctx context.Context, args map[string]interfac
|
||||
}()
|
||||
|
||||
var allResults []searchResult
|
||||
for matches := range resultsChan {
|
||||
allResults = append(allResults, matches...)
|
||||
var failedFiles []string
|
||||
for outcome := range resultsChan {
|
||||
if outcome.err != nil {
|
||||
relPath, _ := filepath.Rel(t.workspace, outcome.file)
|
||||
if relPath == "" {
|
||||
relPath = outcome.file
|
||||
}
|
||||
failedFiles = append(failedFiles, relPath)
|
||||
continue
|
||||
}
|
||||
allResults = append(allResults, outcome.matches...)
|
||||
}
|
||||
|
||||
return t.renderSearchResults(query, allResults, maxResults), nil
|
||||
output := t.renderSearchResults(query, allResults, maxResults)
|
||||
if len(failedFiles) > 0 {
|
||||
suffix := formatSearchWarningSuffix(failedFiles)
|
||||
if strings.HasPrefix(output, "No memory found for query:") {
|
||||
return output + suffix, nil
|
||||
}
|
||||
return output + "\n" + suffix, nil
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func (t *MemorySearchTool) searchInIndex(idx *memoryIndex, keywords []string) []searchResult {
|
||||
@@ -190,9 +221,13 @@ func (t *MemorySearchTool) getMemoryFiles() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Check long-term memory in both legacy and current locations.
|
||||
addIfExists(filepath.Join(t.workspace, "MEMORY.md"))
|
||||
addIfExists(filepath.Join(t.workspace, "memory", "MEMORY.md"))
|
||||
// Prefer canonical long-term memory path.
|
||||
canonical := filepath.Join(t.workspace, "memory", "MEMORY.md")
|
||||
addIfExists(canonical)
|
||||
// Legacy path fallback only when canonical file is absent.
|
||||
if _, err := os.Stat(canonical); err != nil {
|
||||
addIfExists(filepath.Join(t.workspace, "MEMORY.md"))
|
||||
}
|
||||
|
||||
// Check memory/ directory recursively (e.g., memory/YYYYMM/YYYYMMDD.md).
|
||||
memDir := filepath.Join(t.workspace, "memory")
|
||||
@@ -221,6 +256,7 @@ func (t *MemorySearchTool) searchFile(path string, keywords []string) ([]searchR
|
||||
|
||||
var results []searchResult
|
||||
scanner := bufio.NewScanner(file)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
|
||||
var currentBlock strings.Builder
|
||||
var blockStartLine int = 1
|
||||
@@ -312,5 +348,41 @@ func (t *MemorySearchTool) searchFile(path string, keywords []string) ([]searchR
|
||||
}
|
||||
|
||||
processBlock() // Flush last block
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func parseIntArg(value interface{}) (int, bool) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return int(v), true
|
||||
case int:
|
||||
return v, true
|
||||
case int64:
|
||||
return int(v), true
|
||||
case string:
|
||||
n, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
if err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func formatSearchWarningSuffix(failedFiles []string) string {
|
||||
if len(failedFiles) == 0 {
|
||||
return ""
|
||||
}
|
||||
maxShown := 3
|
||||
shown := failedFiles
|
||||
if len(shown) > maxShown {
|
||||
shown = shown[:maxShown]
|
||||
}
|
||||
msg := fmt.Sprintf("Warning: memory_search skipped %d file(s) due to read/parse errors: %s", len(failedFiles), strings.Join(shown, ", "))
|
||||
if len(failedFiles) > maxShown {
|
||||
msg += ", ..."
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
102
pkg/tools/memory_test.go
Normal file
102
pkg/tools/memory_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMemorySearchToolClampsMaxResults(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
memDir := filepath.Join(workspace, "memory")
|
||||
if err := os.MkdirAll(memDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir failed: %v", err)
|
||||
}
|
||||
content := "# Long-term Memory\n\nalpha one\n\nalpha two\n"
|
||||
if err := os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte(content), 0644); err != nil {
|
||||
t.Fatalf("write failed: %v", err)
|
||||
}
|
||||
|
||||
tool := NewMemorySearchTool(workspace)
|
||||
out, err := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"query": "alpha",
|
||||
"maxResults": -5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, "Found 1 memories") {
|
||||
t.Fatalf("expected clamped result count, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemorySearchToolScannerHandlesLargeLine(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
memDir := filepath.Join(workspace, "memory")
|
||||
if err := os.MkdirAll(memDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir failed: %v", err)
|
||||
}
|
||||
large := strings.Repeat("x", 80*1024) + " needle"
|
||||
if err := os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte(large), 0644); err != nil {
|
||||
t.Fatalf("write failed: %v", err)
|
||||
}
|
||||
|
||||
tool := NewMemorySearchTool(workspace)
|
||||
out, err := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"query": "needle",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, "needle") {
|
||||
t.Fatalf("expected search hit in large line, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemorySearchToolPrefersCanonicalMemoryPath(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
memDir := filepath.Join(workspace, "memory")
|
||||
if err := os.MkdirAll(memDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("canonical"), 0644); err != nil {
|
||||
t.Fatalf("write canonical failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(workspace, "MEMORY.md"), []byte("legacy"), 0644); err != nil {
|
||||
t.Fatalf("write legacy failed: %v", err)
|
||||
}
|
||||
|
||||
tool := NewMemorySearchTool(workspace)
|
||||
files := tool.getMemoryFiles()
|
||||
for _, file := range files {
|
||||
if file == filepath.Join(workspace, "MEMORY.md") {
|
||||
t.Fatalf("legacy path should be ignored when canonical exists: %v", files)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemorySearchToolReportsFileScanWarnings(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
memDir := filepath.Join(workspace, "memory")
|
||||
if err := os.MkdirAll(memDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir failed: %v", err)
|
||||
}
|
||||
|
||||
tooLargeLine := strings.Repeat("x", 2*1024*1024) + "\n"
|
||||
if err := os.WriteFile(filepath.Join(memDir, "bad.md"), []byte(tooLargeLine), 0644); err != nil {
|
||||
t.Fatalf("write bad file failed: %v", err)
|
||||
}
|
||||
|
||||
tool := NewMemorySearchTool(workspace)
|
||||
out, err := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"query": "needle",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, "Warning: memory_search skipped") {
|
||||
t.Fatalf("expected warning suffix when scan errors happen, got: %s", out)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user