diff --git a/pkg/api/server.go b/pkg/api/server.go index 27a91fc..353380e 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -238,10 +238,25 @@ func (s *Server) withCORS(next http.Handler) http.Handler { next = http.NotFoundHandler() } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With") + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin != "" { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Add("Vary", "Origin") + } else { + w.Header().Set("Access-Control-Allow-Origin", "*") + } + allowMethods := strings.TrimSpace(r.Header.Get("Access-Control-Request-Method")) + if allowMethods == "" { + allowMethods = "GET, POST, PUT, PATCH, DELETE, OPTIONS" + } + w.Header().Set("Access-Control-Allow-Methods", allowMethods) + allowHeaders := strings.TrimSpace(r.Header.Get("Access-Control-Request-Headers")) + if allowHeaders == "" { + allowHeaders = "Authorization, Content-Type, X-Requested-With, Accept, Origin, Cache-Control, Pragma" + } + w.Header().Set("Access-Control-Allow-Headers", allowHeaders) w.Header().Set("Access-Control-Expose-Headers", "*") + w.Header().Set("Access-Control-Max-Age", "86400") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index c7e34ec..c8d691a 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -172,6 +172,35 @@ func TestHandleWebUIConfigPostSavesNormalizedConfig(t *testing.T) { } } +func TestWithCORSEchoesPreflightHeaders(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "") + handler := srv.withCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodOptions, "/api/config", nil) + req.Header.Set("Origin", "https://dash.clawgo.dev") + req.Header.Set("Access-Control-Request-Method", "POST") + req.Header.Set("Access-Control-Request-Headers", "authorization,content-type,x-clawgo-client") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("expected 204, got %d", rec.Code) + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "https://dash.clawgo.dev" { + t.Fatalf("unexpected allow origin: %q", got) + } + if got := rec.Header().Get("Access-Control-Allow-Methods"); got != "POST" { + t.Fatalf("unexpected allow methods: %q", got) + } + if got := rec.Header().Get("Access-Control-Allow-Headers"); got != "authorization,content-type,x-clawgo-client" { + t.Fatalf("unexpected allow headers: %q", got) + } +} + func TestHandleWebUISessionsHidesInternalSessionsByDefault(t *testing.T) { t.Parallel()