package restapi

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

const (
	// corsProbePath is the single route registered by newCORSRouter.
	corsProbePath = "/api/v1/public/probe"
	// corsTrustedOrigin is the origin most tests put on the allow-list.
	corsTrustedOrigin = "https://www.galaxy.lan"
)

// init switches gin to test mode for every test in this package.
func init() {
	gin.SetMode(gin.TestMode)
}

// newCORSRouter returns a bare gin engine that runs withCORS(allowedOrigins)
// in front of one GET probe route, which answers 200 with status "ok".
func newCORSRouter(allowedOrigins []string) *gin.Engine {
	engine := gin.New()
	engine.Use(withCORS(allowedOrigins))
	engine.GET(corsProbePath, func(ctx *gin.Context) {
		ctx.JSON(http.StatusOK, statusResponse{Status: "ok"})
	})
	return engine
}

// serveCORSProbe builds a router from allowedOrigins and sends it a request
// with the given method to the probe route. Each entry of headers is set on
// the request; a nil map sends no extra headers. It returns the recorded
// response.
func serveCORSProbe(allowedOrigins []string, method string, headers map[string]string) *httptest.ResponseRecorder {
	engine := newCORSRouter(allowedOrigins)
	request := httptest.NewRequest(method, corsProbePath, nil)
	for name, value := range headers {
		request.Header.Set(name, value)
	}
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)
	return rec
}

// TestWithCORSAllowsListedOrigin checks that a GET from an allow-listed origin
// is served and that the response echoes the origin, varies on Origin and
// allows credentials.
func TestWithCORSAllowsListedOrigin(t *testing.T) {
	t.Parallel()

	rec := serveCORSProbe([]string{corsTrustedOrigin}, http.MethodGet, map[string]string{
		"Origin": corsTrustedOrigin,
	})
	got := rec.Header()

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, corsTrustedOrigin, got.Get("Access-Control-Allow-Origin"))
	assert.Equal(t, "Origin", got.Get("Vary"))
	assert.Equal(t, "true", got.Get("Access-Control-Allow-Credentials"))
}

// TestWithCORSPreflightShortCircuits checks that an OPTIONS preflight from an
// allow-listed origin gets an empty 204 response with the preflight headers:
// methods including POST, the requested headers echoed back, and a max-age of
// 3600.
func TestWithCORSPreflightShortCircuits(t *testing.T) {
	t.Parallel()

	rec := serveCORSProbe([]string{corsTrustedOrigin}, http.MethodOptions, map[string]string{
		"Origin":                         corsTrustedOrigin,
		"Access-Control-Request-Method":  "POST",
		"Access-Control-Request-Headers": "Content-Type, X-Galaxy-Trace",
	})
	got := rec.Header()

	assert.Equal(t, http.StatusNoContent, rec.Code)
	assert.Equal(t, corsTrustedOrigin, got.Get("Access-Control-Allow-Origin"))
	assert.Contains(t, got.Get("Access-Control-Allow-Methods"), "POST")
	assert.Equal(t, "Content-Type, X-Galaxy-Trace", got.Get("Access-Control-Allow-Headers"))
	assert.Equal(t, "3600", got.Get("Access-Control-Max-Age"))
	assert.Empty(t, rec.Body.String(), "preflight must not return a body")
}

// TestWithCORSPreflightFallbackHeadersWhenRequestHeadersMissing checks that a
// preflight without Access-Control-Request-Headers still gets 204 and the
// default allow-headers list.
func TestWithCORSPreflightFallbackHeadersWhenRequestHeadersMissing(t *testing.T) {
	t.Parallel()

	rec := serveCORSProbe([]string{corsTrustedOrigin}, http.MethodOptions, map[string]string{
		"Origin": corsTrustedOrigin,
	})

	assert.Equal(t, http.StatusNoContent, rec.Code)
	assert.Equal(t, "Content-Type, Authorization", rec.Header().Get("Access-Control-Allow-Headers"))
}

// TestWithCORSRejectsUnknownOrigin checks that a request from an origin not on
// the allow-list is still served but carries no allow-origin header.
func TestWithCORSRejectsUnknownOrigin(t *testing.T) {
	t.Parallel()

	rec := serveCORSProbe([]string{corsTrustedOrigin}, http.MethodGet, map[string]string{
		"Origin": "https://evil.example.com",
	})

	assert.Equal(t, http.StatusOK, rec.Code, "real call must still succeed; the browser is the one that blocks the response")
	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"), "no allow-origin header for rejected origin")
}

// TestWithCORSPassThroughWithoutOriginHeader checks that a request with no
// Origin header is served without any allow-origin header.
func TestWithCORSPassThroughWithoutOriginHeader(t *testing.T) {
	t.Parallel()

	rec := serveCORSProbe([]string{corsTrustedOrigin}, http.MethodGet, nil)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"))
}

// TestWithCORSDisabledByEmptyConfig checks that with a nil allow-list even the
// otherwise-trusted origin gets no allow-origin header.
func TestWithCORSDisabledByEmptyConfig(t *testing.T) {
	t.Parallel()

	rec := serveCORSProbe(nil, http.MethodGet, map[string]string{
		"Origin": corsTrustedOrigin,
	})

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"))
}