diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 3bab6ca..3d2e9fb 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -157,9 +157,9 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers }) toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool("")) - toolsRegistry.Register(tools.NewWriteFileTool("")) - toolsRegistry.Register(tools.NewListDirTool("")) + toolsRegistry.Register(tools.NewReadFileTool(workspace)) + toolsRegistry.Register(tools.NewWriteFileTool(workspace)) + toolsRegistry.Register(tools.NewListDirTool(workspace)) toolsRegistry.Register(tools.NewExecTool(cfg.Tools.Shell, workspace)) if cs != nil { @@ -196,7 +196,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(tools.NewPipelineDispatchTool(orchestrator, subagentManager)) // Register edit file tool - editFileTool := tools.NewEditFileTool("") + editFileTool := tools.NewEditFileTool(workspace) toolsRegistry.Register(editFileTool) // Register memory search tool diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 39560be..674f566 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -8,6 +8,20 @@ import ( "strings" ) +func resolveToolPath(baseDir, path string) (string, error) { + if filepath.IsAbs(path) { + return filepath.Clean(path), nil + } + if baseDir != "" { + return filepath.Clean(filepath.Join(baseDir, path)), nil + } + abs, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("failed to resolve path: %w", err) + } + return abs, nil +} + // ReadFileTool reads the contents of a file. type ReadFileTool struct { allowedDir string @@ -52,22 +66,9 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) return "", fmt.Errorf("path is required") } - resolvedPath := path - if filepath.IsAbs(path) { - resolvedPath = filepath.Clean(path) - } else { - abs, err := filepath.Abs(path) - if err != nil { - return "", fmt.Errorf("failed to resolve path: %w", err) - } - resolvedPath = abs - } - - if t.allowedDir != "" { - allowedAbs, _ := filepath.Abs(t.allowedDir) - if !strings.HasPrefix(resolvedPath, allowedAbs) { - return "", fmt.Errorf("path %s is outside allowed directory", path) - } + resolvedPath, err := resolveToolPath(t.allowedDir, path) + if err != nil { + return "", err } f, err := os.Open(resolvedPath) @@ -149,22 +150,9 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{} return "", fmt.Errorf("content is required") } - resolvedPath := path - if filepath.IsAbs(path) { - resolvedPath = filepath.Clean(path) - } else { - abs, err := filepath.Abs(path) - if err != nil { - return "", fmt.Errorf("failed to resolve path: %w", err) - } - resolvedPath = abs - } - - if t.allowedDir != "" { - allowedAbs, _ := filepath.Abs(t.allowedDir) - if !strings.HasPrefix(resolvedPath, allowedAbs) { - return "", fmt.Errorf("path %s is outside allowed directory", path) - } + resolvedPath, err := resolveToolPath(t.allowedDir, path) + if err != nil { + return "", err } if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { @@ -216,22 +204,9 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) recursive, _ := args["recursive"].(bool) - resolvedPath := path - if filepath.IsAbs(path) { - resolvedPath = filepath.Clean(path) - } else { - abs, err := filepath.Abs(path) - if err != nil { - return "", fmt.Errorf("failed to resolve path: %w", err) - } - resolvedPath = abs - } - - if t.allowedDir != "" { - allowedAbs, _ := filepath.Abs(t.allowedDir) - if !strings.HasPrefix(resolvedPath, allowedAbs) { - return "", fmt.Errorf("path %s is outside allowed directory", path) - } + resolvedPath, err := resolveToolPath(t.allowedDir, path) + if err != nil { + return "", err } var results []string @@ -330,22 +305,9 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) return "", fmt.Errorf("new_text is required") } - resolvedPath := path - if filepath.IsAbs(path) { - resolvedPath = filepath.Clean(path) - } else { - abs, err := filepath.Abs(path) - if err != nil { - return "", fmt.Errorf("failed to resolve path: %w", err) - } - resolvedPath = abs - } - - if t.allowedDir != "" { - allowedAbs, _ := filepath.Abs(t.allowedDir) - if !strings.HasPrefix(resolvedPath, allowedAbs) { - return "", fmt.Errorf("path %s is outside allowed directory", path) - } + resolvedPath, err := resolveToolPath(t.allowedDir, path) + if err != nil { + return "", err } content, err := os.ReadFile(resolvedPath) diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go new file mode 100644 index 0000000..1198c20 --- /dev/null +++ b/pkg/tools/filesystem_test.go @@ -0,0 +1,54 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestReadFileToolResolvesRelativePathFromAllowedDir(t *testing.T) { + workspace := t.TempDir() + targetPath := filepath.Join(workspace, "cmd", "clawgo", "main.go") + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + if err := os.WriteFile(targetPath, []byte("package main"), 0644); err != nil { + t.Fatalf("write failed: %v", err) + } + + tool := NewReadFileTool(workspace) + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "path": "cmd/clawgo/main.go", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "package main" { + t.Fatalf("unexpected output: %q", out) + } +} + +func TestReadFileToolAllowsParentTraversalWhenPermitted(t *testing.T) { + workspace := t.TempDir() + parentFile := filepath.Join(filepath.Dir(workspace), "outside.txt") + if err := os.WriteFile(parentFile, []byte("outside"), 0644); err != nil { + t.Fatalf("write failed: %v", err) + } + + tool := NewReadFileTool(workspace) + relPath, err := filepath.Rel(workspace, parentFile) + if err != nil { + t.Fatalf("rel failed: %v", err) + } + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "path": relPath, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "outside" { + t.Fatalf("unexpected output: %q", out) + } +}