Files
clawgo/pkg/api/server_security_test.go
2026-03-15 23:46:06 +08:00

252 lines
7.9 KiB
Go

package api
import (
"bytes"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/YspCoder/clawgo/pkg/nodes"
"github.com/gorilla/websocket"
)
func TestCheckAuthAllowsBearerAndCookieOnly(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "secret-token", nil)
bearerReq := httptest.NewRequest(http.MethodGet, "/", nil)
bearerReq.Header.Set("Authorization", "Bearer secret-token")
if !srv.checkAuth(bearerReq) {
t.Fatalf("expected bearer auth to succeed")
}
cookieReq := httptest.NewRequest(http.MethodGet, "/", nil)
cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"})
if !srv.checkAuth(cookieReq) {
t.Fatalf("expected cookie auth to succeed")
}
queryReq := httptest.NewRequest(http.MethodGet, "/?token=secret-token", nil)
if srv.checkAuth(queryReq) {
t.Fatalf("expected query token auth to fail")
}
refererReq := httptest.NewRequest(http.MethodGet, "/", nil)
refererReq.Header.Set("Referer", "https://example.com/?token=secret-token")
if srv.checkAuth(refererReq) {
t.Fatalf("expected referer token auth to fail")
}
}
func TestWithCORSRejectsInvalidOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil)
req.Host = "example.com"
req.Header.Set("Origin", "javascript:alert(1)")
rec := httptest.NewRecorder()
srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", rec.Code)
}
}
func TestWithCORSAcceptsSameOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil)
req.Host = "example.com"
req.Header.Set("Origin", "http://example.com")
rec := httptest.NewRecorder()
srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rec.Code)
}
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Fatalf("unexpected allow-origin header %q", got)
}
}
func TestWithCORSAcceptsCrossOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com/api/config", nil)
req.Host = "example.com"
req.Header.Set("Origin", "https://web.example")
rec := httptest.NewRecorder()
srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rec.Code)
}
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "https://web.example" {
t.Fatalf("unexpected allow-origin header %q", got)
}
if got := rec.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
t.Fatalf("expected allow credentials, got %q", got)
}
}
func TestHandleNodeConnectRejectsInvalidOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nodes.NewManager())
mux := http.NewServeMux()
mux.HandleFunc("/nodes/connect", srv.handleNodeConnect)
httpSrv := httptest.NewServer(mux)
defer httpSrv.Close()
wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/nodes/connect"
dialer := websocket.Dialer{}
headers := http.Header{"Origin": []string{"javascript:alert(1)"}}
conn, resp, err := dialer.Dial(wsURL, headers)
if err == nil {
conn.Close()
t.Fatalf("expected websocket handshake to fail")
}
if resp == nil || resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected 403 response, got %#v", resp)
}
}
func TestHandleNodeConnectAcceptsCrossOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nodes.NewManager())
mux := http.NewServeMux()
mux.HandleFunc("/nodes/connect", srv.handleNodeConnect)
httpSrv := httptest.NewServer(mux)
defer httpSrv.Close()
wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/nodes/connect"
dialer := websocket.Dialer{}
headers := http.Header{"Origin": []string{"https://web.example"}}
conn, resp, err := dialer.Dial(wsURL, headers)
if err != nil {
t.Fatalf("expected websocket handshake to succeed, resp=%#v err=%v", resp, err)
}
_ = conn.Close()
}
func TestHandleWebUISetsCookieForBearerOnly(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "secret-token", nil)
bearerReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
bearerReq.Header.Set("Authorization", "Bearer secret-token")
bearerRec := httptest.NewRecorder()
srv.handleWebUI(bearerRec, bearerReq)
if len(bearerRec.Result().Cookies()) == 0 {
t.Fatalf("expected bearer-authenticated UI request to set cookie")
}
cookieReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
cookieReq.AddCookie(&http.Cookie{Name: "clawgo_webui_token", Value: "secret-token"})
cookieRec := httptest.NewRecorder()
srv.handleWebUI(cookieRec, cookieReq)
if len(cookieRec.Result().Cookies()) != 0 {
t.Fatalf("expected cookie-authenticated UI request not to reset cookie")
}
}
func TestHandleWebUIAuthSessionSetsCrossSiteCookieForAllowedOrigin(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "secret-token", nil)
req := httptest.NewRequest(http.MethodPost, "http://gateway.example/api/auth/session", nil)
req.Host = "gateway.example"
req.Header.Set("Origin", "https://web.example")
req.Header.Set("Authorization", "Bearer secret-token")
rec := httptest.NewRecorder()
srv.withCORS(http.HandlerFunc(srv.handleWebUIAuthSession)).ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
}
cookies := rec.Result().Cookies()
if len(cookies) != 1 {
t.Fatalf("expected one cookie, got %d", len(cookies))
}
if cookies[0].SameSite != http.SameSiteNoneMode {
t.Fatalf("expected SameSite=None for cross-site session, got %v", cookies[0].SameSite)
}
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "https://web.example" {
t.Fatalf("unexpected allow-origin header %q", got)
}
}
func TestHandleWebUIUploadDoesNotExposeAbsolutePath(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
var form bytes.Buffer
mw := multipartWriter(t, &form, "file", "demo.txt", []byte("upload-body"))
req := httptest.NewRequest(http.MethodPost, "/api/upload", &form)
req.Header.Set("Content-Type", mw.FormDataContentType())
rec := httptest.NewRecorder()
srv.handleWebUIUpload(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
}
var payload map[string]interface{}
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if _, ok := payload["path"]; ok {
t.Fatalf("expected upload response to omit absolute path: %+v", payload)
}
if strings.TrimSpace(payload["media"].(string)) == "" {
t.Fatalf("expected media handle in response: %+v", payload)
}
}
func multipartWriter(t *testing.T, dst *bytes.Buffer, fieldName, filename string, body []byte) *multipart.Writer {
t.Helper()
mw := multipart.NewWriter(dst)
part, err := mw.CreateFormFile(fieldName, filename)
if err != nil {
t.Fatalf("create form file: %v", err)
}
if _, err := part.Write(body); err != nil {
t.Fatalf("write form file: %v", err)
}
if err := mw.Close(); err != nil {
t.Fatalf("close multipart writer: %v", err)
}
return mw
}
func TestHandleWebUIProviderOAuthStartRejectsGet(t *testing.T) {
t.Parallel()
srv := NewServer("127.0.0.1", 0, "", nil)
req := httptest.NewRequest(http.MethodGet, "/api/provider/oauth/start?provider=openai", nil)
rec := httptest.NewRecorder()
srv.handleWebUIProviderOAuthStart(rec, req)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected 405, got %d", rec.Code)
}
}