package restapi

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/config"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestPublicHandlerHealthEndpoints sends a table of requests through the
// handler built by newPublicHandler and compares the recorded status code,
// Content-Type header, body, and Allow header with the expected values.
func TestPublicHandlerHealthEndpoints(t *testing.T) {
	t.Parallel()

	// Error payloads that are expected verbatim from the handler.
	const (
		methodNotAllowedBody = `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`
		notFoundBody         = `{"error":{"code":"not_found","message":"resource was not found"}}`
	)

	handler := newPublicHandler(ServerDependencies{})

	cases := []struct {
		name        string
		method      string
		target      string
		status      int
		contentType string
		body        string
		allow       string
	}{
		{
			name:        "healthz",
			method:      http.MethodGet,
			target:      "/healthz",
			status:      http.StatusOK,
			contentType: jsonContentType,
			body:        `{"status":"ok"}`,
		},
		{
			name:        "readyz",
			method:      http.MethodGet,
			target:      "/readyz",
			status:      http.StatusOK,
			contentType: jsonContentType,
			body:        `{"status":"ready"}`,
		},
		{
			name:        "wrong method on known route",
			method:      http.MethodPost,
			target:      "/healthz",
			status:      http.StatusMethodNotAllowed,
			contentType: jsonContentType,
			body:        methodNotAllowedBody,
			allow:       http.MethodGet,
		},
		{
			name:        "unknown route",
			method:      http.MethodGet,
			target:      "/unknown",
			status:      http.StatusNotFound,
			contentType: jsonContentType,
			body:        notFoundBody,
		},
		{
			name:        "wrong method on public auth route",
			method:      http.MethodGet,
			target:      "/api/v1/public/auth/send-email-code",
			status:      http.StatusMethodNotAllowed,
			contentType: jsonContentType,
			body:        methodNotAllowedBody,
			allow:       http.MethodPost,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, httptest.NewRequest(tc.method, tc.target, nil))

			headers := rec.Header()
			assert.Equal(t, tc.status, rec.Code)
			assert.Equal(t, tc.contentType, headers.Get("Content-Type"))
			assert.Equal(t, tc.body, rec.Body.String())
			// An empty expectation also asserts that no Allow header is set.
			assert.Equal(t, tc.allow, headers.Get("Allow"))
		})
	}
}

// TestDefaultPublicTrafficClassifier checks the class returned by
// defaultPublicTrafficClassifier.Classify for a table of method, target, and
// optional Accept header combinations.
func TestDefaultPublicTrafficClassifier(t *testing.T) {
	t.Parallel()

	classifier := defaultPublicTrafficClassifier{}

	cases := []struct {
		name   string
		method string
		target string
		accept string
		want   PublicRouteClass
	}{
		{
			name:   "public auth route",
			method: http.MethodPost,
			target: "/api/v1/public/auth/send-email-code",
			want:   PublicRouteClassPublicAuth,
		},
		{
			name:   "public auth confirm route",
			method: http.MethodPost,
			target: "/api/v1/public/auth/confirm-email-code",
			want:   PublicRouteClassPublicAuth,
		},
		{
			name:   "browser bootstrap route",
			method: http.MethodGet,
			target: "/",
			want:   PublicRouteClassBrowserBootstrap,
		},
		{
			name:   "browser asset route",
			method: http.MethodGet,
			target: "/assets/app.js",
			want:   PublicRouteClassBrowserAsset,
		},
		{
			name:   "browser asset head request",
			method: http.MethodHead,
			target: "/assets/app.js",
			want:   PublicRouteClassBrowserAsset,
		},
		{
			name:   "browser asset extension request",
			method: http.MethodGet,
			target: "/manifest.webmanifest",
			want:   PublicRouteClassBrowserAsset,
		},
		{
			name:   "public misc route",
			method: http.MethodPost,
			target: "/api/v1/public/unknown",
			want:   PublicRouteClassPublicMisc,
		},
		{
			name:   "html accept bootstrap route",
			method: http.MethodGet,
			target: "/app",
			accept: "application/json, text/html;q=0.9",
			want:   PublicRouteClassBrowserBootstrap,
		},
		{
			name:   "public auth wins over browser accept header",
			method: http.MethodPost,
			target: "/api/v1/public/auth/confirm-email-code",
			accept: "text/html",
			want:   PublicRouteClassPublicAuth,
		},
		{
			name:   "probe with html accept is bootstrap",
			method: http.MethodGet,
			target: "/healthz",
			accept: "text/html",
			want:   PublicRouteClassBrowserBootstrap,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			request := httptest.NewRequest(tc.method, tc.target, nil)
			// Only cases that specify an Accept value send the header at all.
			if tc.accept != "" {
				request.Header.Set("Accept", tc.accept)
			}

			got := classifier.Classify(request)
			assert.Equal(t, tc.want, got)
		})
	}
}

