Harden gateway auth and file boundaries

This commit is contained in:
lpf
2026-03-15 15:31:00 +08:00
parent 617f7cc0f1
commit ba95aeed35
16 changed files with 587 additions and 91 deletions

View File

@@ -9,16 +9,31 @@ import (
)
func resolveToolPath(baseDir, path string) (string, error) {
path = strings.TrimSpace(path)
if path == "" {
return "", fmt.Errorf("path is required")
}
if filepath.IsAbs(path) {
return filepath.Clean(path), nil
return "", fmt.Errorf("absolute path is not allowed")
}
joined := path
if baseDir != "" {
return filepath.Clean(filepath.Join(baseDir, path)), nil
joined = filepath.Join(baseDir, path)
}
abs, err := filepath.Abs(path)
abs, err := filepath.Abs(joined)
if err != nil {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
if baseDir != "" {
absBase, err := filepath.Abs(baseDir)
if err != nil {
return "", fmt.Errorf("failed to resolve base path: %w", err)
}
rel, err := filepath.Rel(absBase, abs)
if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return "", fmt.Errorf("path escapes allowed directory")
}
}
return abs, nil
}

View File

@@ -0,0 +1,44 @@
package tools
import (
"context"
"os"
"path/filepath"
"testing"
)
func TestResolveToolPathRejectsAbsolutePath(t *testing.T) {
t.Parallel()
base := t.TempDir()
if _, err := resolveToolPath(base, "/tmp/outside.txt"); err == nil {
t.Fatalf("expected absolute path to be rejected")
}
}
func TestResolveToolPathRejectsTraversal(t *testing.T) {
t.Parallel()
base := t.TempDir()
if _, err := resolveToolPath(base, "../outside.txt"); err == nil {
t.Fatalf("expected traversal path to be rejected")
}
}
func TestReadFileToolAllowsWorkspaceRelativePath(t *testing.T) {
t.Parallel()
base := t.TempDir()
path := filepath.Join(base, "notes.txt")
if err := os.WriteFile(path, []byte("hello"), 0o644); err != nil {
t.Fatalf("write fixture: %v", err)
}
tool := NewReadFileTool(base)
got, err := tool.Execute(context.Background(), map[string]interface{}{"path": "notes.txt"})
if err != nil {
t.Fatalf("read file: %v", err)
}
if got != "hello" {
t.Fatalf("unexpected content %q", got)
}
}