Harden gateway auth and file boundaries

This commit is contained in:
lpf
2026-03-15 15:31:00 +08:00
parent 617f7cc0f1
commit ba95aeed35
16 changed files with 587 additions and 91 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
@@ -73,10 +74,6 @@ type Server struct {
skillsRPCReg *rpcpkg.Registry
}
var nodesWebsocketUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server {
addr := strings.TrimSpace(host)
if addr == "" {
@@ -290,15 +287,82 @@ func (s *Server) Start(ctx context.Context) error {
return nil
}
func requestUsesTLS(r *http.Request) bool {
if r == nil {
return false
}
if r.TLS != nil {
return true
}
return strings.EqualFold(strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")), "https")
}
func canonicalOriginHost(host string, https bool) string {
host = strings.TrimSpace(host)
if host == "" {
return ""
}
if parsedHost, parsedPort, err := net.SplitHostPort(host); err == nil {
return strings.ToLower(net.JoinHostPort(parsedHost, parsedPort))
}
port := "80"
if https {
port = "443"
}
return strings.ToLower(net.JoinHostPort(host, port))
}
func (s *Server) isTrustedOrigin(r *http.Request) bool {
if r == nil {
return false
}
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
return true
}
u, err := url.Parse(origin)
if err != nil {
return false
}
switch strings.ToLower(strings.TrimSpace(u.Scheme)) {
case "http", "https":
default:
return false
}
return canonicalOriginHost(u.Host, strings.EqualFold(u.Scheme, "https")) == canonicalOriginHost(r.Host, requestUsesTLS(r))
}
func (s *Server) websocketUpgrader() *websocket.Upgrader {
return &websocket.Upgrader{
CheckOrigin: s.isTrustedOrigin,
}
}
func (s *Server) isBearerAuthorized(r *http.Request) bool {
if s == nil || r == nil || strings.TrimSpace(s.token) == "" {
return false
}
return strings.TrimSpace(r.Header.Get("Authorization")) == "Bearer "+s.token
}
func (s *Server) withCORS(next http.Handler) http.Handler {
if next == nil {
next = http.NotFoundHandler()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With")
w.Header().Set("Access-Control-Expose-Headers", "*")
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin != "" {
if !s.isTrustedOrigin(r) {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With")
w.Header().Set("Access-Control-Expose-Headers", "*")
w.Header().Add("Vary", "Origin")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
@@ -357,23 +421,11 @@ func (s *Server) checkAuth(r *http.Request) bool {
if s.token == "" {
return true
}
auth := strings.TrimSpace(r.Header.Get("Authorization"))
if auth == "Bearer "+s.token {
return true
}
if strings.TrimSpace(r.URL.Query().Get("token")) == s.token {
if s.isBearerAuthorized(r) {
return true
}
if c, err := r.Cookie("clawgo_webui_token"); err == nil && strings.TrimSpace(c.Value) == s.token {
return true
}
// Browser asset fallback: allow token propagated via Referer query.
if ref := strings.TrimSpace(r.Referer()); ref != "" {
if u, err := url.Parse(ref); err == nil {
if strings.TrimSpace(u.Query().Get("token")) == s.token {
return true
}
}
}
return false
}

View File

@@ -93,7 +93,7 @@ func (s *Server) handleWebUIChatLive(w http.ResponseWriter, r *http.Request) {
http.Error(w, "chat handler not configured", http.StatusInternalServerError)
return
}
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
if err != nil {
return
}

View File

@@ -129,7 +129,7 @@ func (s *Server) handleWebUILogsLive(w http.ResponseWriter, r *http.Request) {
http.Error(w, "log path not configured", http.StatusInternalServerError)
return
}
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
if err != nil {
return
}

View File

@@ -317,10 +317,6 @@ func resolveArtifactPath(workspace, raw string) string {
return ""
}
if filepath.IsAbs(raw) {
clean := filepath.Clean(raw)
if info, err := os.Stat(clean); err == nil && !info.IsDir() {
return clean
}
return ""
}
root := strings.TrimSpace(workspace)
@@ -338,12 +334,12 @@ func resolveArtifactPath(workspace, raw string) string {
}
func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, string, error) {
if content := strings.TrimSpace(fmt.Sprint(item["content_base64"])); content != "" {
if content := strings.TrimSpace(stringFromMap(item, "content_base64")); content != "" {
raw, err := base64.StdEncoding.DecodeString(content)
if err != nil {
return nil, "", err
}
return raw, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil
return raw, strings.TrimSpace(stringFromMap(item, "mime_type")), nil
}
for _, rawPath := range []string{fmt.Sprint(item["source_path"]), fmt.Sprint(item["path"])} {
if path := resolveArtifactPath(workspace, rawPath); path != "" {
@@ -351,10 +347,10 @@ func readArtifactBytes(workspace string, item map[string]interface{}) ([]byte, s
if err != nil {
return nil, "", err
}
return b, strings.TrimSpace(fmt.Sprint(item["mime_type"])), nil
return b, strings.TrimSpace(stringFromMap(item, "mime_type")), nil
}
}
if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" {
if contentText := strings.TrimSpace(stringFromMap(item, "content_text")); contentText != "" {
return []byte(contentText), "text/plain; charset=utf-8", nil
}
return nil, "", fmt.Errorf("artifact content unavailable")
@@ -411,9 +407,14 @@ func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.R
nodeList, _ := payload["nodes"].([]nodes.NodeInfo)
p2p, _ := payload["p2p"].(map[string]interface{})
alerts := filteredNodeAlerts(s.webUINodeAlertsPayload(nodeList, p2p, dispatches), nodeFilter)
var archive bytes.Buffer
zw := zip.NewWriter(&archive)
filename := "node-artifacts-export.zip"
if nodeFilter != "" {
filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter))
}
w.Header().Set("Content-Type", "application/zip")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
w.WriteHeader(http.StatusOK)
zw := zip.NewWriter(w)
writeZipJSON := func(name string, value interface{}) error {
entry, err := zw.Create(name)
if err != nil {
@@ -470,26 +471,13 @@ func (s *Server) handleWebUINodeArtifactsExport(w http.ResponseWriter, r *http.R
}
entry, err := zw.Create(entryName)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if _, err := entry.Write(raw); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
if err := zw.Close(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
filename := "node-artifacts-export.zip"
if nodeFilter != "" {
filename = fmt.Sprintf("node-artifacts-%s.zip", sanitizeZipEntryName(nodeFilter))
}
w.Header().Set("Content-Type", "application/zip")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(archive.Bytes())
_ = zw.Close()
}
func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.Request) {
@@ -519,7 +507,7 @@ func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.
if mimeType == "" {
mimeType = "application/octet-stream"
}
if contentB64 := strings.TrimSpace(fmt.Sprint(item["content_base64"])); contentB64 != "" {
if contentB64 := strings.TrimSpace(stringFromMap(item, "content_base64")); contentB64 != "" {
payload, err := base64.StdEncoding.DecodeString(contentB64)
if err != nil {
http.Error(w, "invalid inline artifact payload", http.StatusBadRequest)
@@ -536,7 +524,7 @@ func (s *Server) handleWebUINodeArtifactDownload(w http.ResponseWriter, r *http.
return
}
}
if contentText := fmt.Sprint(item["content_text"]); strings.TrimSpace(contentText) != "" {
if contentText := strings.TrimSpace(stringFromMap(item, "content_text")); contentText != "" {
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
_, _ = w.Write([]byte(contentText))

View File

@@ -31,7 +31,7 @@ func TestHandleWebUINodeArtifactsListAndDelete(t *testing.T) {
if err := os.WriteFile(artifactPath, []byte("artifact-body"), 0o644); err != nil {
t.Fatalf("write artifact: %v", err)
}
auditLine := fmt.Sprintf("{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"%s\",\"size_bytes\":13}]}\n", artifactPath)
auditLine := "{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"artifact.txt\",\"size_bytes\":13}]}\n"
if err := os.WriteFile(filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl"), []byte(auditLine), 0o644); err != nil {
t.Fatalf("write audit: %v", err)
}
@@ -68,6 +68,36 @@ func TestHandleWebUINodeArtifactsListAndDelete(t *testing.T) {
}
}
func TestHandleWebUINodeArtifactDownloadRejectsAbsoluteSourcePath(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nodes.NewManager())
workspace := t.TempDir()
srv.SetWorkspacePath(workspace)
if err := os.MkdirAll(filepath.Join(workspace, "memory"), 0o755); err != nil {
t.Fatalf("mkdir memory: %v", err)
}
artifactPath := filepath.Join(workspace, "artifact.txt")
if err := os.WriteFile(artifactPath, []byte("artifact-body"), 0o644); err != nil {
t.Fatalf("write artifact: %v", err)
}
auditLine := fmt.Sprintf("{\"time\":\"2026-03-09T00:00:00Z\",\"node\":\"edge-a\",\"action\":\"run\",\"artifacts\":[{\"name\":\"artifact.txt\",\"kind\":\"text\",\"mime_type\":\"text/plain\",\"source_path\":\"%s\",\"size_bytes\":13}]}\n", artifactPath)
if err := os.WriteFile(filepath.Join(workspace, "memory", "nodes-dispatch-audit.jsonl"), []byte(auditLine), 0o644); err != nil {
t.Fatalf("write audit: %v", err)
}
items := srv.webUINodeArtifactsPayload(10)
if len(items) != 1 {
t.Fatalf("expected 1 artifact, got %d", len(items))
}
req := httptest.NewRequest(http.MethodGet, "/api/node_artifacts/download?id="+fmt.Sprint(items[0]["id"]), nil)
rec := httptest.NewRecorder()
srv.handleWebUINodeArtifactDownload(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", rec.Code, rec.Body.String())
}
}
func TestHandleWebUINodeArtifactsExport(t *testing.T) {
t.Parallel()

View File

@@ -103,7 +103,7 @@ func (s *Server) handleNodeConnect(w http.ResponseWriter, r *http.Request) {
http.Error(w, "nodes manager unavailable", http.StatusInternalServerError)
return
}
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
if err != nil {
return
}

View File

@@ -19,7 +19,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodPost && r.Method != http.MethodGet {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
@@ -29,15 +29,9 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
NetworkProxy string `json:"network_proxy"`
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
}
if r.Method == http.MethodPost {
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
} else {
body.Provider = strings.TrimSpace(r.URL.Query().Get("provider"))
body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label"))
body.NetworkProxy = strings.TrimSpace(r.URL.Query().Get("network_proxy"))
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
if err != nil {

View File

@@ -26,7 +26,7 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
conn, err := nodesWebsocketUpgrader.Upgrade(w, r, nil)
conn, err := s.websocketUpgrader().Upgrade(w, r, nil)
if err != nil {
return
}

View File

@@ -0,0 +1,179 @@
package api
import (
"bytes"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/YspCoder/clawgo/pkg/nodes"
"github.com/gorilla/websocket"
)
func TestCheckAuthAllowsBearerAndCookieOnly(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "secret-token", nil)
bearerReq := httptest.NewRequest(http.MethodGet, "/", nil)
bearerReq.Header.Set("Authorization", "Bearer secret-token")
if !srv.checkAuth(bearerReq) {
t.Fatalf("expected bearer auth to succeed")
}
cookieReq := httptest.NewRequest(http.MethodGet, "/", nil)
cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"})
if !srv.checkAuth(cookieReq) {
t.Fatalf("expected cookie auth to succeed")
}
queryReq := httptest.NewRequest(http.MethodGet, "/?token=secret-token", nil)
if srv.checkAuth(queryReq) {
t.Fatalf("expected query token auth to fail")
}
refererReq := httptest.NewRequest(http.MethodGet, "/", nil)
refererReq.Header.Set("Referer", "https://example.com/?token=secret-token")
if srv.checkAuth(refererReq) {
t.Fatalf("expected referer token auth to fail")
}
}
func TestWithCORSRejectsForeignOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil)
req.Host = "example.com"
req.Header.Set("Origin", "https://evil.example")
rec := httptest.NewRecorder()
srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", rec.Code)
}
}
func TestWithCORSAcceptsSameOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil)
req.Host = "example.com"
req.Header.Set("Origin", "http://example.com")
rec := httptest.NewRecorder()
srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rec.Code)
}
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Fatalf("unexpected allow-origin header %q", got)
}
}
func TestHandleNodeConnectRejectsForeignOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nodes.NewManager())
mux := http.NewServeMux()
mux.HandleFunc("/nodes/connect", srv.handleNodeConnect)
httpSrv := httptest.NewServer(mux)
defer httpSrv.Close()
wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/nodes/connect"
dialer := websocket.Dialer{}
headers := http.Header{"Origin": []string{"https://evil.example"}}
conn, resp, err := dialer.Dial(wsURL, headers)
if err == nil {
conn.Close()
t.Fatalf("expected websocket handshake to fail")
}
if resp == nil || resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected 403 response, got %#v", resp)
}
}
func TestHandleWebUISetsCookieForBearerOnly(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "secret-token", nil)
bearerReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
bearerReq.Header.Set("Authorization", "Bearer secret-token")
bearerRec := httptest.NewRecorder()
srv.handleWebUI(bearerRec, bearerReq)
if len(bearerRec.Result().Cookies()) == 0 {
t.Fatalf("expected bearer-authenticated UI request to set cookie")
}
cookieReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"})
cookieRec := httptest.NewRecorder()
srv.handleWebUI(cookieRec, cookieReq)
if len(cookieRec.Result().Cookies()) != 0 {
t.Fatalf("expected cookie-authenticated UI request not to reset cookie")
}
}
func TestHandleWebUIUploadDoesNotExposeAbsolutePath(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
var form bytes.Buffer
mw := multipartWriter(t, &form, "file", "demo.txt", []byte("upload-body"))
req := httptest.NewRequest(http.MethodPost, "/api/upload", &form)
req.Header.Set("Content-Type", mw.FormDataContentType())
rec := httptest.NewRecorder()
srv.handleWebUIUpload(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
}
var payload map[string]interface{}
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if _, ok := payload["path"]; ok {
t.Fatalf("expected upload response to omit absolute path: %+v", payload)
}
if strings.TrimSpace(payload["media"].(string)) == "" {
t.Fatalf("expected media handle in response: %+v", payload)
}
}
func multipartWriter(t *testing.T, dst *bytes.Buffer, fieldName, filename string, body []byte) *multipart.Writer {
t.Helper()
mw := multipart.NewWriter(dst)
part, err := mw.CreateFormFile(fieldName, filename)
if err != nil {
t.Fatalf("create form file: %v", err)
}
if _, err := part.Write(body); err != nil {
t.Fatalf("write form file: %v", err)
}
if err := mw.Close(); err != nil {
t.Fatalf("close multipart writer: %v", err)
}
return mw
}
func TestHandleWebUIProviderOAuthStartRejectsGet(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
req := httptest.NewRequest(http.MethodGet, "/api/provider/oauth/start?provider=openai", nil)
rec := httptest.NewRecorder()
srv.handleWebUIProviderOAuthStart(rec, req)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected 405, got %d", rec.Code)
}
}

View File

@@ -22,6 +22,39 @@ import (
"github.com/YspCoder/clawgo/pkg/tools"
)
const (
skillArchiveUploadLimit int64 = 32 << 20
skillArchiveMaxFiles = 256
skillArchiveMaxSingleFile int64 = 8 << 20
skillArchiveMaxExpanded int64 = 32 << 20
)
type archiveExtractLimits struct {
maxFiles int
maxSingleFile int64
maxExpanded int64
fileCount int
totalExpanded int64
}
func (l *archiveExtractLimits) addFile(size int64) error {
if size < 0 {
return fmt.Errorf("invalid archive entry size")
}
l.fileCount++
if l.maxFiles > 0 && l.fileCount > l.maxFiles {
return fmt.Errorf("archive contains too many files")
}
if l.maxSingleFile > 0 && size > l.maxSingleFile {
return fmt.Errorf("archive entry exceeds size limit")
}
l.totalExpanded += size
if l.maxExpanded > 0 && l.totalExpanded > l.maxExpanded {
return fmt.Errorf("archive exceeds expanded size limit")
}
return nil
}
func mustMap(v interface{}) map[string]interface{} {
if v == nil {
return map[string]interface{}{}
@@ -262,7 +295,13 @@ func ensureClawHubReady(ctx context.Context) (string, error) {
}
func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]string, error) {
if err := r.ParseMultipartForm(128 << 20); err != nil {
if r.ContentLength > skillArchiveUploadLimit {
return nil, fmt.Errorf("archive upload exceeds size limit")
}
if r.Body != nil {
r.Body = http.MaxBytesReader(nil, r.Body, skillArchiveUploadLimit)
}
if err := r.ParseMultipartForm(8 << 20); err != nil {
return nil, err
}
f, h, err := r.FormFile("file")
@@ -292,7 +331,12 @@ func importSkillArchiveFromMultipart(r *http.Request, skillsDir string) ([]strin
}
defer os.RemoveAll(extractDir)
if err := extractArchive(archivePath, extractDir); err != nil {
limits := archiveExtractLimits{
maxFiles: skillArchiveMaxFiles,
maxSingleFile: skillArchiveMaxSingleFile,
maxExpanded: skillArchiveMaxExpanded,
}
if err := extractArchive(archivePath, extractDir, &limits); err != nil {
return nil, err
}
@@ -403,21 +447,21 @@ func sanitizeSkillName(name string) string {
return out
}
func extractArchive(archivePath, targetDir string) error {
func extractArchive(archivePath, targetDir string, limits *archiveExtractLimits) error {
lower := strings.ToLower(archivePath)
switch {
case strings.HasSuffix(lower, ".zip"):
return extractZip(archivePath, targetDir)
return extractZip(archivePath, targetDir, limits)
case strings.HasSuffix(lower, ".tar.gz"), strings.HasSuffix(lower, ".tgz"):
return extractTarGz(archivePath, targetDir)
return extractTarGz(archivePath, targetDir, limits)
case strings.HasSuffix(lower, ".tar"):
return extractTar(archivePath, targetDir)
return extractTar(archivePath, targetDir, limits)
default:
return fmt.Errorf("unsupported archive format: %s", filepath.Base(archivePath))
}
}
func extractZip(archivePath, targetDir string) error {
func extractZip(archivePath, targetDir string, limits *archiveExtractLimits) error {
zr, err := zip.OpenReader(archivePath)
if err != nil {
return err
@@ -425,16 +469,21 @@ func extractZip(archivePath, targetDir string) error {
defer zr.Close()
for _, f := range zr.File {
if !f.FileInfo().IsDir() && limits != nil {
if err := limits.addFile(int64(f.UncompressedSize64)); err != nil {
return err
}
}
if err := writeArchivedEntry(targetDir, f.Name, f.FileInfo().IsDir(), func() (io.ReadCloser, error) {
return f.Open()
}); err != nil {
}, limits.maxSingleFile); err != nil {
return err
}
}
return nil
}
func extractTarGz(archivePath, targetDir string) error {
func extractTarGz(archivePath, targetDir string, limits *archiveExtractLimits) error {
f, err := os.Open(archivePath)
if err != nil {
return err
@@ -445,19 +494,19 @@ func extractTarGz(archivePath, targetDir string) error {
return err
}
defer gz.Close()
return extractTarReader(tar.NewReader(gz), targetDir)
return extractTarReader(tar.NewReader(gz), targetDir, limits)
}
func extractTar(archivePath, targetDir string) error {
func extractTar(archivePath, targetDir string, limits *archiveExtractLimits) error {
f, err := os.Open(archivePath)
if err != nil {
return err
}
defer f.Close()
return extractTarReader(tar.NewReader(f), targetDir)
return extractTarReader(tar.NewReader(f), targetDir, limits)
}
func extractTarReader(tr *tar.Reader, targetDir string) error {
func extractTarReader(tr *tar.Reader, targetDir string, limits *archiveExtractLimits) error {
for {
hdr, err := tr.Next()
if errors.Is(err, io.EOF) {
@@ -468,21 +517,34 @@ func extractTarReader(tr *tar.Reader, targetDir string) error {
}
switch hdr.Typeflag {
case tar.TypeDir:
if err := writeArchivedEntry(targetDir, hdr.Name, true, nil); err != nil {
maxEntryBytes := int64(0)
if limits != nil {
maxEntryBytes = limits.maxSingleFile
}
if err := writeArchivedEntry(targetDir, hdr.Name, true, nil, maxEntryBytes); err != nil {
return err
}
case tar.TypeReg, tar.TypeRegA:
if limits != nil {
if err := limits.addFile(hdr.Size); err != nil {
return err
}
}
name := hdr.Name
maxEntryBytes := int64(0)
if limits != nil {
maxEntryBytes = limits.maxSingleFile
}
if err := writeArchivedEntry(targetDir, name, false, func() (io.ReadCloser, error) {
return io.NopCloser(tr), nil
}); err != nil {
}, maxEntryBytes); err != nil {
return err
}
}
}
}
func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error)) error {
func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.ReadCloser, error), maxBytes int64) error {
clean := filepath.Clean(strings.TrimSpace(name))
clean = strings.TrimPrefix(clean, string(filepath.Separator))
clean = strings.TrimPrefix(clean, "/")
@@ -514,7 +576,17 @@ func writeArchivedEntry(targetDir, name string, isDir bool, opener func() (io.Re
return err
}
defer out.Close()
_, err = io.Copy(out, rc)
reader := io.Reader(rc)
if maxBytes > 0 {
reader = io.LimitReader(rc, maxBytes+1)
}
written, err := io.Copy(out, reader)
if err != nil {
return err
}
if maxBytes > 0 && written > maxBytes {
return fmt.Errorf("archive entry exceeds size limit")
}
return err
}

View File

@@ -0,0 +1,93 @@
package api
import (
"archive/zip"
"bytes"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
func TestImportSkillArchiveFromMultipartSucceedsForSmallArchive(t *testing.T) {
t.Parallel()
var archive bytes.Buffer
zw := zip.NewWriter(&archive)
entry, err := zw.Create("demo/SKILL.md")
if err != nil {
t.Fatalf("create entry: %v", err)
}
if _, err := entry.Write([]byte("# Demo\n")); err != nil {
t.Fatalf("write entry: %v", err)
}
if err := zw.Close(); err != nil {
t.Fatalf("close zip: %v", err)
}
req := multipartRequest(t, "demo.zip", archive.Bytes())
skillsDir := t.TempDir()
imported, err := importSkillArchiveFromMultipart(req, skillsDir)
if err != nil {
t.Fatalf("import archive: %v", err)
}
if len(imported) != 1 || imported[0] != "demo" {
t.Fatalf("unexpected imported skills: %+v", imported)
}
if _, err := os.Stat(filepath.Join(skillsDir, "demo", "SKILL.md")); err != nil {
t.Fatalf("expected imported SKILL.md: %v", err)
}
}
func TestExtractArchiveRejectsOversizedExpandedEntry(t *testing.T) {
t.Parallel()
var archive bytes.Buffer
zw := zip.NewWriter(&archive)
entry, err := zw.Create("demo/SKILL.md")
if err != nil {
t.Fatalf("create entry: %v", err)
}
if _, err := entry.Write(bytes.Repeat([]byte("a"), int(skillArchiveMaxSingleFile+1))); err != nil {
t.Fatalf("write entry: %v", err)
}
if err := zw.Close(); err != nil {
t.Fatalf("close zip: %v", err)
}
archivePath := filepath.Join(t.TempDir(), "demo.zip")
if err := os.WriteFile(archivePath, archive.Bytes(), 0o644); err != nil {
t.Fatalf("write archive: %v", err)
}
err = extractArchive(archivePath, t.TempDir(), &archiveExtractLimits{
maxFiles: skillArchiveMaxFiles,
maxSingleFile: skillArchiveMaxSingleFile,
maxExpanded: skillArchiveMaxExpanded,
})
if err == nil || !strings.Contains(err.Error(), "size limit") {
t.Fatalf("expected size limit error, got %v", err)
}
}
func multipartRequest(t *testing.T, filename string, body []byte) *http.Request {
t.Helper()
var form bytes.Buffer
mw := multipart.NewWriter(&form)
part, err := mw.CreateFormFile("file", filename)
if err != nil {
t.Fatalf("create form file: %v", err)
}
if _, err := part.Write(body); err != nil {
t.Fatalf("write form file: %v", err)
}
if err := mw.Close(); err != nil {
t.Fatalf("close multipart writer: %v", err)
}
req := httptest.NewRequest("POST", "/api/skills", &form)
req.Header.Set("Content-Type", mw.FormDataContentType())
req.ContentLength = int64(form.Len())
return req
}

View File

@@ -20,12 +20,13 @@ func (s *Server) handleWebUI(w http.ResponseWriter, r *http.Request) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if s.token != "" {
if s.token != "" && s.isBearerAuthorized(r) {
http.SetCookie(w, &http.Cookie{
Name: "clawgo_webui_token",
Value: s.token,
Path: "/",
HttpOnly: true,
Secure: requestUsesTLS(r),
SameSite: http.SameSiteLaxMode,
MaxAge: 86400,
})
@@ -118,7 +119,7 @@ func (s *Server) handleWebUIUpload(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]interface{}{"ok": true, "path": path, "name": h.Filename})
writeJSON(w, map[string]interface{}{"ok": true, "media": name, "name": h.Filename})
}
func gatewayBuildVersion() string {
@@ -165,14 +166,14 @@ const webUIHTML = `<!doctype html>
<div>Session: <input id="session" value="webui:default"/> <input id="msg" placeholder="message" style="width:420px"/> <input id="file" type="file"/> <button onclick="sendChat()">Send</button></div>
<div id="chatlog"></div>
<script>
function auth(){const t=document.getElementById('token').value.trim();return t?('?token='+encodeURIComponent(t)):''}
async function loadCfg(){let r=await fetch('/api/config'+auth());document.getElementById('cfg').value=await r.text()}
async function saveCfg(){let j=JSON.parse(document.getElementById('cfg').value);let r=await fetch('/api/config'+auth(),{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(j)});alert(await r.text())}
function authHeaders(extra){const h=Object.assign({},extra||{});const t=document.getElementById('token').value.trim();if(t)h['Authorization']='Bearer '+t;return h}
async function loadCfg(){let r=await fetch('/api/config',{headers:authHeaders()});document.getElementById('cfg').value=await r.text()}
async function saveCfg(){let j=JSON.parse(document.getElementById('cfg').value);let r=await fetch('/api/config',{method:'POST',headers:authHeaders({'Content-Type':'application/json'}),body:JSON.stringify(j)});alert(await r.text())}
async function sendChat(){
let media='';const f=document.getElementById('file').files[0];
if(f){let fd=new FormData();fd.append('file',f);let ur=await fetch('/api/upload'+auth(),{method:'POST',body:fd});let uj=await ur.json();media=uj.path||''}
if(f){let fd=new FormData();fd.append('file',f);let ur=await fetch('/api/upload',{method:'POST',headers:authHeaders(),body:fd});let uj=await ur.json();media=uj.media||uj.name||''}
const payload={session:document.getElementById('session').value,message:document.getElementById('msg').value,media};
let r=await fetch('/api/chat'+auth(),{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(payload)});let t=await r.text();
let r=await fetch('/api/chat',{method:'POST',headers:authHeaders({'Content-Type':'application/json'}),body:JSON.stringify(payload)});let t=await r.text();
document.getElementById('chatlog').textContent += '\nUSER> '+payload.message+(media?(' [file:'+media+']'):'')+'\nBOT> '+t+'\n';
}
loadCfg();

View File

@@ -804,7 +804,7 @@ func SaveConfig(path string, cfg *Config) error {
return err
}
return os.WriteFile(path, data, 0644)
return os.WriteFile(path, data, 0600)
}
func (c *Config) WorkspacePath() string {

View File

@@ -0,0 +1,28 @@
package config
import (
"os"
"path/filepath"
"runtime"
"testing"
)
func TestSaveConfigUsesOwnerOnlyPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("permission bits are not reliable on windows")
}
path := filepath.Join(t.TempDir(), "config.json")
if err := SaveConfig(path, DefaultConfig()); err != nil {
t.Fatalf("save config: %v", err)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("stat config: %v", err)
}
if got := info.Mode().Perm(); got != 0o600 {
t.Fatalf("expected 0600 permissions, got %o", got)
}
}

View File

@@ -9,16 +9,31 @@ import (
)
func resolveToolPath(baseDir, path string) (string, error) {
path = strings.TrimSpace(path)
if path == "" {
return "", fmt.Errorf("path is required")
}
if filepath.IsAbs(path) {
return filepath.Clean(path), nil
return "", fmt.Errorf("absolute path is not allowed")
}
joined := path
if baseDir != "" {
return filepath.Clean(filepath.Join(baseDir, path)), nil
joined = filepath.Join(baseDir, path)
}
abs, err := filepath.Abs(path)
abs, err := filepath.Abs(joined)
if err != nil {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
if baseDir != "" {
absBase, err := filepath.Abs(baseDir)
if err != nil {
return "", fmt.Errorf("failed to resolve base path: %w", err)
}
rel, err := filepath.Rel(absBase, abs)
if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return "", fmt.Errorf("path escapes allowed directory")
}
}
return abs, nil
}

View File

@@ -0,0 +1,44 @@
package tools
import (
"context"
"os"
"path/filepath"
"testing"
)
func TestResolveToolPathRejectsAbsolutePath(t *testing.T) {
t.Parallel()
base := t.TempDir()
if _, err := resolveToolPath(base, "/tmp/outside.txt"); err == nil {
t.Fatalf("expected absolute path to be rejected")
}
}
func TestResolveToolPathRejectsTraversal(t *testing.T) {
t.Parallel()
base := t.TempDir()
if _, err := resolveToolPath(base, "../outside.txt"); err == nil {
t.Fatalf("expected traversal path to be rejected")
}
}
func TestReadFileToolAllowsWorkspaceRelativePath(t *testing.T) {
t.Parallel()
base := t.TempDir()
path := filepath.Join(base, "notes.txt")
if err := os.WriteFile(path, []byte("hello"), 0o644); err != nil {
t.Fatalf("write fixture: %v", err)
}
tool := NewReadFileTool(base)
got, err := tool.Execute(context.Background(), map[string]interface{}{"path": "notes.txt"})
if err != nil {
t.Fatalf("read file: %v", err)
}
if got != "hello" {
t.Fatalf("unexpected content %q", got)
}
}