From ba95aeed351da101dda64323aa5f98a837615929 Mon Sep 17 00:00:00 2001 From: lpf Date: Sun, 15 Mar 2026 15:31:00 +0800 Subject: [PATCH] Harden gateway auth and file boundaries --- pkg/api/server.go | 94 +++++++++++--- pkg/api/server_chat_whatsapp.go | 2 +- pkg/api/server_live.go | 2 +- pkg/api/server_node_artifacts.go | 42 +++--- pkg/api/server_node_artifacts_test.go | 32 ++++- pkg/api/server_nodes_gateway.go | 2 +- pkg/api/server_providers.go | 14 +- pkg/api/server_runtime_nodes.go | 2 +- pkg/api/server_security_test.go | 179 ++++++++++++++++++++++++++ pkg/api/server_skills.go | 106 ++++++++++++--- pkg/api/server_skills_test.go | 93 +++++++++++++ pkg/api/server_webui.go | 15 ++- pkg/config/config.go | 2 +- pkg/config/config_save_test.go | 28 ++++ pkg/tools/filesystem.go | 21 ++- pkg/tools/filesystem_security_test.go | 44 +++++++ 16 files changed, 587 insertions(+), 91 deletions(-) create mode 100644 pkg/api/server_security_test.go create mode 100644 pkg/api/server_skills_test.go create mode 100644 pkg/config/config_save_test.go create mode 100644 pkg/tools/filesystem_security_test.go diff --git a/pkg/api/server.go b/pkg/api/server.go index fbe7413..4be7156 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net" "net/http" "net/url" "strconv" @@ -73,10 +74,6 @@ type Server struct { skillsRPCReg *rpcpkg.Registry } -var nodesWebsocketUpgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, -} - func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server { addr := strings.TrimSpace(host) if addr == "" { @@ -290,15 +287,82 @@ func (s *Server) Start(ctx context.Context) error { return nil } +func requestUsesTLS(r *http.Request) bool { + if r == nil { + return false + } + if r.TLS != nil { + return true + } + return strings.EqualFold(strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")), "https") +} + +func canonicalOriginHost(host string, https bool) string { + host = strings.TrimSpace(host) + if host == "" { + return "" + } + if parsedHost, parsedPort, err := net.SplitHostPort(host); err == nil { + return strings.ToLower(net.JoinHostPort(parsedHost, parsedPort)) + } + port := "80" + if https { + port = "443" + } + return strings.ToLower(net.JoinHostPort(host, port)) +} + +func (s *Server) isTrustedOrigin(r *http.Request) bool { + if r == nil { + return false + } + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin == "" { + return true + } + u, err := url.Parse(origin) + if err != nil { + return false + } + switch strings.ToLower(strings.TrimSpace(u.Scheme)) { + case "http", "https": + default: + return false + } + return canonicalOriginHost(u.Host, strings.EqualFold(u.Scheme, "https")) == canonicalOriginHost(r.Host, requestUsesTLS(r)) +} + +func (s *Server) websocketUpgrader() *websocket.Upgrader { + return &websocket.Upgrader{ + CheckOrigin: s.isTrustedOrigin, + } +} + +func (s *Server) isBearerAuthorized(r *http.Request) bool { + if s == nil || r == nil || strings.TrimSpace(s.token) == "" { + return false + } + return strings.TrimSpace(r.Header.Get("Authorization")) == "Bearer "+s.token +} + func (s *Server) withCORS(next http.Handler) http.Handler { if next == nil { next = http.NotFoundHandler() } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With") - w.Header().Set("Access-Control-Expose-Headers", "*") + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin != "" { + if !s.isTrustedOrigin(r) { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With") + w.Header().Set("Access-Control-Expose-Headers", "*") + w.Header().Add("Vary", "Origin") + } if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return @@ -357,23 +421,11 @@ func (s *Server) checkAuth(r *http.Request) bool { if s.token == "" { return true } - auth := strings.TrimSpace(r.Header.Get("Authorization")) - if auth == "Bearer "+s.token { - return true - } - if strings.TrimSpace(r.URL.Query().Get("token")) == s.token { + if s.isBearerAuthorized(r) { return true } if c, err := r.Cookie("clawgo_webui_token"); err == nil && strings.TrimSpace(c.Value) == s.token { return true } - // Browser asset fallback: allow token propagated via Referer query. - if ref := strings.TrimSpace(r.Referer()); ref != "" { - if u, err := url.Parse(ref); err == nil { - if strings.TrimSpace(u.Query().Get("token")) == s.token { - return true - } - } - } return false } diff --git a/pkg/api/server_chat_whatsapp.go b/pkg/api/server_chat_whatsapp.go index 9d1a216..1dad3a4 100644 --- a/pkg/api/server_chat_whatsapp.go +++ b/pkg/api/server_chat_whatsapp.go @@ -93,7 +93,7 @@ func (s *Server) handleWebUIChatLive(w http.ResponseWriter, r *http.Request) { http.Error(w, "chat handler not configured", http.StatusInternalServerError) return } - conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil) + conn, err := s.websocketUpgrader().Upgrade(w, r, nil) if err != nil { return } diff --git a/pkg/api/server_live.go b/pkg/api/server_live.go index ace918e..716f967 100644 --- a/pkg/api/server_live.go +++ b/pkg/api/server_live.go @@ -129,7 +129,7 @@ func (s *Server) handleWebUILogsLive(w http.ResponseWriter, r *http.Request) { http.Error(w, "log path not configured", http.StatusInternalServerError) return } - conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil) + conn, err := s.websocketUpgrader().Upgrade(w, r, nil) if err != nil { return } diff --git a/pkg/api/server_node_artifacts.go b/pkg/api/server_node_artifacts.go index 8ac6f12..b08788e 100644 --- a/pkg/api/server_node_artifacts.go +++ b/pkg/api/server_node_artifacts.go @@ -317,10 +317,6 @@ func resolveArtifactPath(workspace, raw string) string { return "" } if filepath.IsAbs(raw) { - clean := filepath.Clean(raw) - if info, err := os.Stat(clean); err == nil && !info.IsDir() { - return clean - } return "" } root := strings.TrimSpace(workspace) @@ -338,12 +334,12 @@ func resolveArtifactPath(workspace, raw string) string { } func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, string, error) { - if content := strings.TrimSpace(fmt.Sprint(item["content_base64"])); content != "" { + if content := strings.TrimSpace(stringFromMap(item, "content_base64")); content != "" { raw, err := base64.StdEncoding.DecodeString(content) if err != nil { return nil, "", err } - return raw, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil + return raw, strings.TrimSpace(stringFromMap(item, "mime_type")), nil } for _, rawPath := range []string{fmt.Sprint(item["source_path"]), fmt.Sprint(item["path"])} { if path := resolveArtifactPath(workspace, rawPath); path != "" { @@ -351,10 +347,10 @@ func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, s if err != nil { return nil, "", err } - return b, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil + return b, strings.TrimSpace(stringFromMap(item, "mime_type")), nil } } - if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" { + if contentText := strings.TrimSpace(stringFromMap(item, "content_text")); contentText != "" { return []byte(contentText), "text/plain; charset=utf-8", nil } return nil, "", fmt.Errorf("artifact content unavailable") @@ -411,9 +407,14 @@ func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.R nodeList, _ := payload["nodes"].([]nodes.NodeInfo) p2p, _ := payload["p2p"].(map[string]interface{}) alerts := filteredNodeAlerts(s.webUINodeAlertsPayload(nodeList, p2p, dispatches), nodeFilter) - - var archive bytes.Buffer - zw := zip.NewWriter(&archive) + filename := "node-artifacts-export.zip" + if nodeFilter != "" { + filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter)) + } + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename)) + w.WriteHeader(http.StatusOK) + zw := zip.NewWriter(w) writeZipJSON := func(name string, value interface{}) error { entry, err := zw.Create(name) if err != nil { @@ -470,26 +471,13 @@ func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.R } entry, err := zw.Create(entryName) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) return } if _, err := entry.Write(raw); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) return } } - if err := zw.Close(); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - filename := "node-artifacts-export.zip" - if nodeFilter != "" { - filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter)) - } - w.Header().Set("Content-Type", "application/zip") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename)) - w.WriteHeader(http.StatusOK) - _, _ = w.Write(archive.Bytes()) + _ = zw.Close() } func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.Request) { @@ -519,7 +507,7 @@ func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http. if mimeType == "" { mimeType = "application/octet-stream" } - if contentB64 := strings.TrimSpace(fmt.Sprint(item["content_base64"])); contentB64 != "" { + if contentB64 := strings.TrimSpace(stringFromMap(item, "content_base64")); contentB64 != "" { payload, err := base64.StdEncoding.DecodeString(contentB64) if err != nil { http.Error(w, "invalid inline artifact payload", http.StatusBadRequest) @@ -536,7 +524,7 @@ func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http. return } } - if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" { + if contentText := strings.TrimSpace(stringFromMap(item, "content_text")); contentText != "" { w.Header().Set("Content-Type", mimeType) w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) _, _ = w.Write([]byte(contentText)) diff --git a/pkg/api/server_node_artifacts_test.go b/pkg/api/server_node_artifacts_test.go index 099781c..7ea1691 100644 --- a/pkg/api/server_node_artifacts_test.go +++ b/pkg/api/server_node_artifacts_test.go @@ -31,7 +31,7 @@ func TestHandleWebUINodeArtifactsListAndDelete(t *testing.T) { if err := os.WriteFile(artifactPath, []byte("artifact-body"), 0o644); err != nil { t.Fatalf("write artifact: %v", err) } - auditLine := fmt.Sprintf("{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"%s\",\"size_bytes\":13}]}\n", artifactPath) + auditLine := "{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"artifact.txt\",\"size_bytes\":13}]}\n" if err := os.WriteFile(filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl"), []byte(auditLine), 0o644); err != nil { t.Fatalf("write audit: %v", err) } @@ -68,6 +68,36 @@ func TestHandleWebUINodeArtifactsListAndDelete(t *testing.T) { } } +func TestHandleWebUINodeArtifactDownloadRejectsAbsoluteSourcePath(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + workspace := t.TempDir() + srv.SetWorkspacePath(workspace) + if err := os.MkdirAll(filepath.Join(workspace, "memory"), 0o755); err != nil { + t.Fatalf("mkdir memory: %v", err) + } + artifactPath := filepath.Join(workspace, "artifact.txt") + if err := os.WriteFile(artifactPath, []byte("artifact-body"), 0o644); err != nil { + t.Fatalf("write artifact: %v", err) + } + auditLine := fmt.Sprintf("{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"%s\",\"size_bytes\":13}]}\n", artifactPath) + if err := os.WriteFile(filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl"), []byte(auditLine), 0o644); err != nil { + t.Fatalf("write audit: %v", err) + } + + items := srv.webUINodeArtifactsPayload(10) + if len(items) != 1 { + t.Fatalf("expected 1 artifact, got %d", len(items)) + } + req := httptest.NewRequest(http.MethodGet, "/api/node_artifacts/download?id="+fmt.Sprint(items[0]["id"]), nil) + rec := httptest.NewRecorder() + srv.handleWebUINodeArtifactDownload(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d: %s", rec.Code, rec.Body.String()) + } +} + func TestHandleWebUINodeArtifactsExport(t *testing.T) { t.Parallel() diff --git a/pkg/api/server_nodes_gateway.go b/pkg/api/server_nodes_gateway.go index f38716a..91b86af 100644 --- a/pkg/api/server_nodes_gateway.go +++ b/pkg/api/server_nodes_gateway.go @@ -103,7 +103,7 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) { http.Error(w, "nodes manager unavailable", http.StatusInternalServerError) return } - conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil) + conn, err := s.websocketUpgrader().Upgrade(w, r, nil) if err != nil { return } diff --git a/pkg/api/server_providers.go b/pkg/api/server_providers.go index 10a201d..15b5ca1 100644 --- a/pkg/api/server_providers.go +++ b/pkg/api/server_providers.go @@ -19,7 +19,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re http.Error(w, "unauthorized", http.StatusUnauthorized) return } - if r.Method != http.MethodPost && r.Method != http.MethodGet { + if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } @@ -29,15 +29,9 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re NetworkProxy string `json:"network_proxy"` ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"` } - if r.Method == http.MethodPost { - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) - return - } - } else { - body.Provider = strings.TrimSpace(r.URL.Query().Get("provider")) - body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label")) - body.NetworkProxy = strings.TrimSpace(r.URL.Query().Get("network_proxy")) + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return } cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig) if err != nil { diff --git a/pkg/api/server_runtime_nodes.go b/pkg/api/server_runtime_nodes.go index 0b9ba54..793bf14 100644 --- a/pkg/api/server_runtime_nodes.go +++ b/pkg/api/server_runtime_nodes.go @@ -26,7 +26,7 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) { http.Error(w, "unauthorized", http.StatusUnauthorized) return } - conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil) + conn, err := s.websocketUpgrader().Upgrade(w, r, nil) if err != nil { return } diff --git a/pkg/api/server_security_test.go b/pkg/api/server_security_test.go new file mode 100644 index 0000000..003f67d --- /dev/null +++ b/pkg/api/server_security_test.go @@ -0,0 +1,179 @@ +package api + +import ( + "bytes" + "encoding/json" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/YspCoder/clawgo/pkg/nodes" + "github.com/gorilla/websocket" +) + +func TestCheckAuthAllowsBearerAndCookieOnly(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "secret-token", nil) + + bearerReq := httptest.NewRequest(http.MethodGet, "/", nil) + bearerReq.Header.Set("Authorization", "Bearer secret-token") + if !srv.checkAuth(bearerReq) { + t.Fatalf("expected bearer auth to succeed") + } + + cookieReq := httptest.NewRequest(http.MethodGet, "/", nil) + cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"}) + if !srv.checkAuth(cookieReq) { + t.Fatalf("expected cookie auth to succeed") + } + + queryReq := httptest.NewRequest(http.MethodGet, "/?token=secret-token", nil) + if srv.checkAuth(queryReq) { + t.Fatalf("expected query token auth to fail") + } + + refererReq := httptest.NewRequest(http.MethodGet, "/", nil) + refererReq.Header.Set("Referer", "https://example.com/?token=secret-token") + if srv.checkAuth(refererReq) { + t.Fatalf("expected referer token auth to fail") + } +} + +func TestWithCORSRejectsForeignOrigin(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nil) + req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil) + req.Host = "example.com" + req.Header.Set("Origin", "https://evil.example") + rec := httptest.NewRecorder() + + srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rec.Code) + } +} + +func TestWithCORSAcceptsSameOrigin(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nil) + req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil) + req.Host = "example.com" + req.Header.Set("Origin", "http://example.com") + rec := httptest.NewRecorder() + + srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" { + t.Fatalf("unexpected allow-origin header %q", got) + } +} + +func TestHandleNodeConnectRejectsForeignOrigin(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nodes.NewManager()) + mux := http.NewServeMux() + mux.HandleFunc("/nodes/connect", srv.handleNodeConnect) + httpSrv := httptest.NewServer(mux) + defer httpSrv.Close() + + wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/nodes/connect" + dialer := websocket.Dialer{} + headers := http.Header{"Origin": []string{"https://evil.example"}} + conn, resp, err := dialer.Dial(wsURL, headers) + if err == nil { + conn.Close() + t.Fatalf("expected websocket handshake to fail") + } + if resp == nil || resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected 403 response, got %#v", resp) + } +} + +func TestHandleWebUISetsCookieForBearerOnly(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "secret-token", nil) + + bearerReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + bearerReq.Header.Set("Authorization", "Bearer secret-token") + bearerRec := httptest.NewRecorder() + srv.handleWebUI(bearerRec, bearerReq) + if len(bearerRec.Result().Cookies()) == 0 { + t.Fatalf("expected bearer-authenticated UI request to set cookie") + } + + cookieReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"}) + cookieRec := httptest.NewRecorder() + srv.handleWebUI(cookieRec, cookieReq) + if len(cookieRec.Result().Cookies()) != 0 { + t.Fatalf("expected cookie-authenticated UI request not to reset cookie") + } +} + +func TestHandleWebUIUploadDoesNotExposeAbsolutePath(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nil) + var form bytes.Buffer + mw := multipartWriter(t, &form, "file", "demo.txt", []byte("upload-body")) + req := httptest.NewRequest(http.MethodPost, "/api/upload", &form) + req.Header.Set("Content-Type", mw.FormDataContentType()) + rec := httptest.NewRecorder() + srv.handleWebUIUpload(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + var payload map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if _, ok := payload["path"]; ok { + t.Fatalf("expected upload response to omit absolute path: %+v", payload) + } + if strings.TrimSpace(payload["media"].(string)) == "" { + t.Fatalf("expected media handle in response: %+v", payload) + } +} + +func multipartWriter(t *testing.T, dst *bytes.Buffer, fieldName, filename string, body []byte) *multipart.Writer { + t.Helper() + mw := multipart.NewWriter(dst) + part, err := mw.CreateFormFile(fieldName, filename) + if err != nil { + t.Fatalf("create form file: %v", err) + } + if _, err := part.Write(body); err != nil { + t.Fatalf("write form file: %v", err) + } + if err := mw.Close(); err != nil { + t.Fatalf("close multipart writer: %v", err) + } + return mw +} + +func TestHandleWebUIProviderOAuthStartRejectsGet(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "", nil) + req := httptest.NewRequest(http.MethodGet, "/api/provider/oauth/start?provider=openai", nil) + rec := httptest.NewRecorder() + srv.handleWebUIProviderOAuthStart(rec, req) + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected 405, got %d", rec.Code) + } +} diff --git a/pkg/api/server_skills.go b/pkg/api/server_skills.go index 9dd3e77..68a4443 100644 --- a/pkg/api/server_skills.go +++ b/pkg/api/server_skills.go @@ -22,6 +22,39 @@ import ( "github.com/YspCoder/clawgo/pkg/tools" ) +const ( + skillArchiveUploadLimit int64 = 32 << 20 + skillArchiveMaxFiles = 256 + skillArchiveMaxSingleFile int64 = 8 << 20 + skillArchiveMaxExpanded int64 = 32 << 20 +) + +type archiveExtractLimits struct { + maxFiles int + maxSingleFile int64 + maxExpanded int64 + fileCount int + totalExpanded int64 +} + +func (l *archiveExtractLimits) addFile(size int64) error { + if size < 0 { + return fmt.Errorf("invalid archive entry size") + } + l.fileCount++ + if l.maxFiles > 0 && l.fileCount > l.maxFiles { + return fmt.Errorf("archive contains too many files") + } + if l.maxSingleFile > 0 && size > l.maxSingleFile { + return fmt.Errorf("archive entry exceeds size limit") + } + l.totalExpanded += size + if l.maxExpanded > 0 && l.totalExpanded > l.maxExpanded { + return fmt.Errorf("archive exceeds expanded size limit") + } + return nil +} + func mustMap(v interface{}) map[string]interface{} { if v == nil { return map[string]interface{}{} @@ -262,7 +295,13 @@ func ensureClawHubReady(ctx context.Context) (string, error) { } func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]string, error) { - if err := r.ParseMultipartForm(128 << 20); err != nil { + if r.ContentLength > skillArchiveUploadLimit { + return nil, fmt.Errorf("archive upload exceeds size limit") + } + if r.Body != nil { + r.Body = http.MaxBytesReader(nil, r.Body, skillArchiveUploadLimit) + } + if err := r.ParseMultipartForm(8 << 20); err != nil { return nil, err } f, h, err := r.FormFile("file") @@ -292,7 +331,12 @@ func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]strin } defer os.RemoveAll(extractDir) - if err := extractArchive(archivePath, extractDir); err != nil { + limits := archiveExtractLimits{ + maxFiles: skillArchiveMaxFiles, + maxSingleFile: skillArchiveMaxSingleFile, + maxExpanded: skillArchiveMaxExpanded, + } + if err := extractArchive(archivePath, extractDir, &limits); err != nil { return nil, err } @@ -403,21 +447,21 @@ func sanitizeSkillName(name string) string { return out } -func extractArchive(archivePath, targetDir string) error { +func extractArchive(archivePath, targetDir string, limits *archiveExtractLimits) error { lower := strings.ToLower(archivePath) switch { case strings.HasSuffix(lower, ".zip"): - return extractZip(archivePath, targetDir) + return extractZip(archivePath, targetDir, limits) case strings.HasSuffix(lower, ".tar.gz"), strings.HasSuffix(lower, ".tgz"): - return extractTarGz(archivePath, targetDir) + return extractTarGz(archivePath, targetDir, limits) case strings.HasSuffix(lower, ".tar"): - return extractTar(archivePath, targetDir) + return extractTar(archivePath, targetDir, limits) default: return fmt.Errorf("unsupported archive format: %s", filepath.Base(archivePath)) } } -func extractZip(archivePath, targetDir string) error { +func extractZip(archivePath, targetDir string, limits *archiveExtractLimits) error { zr, err := zip.OpenReader(archivePath) if err != nil { return err @@ -425,16 +469,21 @@ func extractZip(archivePath, targetDir string) error { defer zr.Close() for _, f := range zr.File { + if !f.FileInfo().IsDir() && limits != nil { + if err := limits.addFile(int64(f.UncompressedSize64)); err != nil { + return err + } + } if err := writeArchivedEntry(targetDir, f.Name, f.FileInfo().IsDir(), func() (io.ReadCloser, error) { return f.Open() - }); err != nil { + }, limits.maxSingleFile); err != nil { return err } } return nil } -func extractTarGz(archivePath, targetDir string) error { +func extractTarGz(archivePath, targetDir string, limits *archiveExtractLimits) error { f, err := os.Open(archivePath) if err != nil { return err @@ -445,19 +494,19 @@ func extractTarGz(archivePath, targetDir string) error { return err } defer gz.Close() - return extractTarReader(tar.NewReader(gz), targetDir) + return extractTarReader(tar.NewReader(gz), targetDir, limits) } -func extractTar(archivePath, targetDir string) error { +func extractTar(archivePath, targetDir string, limits *archiveExtractLimits) error { f, err := os.Open(archivePath) if err != nil { return err } defer f.Close() - return extractTarReader(tar.NewReader(f), targetDir) + return extractTarReader(tar.NewReader(f), targetDir, limits) } -func extractTarReader(tr *tar.Reader, targetDir string) error { +func extractTarReader(tr *tar.Reader, targetDir string, limits *archiveExtractLimits) error { for { hdr, err := tr.Next() if errors.Is(err, io.EOF) { @@ -468,21 +517,34 @@ func extractTarReader(tr *tar.Reader, targetDir string) error { } switch hdr.Typeflag { case tar.TypeDir: - if err := writeArchivedEntry(targetDir, hdr.Name, true, nil); err != nil { + maxEntryBytes := int64(0) + if limits != nil { + maxEntryBytes = limits.maxSingleFile + } + if err := writeArchivedEntry(targetDir, hdr.Name, true, nil, maxEntryBytes); err != nil { return err } case tar.TypeReg, tar.TypeRegA: + if limits != nil { + if err := limits.addFile(hdr.Size); err != nil { + return err + } + } name := hdr.Name + maxEntryBytes := int64(0) + if limits != nil { + maxEntryBytes = limits.maxSingleFile + } if err := writeArchivedEntry(targetDir, name, false, func() (io.ReadCloser, error) { return io.NopCloser(tr), nil - }); err != nil { + }, maxEntryBytes); err != nil { return err } } } } -func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error)) error { +func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error), maxBytes int64) error { clean := filepath.Clean(strings.TrimSpace(name)) clean = strings.TrimPrefix(clean, string(filepath.Separator)) clean = strings.TrimPrefix(clean, "/") @@ -514,7 +576,17 @@ func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.Re return err } defer out.Close() - _, err = io.Copy(out, rc) + reader := io.Reader(rc) + if maxBytes > 0 { + reader = io.LimitReader(rc, maxBytes+1) + } + written, err := io.Copy(out, reader) + if err != nil { + return err + } + if maxBytes > 0 && written > maxBytes { + return fmt.Errorf("archive entry exceeds size limit") + } return err } diff --git a/pkg/api/server_skills_test.go b/pkg/api/server_skills_test.go new file mode 100644 index 0000000..84305a9 --- /dev/null +++ b/pkg/api/server_skills_test.go @@ -0,0 +1,93 @@ +package api + +import ( + "archive/zip" + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestImportSkillArchiveFromMultipartSucceedsForSmallArchive(t *testing.T) { + t.Parallel() + + var archive bytes.Buffer + zw := zip.NewWriter(&archive) + entry, err := zw.Create("demo/SKILL.md") + if err != nil { + t.Fatalf("create entry: %v", err) + } + if _, err := entry.Write([]byte("# Demo\n")); err != nil { + t.Fatalf("write entry: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("close zip: %v", err) + } + + req := multipartRequest(t, "demo.zip", archive.Bytes()) + skillsDir := t.TempDir() + imported, err := importSkillArchiveFromMultipart(req, skillsDir) + if err != nil { + t.Fatalf("import archive: %v", err) + } + if len(imported) != 1 || imported[0] != "demo" { + t.Fatalf("unexpected imported skills: %+v", imported) + } + if _, err := os.Stat(filepath.Join(skillsDir, "demo", "SKILL.md")); err != nil { + t.Fatalf("expected imported SKILL.md: %v", err) + } +} + +func TestExtractArchiveRejectsOversizedExpandedEntry(t *testing.T) { + t.Parallel() + + var archive bytes.Buffer + zw := zip.NewWriter(&archive) + entry, err := zw.Create("demo/SKILL.md") + if err != nil { + t.Fatalf("create entry: %v", err) + } + if _, err := entry.Write(bytes.Repeat([]byte("a"), int(skillArchiveMaxSingleFile+1))); err != nil { + t.Fatalf("write entry: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("close zip: %v", err) + } + + archivePath := filepath.Join(t.TempDir(), "demo.zip") + if err := os.WriteFile(archivePath, archive.Bytes(), 0o644); err != nil { + t.Fatalf("write archive: %v", err) + } + err = extractArchive(archivePath, t.TempDir(), &archiveExtractLimits{ + maxFiles: skillArchiveMaxFiles, + maxSingleFile: skillArchiveMaxSingleFile, + maxExpanded: skillArchiveMaxExpanded, + }) + if err == nil || !strings.Contains(err.Error(), "size limit") { + t.Fatalf("expected size limit error, got %v", err) + } +} + +func multipartRequest(t *testing.T, filename string, body []byte) *http.Request { + t.Helper() + var form bytes.Buffer + mw := multipart.NewWriter(&form) + part, err := mw.CreateFormFile("file", filename) + if err != nil { + t.Fatalf("create form file: %v", err) + } + if _, err := part.Write(body); err != nil { + t.Fatalf("write form file: %v", err) + } + if err := mw.Close(); err != nil { + t.Fatalf("close multipart writer: %v", err) + } + req := httptest.NewRequest("POST", "/api/skills", &form) + req.Header.Set("Content-Type", mw.FormDataContentType()) + req.ContentLength = int64(form.Len()) + return req +} diff --git a/pkg/api/server_webui.go b/pkg/api/server_webui.go index d81b99e..341eb0a 100644 --- a/pkg/api/server_webui.go +++ b/pkg/api/server_webui.go @@ -20,12 +20,13 @@ func (s *Server) handleWebUI(w http.ResponseWriter, r *http.Request) { http.Error(w, "unauthorized", http.StatusUnauthorized) return } - if s.token != "" { + if s.token != "" && s.isBearerAuthorized(r) { http.SetCookie(w, &http.Cookie{ Name: "clawgo_webui_token", Value: s.token, Path: "/", HttpOnly: true, + Secure: requestUsesTLS(r), SameSite: http.SameSiteLaxMode, MaxAge: 86400, }) @@ -118,7 +119,7 @@ func (s *Server) handleWebUIUpload(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - writeJSON(w, map[string]interface{}{"ok": true, "path": path, "name": h.Filename}) + writeJSON(w, map[string]interface{}{"ok": true, "media": name, "name": h.Filename}) } func gatewayBuildVersion() string { @@ -165,14 +166,14 @@ const webUIHTML = `
Session: