feat: edge gateway service

This commit is contained in:
Ilia Denisov
2026-04-02 19:18:42 +02:00
committed by GitHub
parent 8cde99936c
commit 436c97a38b
95 changed files with 20504 additions and 57 deletions
+76
View File
@@ -0,0 +1,76 @@
package restapi
import (
"time"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
if logger == nil {
logger = zap.NewNop()
}
return func(c *gin.Context) {
start := time.Now()
c.Next()
statusCode := c.Writer.Status()
route := c.FullPath()
if route == "" {
route = c.Request.URL.Path
}
class, ok := PublicRouteClassFromContext(c.Request.Context())
if !ok {
class = PublicRouteClassPublicMisc
}
errorCode, _ := c.Get(publicErrorCodeContextKey)
errorCodeValue, _ := errorCode.(string)
outcome := telemetry.OutcomeFromPublicErrorCode(statusCode, errorCodeValue)
rejectReason := telemetry.RejectReason(outcome)
duration := time.Since(start)
attrs := []attribute.KeyValue{
attribute.String("route_class", string(class)),
attribute.String("route", route),
attribute.String("method", c.Request.Method),
attribute.String("edge_outcome", string(outcome)),
}
if rejectReason != "" {
attrs = append(attrs, attribute.String("reject_reason", rejectReason))
}
metrics.RecordPublicRequest(c.Request.Context(), attrs, duration)
fields := []zap.Field{
zap.String("component", "public_http"),
zap.String("transport", "http"),
zap.String("route", route),
zap.String("route_class", string(class)),
zap.String("method", c.Request.Method),
zap.Int("status_code", statusCode),
zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
zap.String("edge_outcome", string(outcome)),
}
if rejectReason != "" {
fields = append(fields, zap.String("reject_reason", rejectReason))
}
fields = append(fields, logging.TraceFieldsFromContext(c.Request.Context())...)
switch outcome {
case telemetry.EdgeOutcomeSuccess:
logger.Info("public request completed", fields...)
case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeInternalError:
logger.Error("public request failed", fields...)
default:
logger.Warn("public request rejected", fields...)
}
}
}
+30
View File
@@ -0,0 +1,30 @@
package restapi
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/getkin/kin-openapi/openapi3"
"github.com/stretchr/testify/require"
)
func TestPublicOpenAPISpecValidates(t *testing.T) {
t.Parallel()
_, thisFile, _, ok := runtime.Caller(0)
require.True(t, ok)
specPath := filepath.Join(filepath.Dir(thisFile), "..", "..", "openapi.yaml")
ctx := context.Background()
loader := openapi3.NewLoader()
doc, err := loader.LoadFromFile(specPath)
require.NoError(t, err)
require.NotNil(t, doc)
require.NotNil(t, doc.Info)
require.Equal(t, "v1", doc.Info.Version)
require.NoError(t, doc.Validate(ctx))
}
@@ -0,0 +1,378 @@
package restapi
import (
"bytes"
"errors"
"io"
"math"
"net"
"net/http"
"path"
"strconv"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
const (
errorCodeRequestTooLarge = "request_too_large"
errorCodeRateLimited = "rate_limited"
publicRESTIPBucketKeySegment = "/ip="
)
var errRequestBodyTooLarge = errors.New("request body exceeds the configured limit")
// PublicMalformedRequestReason identifies the stable malformed-request counter
// dimension recorded by the public REST anti-abuse middleware.
type PublicMalformedRequestReason string
const (
// PublicMalformedRequestReasonEmptyBody records a missing request body.
PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body"
// PublicMalformedRequestReasonMalformedJSON records syntactically malformed
// JSON.
PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json"
// PublicMalformedRequestReasonInvalidJSONValue records JSON values whose
// types do not match the expected request schema.
PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value"
// PublicMalformedRequestReasonUnknownField records JSON objects with fields
// outside the documented schema.
PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field"
// PublicMalformedRequestReasonMultipleJSONObjects records requests that
// contain more than one JSON object.
PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects"
// PublicMalformedRequestReasonOversizedBody records requests whose bodies
// exceed the configured class limit.
PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body"
)
// PublicRateLimitDecision describes the outcome returned by a public REST
// limiter for one request bucket reservation attempt.
type PublicRateLimitDecision struct {
// Allowed reports whether the request may proceed immediately.
Allowed bool
// RetryAfter is the minimum delay the client should wait before retrying
// when Allowed is false.
RetryAfter time.Duration
}
// PublicRequestLimiter applies public REST rate-limit policy to a concrete
// bucket key.
type PublicRequestLimiter interface {
// Reserve evaluates key under policy and returns whether the request may
// proceed immediately.
Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision
}
// PublicRequestObserver captures low-cardinality public REST anti-abuse
// telemetry.
type PublicRequestObserver interface {
// RecordMalformedRequest records one malformed request in class for reason.
RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason)
}
type noopPublicRequestObserver struct{}
func (noopPublicRequestObserver) RecordMalformedRequest(PublicRouteClass, PublicMalformedRequestReason) {
}
type inMemoryPublicRequestLimiter struct {
now func() time.Time
cleanupInterval time.Duration
mu sync.Mutex
entries map[string]*publicRateLimiterEntry
nextCleanup time.Time
}
type publicRateLimiterEntry struct {
limiter *rate.Limiter
limit rate.Limit
burst int
expiresAt time.Time
}
func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter {
return &inMemoryPublicRequestLimiter{
now: time.Now,
cleanupInterval: time.Minute,
entries: make(map[string]*publicRateLimiterEntry),
}
}
func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision {
now := l.now()
limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())
l.mu.Lock()
defer l.mu.Unlock()
l.cleanupExpiredBucketsLocked(now)
entry, ok := l.entries[key]
if !ok || entry.limit != limit || entry.burst != policy.Burst {
entry = &publicRateLimiterEntry{
limiter: rate.NewLimiter(limit, policy.Burst),
limit: limit,
burst: policy.Burst,
}
l.entries[key] = entry
}
entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window))
reservation := entry.limiter.ReserveN(now, 1)
if !reservation.OK() {
return PublicRateLimitDecision{
Allowed: false,
RetryAfter: policy.Window,
}
}
retryAfter := reservation.DelayFrom(now)
if retryAfter > 0 {
return PublicRateLimitDecision{
Allowed: false,
RetryAfter: retryAfter,
}
}
return PublicRateLimitDecision{Allowed: true}
}
func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) {
if !l.nextCleanup.IsZero() && now.Before(l.nextCleanup) {
return
}
for key, entry := range l.entries {
if !entry.expiresAt.After(now) {
delete(l.entries, key)
}
}
l.nextCleanup = now.Add(l.cleanupInterval)
}
func publicRateLimiterEntryTTL(window time.Duration) time.Duration {
if window < time.Minute {
return time.Minute
}
return 2 * window
}
func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc {
return func(c *gin.Context) {
class, ok := PublicRouteClassFromContext(c.Request.Context())
if !ok {
class = PublicRouteClassPublicMisc
}
allowedMethods := allowedMethodsForRequestShape(c.Request)
if len(allowedMethods) > 0 && !isAllowedMethod(c.Request.Method, allowedMethods) {
c.Header("Allow", strings.Join(allowedMethods, ", "))
abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
return
}
classPolicy := publicRoutePolicyForClass(policy, class)
bodyBytes, err := bufferRequestBody(c.Request, classPolicy.MaxBodyBytes)
if err != nil {
switch {
case errors.Is(err, errRequestBodyTooLarge):
observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody)
abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit")
default:
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
}
return
}
clientIP := clientIPFromRemoteAddr(c.Request.RemoteAddr)
if decision := limiter.Reserve(publicRESTIPBucketKey(class, clientIP), classPolicy.RateLimit); !decision.Allowed {
abortRateLimited(c, decision.RetryAfter)
return
}
identity, err := extractPublicAuthIdentity(c.Request.URL.Path, bodyBytes)
switch {
case err == nil:
identityPolicy := publicAuthIdentityPolicyForPath(c.Request.URL.Path, policy)
if decision := limiter.Reserve(publicAuthIdentityBucketKey(class, identity.kind, identity.value), identityPolicy.RateLimit); !decision.Allowed {
abortRateLimited(c, decision.RetryAfter)
return
}
case errors.Is(err, errPublicAuthIdentityNotApplicable):
default:
if reason, malformed := malformedRequestReasonFromError(err); malformed {
observer.RecordMalformedRequest(class, reason)
}
}
c.Next()
}
}
func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig {
switch class.Normalized() {
case PublicRouteClassPublicAuth:
return policy.PublicAuth
case PublicRouteClassBrowserBootstrap:
return policy.BrowserBootstrap
case PublicRouteClassBrowserAsset:
return policy.BrowserAsset
default:
return policy.PublicMisc
}
}
func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig {
switch requestPath {
case "/api/v1/public/auth/send-email-code":
return policy.SendEmailCodeIdentity
case "/api/v1/public/auth/confirm-email-code":
return policy.ConfirmEmailCodeIdentity
default:
return config.PublicAuthIdentityPolicyConfig{}
}
}
func allowedMethodsForRequestShape(r *http.Request) []string {
switch {
case isPublicAuthPath(r.URL.Path):
return []string{http.MethodPost}
case isProbePath(r.URL.Path):
return []string{http.MethodGet}
case matchesBrowserAssetRequestShape(r):
return []string{http.MethodGet, http.MethodHead}
case matchesBrowserBootstrapRequestShape(r):
return []string{http.MethodGet, http.MethodHead}
default:
return nil
}
}
func isAllowedMethod(method string, allowedMethods []string) bool {
for _, allowedMethod := range allowedMethods {
if method == allowedMethod {
return true
}
}
return false
}
func isPublicAuthPath(requestPath string) bool {
switch requestPath {
case "/api/v1/public/auth/send-email-code", "/api/v1/public/auth/confirm-email-code":
return true
default:
return false
}
}
func isProbePath(requestPath string) bool {
switch requestPath {
case "/healthz", "/readyz":
return true
default:
return false
}
}
func matchesBrowserBootstrapRequestShape(r *http.Request) bool {
if r.URL.Path == "/" {
return true
}
return strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html")
}
func matchesBrowserAssetRequestShape(r *http.Request) bool {
if strings.HasPrefix(r.URL.Path, "/assets/") {
return true
}
switch strings.ToLower(path.Ext(r.URL.Path)) {
case ".js", ".mjs", ".css", ".map", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".json", ".webmanifest":
return true
default:
return false
}
}
func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) {
if r == nil {
return nil, nil
}
if r.Body == nil {
r.Body = io.NopCloser(bytes.NewReader(nil))
return nil, nil
}
bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, maxBodyBytes+1))
closeErr := r.Body.Close()
if err != nil {
return nil, err
}
if closeErr != nil {
return nil, closeErr
}
if int64(len(bodyBytes)) > maxBodyBytes {
return nil, errRequestBodyTooLarge
}
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return bodyBytes, nil
}
func abortRateLimited(c *gin.Context, retryAfter time.Duration) {
c.Header("Retry-After", retryAfterHeaderValue(retryAfter))
abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded")
}
func retryAfterHeaderValue(delay time.Duration) string {
seconds := int64(math.Ceil(delay.Seconds()))
if seconds < 1 {
seconds = 1
}
return strconv.FormatInt(seconds, 10)
}
func clientIPFromRemoteAddr(remoteAddr string) string {
host, _, err := net.SplitHostPort(strings.TrimSpace(remoteAddr))
if err == nil {
return host
}
remoteAddr = strings.TrimSpace(remoteAddr)
if remoteAddr == "" {
return "unknown"
}
return remoteAddr
}
func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string {
return class.BaseBucketKey() + publicRESTIPBucketKeySegment + clientIP
}
func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string {
return class.BaseBucketKey() + "/" + identityKind + "=" + identityValue
}
@@ -0,0 +1,455 @@
package restapi
import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPublicAntiAbuseRejectsOversizedBodies(t *testing.T) {
t.Parallel()
oversizedJSONBody := `{"email":"` + strings.Repeat("a", 8200) + `@example.com"}`
oversizedConfirmJSONBody := `{"challenge_id":"` + strings.Repeat("c", 8300) + `","code":"123456","client_public_key":"key"}`
tests := []struct {
name string
method string
target string
body string
wantClass PublicRouteClass
}{
{
name: "send email",
method: http.MethodPost,
target: "/api/v1/public/auth/send-email-code",
body: oversizedJSONBody,
wantClass: PublicRouteClassPublicAuth,
},
{
name: "confirm email",
method: http.MethodPost,
target: "/api/v1/public/auth/confirm-email-code",
body: oversizedConfirmJSONBody,
wantClass: PublicRouteClassPublicAuth,
},
{
name: "healthz body",
method: http.MethodGet,
target: "/healthz",
body: `x`,
wantClass: PublicRouteClassPublicMisc,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
observer := &recordingPublicRequestObserver{}
authService := &recordingAuthServiceClient{
sendEmailCodeResult: SendEmailCodeResult{ChallengeID: "challenge-123"},
confirmEmailCodeResult: ConfirmEmailCodeResult{
DeviceSessionID: "device-session-123",
},
}
handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
AuthService: authService,
Observer: observer,
})
req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Equal(t, `{"error":{"code":"request_too_large","message":"request body exceeds the configured limit"}}`, recorder.Body.String())
assert.Equal(t, 0, authService.sendEmailCodeCalls)
assert.Equal(t, 0, authService.confirmEmailCodeCalls)
assert.Equal(t, []malformedObservation{{
class: tt.wantClass,
reason: PublicMalformedRequestReasonOversizedBody,
}}, observer.snapshot())
})
}
}
func TestPublicAntiAbuseRejectsInvalidMethodsForBrowserShapes(t *testing.T) {
t.Parallel()
handler := newPublicHandler(ServerDependencies{})
tests := []struct {
name string
method string
target string
accept string
wantAllow string
}{
{
name: "asset path",
method: http.MethodPost,
target: "/assets/app.js",
wantAllow: "GET, HEAD",
},
{
name: "bootstrap request",
method: http.MethodPost,
target: "/",
accept: "text/html",
wantAllow: "GET, HEAD",
},
{
name: "head probe rejected",
method: http.MethodHead,
target: "/healthz",
wantAllow: http.MethodGet,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(tt.method, tt.target, nil)
if tt.accept != "" {
req.Header.Set("Accept", tt.accept)
}
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
assert.Equal(t, `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`, recorder.Body.String())
})
}
}
func TestPublicAntiAbuseBrowserClassBucketsStayIsolatedFromPublicAuth(t *testing.T) {
t.Parallel()
tests := []struct {
name string
burstRequest *http.Request
}{
{
name: "browser asset",
burstRequest: httptest.NewRequest(http.MethodGet, "/assets/app.js", nil),
},
{
name: "browser bootstrap",
burstRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Accept", "text/html")
return req
}(),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
cfg := config.DefaultPublicHTTPConfig()
cfg.AntiAbuse.BrowserAsset.RateLimit = config.PublicRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
cfg.AntiAbuse.BrowserBootstrap.RateLimit = config.PublicRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
authService := &recordingAuthServiceClient{
sendEmailCodeResult: SendEmailCodeResult{
ChallengeID: "challenge-123",
},
}
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
tt.burstRequest.RemoteAddr = "192.0.2.10:1234"
firstBurst := httptest.NewRecorder()
handler.ServeHTTP(firstBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))
secondBurst := httptest.NewRecorder()
handler.ServeHTTP(secondBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))
authReq := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(`{"email":"pilot@example.com"}`))
authReq.Header.Set("Content-Type", "application/json")
authReq.RemoteAddr = "192.0.2.10:1234"
authResp := httptest.NewRecorder()
handler.ServeHTTP(authResp, authReq)
assert.Equal(t, http.StatusNotFound, firstBurst.Code)
assert.Equal(t, http.StatusTooManyRequests, secondBurst.Code)
assert.Equal(t, http.StatusOK, authResp.Code)
assert.Equal(t, `{"challenge_id":"challenge-123"}`, authResp.Body.String())
assert.Equal(t, 1, authService.sendEmailCodeCalls)
})
}
}
func TestPublicAntiAbuseSendEmailIdentityThrottle(t *testing.T) {
t.Parallel()
cfg := config.DefaultPublicHTTPConfig()
cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
cfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
authService := &recordingAuthServiceClient{
sendEmailCodeResult: SendEmailCodeResult{
ChallengeID: "challenge-123",
},
}
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
first := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
second := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
third := sendEmailCodeRequest(`{"email":"other@example.com"}`)
firstResp := httptest.NewRecorder()
handler.ServeHTTP(firstResp, first)
secondResp := httptest.NewRecorder()
handler.ServeHTTP(secondResp, second)
thirdResp := httptest.NewRecorder()
handler.ServeHTTP(thirdResp, third)
assert.Equal(t, http.StatusOK, firstResp.Code)
assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
assert.Equal(t, http.StatusOK, thirdResp.Code)
assert.Equal(t, 2, authService.sendEmailCodeCalls)
thirdInput := authService.sendEmailCodeInput
assert.Equal(t, "other@example.com", thirdInput.Email)
}
func TestPublicAntiAbuseConfirmEmailIdentityThrottle(t *testing.T) {
t.Parallel()
cfg := config.DefaultPublicHTTPConfig()
cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
cfg.AntiAbuse.ConfirmEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
authService := &recordingAuthServiceClient{
confirmEmailCodeResult: ConfirmEmailCodeResult{
DeviceSessionID: "device-session-123",
},
}
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
first := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
second := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
third := confirmEmailCodeRequest(`{"challenge_id":"challenge-456","code":"123456","client_public_key":"public-key-material"}`)
firstResp := httptest.NewRecorder()
handler.ServeHTTP(firstResp, first)
secondResp := httptest.NewRecorder()
handler.ServeHTTP(secondResp, second)
thirdResp := httptest.NewRecorder()
handler.ServeHTTP(thirdResp, third)
assert.Equal(t, http.StatusOK, firstResp.Code)
assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
assert.Equal(t, http.StatusOK, thirdResp.Code)
assert.Equal(t, 2, authService.confirmEmailCodeCalls)
assert.Equal(t, "challenge-456", authService.confirmEmailCodeInput.ChallengeID)
}
func TestPublicAntiAbuseMalformedTelemetry(t *testing.T) {
t.Parallel()
tests := []struct {
name string
body string
wantReason PublicMalformedRequestReason
wantRecords int
}{
{
name: "empty body",
body: ``,
wantReason: PublicMalformedRequestReasonEmptyBody,
wantRecords: 1,
},
{
name: "malformed json",
body: `{"email":`,
wantReason: PublicMalformedRequestReasonMalformedJSON,
wantRecords: 1,
},
{
name: "invalid json value",
body: `{"email":123}`,
wantReason: PublicMalformedRequestReasonInvalidJSONValue,
wantRecords: 1,
},
{
name: "unknown field",
body: `{"email":"pilot@example.com","extra":"x"}`,
wantReason: PublicMalformedRequestReasonUnknownField,
wantRecords: 1,
},
{
name: "multiple objects",
body: `{"email":"pilot@example.com"}{"email":"pilot@example.com"}`,
wantReason: PublicMalformedRequestReasonMultipleJSONObjects,
wantRecords: 1,
},
{
name: "validation error does not count as malformed",
body: `{"email":"not-an-email"}`,
wantRecords: 0,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
observer := &recordingPublicRequestObserver{}
authService := &recordingAuthServiceClient{}
handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
AuthService: authService,
Observer: observer,
})
req := sendEmailCodeRequest(tt.body)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Equal(t, tt.wantRecords, len(observer.snapshot()))
assert.Equal(t, 0, authService.sendEmailCodeCalls)
if tt.wantRecords == 1 {
assert.Equal(t, malformedObservation{
class: PublicRouteClassPublicAuth,
reason: tt.wantReason,
}, observer.snapshot()[0])
}
})
}
}
func TestInMemoryPublicRequestLimiterCleansExpiredBuckets(t *testing.T) {
t.Parallel()
now := time.Unix(1000, 0)
limiter := newInMemoryPublicRequestLimiter()
limiter.now = func() time.Time {
return now
}
limiter.cleanupInterval = time.Second
policy := config.PublicRateLimitConfig{
Requests: 1,
Window: time.Minute,
Burst: 1,
}
firstDecision := limiter.Reserve("bucket-1", policy)
secondDecision := limiter.Reserve("bucket-2", policy)
require.True(t, firstDecision.Allowed)
require.True(t, secondDecision.Allowed)
require.Len(t, limiter.entries, 2)
now = now.Add(3 * time.Minute)
thirdDecision := limiter.Reserve("bucket-3", policy)
require.True(t, thirdDecision.Allowed)
assert.Len(t, limiter.entries, 1)
_, exists := limiter.entries["bucket-3"]
assert.True(t, exists)
}
func sendEmailCodeRequest(body string) *http.Request {
req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.RemoteAddr = "192.0.2.10:1234"
return req
}
func confirmEmailCodeRequest(body string) *http.Request {
req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/confirm-email-code", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.RemoteAddr = "192.0.2.10:1234"
return req
}
type malformedObservation struct {
class PublicRouteClass
reason PublicMalformedRequestReason
}
type recordingPublicRequestObserver struct {
mu sync.Mutex
observations []malformedObservation
}
func (o *recordingPublicRequestObserver) RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason) {
o.mu.Lock()
defer o.mu.Unlock()
o.observations = append(o.observations, malformedObservation{
class: class,
reason: reason,
})
}
func (o *recordingPublicRequestObserver) snapshot() []malformedObservation {
o.mu.Lock()
defer o.mu.Unlock()
return append([]malformedObservation(nil), o.observations...)
}
+446
View File
@@ -0,0 +1,446 @@
package restapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/mail"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var errPublicAuthIdentityNotApplicable = errors.New("public auth identity does not apply to this route")
type malformedJSONRequestError struct {
message string
reason PublicMalformedRequestReason
}
func (e *malformedJSONRequestError) Error() string {
if e == nil {
return ""
}
return e.message
}
type publicAuthIdentity struct {
kind string
value string
}
// AuthServiceClient defines the consumer-side contract used by public auth
// REST handlers to delegate unauthenticated authentication commands to the
// Auth / Session Service.
type AuthServiceClient interface {
// SendEmailCode starts a login challenge for input.Email and returns the
// challenge identifier that the client must later confirm.
SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error)
// ConfirmEmailCode completes a previously issued challenge, registers
// input.ClientPublicKey for the new device session, and returns the created
// device session identifier.
ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error)
}
// SendEmailCodeInput describes the public REST and adapter payload used to
// request a login code for a single e-mail address.
type SendEmailCodeInput struct {
// Email is the single client e-mail address that should receive the login
// code challenge.
Email string `json:"email"`
}
// SendEmailCodeResult describes the public REST and adapter payload returned
// after the Auth / Session Service creates a login challenge.
type SendEmailCodeResult struct {
// ChallengeID identifies the issued challenge that must be confirmed by the
// client in the next public auth step.
ChallengeID string `json:"challenge_id"`
}
// ConfirmEmailCodeInput describes the public REST and adapter payload used to
// complete a previously issued login challenge.
type ConfirmEmailCodeInput struct {
// ChallengeID identifies the challenge previously returned by
// SendEmailCode.
ChallengeID string `json:"challenge_id"`
// Code is the verification code delivered to the client by the Auth /
// Session Service.
Code string `json:"code"`
// ClientPublicKey is the standard base64-encoded raw 32-byte Ed25519 public
// key that should be registered for the created device session.
ClientPublicKey string `json:"client_public_key"`
}
// ConfirmEmailCodeResult describes the public REST and adapter payload
// returned after the Auth / Session Service creates a device session.
type ConfirmEmailCodeResult struct {
// DeviceSessionID is the stable identifier of the created device session.
DeviceSessionID string `json:"device_session_id"`
}
// AuthServiceError allows an auth adapter to project a stable public REST
// error without teaching the gateway transport layer about upstream business
// rules.
type AuthServiceError struct {
// StatusCode is the HTTP status that the public REST handler should expose.
StatusCode int
// Code is the stable edge-level error code written into the JSON envelope.
Code string
// Message is the human-readable client-safe error description.
Message string
}
// Error returns a readable representation of the projected auth service error.
func (e *AuthServiceError) Error() string {
if e == nil {
return ""
}
switch {
case strings.TrimSpace(e.Code) == "" && strings.TrimSpace(e.Message) == "":
return http.StatusText(e.normalizedStatusCode())
case strings.TrimSpace(e.Code) == "":
return e.Message
case strings.TrimSpace(e.Message) == "":
return e.Code
default:
return e.Code + ": " + e.Message
}
}
func (e *AuthServiceError) normalizedStatusCode() int {
if e == nil || e.StatusCode < 400 || e.StatusCode > 599 {
return http.StatusInternalServerError
}
return e.StatusCode
}
func (e *AuthServiceError) normalizedCode() string {
if e == nil {
return errorCodeInternalError
}
code := strings.TrimSpace(e.Code)
if code == "" {
switch e.normalizedStatusCode() {
case http.StatusServiceUnavailable:
return errorCodeServiceUnavailable
case http.StatusBadRequest:
return errorCodeInvalidRequest
default:
return errorCodeInternalError
}
}
return code
}
func (e *AuthServiceError) normalizedMessage() string {
if e == nil {
return "internal server error"
}
message := strings.TrimSpace(e.Message)
if message == "" {
switch e.normalizedStatusCode() {
case http.StatusServiceUnavailable:
return "auth service is unavailable"
case http.StatusBadRequest:
return "request is invalid"
default:
return "internal server error"
}
}
return message
}
// unavailableAuthServiceClient keeps the public auth surface mounted until a
// concrete upstream adapter is wired into the gateway process.
type unavailableAuthServiceClient struct{}
func (unavailableAuthServiceClient) SendEmailCode(context.Context, SendEmailCodeInput) (SendEmailCodeResult, error) {
return SendEmailCodeResult{}, &AuthServiceError{
StatusCode: http.StatusServiceUnavailable,
Code: errorCodeServiceUnavailable,
Message: "auth service is unavailable",
}
}
func (unavailableAuthServiceClient) ConfirmEmailCode(context.Context, ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
return ConfirmEmailCodeResult{}, &AuthServiceError{
StatusCode: http.StatusServiceUnavailable,
Code: errorCodeServiceUnavailable,
Message: "auth service is unavailable",
}
}
func handleSendEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
var input SendEmailCodeInput
if err := decodeJSONRequest(c.Request, &input); err != nil {
abortInvalidRequest(c, err.Error())
return
}
if err := validateSendEmailCodeInput(&input); err != nil {
abortInvalidRequest(c, err.Error())
return
}
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := authService.SendEmailCode(callCtx, input)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
return
}
abortWithAuthServiceError(c, err)
return
}
if err := validateSendEmailCodeResult(&result); err != nil {
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
return
}
c.JSON(http.StatusOK, result)
}
}
func handleConfirmEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
var input ConfirmEmailCodeInput
if err := decodeJSONRequest(c.Request, &input); err != nil {
abortInvalidRequest(c, err.Error())
return
}
if err := validateConfirmEmailCodeInput(&input); err != nil {
abortInvalidRequest(c, err.Error())
return
}
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := authService.ConfirmEmailCode(callCtx, input)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
return
}
abortWithAuthServiceError(c, err)
return
}
if err := validateConfirmEmailCodeResult(&result); err != nil {
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
return
}
c.JSON(http.StatusOK, result)
}
}
func abortInvalidRequest(c *gin.Context, message string) {
abortWithError(c, http.StatusBadRequest, errorCodeInvalidRequest, message)
}
func abortWithAuthServiceError(c *gin.Context, err error) {
var authErr *AuthServiceError
if errors.As(err, &authErr) {
abortWithError(c, authErr.normalizedStatusCode(), authErr.normalizedCode(), authErr.normalizedMessage())
return
}
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
}
func decodeJSONRequest(r *http.Request, target any) error {
if r == nil || r.Body == nil {
return &malformedJSONRequestError{
message: "request body must not be empty",
reason: PublicMalformedRequestReasonEmptyBody,
}
}
return decodeJSONReader(r.Body, target)
}
func decodeJSONBytes(bodyBytes []byte, target any) error {
return decodeJSONReader(bytes.NewReader(bodyBytes), target)
}
func decodeJSONReader(reader io.Reader, target any) error {
decoder := json.NewDecoder(reader)
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return describeJSONDecodeError(err)
}
if err := decoder.Decode(&struct{}{}); err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return &malformedJSONRequestError{
message: "request body must contain a single JSON object",
reason: PublicMalformedRequestReasonMultipleJSONObjects,
}
}
return &malformedJSONRequestError{
message: "request body must contain a single JSON object",
reason: PublicMalformedRequestReasonMultipleJSONObjects,
}
}
func describeJSONDecodeError(err error) error {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
switch {
case errors.Is(err, io.EOF):
return &malformedJSONRequestError{
message: "request body must not be empty",
reason: PublicMalformedRequestReasonEmptyBody,
}
case errors.As(err, &syntaxErr):
return &malformedJSONRequestError{
message: "request body contains malformed JSON",
reason: PublicMalformedRequestReasonMalformedJSON,
}
case errors.Is(err, io.ErrUnexpectedEOF):
return &malformedJSONRequestError{
message: "request body contains malformed JSON",
reason: PublicMalformedRequestReasonMalformedJSON,
}
case errors.As(err, &typeErr):
if strings.TrimSpace(typeErr.Field) != "" {
return &malformedJSONRequestError{
message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
reason: PublicMalformedRequestReasonInvalidJSONValue,
}
}
return &malformedJSONRequestError{
message: "request body contains an invalid JSON value",
reason: PublicMalformedRequestReasonInvalidJSONValue,
}
case strings.HasPrefix(err.Error(), "json: unknown field "):
return &malformedJSONRequestError{
message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
reason: PublicMalformedRequestReasonUnknownField,
}
default:
return &malformedJSONRequestError{
message: "request body contains invalid JSON",
reason: PublicMalformedRequestReasonMalformedJSON,
}
}
}
func validateSendEmailCodeInput(input *SendEmailCodeInput) error {
input.Email = strings.TrimSpace(input.Email)
if input.Email == "" {
return errors.New("email must not be empty")
}
parsedAddress, err := mail.ParseAddress(input.Email)
if err != nil || parsedAddress.Name != "" || parsedAddress.Address != input.Email {
return errors.New("email must be a single valid email address")
}
return nil
}
func validateSendEmailCodeResult(result *SendEmailCodeResult) error {
result.ChallengeID = strings.TrimSpace(result.ChallengeID)
if result.ChallengeID == "" {
return errors.New("auth service returned an empty challenge_id")
}
return nil
}
func validateConfirmEmailCodeInput(input *ConfirmEmailCodeInput) error {
input.ChallengeID = strings.TrimSpace(input.ChallengeID)
if input.ChallengeID == "" {
return errors.New("challenge_id must not be empty")
}
input.Code = strings.TrimSpace(input.Code)
if input.Code == "" {
return errors.New("code must not be empty")
}
input.ClientPublicKey = strings.TrimSpace(input.ClientPublicKey)
if input.ClientPublicKey == "" {
return errors.New("client_public_key must not be empty")
}
return nil
}
func validateConfirmEmailCodeResult(result *ConfirmEmailCodeResult) error {
result.DeviceSessionID = strings.TrimSpace(result.DeviceSessionID)
if result.DeviceSessionID == "" {
return errors.New("auth service returned an empty device_session_id")
}
return nil
}
func malformedRequestReasonFromError(err error) (PublicMalformedRequestReason, bool) {
var malformedErr *malformedJSONRequestError
if !errors.As(err, &malformedErr) {
return "", false
}
return malformedErr.reason, true
}
func extractPublicAuthIdentity(requestPath string, bodyBytes []byte) (publicAuthIdentity, error) {
switch requestPath {
case "/api/v1/public/auth/send-email-code":
var input SendEmailCodeInput
if err := decodeJSONBytes(bodyBytes, &input); err != nil {
return publicAuthIdentity{}, err
}
if err := validateSendEmailCodeInput(&input); err != nil {
return publicAuthIdentity{}, err
}
return publicAuthIdentity{
kind: "email",
value: input.Email,
}, nil
case "/api/v1/public/auth/confirm-email-code":
var input ConfirmEmailCodeInput
if err := decodeJSONBytes(bodyBytes, &input); err != nil {
return publicAuthIdentity{}, err
}
if err := validateConfirmEmailCodeInput(&input); err != nil {
return publicAuthIdentity{}, err
}
return publicAuthIdentity{
kind: "challenge",
value: input.ChallengeID,
}, nil
default:
return publicAuthIdentity{}, errPublicAuthIdentityNotApplicable
}
}
@@ -0,0 +1,377 @@
package restapi
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSendEmailCodeHandlerSuccess(t *testing.T) {
t.Parallel()
authService := &recordingAuthServiceClient{
sendEmailCodeResult: SendEmailCodeResult{
ChallengeID: "challenge-123",
},
}
handler := newPublicHandler(ServerDependencies{AuthService: authService})
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/send-email-code",
strings.NewReader(`{"email":" pilot@example.com "}`),
)
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Equal(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String())
assert.Equal(t, 1, authService.sendEmailCodeCalls)
assert.Equal(t, 0, authService.confirmEmailCodeCalls)
assert.Equal(t, SendEmailCodeInput{Email: "pilot@example.com"}, authService.sendEmailCodeInput)
assert.True(t, authService.sendEmailCodeRouteClassOK)
assert.Equal(t, PublicRouteClassPublicAuth, authService.sendEmailCodeRouteClass)
}
func TestConfirmEmailCodeHandlerSuccess(t *testing.T) {
t.Parallel()
authService := &recordingAuthServiceClient{
confirmEmailCodeResult: ConfirmEmailCodeResult{
DeviceSessionID: "device-session-123",
},
}
handler := newPublicHandler(ServerDependencies{AuthService: authService})
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/confirm-email-code",
strings.NewReader(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`),
)
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Equal(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String())
assert.Equal(t, 0, authService.sendEmailCodeCalls)
assert.Equal(t, 1, authService.confirmEmailCodeCalls)
assert.Equal(t, ConfirmEmailCodeInput{
ChallengeID: "challenge-123",
Code: "123456",
ClientPublicKey: "public-key-material",
}, authService.confirmEmailCodeInput)
assert.True(t, authService.confirmEmailCodeRouteClassOK)
assert.Equal(t, PublicRouteClassPublicAuth, authService.confirmEmailCodeRouteClass)
}
func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) {
t.Parallel()
tests := []struct {
name string
target string
body string
wantStatus int
wantBody string
wantSendCalls int
wantConfirmCalls int
}{
{
name: "send email malformed json",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
wantSendCalls: 0,
wantConfirmCalls: 0,
},
{
name: "send email validation error",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"not-an-email"}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`,
wantSendCalls: 0,
wantConfirmCalls: 0,
},
{
name: "confirm email empty code",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"code must not be empty"}}`,
wantSendCalls: 0,
wantConfirmCalls: 0,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
authService := &recordingAuthServiceClient{}
handler := newPublicHandler(ServerDependencies{AuthService: authService})
req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Equal(t, tt.wantBody, recorder.Body.String())
assert.Equal(t, tt.wantSendCalls, authService.sendEmailCodeCalls)
assert.Equal(t, tt.wantConfirmCalls, authService.confirmEmailCodeCalls)
})
}
}
func TestPublicAuthHandlersMapAdapterErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
target string
body string
authClient *recordingAuthServiceClient
wantStatus int
wantBody string
}{
{
name: "auth service projected bad request",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
authClient: &recordingAuthServiceClient{
confirmEmailCodeErr: &AuthServiceError{
StatusCode: http.StatusBadRequest,
Code: errorCodeInvalidRequest,
Message: "confirmation code is invalid",
},
},
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"confirmation code is invalid"}}`,
},
{
name: "auth service projected custom too many requests",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"pilot@example.com"}`,
authClient: &recordingAuthServiceClient{
sendEmailCodeErr: &AuthServiceError{
StatusCode: http.StatusTooManyRequests,
Code: "upstream_rate_limited",
Message: "too many attempts for this email",
},
},
wantStatus: http.StatusTooManyRequests,
wantBody: `{"error":{"code":"upstream_rate_limited","message":"too many attempts for this email"}}`,
},
{
name: "auth service projected gateway normalizes blank gateway error fields",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
authClient: &recordingAuthServiceClient{
confirmEmailCodeErr: &AuthServiceError{
StatusCode: http.StatusBadGateway,
},
},
wantStatus: http.StatusBadGateway,
wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`,
},
{
name: "unexpected auth service error",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"pilot@example.com"}`,
authClient: &recordingAuthServiceClient{
sendEmailCodeErr: errors.New("boom"),
},
wantStatus: http.StatusInternalServerError,
wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := newPublicHandler(ServerDependencies{AuthService: tt.authClient})
req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Equal(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestDefaultAuthServiceReturnsServiceUnavailable(t *testing.T) {
t.Parallel()
handler := newPublicHandler(ServerDependencies{})
tests := []struct {
name string
method string
target string
body string
wantStatus int
wantBody string
}{
{
name: "send email code",
method: http.MethodPost,
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"pilot@example.com"}`,
wantStatus: http.StatusServiceUnavailable,
wantBody: `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
},
{
name: "confirm email code",
method: http.MethodPost,
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
wantStatus: http.StatusServiceUnavailable,
wantBody: `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
},
{
name: "healthz remains available",
method: http.MethodGet,
target: "/healthz",
wantStatus: http.StatusOK,
wantBody: `{"status":"ok"}`,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
if tt.body != "" {
req.Header.Set("Content-Type", "application/json")
}
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Equal(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
t.Parallel()
authService := &recordingAuthServiceClient{
sendEmailCodeErr: context.DeadlineExceeded,
}
cfg := config.DefaultPublicHTTPConfig()
cfg.AuthUpstreamTimeout = 5 * time.Millisecond
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/send-email-code",
strings.NewReader(`{"email":"pilot@example.com"}`),
)
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
assert.Equal(t, `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`, recorder.Body.String())
}
func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) {
t.Parallel()
logger, buffer := testutil.NewObservedLogger(t)
handler := newPublicHandler(ServerDependencies{
Logger: logger,
AuthService: &recordingAuthServiceClient{
confirmEmailCodeResult: ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"},
},
})
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/confirm-email-code",
strings.NewReader(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`),
)
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
logOutput := buffer.String()
assert.NotContains(t, logOutput, "challenge-123")
assert.NotContains(t, logOutput, "123456")
assert.NotContains(t, logOutput, "public-key-material")
assert.NotContains(t, logOutput, "pilot@example.com")
}
// recordingAuthServiceClient captures handler inputs and route classification
// so tests can assert the exact adapter delegation contract.
type recordingAuthServiceClient struct {
sendEmailCodeResult SendEmailCodeResult
sendEmailCodeErr error
sendEmailCodeInput SendEmailCodeInput
sendEmailCodeRouteClass PublicRouteClass
sendEmailCodeRouteClassOK bool
sendEmailCodeCalls int
confirmEmailCodeResult ConfirmEmailCodeResult
confirmEmailCodeErr error
confirmEmailCodeInput ConfirmEmailCodeInput
confirmEmailCodeRouteClass PublicRouteClass
confirmEmailCodeRouteClassOK bool
confirmEmailCodeCalls int
}
func (c *recordingAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
c.sendEmailCodeCalls++
c.sendEmailCodeInput = input
c.sendEmailCodeRouteClass, c.sendEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
return c.sendEmailCodeResult, c.sendEmailCodeErr
}
func (c *recordingAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
c.confirmEmailCodeCalls++
c.confirmEmailCodeInput = input
c.confirmEmailCodeRouteClass, c.confirmEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
return c.confirmEmailCodeResult, c.confirmEmailCodeErr
}
+388
View File
@@ -0,0 +1,388 @@
// Package restapi exposes the unauthenticated public REST surface of the
// gateway.
package restapi
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
"go.uber.org/zap"
)
const (
jsonContentType = "application/json; charset=utf-8"
errorCodeInvalidRequest = "invalid_request"
errorCodeNotFound = "not_found"
errorCodeMethodNotAllowed = "method_not_allowed"
errorCodeInternalError = "internal_error"
errorCodeServiceUnavailable = "service_unavailable"
publicRESTBaseBucketKeyPrefix = "public_rest/class="
)
// PublicRouteClass identifies the public traffic class assigned to an incoming
// REST request before route handling and edge policy evaluation.
type PublicRouteClass string
const (
// PublicRouteClassPublicAuth identifies public authentication commands.
PublicRouteClassPublicAuth PublicRouteClass = "public_auth"
// PublicRouteClassBrowserBootstrap identifies browser bootstrap traffic such
// as the main document request.
PublicRouteClassBrowserBootstrap PublicRouteClass = "browser_bootstrap"
// PublicRouteClassBrowserAsset identifies browser asset requests.
PublicRouteClassBrowserAsset PublicRouteClass = "browser_asset"
// PublicRouteClassPublicMisc identifies public traffic that does not match a
// more specific class.
PublicRouteClassPublicMisc PublicRouteClass = "public_misc"
)
var configureGinModeOnce sync.Once
// Normalized returns c when it belongs to the stable public route class set.
// Unknown or empty values collapse to PublicRouteClassPublicMisc so edge policy
// code can rely on a fixed anti-abuse namespace.
func (c PublicRouteClass) Normalized() PublicRouteClass {
switch c {
case PublicRouteClassPublicAuth,
PublicRouteClassBrowserBootstrap,
PublicRouteClassBrowserAsset,
PublicRouteClassPublicMisc:
return c
default:
return PublicRouteClassPublicMisc
}
}
// BaseBucketKey returns the canonical base rate-limit namespace for c. The key
// stays scoped only by the normalized public route class; callers may append
// subject dimensions such as IP or identity without redefining the class
// namespace.
func (c PublicRouteClass) BaseBucketKey() string {
return publicRESTBaseBucketKeyPrefix + string(c.Normalized())
}
// PublicTrafficClassifier maps public REST requests to the public anti-abuse
// class used by the gateway edge. The server normalizes classifier outputs to
// the stable class set before storing them in request context.
type PublicTrafficClassifier interface {
Classify(*http.Request) PublicRouteClass
}
// ServerDependencies describes the optional collaborators used by the public
// REST server. The zero value is valid and keeps the process runnable with the
// built-in defaults.
type ServerDependencies struct {
// Classifier assigns the public anti-abuse class before route handling.
// When nil, the gateway default classifier is used.
Classifier PublicTrafficClassifier
// AuthService delegates public auth commands to the Auth / Session Service.
// When nil, public auth routes remain mounted and return a stable
// service-unavailable response.
AuthService AuthServiceClient
// Limiter applies the public REST rate-limit policy. When nil, a default
// process-local in-memory limiter is used.
Limiter PublicRequestLimiter
// Observer records malformed-request telemetry for the public REST layer.
// When nil, a no-op observer is used.
Observer PublicRequestObserver
// Logger writes structured transport logs for public REST traffic. When nil,
// a no-op logger is used.
Logger *zap.Logger
// Telemetry records low-cardinality edge metrics. When nil, metrics are
// disabled.
Telemetry *telemetry.Runtime
}
// Server owns the public unauthenticated REST listener exposed by the gateway.
type Server struct {
cfg config.PublicHTTPConfig
handler http.Handler
logger *zap.Logger
stateMu sync.RWMutex
server *http.Server
listener net.Listener
}
// NewServer constructs a public REST server for the supplied listener
// configuration and dependency bundle. Nil dependencies are replaced with safe
// defaults so the gateway can still expose the documented public surface.
func NewServer(cfg config.PublicHTTPConfig, deps ServerDependencies) *Server {
deps = normalizeServerDependencies(deps)
return &Server{
cfg: cfg,
handler: newPublicHandlerWithConfig(cfg, deps),
logger: deps.Logger.Named("public_http"),
}
}
// Run binds the configured listener and serves the public REST surface until
// Shutdown closes the server.
func (s *Server) Run(ctx context.Context) error {
if ctx == nil {
return errors.New("run public REST server: nil context")
}
if err := ctx.Err(); err != nil {
return err
}
listener, err := net.Listen("tcp", s.cfg.Addr)
if err != nil {
return fmt.Errorf("run public REST server: listen on %q: %w", s.cfg.Addr, err)
}
server := &http.Server{
Handler: s.handler,
ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
ReadTimeout: s.cfg.ReadTimeout,
IdleTimeout: s.cfg.IdleTimeout,
}
s.stateMu.Lock()
s.server = server
s.listener = listener
s.stateMu.Unlock()
s.logger.Info("public REST server started", zap.String("addr", listener.Addr().String()))
defer func() {
s.stateMu.Lock()
s.server = nil
s.listener = nil
s.stateMu.Unlock()
}()
err = server.Serve(listener)
switch {
case err == nil:
return nil
case errors.Is(err, http.ErrServerClosed):
s.logger.Info("public REST server stopped")
return nil
default:
return fmt.Errorf("run public REST server: serve on %q: %w", s.cfg.Addr, err)
}
}
// Shutdown gracefully stops the public REST server within ctx.
func (s *Server) Shutdown(ctx context.Context) error {
if ctx == nil {
return errors.New("shutdown public REST server: nil context")
}
s.stateMu.RLock()
server := s.server
s.stateMu.RUnlock()
if server == nil {
return nil
}
if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("shutdown public REST server: %w", err)
}
return nil
}
// PublicRouteClassFromContext returns the previously classified normalized
// public route class stored in ctx.
func PublicRouteClassFromContext(ctx context.Context) (PublicRouteClass, bool) {
if ctx == nil {
return "", false
}
class, ok := ctx.Value(publicRouteClassContextKey{}).(PublicRouteClass)
if !ok {
return "", false
}
return class.Normalized(), true
}
type publicRouteClassContextKey struct{}
type defaultPublicTrafficClassifier struct{}
// Classify maps the incoming request into a stable public route class that can
// later drive anti-abuse policy and rate limiting.
func (defaultPublicTrafficClassifier) Classify(r *http.Request) PublicRouteClass {
switch {
case isPublicAuthRequest(r):
return PublicRouteClassPublicAuth
case isBrowserBootstrapRequest(r):
return PublicRouteClassBrowserBootstrap
case isBrowserAssetRequest(r):
return PublicRouteClassBrowserAsset
default:
return PublicRouteClassPublicMisc
}
}
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
if deps.Classifier == nil {
deps.Classifier = defaultPublicTrafficClassifier{}
}
if deps.AuthService == nil {
deps.AuthService = unavailableAuthServiceClient{}
}
if deps.Limiter == nil {
deps.Limiter = newInMemoryPublicRequestLimiter()
}
if deps.Observer == nil {
deps.Observer = noopPublicRequestObserver{}
}
if deps.Logger == nil {
deps.Logger = zap.NewNop()
}
return deps
}
func newPublicHandler(deps ServerDependencies) http.Handler {
return newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), deps)
}
func newPublicHandlerWithConfig(cfg config.PublicHTTPConfig, deps ServerDependencies) http.Handler {
configureGinModeOnce.Do(func() {
gin.SetMode(gin.ReleaseMode)
})
deps = normalizeServerDependencies(deps)
router := gin.New()
router.HandleMethodNotAllowed = true
router.Use(gin.CustomRecovery(func(c *gin.Context, _ any) {
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
}))
router.Use(otelgin.Middleware("galaxy-edge-gateway-public"))
router.Use(withPublicObservability(deps.Logger.Named("public_http"), deps.Telemetry))
router.Use(withPublicRouteClass(deps.Classifier))
router.Use(withPublicAntiAbuse(cfg.AntiAbuse, deps.Limiter, deps.Observer))
router.GET("/healthz", handleHealthz)
router.GET("/readyz", handleReadyz)
router.POST("/api/v1/public/auth/send-email-code", handleSendEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))
router.POST("/api/v1/public/auth/confirm-email-code", handleConfirmEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))
router.NoMethod(func(c *gin.Context) {
allowMethods := allowedMethodsForPath(c.Request.URL.Path)
if allowMethods != "" {
c.Header("Allow", allowMethods)
}
abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
})
router.NoRoute(func(c *gin.Context) {
abortWithError(c, http.StatusNotFound, errorCodeNotFound, "resource was not found")
})
return router
}
func handleHealthz(c *gin.Context) {
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
}
func handleReadyz(c *gin.Context) {
c.JSON(http.StatusOK, statusResponse{Status: "ready"})
}
func withPublicRouteClass(classifier PublicTrafficClassifier) gin.HandlerFunc {
return func(c *gin.Context) {
class := classifier.Classify(c.Request).Normalized()
ctx := context.WithValue(c.Request.Context(), publicRouteClassContextKey{}, class)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
func isPublicAuthRequest(r *http.Request) bool {
return r.Method == http.MethodPost && isPublicAuthPath(r.URL.Path)
}
func isBrowserBootstrapRequest(r *http.Request) bool {
if r.Method == http.MethodGet && r.URL.Path == "/" {
return true
}
return matchesBrowserBootstrapRequestShape(r)
}
func isBrowserAssetRequest(r *http.Request) bool {
if r.Method != http.MethodGet && r.Method != http.MethodHead {
return false
}
return matchesBrowserAssetRequestShape(r)
}
type statusResponse struct {
Status string `json:"status"`
}
type errorResponse struct {
Error errorBody `json:"error"`
}
type errorBody struct {
Code string `json:"code"`
Message string `json:"message"`
}
func abortWithError(c *gin.Context, statusCode int, code string, message string) {
if c != nil {
c.Set(publicErrorCodeContextKey, code)
}
c.AbortWithStatusJSON(statusCode, errorResponse{
Error: errorBody{
Code: code,
Message: message,
},
})
}
const publicErrorCodeContextKey = "public_error_code"
func allowedMethodsForPath(requestPath string) string {
switch requestPath {
case "/healthz", "/readyz":
return http.MethodGet
case "/api/v1/public/auth/send-email-code", "/api/v1/public/auth/confirm-email-code":
return http.MethodPost
default:
return ""
}
}
func (s *Server) listenAddr() string {
s.stateMu.RLock()
defer s.stateMu.RUnlock()
if s.listener == nil {
return ""
}
return s.listener.Addr().String()
}
+459
View File
@@ -0,0 +1,459 @@
package restapi
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPublicHandlerHealthEndpoints(t *testing.T) {
t.Parallel()
handler := newPublicHandler(ServerDependencies{})
tests := []struct {
name string
method string
target string
wantStatus int
wantType string
wantBody string
wantAllow string
}{
{
name: "healthz",
method: http.MethodGet,
target: "/healthz",
wantStatus: http.StatusOK,
wantType: jsonContentType,
wantBody: `{"status":"ok"}`,
},
{
name: "readyz",
method: http.MethodGet,
target: "/readyz",
wantStatus: http.StatusOK,
wantType: jsonContentType,
wantBody: `{"status":"ready"}`,
},
{
name: "wrong method on known route",
method: http.MethodPost,
target: "/healthz",
wantStatus: http.StatusMethodNotAllowed,
wantType: jsonContentType,
wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
wantAllow: http.MethodGet,
},
{
name: "unknown route",
method: http.MethodGet,
target: "/unknown",
wantStatus: http.StatusNotFound,
wantType: jsonContentType,
wantBody: `{"error":{"code":"not_found","message":"resource was not found"}}`,
},
{
name: "wrong method on public auth route",
method: http.MethodGet,
target: "/api/v1/public/auth/send-email-code",
wantStatus: http.StatusMethodNotAllowed,
wantType: jsonContentType,
wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
wantAllow: http.MethodPost,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(tt.method, tt.target, nil)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, tt.wantType, recorder.Header().Get("Content-Type"))
assert.Equal(t, tt.wantBody, recorder.Body.String())
assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
})
}
}
func TestDefaultPublicTrafficClassifier(t *testing.T) {
t.Parallel()
classifier := defaultPublicTrafficClassifier{}
tests := []struct {
name string
method string
target string
accept string
wantClass PublicRouteClass
}{
{
name: "public auth route",
method: http.MethodPost,
target: "/api/v1/public/auth/send-email-code",
wantClass: PublicRouteClassPublicAuth,
},
{
name: "public auth confirm route",
method: http.MethodPost,
target: "/api/v1/public/auth/confirm-email-code",
wantClass: PublicRouteClassPublicAuth,
},
{
name: "browser bootstrap route",
method: http.MethodGet,
target: "/",
wantClass: PublicRouteClassBrowserBootstrap,
},
{
name: "browser asset route",
method: http.MethodGet,
target: "/assets/app.js",
wantClass: PublicRouteClassBrowserAsset,
},
{
name: "browser asset head request",
method: http.MethodHead,
target: "/assets/app.js",
wantClass: PublicRouteClassBrowserAsset,
},
{
name: "browser asset extension request",
method: http.MethodGet,
target: "/manifest.webmanifest",
wantClass: PublicRouteClassBrowserAsset,
},
{
name: "public misc route",
method: http.MethodPost,
target: "/api/v1/public/unknown",
wantClass: PublicRouteClassPublicMisc,
},
{
name: "html accept bootstrap route",
method: http.MethodGet,
target: "/app",
accept: "application/json, text/html;q=0.9",
wantClass: PublicRouteClassBrowserBootstrap,
},
{
name: "public auth wins over browser accept header",
method: http.MethodPost,
target: "/api/v1/public/auth/confirm-email-code",
accept: "text/html",
wantClass: PublicRouteClassPublicAuth,
},
{
name: "probe with html accept is bootstrap",
method: http.MethodGet,
target: "/healthz",
accept: "text/html",
wantClass: PublicRouteClassBrowserBootstrap,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(tt.method, tt.target, nil)
if tt.accept != "" {
req.Header.Set("Accept", tt.accept)
}
assert.Equal(t, tt.wantClass, classifier.Classify(req))
})
}
}
func TestPublicRouteClassNormalized(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input PublicRouteClass
want PublicRouteClass
}{
{
name: "public auth",
input: PublicRouteClassPublicAuth,
want: PublicRouteClassPublicAuth,
},
{
name: "browser bootstrap",
input: PublicRouteClassBrowserBootstrap,
want: PublicRouteClassBrowserBootstrap,
},
{
name: "browser asset",
input: PublicRouteClassBrowserAsset,
want: PublicRouteClassBrowserAsset,
},
{
name: "public misc",
input: PublicRouteClassPublicMisc,
want: PublicRouteClassPublicMisc,
},
{
name: "unknown collapses to misc",
input: PublicRouteClass("unexpected"),
want: PublicRouteClassPublicMisc,
},
{
name: "empty collapses to misc",
input: PublicRouteClass(""),
want: PublicRouteClassPublicMisc,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, tt.input.Normalized())
})
}
}
func TestPublicRouteClassBaseBucketKeyIsolation(t *testing.T) {
t.Parallel()
tests := []struct {
name string
class PublicRouteClass
wantKey string
}{
{
name: "public auth",
class: PublicRouteClassPublicAuth,
wantKey: "public_rest/class=public_auth",
},
{
name: "browser bootstrap",
class: PublicRouteClassBrowserBootstrap,
wantKey: "public_rest/class=browser_bootstrap",
},
{
name: "browser asset",
class: PublicRouteClassBrowserAsset,
wantKey: "public_rest/class=browser_asset",
},
{
name: "public misc",
class: PublicRouteClassPublicMisc,
wantKey: "public_rest/class=public_misc",
},
{
name: "unknown collapses to misc namespace",
class: PublicRouteClass("unexpected"),
wantKey: "public_rest/class=public_misc",
},
}
seenKeys := make(map[string]PublicRouteClass, len(tests))
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.wantKey, tt.class.BaseBucketKey())
})
normalizedClass := tt.class.Normalized()
if normalizedClass == PublicRouteClassPublicMisc && tt.class != PublicRouteClassPublicMisc {
continue
}
if previousClass, exists := seenKeys[tt.wantKey]; exists {
require.FailNowf(t, "bucket key collision", "class %q collides with %q on key %q", tt.class, previousClass, tt.wantKey)
}
seenKeys[tt.wantKey] = tt.class
}
assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserBootstrap.BaseBucketKey())
assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserAsset.BaseBucketKey())
}
func TestWithPublicRouteClassStoresClassInContext(t *testing.T) {
t.Parallel()
router := gin.New()
router.Use(withPublicRouteClass(staticClassifier{class: PublicRouteClassBrowserAsset}))
router.GET("/assets/app.js", func(c *gin.Context) {
class, ok := PublicRouteClassFromContext(c.Request.Context())
require.True(t, ok)
assert.Equal(t, PublicRouteClassBrowserAsset, class)
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/assets/app.js", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
}
func TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc(t *testing.T) {
t.Parallel()
tests := []struct {
name string
class PublicRouteClass
}{
{
name: "unknown class",
class: PublicRouteClass("unexpected"),
},
{
name: "empty class",
class: PublicRouteClass(""),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
router := gin.New()
router.Use(withPublicRouteClass(staticClassifier{class: tt.class}))
router.GET("/", func(c *gin.Context) {
class, ok := PublicRouteClassFromContext(c.Request.Context())
require.True(t, ok)
assert.Equal(t, PublicRouteClassPublicMisc, class)
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
})
}
}
func TestServerLifecycle(t *testing.T) {
t.Parallel()
cfg := config.Config{
ShutdownTimeout: time.Second,
PublicHTTP: func() config.PublicHTTPConfig {
publicHTTPCfg := config.DefaultPublicHTTPConfig()
publicHTTPCfg.Addr = "127.0.0.1:0"
publicHTTPCfg.AntiAbuse.PublicMisc.RateLimit = config.PublicRateLimitConfig{
Requests: 1000,
Window: time.Minute,
Burst: 1000,
}
return publicHTTPCfg
}(),
}
server := NewServer(cfg.PublicHTTP, ServerDependencies{})
application := app.New(cfg, server)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resultCh := make(chan error, 1)
go func() {
resultCh <- application.Run(ctx)
}()
addr := waitForListenAddr(t, server)
waitForHealthResponse(t, addr)
cancel()
select {
case err := <-resultCh:
require.NoError(t, err)
case <-time.After(2 * time.Second):
require.FailNow(t, "Run() did not return after cancellation")
}
}
type staticClassifier struct {
class PublicRouteClass
}
func (c staticClassifier) Classify(*http.Request) PublicRouteClass {
return c.class
}
func waitForListenAddr(t *testing.T, server *Server) string {
t.Helper()
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if addr := server.listenAddr(); addr != "" {
return addr
}
time.Sleep(10 * time.Millisecond)
}
require.FailNow(t, "server did not start listening")
return ""
}
func waitForHealthResponse(t *testing.T, addr string) {
t.Helper()
client := &http.Client{Timeout: 100 * time.Millisecond}
url := "http://" + addr + "/healthz"
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
resp, err := client.Get(url)
if err != nil {
time.Sleep(10 * time.Millisecond)
continue
}
body, readErr := io.ReadAll(resp.Body)
closeErr := resp.Body.Close()
require.NoError(t, readErr)
require.NoError(t, closeErr)
require.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, `{"status":"ok"}`, strings.TrimSpace(string(body)))
return
}
require.FailNowf(t, "health check timed out", "url=%s", url)
}