// TestPublicRouteClassNormalized checks that Normalized returns each of the
// four named classes unchanged and maps an unrecognised or empty value to
// PublicRouteClassPublicMisc.
func TestPublicRouteClassNormalized(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		in   PublicRouteClass
		out  PublicRouteClass
	}{
		{name: "public auth", in: PublicRouteClassPublicAuth, out: PublicRouteClassPublicAuth},
		{name: "browser bootstrap", in: PublicRouteClassBrowserBootstrap, out: PublicRouteClassBrowserBootstrap},
		{name: "browser asset", in: PublicRouteClassBrowserAsset, out: PublicRouteClassBrowserAsset},
		{name: "public misc", in: PublicRouteClassPublicMisc, out: PublicRouteClassPublicMisc},
		{name: "unknown collapses to misc", in: PublicRouteClass("unexpected"), out: PublicRouteClassPublicMisc},
		{name: "empty collapses to misc", in: PublicRouteClass(""), out: PublicRouteClassPublicMisc},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tc.out, tc.in.Normalized())
		})
	}
}

// TestPublicRouteClassBaseBucketKeyIsolation checks the string returned by
// BaseBucketKey for each class and fails immediately if two distinct classes
// in the table expect the same key.
func TestPublicRouteClassBaseBucketKeyIsolation(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name  string
		class PublicRouteClass
		key   string
	}{
		{name: "public auth", class: PublicRouteClassPublicAuth, key: "public_rest/class=public_auth"},
		{name: "browser bootstrap", class: PublicRouteClassBrowserBootstrap, key: "public_rest/class=browser_bootstrap"},
		{name: "browser asset", class: PublicRouteClassBrowserAsset, key: "public_rest/class=browser_asset"},
		{name: "public misc", class: PublicRouteClassPublicMisc, key: "public_rest/class=public_misc"},
		{name: "unknown collapses to misc namespace", class: PublicRouteClass("unexpected"), key: "public_rest/class=public_misc"},
	}

	// owners records which class first claimed each expected key.
	owners := make(map[string]PublicRouteClass, len(cases))

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tc.key, tc.class.BaseBucketKey())
		})

		// A class that only reaches the misc key through normalization is
		// expected to share it, so it is left out of the collision check.
		normalized := tc.class.Normalized()
		collapsedIntoMisc := normalized == PublicRouteClassPublicMisc && tc.class != PublicRouteClassPublicMisc
		if collapsedIntoMisc {
			continue
		}

		owner, taken := owners[tc.key]
		if taken {
			require.FailNowf(t, "bucket key collision", "class %q collides with %q on key %q", tc.class, owner, tc.key)
		}
		owners[tc.key] = tc.class
	}

	authKey := PublicRouteClassPublicAuth.BaseBucketKey()
	assert.NotEqual(t, authKey, PublicRouteClassBrowserBootstrap.BaseBucketKey())
	assert.NotEqual(t, authKey, PublicRouteClassBrowserAsset.BaseBucketKey())
}

