460 lines
11 KiB
Go
460 lines
11 KiB
Go
package restapi
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/gateway/internal/app"
|
|
"galaxy/gateway/internal/config"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestPublicHandlerHealthEndpoints(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handler := newPublicHandler(ServerDependencies{})
|
|
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
target string
|
|
wantStatus int
|
|
wantType string
|
|
wantBody string
|
|
wantAllow string
|
|
}{
|
|
{
|
|
name: "healthz",
|
|
method: http.MethodGet,
|
|
target: "/healthz",
|
|
wantStatus: http.StatusOK,
|
|
wantType: jsonContentType,
|
|
wantBody: `{"status":"ok"}`,
|
|
},
|
|
{
|
|
name: "readyz",
|
|
method: http.MethodGet,
|
|
target: "/readyz",
|
|
wantStatus: http.StatusOK,
|
|
wantType: jsonContentType,
|
|
wantBody: `{"status":"ready"}`,
|
|
},
|
|
{
|
|
name: "wrong method on known route",
|
|
method: http.MethodPost,
|
|
target: "/healthz",
|
|
wantStatus: http.StatusMethodNotAllowed,
|
|
wantType: jsonContentType,
|
|
wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
|
|
wantAllow: http.MethodGet,
|
|
},
|
|
{
|
|
name: "unknown route",
|
|
method: http.MethodGet,
|
|
target: "/unknown",
|
|
wantStatus: http.StatusNotFound,
|
|
wantType: jsonContentType,
|
|
wantBody: `{"error":{"code":"not_found","message":"resource was not found"}}`,
|
|
},
|
|
{
|
|
name: "wrong method on public auth route",
|
|
method: http.MethodGet,
|
|
target: "/api/v1/public/auth/send-email-code",
|
|
wantStatus: http.StatusMethodNotAllowed,
|
|
wantType: jsonContentType,
|
|
wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
|
|
wantAllow: http.MethodPost,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
req := httptest.NewRequest(tt.method, tt.target, nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(recorder, req)
|
|
|
|
assert.Equal(t, tt.wantStatus, recorder.Code)
|
|
assert.Equal(t, tt.wantType, recorder.Header().Get("Content-Type"))
|
|
assert.Equal(t, tt.wantBody, recorder.Body.String())
|
|
assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultPublicTrafficClassifier(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
classifier := defaultPublicTrafficClassifier{}
|
|
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
target string
|
|
accept string
|
|
wantClass PublicRouteClass
|
|
}{
|
|
{
|
|
name: "public auth route",
|
|
method: http.MethodPost,
|
|
target: "/api/v1/public/auth/send-email-code",
|
|
wantClass: PublicRouteClassPublicAuth,
|
|
},
|
|
{
|
|
name: "public auth confirm route",
|
|
method: http.MethodPost,
|
|
target: "/api/v1/public/auth/confirm-email-code",
|
|
wantClass: PublicRouteClassPublicAuth,
|
|
},
|
|
{
|
|
name: "browser bootstrap route",
|
|
method: http.MethodGet,
|
|
target: "/",
|
|
wantClass: PublicRouteClassBrowserBootstrap,
|
|
},
|
|
{
|
|
name: "browser asset route",
|
|
method: http.MethodGet,
|
|
target: "/assets/app.js",
|
|
wantClass: PublicRouteClassBrowserAsset,
|
|
},
|
|
{
|
|
name: "browser asset head request",
|
|
method: http.MethodHead,
|
|
target: "/assets/app.js",
|
|
wantClass: PublicRouteClassBrowserAsset,
|
|
},
|
|
{
|
|
name: "browser asset extension request",
|
|
method: http.MethodGet,
|
|
target: "/manifest.webmanifest",
|
|
wantClass: PublicRouteClassBrowserAsset,
|
|
},
|
|
{
|
|
name: "public misc route",
|
|
method: http.MethodPost,
|
|
target: "/api/v1/public/unknown",
|
|
wantClass: PublicRouteClassPublicMisc,
|
|
},
|
|
{
|
|
name: "html accept bootstrap route",
|
|
method: http.MethodGet,
|
|
target: "/app",
|
|
accept: "application/json, text/html;q=0.9",
|
|
wantClass: PublicRouteClassBrowserBootstrap,
|
|
},
|
|
{
|
|
name: "public auth wins over browser accept header",
|
|
method: http.MethodPost,
|
|
target: "/api/v1/public/auth/confirm-email-code",
|
|
accept: "text/html",
|
|
wantClass: PublicRouteClassPublicAuth,
|
|
},
|
|
{
|
|
name: "probe with html accept is bootstrap",
|
|
method: http.MethodGet,
|
|
target: "/healthz",
|
|
accept: "text/html",
|
|
wantClass: PublicRouteClassBrowserBootstrap,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
req := httptest.NewRequest(tt.method, tt.target, nil)
|
|
if tt.accept != "" {
|
|
req.Header.Set("Accept", tt.accept)
|
|
}
|
|
|
|
assert.Equal(t, tt.wantClass, classifier.Classify(req))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPublicRouteClassNormalized(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
input PublicRouteClass
|
|
want PublicRouteClass
|
|
}{
|
|
{
|
|
name: "public auth",
|
|
input: PublicRouteClassPublicAuth,
|
|
want: PublicRouteClassPublicAuth,
|
|
},
|
|
{
|
|
name: "browser bootstrap",
|
|
input: PublicRouteClassBrowserBootstrap,
|
|
want: PublicRouteClassBrowserBootstrap,
|
|
},
|
|
{
|
|
name: "browser asset",
|
|
input: PublicRouteClassBrowserAsset,
|
|
want: PublicRouteClassBrowserAsset,
|
|
},
|
|
{
|
|
name: "public misc",
|
|
input: PublicRouteClassPublicMisc,
|
|
want: PublicRouteClassPublicMisc,
|
|
},
|
|
{
|
|
name: "unknown collapses to misc",
|
|
input: PublicRouteClass("unexpected"),
|
|
want: PublicRouteClassPublicMisc,
|
|
},
|
|
{
|
|
name: "empty collapses to misc",
|
|
input: PublicRouteClass(""),
|
|
want: PublicRouteClassPublicMisc,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, tt.want, tt.input.Normalized())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPublicRouteClassBaseBucketKeyIsolation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
class PublicRouteClass
|
|
wantKey string
|
|
}{
|
|
{
|
|
name: "public auth",
|
|
class: PublicRouteClassPublicAuth,
|
|
wantKey: "public_rest/class=public_auth",
|
|
},
|
|
{
|
|
name: "browser bootstrap",
|
|
class: PublicRouteClassBrowserBootstrap,
|
|
wantKey: "public_rest/class=browser_bootstrap",
|
|
},
|
|
{
|
|
name: "browser asset",
|
|
class: PublicRouteClassBrowserAsset,
|
|
wantKey: "public_rest/class=browser_asset",
|
|
},
|
|
{
|
|
name: "public misc",
|
|
class: PublicRouteClassPublicMisc,
|
|
wantKey: "public_rest/class=public_misc",
|
|
},
|
|
{
|
|
name: "unknown collapses to misc namespace",
|
|
class: PublicRouteClass("unexpected"),
|
|
wantKey: "public_rest/class=public_misc",
|
|
},
|
|
}
|
|
|
|
seenKeys := make(map[string]PublicRouteClass, len(tests))
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, tt.wantKey, tt.class.BaseBucketKey())
|
|
})
|
|
|
|
normalizedClass := tt.class.Normalized()
|
|
if normalizedClass == PublicRouteClassPublicMisc && tt.class != PublicRouteClassPublicMisc {
|
|
continue
|
|
}
|
|
|
|
if previousClass, exists := seenKeys[tt.wantKey]; exists {
|
|
require.FailNowf(t, "bucket key collision", "class %q collides with %q on key %q", tt.class, previousClass, tt.wantKey)
|
|
}
|
|
|
|
seenKeys[tt.wantKey] = tt.class
|
|
}
|
|
|
|
assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserBootstrap.BaseBucketKey())
|
|
assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserAsset.BaseBucketKey())
|
|
}
|
|
|
|
func TestWithPublicRouteClassStoresClassInContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
router := gin.New()
|
|
router.Use(withPublicRouteClass(staticClassifier{class: PublicRouteClassBrowserAsset}))
|
|
router.GET("/assets/app.js", func(c *gin.Context) {
|
|
class, ok := PublicRouteClassFromContext(c.Request.Context())
|
|
require.True(t, ok)
|
|
assert.Equal(t, PublicRouteClassBrowserAsset, class)
|
|
|
|
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/assets/app.js", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
|
|
}
|
|
|
|
func TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
class PublicRouteClass
|
|
}{
|
|
{
|
|
name: "unknown class",
|
|
class: PublicRouteClass("unexpected"),
|
|
},
|
|
{
|
|
name: "empty class",
|
|
class: PublicRouteClass(""),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
router := gin.New()
|
|
router.Use(withPublicRouteClass(staticClassifier{class: tt.class}))
|
|
router.GET("/", func(c *gin.Context) {
|
|
class, ok := PublicRouteClassFromContext(c.Request.Context())
|
|
require.True(t, ok)
|
|
assert.Equal(t, PublicRouteClassPublicMisc, class)
|
|
|
|
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestServerLifecycle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cfg := config.Config{
|
|
ShutdownTimeout: time.Second,
|
|
PublicHTTP: func() config.PublicHTTPConfig {
|
|
publicHTTPCfg := config.DefaultPublicHTTPConfig()
|
|
publicHTTPCfg.Addr = "127.0.0.1:0"
|
|
publicHTTPCfg.AntiAbuse.PublicMisc.RateLimit = config.PublicRateLimitConfig{
|
|
Requests: 1000,
|
|
Window: time.Minute,
|
|
Burst: 1000,
|
|
}
|
|
return publicHTTPCfg
|
|
}(),
|
|
}
|
|
|
|
server := NewServer(cfg.PublicHTTP, ServerDependencies{})
|
|
application := app.New(cfg, server)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
resultCh := make(chan error, 1)
|
|
go func() {
|
|
resultCh <- application.Run(ctx)
|
|
}()
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
waitForHealthResponse(t, addr)
|
|
|
|
cancel()
|
|
|
|
select {
|
|
case err := <-resultCh:
|
|
require.NoError(t, err)
|
|
case <-time.After(2 * time.Second):
|
|
require.FailNow(t, "Run() did not return after cancellation")
|
|
}
|
|
}
|
|
|
|
type staticClassifier struct {
|
|
class PublicRouteClass
|
|
}
|
|
|
|
func (c staticClassifier) Classify(*http.Request) PublicRouteClass {
|
|
return c.class
|
|
}
|
|
|
|
func waitForListenAddr(t *testing.T, server *Server) string {
|
|
t.Helper()
|
|
|
|
deadline := time.Now().Add(time.Second)
|
|
for time.Now().Before(deadline) {
|
|
if addr := server.listenAddr(); addr != "" {
|
|
return addr
|
|
}
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
require.FailNow(t, "server did not start listening")
|
|
return ""
|
|
}
|
|
|
|
func waitForHealthResponse(t *testing.T, addr string) {
|
|
t.Helper()
|
|
|
|
client := &http.Client{Timeout: 100 * time.Millisecond}
|
|
url := "http://" + addr + "/healthz"
|
|
deadline := time.Now().Add(time.Second)
|
|
|
|
for time.Now().Before(deadline) {
|
|
resp, err := client.Get(url)
|
|
if err != nil {
|
|
time.Sleep(10 * time.Millisecond)
|
|
continue
|
|
}
|
|
|
|
body, readErr := io.ReadAll(resp.Body)
|
|
closeErr := resp.Body.Close()
|
|
require.NoError(t, readErr)
|
|
require.NoError(t, closeErr)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, `{"status":"ok"}`, strings.TrimSpace(string(body)))
|
|
|
|
return
|
|
}
|
|
|
|
require.FailNowf(t, "health check timed out", "url=%s", url)
|
|
}
|