mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 06:47:30 +08:00
Harden gateway auth and file boundaries
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@@ -73,10 +74,6 @@ type Server struct {
|
||||
skillsRPCReg *rpcpkg.Registry
|
||||
}
|
||||
|
||||
var nodesWebsocketUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server {
|
||||
addr := strings.TrimSpace(host)
|
||||
if addr == "" {
|
||||
@@ -290,15 +287,82 @@ func (s *Server) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func requestUsesTLS(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")), "https")
|
||||
}
|
||||
|
||||
func canonicalOriginHost(host string, https bool) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
if parsedHost, parsedPort, err := net.SplitHostPort(host); err == nil {
|
||||
return strings.ToLower(net.JoinHostPort(parsedHost, parsedPort))
|
||||
}
|
||||
port := "80"
|
||||
if https {
|
||||
port = "443"
|
||||
}
|
||||
return strings.ToLower(net.JoinHostPort(host, port))
|
||||
}
|
||||
|
||||
func (s *Server) isTrustedOrigin(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(u.Scheme)) {
|
||||
case "http", "https":
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return canonicalOriginHost(u.Host, strings.EqualFold(u.Scheme, "https")) == canonicalOriginHost(r.Host, requestUsesTLS(r))
|
||||
}
|
||||
|
||||
func (s *Server) websocketUpgrader() *websocket.Upgrader {
|
||||
return &websocket.Upgrader{
|
||||
CheckOrigin: s.isTrustedOrigin,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) isBearerAuthorized(r *http.Request) bool {
|
||||
if s == nil || r == nil || strings.TrimSpace(s.token) == "" {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.Header.Get("Authorization")) == "Bearer "+s.token
|
||||
}
|
||||
|
||||
func (s *Server) withCORS(next http.Handler) http.Handler {
|
||||
if next == nil {
|
||||
next = http.NotFoundHandler()
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin != "" {
|
||||
if !s.isTrustedOrigin(r) {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Add("Vary", "Origin")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
@@ -357,23 +421,11 @@ func (s *Server) checkAuth(r *http.Request) bool {
|
||||
if s.token == "" {
|
||||
return true
|
||||
}
|
||||
auth := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if auth == "Bearer "+s.token {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(r.URL.Query().Get("token")) == s.token {
|
||||
if s.isBearerAuthorized(r) {
|
||||
return true
|
||||
}
|
||||
if c, err := r.Cookie("clawgo_webui_token"); err == nil && strings.TrimSpace(c.Value) == s.token {
|
||||
return true
|
||||
}
|
||||
// Browser asset fallback: allow token propagated via Referer query.
|
||||
if ref := strings.TrimSpace(r.Referer()); ref != "" {
|
||||
if u, err := url.Parse(ref); err == nil {
|
||||
if strings.TrimSpace(u.Query().Get("token")) == s.token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func (s *Server) handleWebUIChatLive(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "chat handler not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ func (s *Server) handleWebUILogsLive(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "log path not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -317,10 +317,6 @@ func resolveArtifactPath(workspace, raw string) string {
|
||||
return ""
|
||||
}
|
||||
if filepath.IsAbs(raw) {
|
||||
clean := filepath.Clean(raw)
|
||||
if info, err := os.Stat(clean); err == nil && !info.IsDir() {
|
||||
return clean
|
||||
}
|
||||
return ""
|
||||
}
|
||||
root := strings.TrimSpace(workspace)
|
||||
@@ -338,12 +334,12 @@ func resolveArtifactPath(workspace, raw string) string {
|
||||
}
|
||||
|
||||
func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, string, error) {
|
||||
if content := strings.TrimSpace(fmt.Sprint(item["content_base64"])); content != "" {
|
||||
if content := strings.TrimSpace(stringFromMap(item, "content_base64")); content != "" {
|
||||
raw, err := base64.StdEncoding.DecodeString(content)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return raw, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil
|
||||
return raw, strings.TrimSpace(stringFromMap(item, "mime_type")), nil
|
||||
}
|
||||
for _, rawPath := range []string{fmt.Sprint(item["source_path"]), fmt.Sprint(item["path"])} {
|
||||
if path := resolveArtifactPath(workspace, rawPath); path != "" {
|
||||
@@ -351,10 +347,10 @@ func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, s
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return b, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil
|
||||
return b, strings.TrimSpace(stringFromMap(item, "mime_type")), nil
|
||||
}
|
||||
}
|
||||
if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" {
|
||||
if contentText := strings.TrimSpace(stringFromMap(item, "content_text")); contentText != "" {
|
||||
return []byte(contentText), "text/plain; charset=utf-8", nil
|
||||
}
|
||||
return nil, "", fmt.Errorf("artifact content unavailable")
|
||||
@@ -411,9 +407,14 @@ func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.R
|
||||
nodeList, _ := payload["nodes"].([]nodes.NodeInfo)
|
||||
p2p, _ := payload["p2p"].(map[string]interface{})
|
||||
alerts := filteredNodeAlerts(s.webUINodeAlertsPayload(nodeList, p2p, dispatches), nodeFilter)
|
||||
|
||||
var archive bytes.Buffer
|
||||
zw := zip.NewWriter(&archive)
|
||||
filename := "node-artifacts-export.zip"
|
||||
if nodeFilter != "" {
|
||||
filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
zw := zip.NewWriter(w)
|
||||
writeZipJSON := func(name string, value interface{}) error {
|
||||
entry, err := zw.Create(name)
|
||||
if err != nil {
|
||||
@@ -470,26 +471,13 @@ func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.R
|
||||
}
|
||||
entry, err := zw.Create(entryName)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := entry.Write(raw); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
filename := "node-artifacts-export.zip"
|
||||
if nodeFilter != "" {
|
||||
filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(archive.Bytes())
|
||||
_ = zw.Close()
|
||||
}
|
||||
|
||||
func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -519,7 +507,7 @@ func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
if contentB64 := strings.TrimSpace(fmt.Sprint(item["content_base64"])); contentB64 != "" {
|
||||
if contentB64 := strings.TrimSpace(stringFromMap(item, "content_base64")); contentB64 != "" {
|
||||
payload, err := base64.StdEncoding.DecodeString(contentB64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid inline artifact payload", http.StatusBadRequest)
|
||||
@@ -536,7 +524,7 @@ func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
}
|
||||
if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" {
|
||||
if contentText := strings.TrimSpace(stringFromMap(item, "content_text")); contentText != "" {
|
||||
w.Header().Set("Content-Type", mimeType)
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
|
||||
_, _ = w.Write([]byte(contentText))
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestHandleWebUINodeArtifactsListAndDelete(t *testing.T) {
|
||||
if err := os.WriteFile(artifactPath, []byte("artifact-body"), 0o644); err != nil {
|
||||
t.Fatalf("write artifact: %v", err)
|
||||
}
|
||||
auditLine := fmt.Sprintf("{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"%s\",\"size_bytes\":13}]}\n", artifactPath)
|
||||
auditLine := "{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"artifact.txt\",\"size_bytes\":13}]}\n"
|
||||
if err := os.WriteFile(filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl"), []byte(auditLine), 0o644); err != nil {
|
||||
t.Fatalf("write audit: %v", err)
|
||||
}
|
||||
@@ -68,6 +68,36 @@ func TestHandleWebUINodeArtifactsListAndDelete(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebUINodeArtifactDownloadRejectsAbsoluteSourcePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "", nodes.NewManager())
|
||||
workspace := t.TempDir()
|
||||
srv.SetWorkspacePath(workspace)
|
||||
if err := os.MkdirAll(filepath.Join(workspace, "memory"), 0o755); err != nil {
|
||||
t.Fatalf("mkdir memory: %v", err)
|
||||
}
|
||||
artifactPath := filepath.Join(workspace, "artifact.txt")
|
||||
if err := os.WriteFile(artifactPath, []byte("artifact-body"), 0o644); err != nil {
|
||||
t.Fatalf("write artifact: %v", err)
|
||||
}
|
||||
auditLine := fmt.Sprintf("{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"%s\",\"size_bytes\":13}]}\n", artifactPath)
|
||||
if err := os.WriteFile(filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl"), []byte(auditLine), 0o644); err != nil {
|
||||
t.Fatalf("write audit: %v", err)
|
||||
}
|
||||
|
||||
items := srv.webUINodeArtifactsPayload(10)
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected 1 artifact, got %d", len(items))
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/node_artifacts/download?id="+fmt.Sprint(items[0]["id"]), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.handleWebUINodeArtifactDownload(rec, req)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebUINodeArtifactsExport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "nodes manager unavailable", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
@@ -29,15 +29,9 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
|
||||
NetworkProxy string `json:"network_proxy"`
|
||||
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
|
||||
}
|
||||
if r.Method == http.MethodPost {
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
body.Provider = strings.TrimSpace(r.URL.Query().Get("provider"))
|
||||
body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label"))
|
||||
body.NetworkProxy = strings.TrimSpace(r.URL.Query().Get("network_proxy"))
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
|
||||
if err != nil {
|
||||
|
||||
@@ -26,7 +26,7 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
179
pkg/api/server_security_test.go
Normal file
179
pkg/api/server_security_test.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/YspCoder/clawgo/pkg/nodes"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestCheckAuthAllowsBearerAndCookieOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "secret-token", nil)
|
||||
|
||||
bearerReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
bearerReq.Header.Set("Authorization", "Bearer secret-token")
|
||||
if !srv.checkAuth(bearerReq) {
|
||||
t.Fatalf("expected bearer auth to succeed")
|
||||
}
|
||||
|
||||
cookieReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"})
|
||||
if !srv.checkAuth(cookieReq) {
|
||||
t.Fatalf("expected cookie auth to succeed")
|
||||
}
|
||||
|
||||
queryReq := httptest.NewRequest(http.MethodGet, "/?token=secret-token", nil)
|
||||
if srv.checkAuth(queryReq) {
|
||||
t.Fatalf("expected query token auth to fail")
|
||||
}
|
||||
|
||||
refererReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
refererReq.Header.Set("Referer", "https://example.com/?token=secret-token")
|
||||
if srv.checkAuth(refererReq) {
|
||||
t.Fatalf("expected referer token auth to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCORSRejectsForeignOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil)
|
||||
req.Host = "example.com"
|
||||
req.Header.Set("Origin", "https://evil.example")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCORSAcceptsSameOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil)
|
||||
req.Host = "example.com"
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rec.Code)
|
||||
}
|
||||
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("unexpected allow-origin header %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNodeConnectRejectsForeignOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "", nodes.NewManager())
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/nodes/connect", srv.handleNodeConnect)
|
||||
httpSrv := httptest.NewServer(mux)
|
||||
defer httpSrv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/nodes/connect"
|
||||
dialer := websocket.Dialer{}
|
||||
headers := http.Header{"Origin": []string{"https://evil.example"}}
|
||||
conn, resp, err := dialer.Dial(wsURL, headers)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
t.Fatalf("expected websocket handshake to fail")
|
||||
}
|
||||
if resp == nil || resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 response, got %#v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebUISetsCookieForBearerOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "secret-token", nil)
|
||||
|
||||
bearerReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
bearerReq.Header.Set("Authorization", "Bearer secret-token")
|
||||
bearerRec := httptest.NewRecorder()
|
||||
srv.handleWebUI(bearerRec, bearerReq)
|
||||
if len(bearerRec.Result().Cookies()) == 0 {
|
||||
t.Fatalf("expected bearer-authenticated UI request to set cookie")
|
||||
}
|
||||
|
||||
cookieReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"})
|
||||
cookieRec := httptest.NewRecorder()
|
||||
srv.handleWebUI(cookieRec, cookieReq)
|
||||
if len(cookieRec.Result().Cookies()) != 0 {
|
||||
t.Fatalf("expected cookie-authenticated UI request not to reset cookie")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebUIUploadDoesNotExposeAbsolutePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "", nil)
|
||||
var form bytes.Buffer
|
||||
mw := multipartWriter(t, &form, "file", "demo.txt", []byte("upload-body"))
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/upload", &form)
|
||||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
srv.handleWebUIUpload(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if _, ok := payload["path"]; ok {
|
||||
t.Fatalf("expected upload response to omit absolute path: %+v", payload)
|
||||
}
|
||||
if strings.TrimSpace(payload["media"].(string)) == "" {
|
||||
t.Fatalf("expected media handle in response: %+v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func multipartWriter(t *testing.T, dst *bytes.Buffer, fieldName, filename string, body []byte) *multipart.Writer {
|
||||
t.Helper()
|
||||
mw := multipart.NewWriter(dst)
|
||||
part, err := mw.CreateFormFile(fieldName, filename)
|
||||
if err != nil {
|
||||
t.Fatalf("create form file: %v", err)
|
||||
}
|
||||
if _, err := part.Write(body); err != nil {
|
||||
t.Fatalf("write form file: %v", err)
|
||||
}
|
||||
if err := mw.Close(); err != nil {
|
||||
t.Fatalf("close multipart writer: %v", err)
|
||||
}
|
||||
return mw
|
||||
}
|
||||
|
||||
func TestHandleWebUIProviderOAuthStartRejectsGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := NewServer("127.0.0.1", 0, "", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/provider/oauth/start?provider=openai", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.handleWebUIProviderOAuthStart(rec, req)
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("expected 405, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,39 @@ import (
|
||||
"github.com/YspCoder/clawgo/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
skillArchiveUploadLimit int64 = 32 << 20
|
||||
skillArchiveMaxFiles = 256
|
||||
skillArchiveMaxSingleFile int64 = 8 << 20
|
||||
skillArchiveMaxExpanded int64 = 32 << 20
|
||||
)
|
||||
|
||||
type archiveExtractLimits struct {
|
||||
maxFiles int
|
||||
maxSingleFile int64
|
||||
maxExpanded int64
|
||||
fileCount int
|
||||
totalExpanded int64
|
||||
}
|
||||
|
||||
func (l *archiveExtractLimits) addFile(size int64) error {
|
||||
if size < 0 {
|
||||
return fmt.Errorf("invalid archive entry size")
|
||||
}
|
||||
l.fileCount++
|
||||
if l.maxFiles > 0 && l.fileCount > l.maxFiles {
|
||||
return fmt.Errorf("archive contains too many files")
|
||||
}
|
||||
if l.maxSingleFile > 0 && size > l.maxSingleFile {
|
||||
return fmt.Errorf("archive entry exceeds size limit")
|
||||
}
|
||||
l.totalExpanded += size
|
||||
if l.maxExpanded > 0 && l.totalExpanded > l.maxExpanded {
|
||||
return fmt.Errorf("archive exceeds expanded size limit")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustMap(v interface{}) map[string]interface{} {
|
||||
if v == nil {
|
||||
return map[string]interface{}{}
|
||||
@@ -262,7 +295,13 @@ func ensureClawHubReady(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]string, error) {
|
||||
if err := r.ParseMultipartForm(128 << 20); err != nil {
|
||||
if r.ContentLength > skillArchiveUploadLimit {
|
||||
return nil, fmt.Errorf("archive upload exceeds size limit")
|
||||
}
|
||||
if r.Body != nil {
|
||||
r.Body = http.MaxBytesReader(nil, r.Body, skillArchiveUploadLimit)
|
||||
}
|
||||
if err := r.ParseMultipartForm(8 << 20); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, h, err := r.FormFile("file")
|
||||
@@ -292,7 +331,12 @@ func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]strin
|
||||
}
|
||||
defer os.RemoveAll(extractDir)
|
||||
|
||||
if err := extractArchive(archivePath, extractDir); err != nil {
|
||||
limits := archiveExtractLimits{
|
||||
maxFiles: skillArchiveMaxFiles,
|
||||
maxSingleFile: skillArchiveMaxSingleFile,
|
||||
maxExpanded: skillArchiveMaxExpanded,
|
||||
}
|
||||
if err := extractArchive(archivePath, extractDir, &limits); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -403,21 +447,21 @@ func sanitizeSkillName(name string) string {
|
||||
return out
|
||||
}
|
||||
|
||||
func extractArchive(archivePath, targetDir string) error {
|
||||
func extractArchive(archivePath, targetDir string, limits *archiveExtractLimits) error {
|
||||
lower := strings.ToLower(archivePath)
|
||||
switch {
|
||||
case strings.HasSuffix(lower, ".zip"):
|
||||
return extractZip(archivePath, targetDir)
|
||||
return extractZip(archivePath, targetDir, limits)
|
||||
case strings.HasSuffix(lower, ".tar.gz"), strings.HasSuffix(lower, ".tgz"):
|
||||
return extractTarGz(archivePath, targetDir)
|
||||
return extractTarGz(archivePath, targetDir, limits)
|
||||
case strings.HasSuffix(lower, ".tar"):
|
||||
return extractTar(archivePath, targetDir)
|
||||
return extractTar(archivePath, targetDir, limits)
|
||||
default:
|
||||
return fmt.Errorf("unsupported archive format: %s", filepath.Base(archivePath))
|
||||
}
|
||||
}
|
||||
|
||||
func extractZip(archivePath, targetDir string) error {
|
||||
func extractZip(archivePath, targetDir string, limits *archiveExtractLimits) error {
|
||||
zr, err := zip.OpenReader(archivePath)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -425,16 +469,21 @@ func extractZip(archivePath, targetDir string) error {
|
||||
defer zr.Close()
|
||||
|
||||
for _, f := range zr.File {
|
||||
if !f.FileInfo().IsDir() && limits != nil {
|
||||
if err := limits.addFile(int64(f.UncompressedSize64)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := writeArchivedEntry(targetDir, f.Name, f.FileInfo().IsDir(), func() (io.ReadCloser, error) {
|
||||
return f.Open()
|
||||
}); err != nil {
|
||||
}, limits.maxSingleFile); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractTarGz(archivePath, targetDir string) error {
|
||||
func extractTarGz(archivePath, targetDir string, limits *archiveExtractLimits) error {
|
||||
f, err := os.Open(archivePath)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -445,19 +494,19 @@ func extractTarGz(archivePath, targetDir string) error {
|
||||
return err
|
||||
}
|
||||
defer gz.Close()
|
||||
return extractTarReader(tar.NewReader(gz), targetDir)
|
||||
return extractTarReader(tar.NewReader(gz), targetDir, limits)
|
||||
}
|
||||
|
||||
func extractTar(archivePath, targetDir string) error {
|
||||
func extractTar(archivePath, targetDir string, limits *archiveExtractLimits) error {
|
||||
f, err := os.Open(archivePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return extractTarReader(tar.NewReader(f), targetDir)
|
||||
return extractTarReader(tar.NewReader(f), targetDir, limits)
|
||||
}
|
||||
|
||||
func extractTarReader(tr *tar.Reader, targetDir string) error {
|
||||
func extractTarReader(tr *tar.Reader, targetDir string, limits *archiveExtractLimits) error {
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if errors.Is(err, io.EOF) {
|
||||
@@ -468,21 +517,34 @@ func extractTarReader(tr *tar.Reader, targetDir string) error {
|
||||
}
|
||||
switch hdr.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := writeArchivedEntry(targetDir, hdr.Name, true, nil); err != nil {
|
||||
maxEntryBytes := int64(0)
|
||||
if limits != nil {
|
||||
maxEntryBytes = limits.maxSingleFile
|
||||
}
|
||||
if err := writeArchivedEntry(targetDir, hdr.Name, true, nil, maxEntryBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
case tar.TypeReg, tar.TypeRegA:
|
||||
if limits != nil {
|
||||
if err := limits.addFile(hdr.Size); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
name := hdr.Name
|
||||
maxEntryBytes := int64(0)
|
||||
if limits != nil {
|
||||
maxEntryBytes = limits.maxSingleFile
|
||||
}
|
||||
if err := writeArchivedEntry(targetDir, name, false, func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(tr), nil
|
||||
}); err != nil {
|
||||
}, maxEntryBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error)) error {
|
||||
func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error), maxBytes int64) error {
|
||||
clean := filepath.Clean(strings.TrimSpace(name))
|
||||
clean = strings.TrimPrefix(clean, string(filepath.Separator))
|
||||
clean = strings.TrimPrefix(clean, "/")
|
||||
@@ -514,7 +576,17 @@ func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.Re
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
_, err = io.Copy(out, rc)
|
||||
reader := io.Reader(rc)
|
||||
if maxBytes > 0 {
|
||||
reader = io.LimitReader(rc, maxBytes+1)
|
||||
}
|
||||
written, err := io.Copy(out, reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if maxBytes > 0 && written > maxBytes {
|
||||
return fmt.Errorf("archive entry exceeds size limit")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
93
pkg/api/server_skills_test.go
Normal file
93
pkg/api/server_skills_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestImportSkillArchiveFromMultipartSucceedsForSmallArchive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var archive bytes.Buffer
|
||||
zw := zip.NewWriter(&archive)
|
||||
entry, err := zw.Create("demo/SKILL.md")
|
||||
if err != nil {
|
||||
t.Fatalf("create entry: %v", err)
|
||||
}
|
||||
if _, err := entry.Write([]byte("# Demo\n")); err != nil {
|
||||
t.Fatalf("write entry: %v", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("close zip: %v", err)
|
||||
}
|
||||
|
||||
req := multipartRequest(t, "demo.zip", archive.Bytes())
|
||||
skillsDir := t.TempDir()
|
||||
imported, err := importSkillArchiveFromMultipart(req, skillsDir)
|
||||
if err != nil {
|
||||
t.Fatalf("import archive: %v", err)
|
||||
}
|
||||
if len(imported) != 1 || imported[0] != "demo" {
|
||||
t.Fatalf("unexpected imported skills: %+v", imported)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(skillsDir, "demo", "SKILL.md")); err != nil {
|
||||
t.Fatalf("expected imported SKILL.md: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractArchiveRejectsOversizedExpandedEntry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var archive bytes.Buffer
|
||||
zw := zip.NewWriter(&archive)
|
||||
entry, err := zw.Create("demo/SKILL.md")
|
||||
if err != nil {
|
||||
t.Fatalf("create entry: %v", err)
|
||||
}
|
||||
if _, err := entry.Write(bytes.Repeat([]byte("a"), int(skillArchiveMaxSingleFile+1))); err != nil {
|
||||
t.Fatalf("write entry: %v", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("close zip: %v", err)
|
||||
}
|
||||
|
||||
archivePath := filepath.Join(t.TempDir(), "demo.zip")
|
||||
if err := os.WriteFile(archivePath, archive.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("write archive: %v", err)
|
||||
}
|
||||
err = extractArchive(archivePath, t.TempDir(), &archiveExtractLimits{
|
||||
maxFiles: skillArchiveMaxFiles,
|
||||
maxSingleFile: skillArchiveMaxSingleFile,
|
||||
maxExpanded: skillArchiveMaxExpanded,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "size limit") {
|
||||
t.Fatalf("expected size limit error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func multipartRequest(t *testing.T, filename string, body []byte) *http.Request {
|
||||
t.Helper()
|
||||
var form bytes.Buffer
|
||||
mw := multipart.NewWriter(&form)
|
||||
part, err := mw.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
t.Fatalf("create form file: %v", err)
|
||||
}
|
||||
if _, err := part.Write(body); err != nil {
|
||||
t.Fatalf("write form file: %v", err)
|
||||
}
|
||||
if err := mw.Close(); err != nil {
|
||||
t.Fatalf("close multipart writer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest("POST", "/api/skills", &form)
|
||||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
req.ContentLength = int64(form.Len())
|
||||
return req
|
||||
}
|
||||
@@ -20,12 +20,13 @@ func (s *Server) handleWebUI(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if s.token != "" {
|
||||
if s.token != "" && s.isBearerAuthorized(r) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "clawgo_webui_token",
|
||||
Value: s.token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: requestUsesTLS(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 86400,
|
||||
})
|
||||
@@ -118,7 +119,7 @@ func (s *Server) handleWebUIUpload(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
writeJSON(w, map[string]interface{}{"ok": true, "path": path, "name": h.Filename})
|
||||
writeJSON(w, map[string]interface{}{"ok": true, "media": name, "name": h.Filename})
|
||||
}
|
||||
|
||||
func gatewayBuildVersion() string {
|
||||
@@ -165,14 +166,14 @@ const webUIHTML = `<!doctype html>
|
||||
<div>Session: <input id="session" value="webui:default"/> <input id="msg" placeholder="message" style="width:420px"/> <input id="file" type="file"/> <button onclick="sendChat()">Send</button></div>
|
||||
<div id="chatlog"></div>
|
||||
<script>
|
||||
function auth(){const t=document.getElementById('token').value.trim();return t?('?token='+encodeURIComponent(t)):''}
|
||||
async function loadCfg(){let r=await fetch('/api/config'+auth());document.getElementById('cfg').value=await r.text()}
|
||||
async function saveCfg(){let j=JSON.parse(document.getElementById('cfg').value);let r=await fetch('/api/config'+auth(),{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(j)});alert(await r.text())}
|
||||
function authHeaders(extra){const h=Object.assign({},extra||{});const t=document.getElementById('token').value.trim();if(t)h['Authorization']='Bearer '+t;return h}
|
||||
async function loadCfg(){let r=await fetch('/api/config',{headers:authHeaders()});document.getElementById('cfg').value=await r.text()}
|
||||
async function saveCfg(){let j=JSON.parse(document.getElementById('cfg').value);let r=await fetch('/api/config',{method:'POST',headers:authHeaders({'Content-Type':'application/json'}),body:JSON.stringify(j)});alert(await r.text())}
|
||||
async function sendChat(){
|
||||
let media='';const f=document.getElementById('file').files[0];
|
||||
if(f){let fd=new FormData();fd.append('file',f);let ur=await fetch('/api/upload'+auth(),{method:'POST',body:fd});let uj=await ur.json();media=uj.path||''}
|
||||
if(f){let fd=new FormData();fd.append('file',f);let ur=await fetch('/api/upload',{method:'POST',headers:authHeaders(),body:fd});let uj=await ur.json();media=uj.media||uj.name||''}
|
||||
const payload={session:document.getElementById('session').value,message:document.getElementById('msg').value,media};
|
||||
let r=await fetch('/api/chat'+auth(),{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(payload)});let t=await r.text();
|
||||
let r=await fetch('/api/chat',{method:'POST',headers:authHeaders({'Content-Type':'application/json'}),body:JSON.stringify(payload)});let t=await r.text();
|
||||
document.getElementById('chatlog').textContent += '\nUSER> '+payload.message+(media?(' [file:'+media+']'):'')+'\nBOT> '+t+'\n';
|
||||
}
|
||||
loadCfg();
|
||||
|
||||
@@ -804,7 +804,7 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(path, data, 0644)
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
func (c *Config) WorkspacePath() string {
|
||||
|
||||
28
pkg/config/config_save_test.go
Normal file
28
pkg/config/config_save_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSaveConfigUsesOwnerOnlyPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission bits are not reliable on windows")
|
||||
}
|
||||
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
if err := SaveConfig(path, DefaultConfig()); err != nil {
|
||||
t.Fatalf("save config: %v", err)
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("stat config: %v", err)
|
||||
}
|
||||
if got := info.Mode().Perm(); got != 0o600 {
|
||||
t.Fatalf("expected 0600 permissions, got %o", got)
|
||||
}
|
||||
}
|
||||
@@ -9,16 +9,31 @@ import (
|
||||
)
|
||||
|
||||
func resolveToolPath(baseDir, path string) (string, error) {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return "", fmt.Errorf("path is required")
|
||||
}
|
||||
if filepath.IsAbs(path) {
|
||||
return filepath.Clean(path), nil
|
||||
return "", fmt.Errorf("absolute path is not allowed")
|
||||
}
|
||||
joined := path
|
||||
if baseDir != "" {
|
||||
return filepath.Clean(filepath.Join(baseDir, path)), nil
|
||||
joined = filepath.Join(baseDir, path)
|
||||
}
|
||||
abs, err := filepath.Abs(path)
|
||||
abs, err := filepath.Abs(joined)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
if baseDir != "" {
|
||||
absBase, err := filepath.Abs(baseDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve base path: %w", err)
|
||||
}
|
||||
rel, err := filepath.Rel(absBase, abs)
|
||||
if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
|
||||
return "", fmt.Errorf("path escapes allowed directory")
|
||||
}
|
||||
}
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
|
||||
44
pkg/tools/filesystem_security_test.go
Normal file
44
pkg/tools/filesystem_security_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveToolPathRejectsAbsolutePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := t.TempDir()
|
||||
if _, err := resolveToolPath(base, "/tmp/outside.txt"); err == nil {
|
||||
t.Fatalf("expected absolute path to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolPathRejectsTraversal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := t.TempDir()
|
||||
if _, err := resolveToolPath(base, "../outside.txt"); err == nil {
|
||||
t.Fatalf("expected traversal path to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileToolAllowsWorkspaceRelativePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := t.TempDir()
|
||||
path := filepath.Join(base, "notes.txt")
|
||||
if err := os.WriteFile(path, []byte("hello"), 0o644); err != nil {
|
||||
t.Fatalf("write fixture: %v", err)
|
||||
}
|
||||
tool := NewReadFileTool(base)
|
||||
got, err := tool.Execute(context.Background(), map[string]interface{}{"path": "notes.txt"})
|
||||
if err != nil {
|
||||
t.Fatalf("read file: %v", err)
|
||||
}
|
||||
if got != "hello" {
|
||||
t.Fatalf("unexpected content %q", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user