// TestWithPublicRouteClassStoresClassInContext installs withPublicRouteClass
// with a fixed classifier on a gin router and checks, from inside the route
// handler, that PublicRouteClassFromContext yields that class.
func TestWithPublicRouteClassStoresClassInContext(t *testing.T) {
	t.Parallel()

	engine := gin.New()
	engine.Use(withPublicRouteClass(staticClassifier{class: PublicRouteClassBrowserAsset}))
	engine.GET("/assets/app.js", func(ginCtx *gin.Context) {
		got, found := PublicRouteClassFromContext(ginCtx.Request.Context())
		require.True(t, found)
		assert.Equal(t, PublicRouteClassBrowserAsset, got)

		ginCtx.JSON(http.StatusOK, statusResponse{Status: "ok"})
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/assets/app.js", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, `{"status":"ok"}`, rec.Body.String())
}

// TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc checks that
// when the classifier returns an unrecognised or empty class, the value read
// back from the request context is PublicRouteClassPublicMisc.
func TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name  string
		class PublicRouteClass
	}{
		{name: "unknown class", class: PublicRouteClass("unexpected")},
		{name: "empty class", class: PublicRouteClass("")},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			engine := gin.New()
			engine.Use(withPublicRouteClass(staticClassifier{class: tc.class}))
			engine.GET("/", func(ginCtx *gin.Context) {
				got, found := PublicRouteClassFromContext(ginCtx.Request.Context())
				require.True(t, found)
				assert.Equal(t, PublicRouteClassPublicMisc, got)

				ginCtx.JSON(http.StatusOK, statusResponse{Status: "ok"})
			})

			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

			assert.Equal(t, http.StatusOK, rec.Code)
			assert.Equal(t, `{"status":"ok"}`, rec.Body.String())
		})
	}
}

// TestServerLifecycle runs the application with the public server bound to
// 127.0.0.1:0, waits for a successful /healthz response, cancels the run
// context, and requires Run to return without error within two seconds.
func TestServerLifecycle(t *testing.T) {
	t.Parallel()

	publicHTTP := config.DefaultPublicHTTPConfig()
	publicHTTP.Addr = "127.0.0.1:0"
	publicHTTP.AntiAbuse.PublicMisc.RateLimit = config.PublicRateLimitConfig{
		Requests: 1000,
		Window:   time.Minute,
		Burst:    1000,
	}

	cfg := config.Config{
		ShutdownTimeout: time.Second,
		PublicHTTP:      publicHTTP,
	}

	server := NewServer(cfg.PublicHTTP, ServerDependencies{})
	application := app.New(cfg, server)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so the goroutine can deliver its result even if the test has
	// already given up waiting.
	done := make(chan error, 1)
	go func() {
		done <- application.Run(ctx)
	}()

	listenAddr := waitForListenAddr(t, server)
	waitForHealthResponse(t, listenAddr)

	cancel()

	select {
	case runErr := <-done:
		require.NoError(t, runErr)
	case <-time.After(2 * time.Second):
		require.FailNow(t, "Run() did not return after cancellation")
	}
}

// staticClassifier is a test stub whose Classify method always returns the
// class it was constructed with.
type staticClassifier struct {
	class PublicRouteClass
}

// Classify ignores the request and returns the configured class.
func (s staticClassifier) Classify(*http.Request) PublicRouteClass {
	return s.class
}

// waitForListenAddr polls server.listenAddr every 10ms for up to one second
// and returns the first non-empty address. It fails the test if the address
// is still empty when the deadline passes.
func waitForListenAddr(t *testing.T, server *Server) string {
	t.Helper()

	for stopAt := time.Now().Add(time.Second); time.Now().Before(stopAt); time.Sleep(10 * time.Millisecond) {
		addr := server.listenAddr()
		if addr != "" {
			return addr
		}
	}

	require.FailNow(t, "server did not start listening")
	return ""
}

// waitForHealthResponse retries GET http://<addr>/healthz every 10ms for up
// to one second while the request itself errors. The first response received
// must be a 200 whose trimmed body is {"status":"ok"}; if no response arrives
// before the deadline the test fails.
func waitForHealthResponse(t *testing.T, addr string) {
	t.Helper()

	probe := &http.Client{Timeout: 100 * time.Millisecond}
	endpoint := "http://" + addr + "/healthz"

	for stopAt := time.Now().Add(time.Second); time.Now().Before(stopAt); time.Sleep(10 * time.Millisecond) {
		resp, err := probe.Get(endpoint)
		if err != nil {
			// Transport-level failure: retry after the loop's sleep.
			continue
		}

		// Read and close before asserting so the body is always released,
		// even when one of the requirements below aborts the test.
		payload, readErr := io.ReadAll(resp.Body)
		closeErr := resp.Body.Close()

		require.NoError(t, readErr)
		require.NoError(t, closeErr)
		require.Equal(t, http.StatusOK, resp.StatusCode)
		assert.Equal(t, `{"status":"ok"}`, strings.TrimSpace(string(payload)))
		return
	}

	require.FailNowf(t, "health check timed out", "url=%s", endpoint)
}