feat: edge gateway service
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
c.Next()
|
||||
|
||||
statusCode := c.Writer.Status()
|
||||
route := c.FullPath()
|
||||
if route == "" {
|
||||
route = c.Request.URL.Path
|
||||
}
|
||||
|
||||
class, ok := PublicRouteClassFromContext(c.Request.Context())
|
||||
if !ok {
|
||||
class = PublicRouteClassPublicMisc
|
||||
}
|
||||
|
||||
errorCode, _ := c.Get(publicErrorCodeContextKey)
|
||||
errorCodeValue, _ := errorCode.(string)
|
||||
|
||||
outcome := telemetry.OutcomeFromPublicErrorCode(statusCode, errorCodeValue)
|
||||
rejectReason := telemetry.RejectReason(outcome)
|
||||
duration := time.Since(start)
|
||||
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("route_class", string(class)),
|
||||
attribute.String("route", route),
|
||||
attribute.String("method", c.Request.Method),
|
||||
attribute.String("edge_outcome", string(outcome)),
|
||||
}
|
||||
if rejectReason != "" {
|
||||
attrs = append(attrs, attribute.String("reject_reason", rejectReason))
|
||||
}
|
||||
metrics.RecordPublicRequest(c.Request.Context(), attrs, duration)
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "public_http"),
|
||||
zap.String("transport", "http"),
|
||||
zap.String("route", route),
|
||||
zap.String("route_class", string(class)),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
|
||||
zap.String("edge_outcome", string(outcome)),
|
||||
}
|
||||
if rejectReason != "" {
|
||||
fields = append(fields, zap.String("reject_reason", rejectReason))
|
||||
}
|
||||
fields = append(fields, logging.TraceFieldsFromContext(c.Request.Context())...)
|
||||
|
||||
switch outcome {
|
||||
case telemetry.EdgeOutcomeSuccess:
|
||||
logger.Info("public request completed", fields...)
|
||||
case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeInternalError:
|
||||
logger.Error("public request failed", fields...)
|
||||
default:
|
||||
logger.Warn("public request rejected", fields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPublicOpenAPISpecValidates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, thisFile, _, ok := runtime.Caller(0)
|
||||
require.True(t, ok)
|
||||
|
||||
specPath := filepath.Join(filepath.Dir(thisFile), "..", "..", "openapi.yaml")
|
||||
ctx := context.Background()
|
||||
|
||||
loader := openapi3.NewLoader()
|
||||
doc, err := loader.LoadFromFile(specPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, doc)
|
||||
require.NotNil(t, doc.Info)
|
||||
require.Equal(t, "v1", doc.Info.Version)
|
||||
|
||||
require.NoError(t, doc.Validate(ctx))
|
||||
}
|
||||
@@ -0,0 +1,378 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
errorCodeRequestTooLarge = "request_too_large"
|
||||
errorCodeRateLimited = "rate_limited"
|
||||
|
||||
publicRESTIPBucketKeySegment = "/ip="
|
||||
)
|
||||
|
||||
var errRequestBodyTooLarge = errors.New("request body exceeds the configured limit")
|
||||
|
||||
// PublicMalformedRequestReason identifies the stable malformed-request counter
|
||||
// dimension recorded by the public REST anti-abuse middleware.
|
||||
type PublicMalformedRequestReason string
|
||||
|
||||
const (
|
||||
// PublicMalformedRequestReasonEmptyBody records a missing request body.
|
||||
PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body"
|
||||
|
||||
// PublicMalformedRequestReasonMalformedJSON records syntactically malformed
|
||||
// JSON.
|
||||
PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json"
|
||||
|
||||
// PublicMalformedRequestReasonInvalidJSONValue records JSON values whose
|
||||
// types do not match the expected request schema.
|
||||
PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value"
|
||||
|
||||
// PublicMalformedRequestReasonUnknownField records JSON objects with fields
|
||||
// outside the documented schema.
|
||||
PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field"
|
||||
|
||||
// PublicMalformedRequestReasonMultipleJSONObjects records requests that
|
||||
// contain more than one JSON object.
|
||||
PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects"
|
||||
|
||||
// PublicMalformedRequestReasonOversizedBody records requests whose bodies
|
||||
// exceed the configured class limit.
|
||||
PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body"
|
||||
)
|
||||
|
||||
// PublicRateLimitDecision describes the outcome returned by a public REST
|
||||
// limiter for one request bucket reservation attempt.
|
||||
type PublicRateLimitDecision struct {
|
||||
// Allowed reports whether the request may proceed immediately.
|
||||
Allowed bool
|
||||
|
||||
// RetryAfter is the minimum delay the client should wait before retrying
|
||||
// when Allowed is false.
|
||||
RetryAfter time.Duration
|
||||
}
|
||||
|
||||
// PublicRequestLimiter applies public REST rate-limit policy to a concrete
|
||||
// bucket key.
|
||||
type PublicRequestLimiter interface {
|
||||
// Reserve evaluates key under policy and returns whether the request may
|
||||
// proceed immediately.
|
||||
Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision
|
||||
}
|
||||
|
||||
// PublicRequestObserver captures low-cardinality public REST anti-abuse
|
||||
// telemetry.
|
||||
type PublicRequestObserver interface {
|
||||
// RecordMalformedRequest records one malformed request in class for reason.
|
||||
RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason)
|
||||
}
|
||||
|
||||
type noopPublicRequestObserver struct{}
|
||||
|
||||
func (noopPublicRequestObserver) RecordMalformedRequest(PublicRouteClass, PublicMalformedRequestReason) {
|
||||
}
|
||||
|
||||
type inMemoryPublicRequestLimiter struct {
|
||||
now func() time.Time
|
||||
cleanupInterval time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
entries map[string]*publicRateLimiterEntry
|
||||
nextCleanup time.Time
|
||||
}
|
||||
|
||||
type publicRateLimiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
limit rate.Limit
|
||||
burst int
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter {
|
||||
return &inMemoryPublicRequestLimiter{
|
||||
now: time.Now,
|
||||
cleanupInterval: time.Minute,
|
||||
entries: make(map[string]*publicRateLimiterEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision {
|
||||
now := l.now()
|
||||
limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.cleanupExpiredBucketsLocked(now)
|
||||
|
||||
entry, ok := l.entries[key]
|
||||
if !ok || entry.limit != limit || entry.burst != policy.Burst {
|
||||
entry = &publicRateLimiterEntry{
|
||||
limiter: rate.NewLimiter(limit, policy.Burst),
|
||||
limit: limit,
|
||||
burst: policy.Burst,
|
||||
}
|
||||
l.entries[key] = entry
|
||||
}
|
||||
|
||||
entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window))
|
||||
|
||||
reservation := entry.limiter.ReserveN(now, 1)
|
||||
if !reservation.OK() {
|
||||
return PublicRateLimitDecision{
|
||||
Allowed: false,
|
||||
RetryAfter: policy.Window,
|
||||
}
|
||||
}
|
||||
|
||||
retryAfter := reservation.DelayFrom(now)
|
||||
if retryAfter > 0 {
|
||||
return PublicRateLimitDecision{
|
||||
Allowed: false,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
return PublicRateLimitDecision{Allowed: true}
|
||||
}
|
||||
|
||||
func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) {
|
||||
if !l.nextCleanup.IsZero() && now.Before(l.nextCleanup) {
|
||||
return
|
||||
}
|
||||
|
||||
for key, entry := range l.entries {
|
||||
if !entry.expiresAt.After(now) {
|
||||
delete(l.entries, key)
|
||||
}
|
||||
}
|
||||
|
||||
l.nextCleanup = now.Add(l.cleanupInterval)
|
||||
}
|
||||
|
||||
func publicRateLimiterEntryTTL(window time.Duration) time.Duration {
|
||||
if window < time.Minute {
|
||||
return time.Minute
|
||||
}
|
||||
|
||||
return 2 * window
|
||||
}
|
||||
|
||||
func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
class, ok := PublicRouteClassFromContext(c.Request.Context())
|
||||
if !ok {
|
||||
class = PublicRouteClassPublicMisc
|
||||
}
|
||||
|
||||
allowedMethods := allowedMethodsForRequestShape(c.Request)
|
||||
if len(allowedMethods) > 0 && !isAllowedMethod(c.Request.Method, allowedMethods) {
|
||||
c.Header("Allow", strings.Join(allowedMethods, ", "))
|
||||
abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
|
||||
return
|
||||
}
|
||||
|
||||
classPolicy := publicRoutePolicyForClass(policy, class)
|
||||
bodyBytes, err := bufferRequestBody(c.Request, classPolicy.MaxBodyBytes)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errRequestBodyTooLarge):
|
||||
observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody)
|
||||
abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit")
|
||||
default:
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := clientIPFromRemoteAddr(c.Request.RemoteAddr)
|
||||
if decision := limiter.Reserve(publicRESTIPBucketKey(class, clientIP), classPolicy.RateLimit); !decision.Allowed {
|
||||
abortRateLimited(c, decision.RetryAfter)
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := extractPublicAuthIdentity(c.Request.URL.Path, bodyBytes)
|
||||
switch {
|
||||
case err == nil:
|
||||
identityPolicy := publicAuthIdentityPolicyForPath(c.Request.URL.Path, policy)
|
||||
if decision := limiter.Reserve(publicAuthIdentityBucketKey(class, identity.kind, identity.value), identityPolicy.RateLimit); !decision.Allowed {
|
||||
abortRateLimited(c, decision.RetryAfter)
|
||||
return
|
||||
}
|
||||
case errors.Is(err, errPublicAuthIdentityNotApplicable):
|
||||
default:
|
||||
if reason, malformed := malformedRequestReasonFromError(err); malformed {
|
||||
observer.RecordMalformedRequest(class, reason)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig {
|
||||
switch class.Normalized() {
|
||||
case PublicRouteClassPublicAuth:
|
||||
return policy.PublicAuth
|
||||
case PublicRouteClassBrowserBootstrap:
|
||||
return policy.BrowserBootstrap
|
||||
case PublicRouteClassBrowserAsset:
|
||||
return policy.BrowserAsset
|
||||
default:
|
||||
return policy.PublicMisc
|
||||
}
|
||||
}
|
||||
|
||||
func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig {
|
||||
switch requestPath {
|
||||
case "/api/v1/public/auth/send-email-code":
|
||||
return policy.SendEmailCodeIdentity
|
||||
case "/api/v1/public/auth/confirm-email-code":
|
||||
return policy.ConfirmEmailCodeIdentity
|
||||
default:
|
||||
return config.PublicAuthIdentityPolicyConfig{}
|
||||
}
|
||||
}
|
||||
|
||||
func allowedMethodsForRequestShape(r *http.Request) []string {
|
||||
switch {
|
||||
case isPublicAuthPath(r.URL.Path):
|
||||
return []string{http.MethodPost}
|
||||
case isProbePath(r.URL.Path):
|
||||
return []string{http.MethodGet}
|
||||
case matchesBrowserAssetRequestShape(r):
|
||||
return []string{http.MethodGet, http.MethodHead}
|
||||
case matchesBrowserBootstrapRequestShape(r):
|
||||
return []string{http.MethodGet, http.MethodHead}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func isAllowedMethod(method string, allowedMethods []string) bool {
|
||||
for _, allowedMethod := range allowedMethods {
|
||||
if method == allowedMethod {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isPublicAuthPath(requestPath string) bool {
|
||||
switch requestPath {
|
||||
case "/api/v1/public/auth/send-email-code", "/api/v1/public/auth/confirm-email-code":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isProbePath(requestPath string) bool {
|
||||
switch requestPath {
|
||||
case "/healthz", "/readyz":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func matchesBrowserBootstrapRequestShape(r *http.Request) bool {
|
||||
if r.URL.Path == "/" {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html")
|
||||
}
|
||||
|
||||
func matchesBrowserAssetRequestShape(r *http.Request) bool {
|
||||
if strings.HasPrefix(r.URL.Path, "/assets/") {
|
||||
return true
|
||||
}
|
||||
|
||||
switch strings.ToLower(path.Ext(r.URL.Path)) {
|
||||
case ".js", ".mjs", ".css", ".map", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".json", ".webmanifest":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) {
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if r.Body == nil {
|
||||
r.Body = io.NopCloser(bytes.NewReader(nil))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, maxBodyBytes+1))
|
||||
closeErr := r.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if closeErr != nil {
|
||||
return nil, closeErr
|
||||
}
|
||||
if int64(len(bodyBytes)) > maxBodyBytes {
|
||||
return nil, errRequestBodyTooLarge
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
func abortRateLimited(c *gin.Context, retryAfter time.Duration) {
|
||||
c.Header("Retry-After", retryAfterHeaderValue(retryAfter))
|
||||
abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded")
|
||||
}
|
||||
|
||||
func retryAfterHeaderValue(delay time.Duration) string {
|
||||
seconds := int64(math.Ceil(delay.Seconds()))
|
||||
if seconds < 1 {
|
||||
seconds = 1
|
||||
}
|
||||
|
||||
return strconv.FormatInt(seconds, 10)
|
||||
}
|
||||
|
||||
func clientIPFromRemoteAddr(remoteAddr string) string {
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(remoteAddr))
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
|
||||
remoteAddr = strings.TrimSpace(remoteAddr)
|
||||
if remoteAddr == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
return remoteAddr
|
||||
}
|
||||
|
||||
func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string {
|
||||
return class.BaseBucketKey() + publicRESTIPBucketKeySegment + clientIP
|
||||
}
|
||||
|
||||
func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string {
|
||||
return class.BaseBucketKey() + "/" + identityKind + "=" + identityValue
|
||||
}
|
||||
@@ -0,0 +1,455 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPublicAntiAbuseRejectsOversizedBodies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oversizedJSONBody := `{"email":"` + strings.Repeat("a", 8200) + `@example.com"}`
|
||||
oversizedConfirmJSONBody := `{"challenge_id":"` + strings.Repeat("c", 8300) + `","code":"123456","client_public_key":"key"}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
target string
|
||||
body string
|
||||
wantClass PublicRouteClass
|
||||
}{
|
||||
{
|
||||
name: "send email",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
body: oversizedJSONBody,
|
||||
wantClass: PublicRouteClassPublicAuth,
|
||||
},
|
||||
{
|
||||
name: "confirm email",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
body: oversizedConfirmJSONBody,
|
||||
wantClass: PublicRouteClassPublicAuth,
|
||||
},
|
||||
{
|
||||
name: "healthz body",
|
||||
method: http.MethodGet,
|
||||
target: "/healthz",
|
||||
body: `x`,
|
||||
wantClass: PublicRouteClassPublicMisc,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
observer := &recordingPublicRequestObserver{}
|
||||
authService := &recordingAuthServiceClient{
|
||||
sendEmailCodeResult: SendEmailCodeResult{ChallengeID: "challenge-123"},
|
||||
confirmEmailCodeResult: ConfirmEmailCodeResult{
|
||||
DeviceSessionID: "device-session-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
|
||||
AuthService: authService,
|
||||
Observer: observer,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, `{"error":{"code":"request_too_large","message":"request body exceeds the configured limit"}}`, recorder.Body.String())
|
||||
assert.Equal(t, 0, authService.sendEmailCodeCalls)
|
||||
assert.Equal(t, 0, authService.confirmEmailCodeCalls)
|
||||
assert.Equal(t, []malformedObservation{{
|
||||
class: tt.wantClass,
|
||||
reason: PublicMalformedRequestReasonOversizedBody,
|
||||
}}, observer.snapshot())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicAntiAbuseRejectsInvalidMethodsForBrowserShapes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newPublicHandler(ServerDependencies{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
target string
|
||||
accept string
|
||||
wantAllow string
|
||||
}{
|
||||
{
|
||||
name: "asset path",
|
||||
method: http.MethodPost,
|
||||
target: "/assets/app.js",
|
||||
wantAllow: "GET, HEAD",
|
||||
},
|
||||
{
|
||||
name: "bootstrap request",
|
||||
method: http.MethodPost,
|
||||
target: "/",
|
||||
accept: "text/html",
|
||||
wantAllow: "GET, HEAD",
|
||||
},
|
||||
{
|
||||
name: "head probe rejected",
|
||||
method: http.MethodHead,
|
||||
target: "/healthz",
|
||||
wantAllow: http.MethodGet,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(tt.method, tt.target, nil)
|
||||
if tt.accept != "" {
|
||||
req.Header.Set("Accept", tt.accept)
|
||||
}
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
|
||||
assert.Equal(t, `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicAntiAbuseBrowserClassBucketsStayIsolatedFromPublicAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
burstRequest *http.Request
|
||||
}{
|
||||
{
|
||||
name: "browser asset",
|
||||
burstRequest: httptest.NewRequest(http.MethodGet, "/assets/app.js", nil),
|
||||
},
|
||||
{
|
||||
name: "browser bootstrap",
|
||||
burstRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept", "text/html")
|
||||
return req
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := config.DefaultPublicHTTPConfig()
|
||||
cfg.AntiAbuse.BrowserAsset.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
cfg.AntiAbuse.BrowserBootstrap.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
authService := &recordingAuthServiceClient{
|
||||
sendEmailCodeResult: SendEmailCodeResult{
|
||||
ChallengeID: "challenge-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
|
||||
|
||||
tt.burstRequest.RemoteAddr = "192.0.2.10:1234"
|
||||
|
||||
firstBurst := httptest.NewRecorder()
|
||||
handler.ServeHTTP(firstBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))
|
||||
|
||||
secondBurst := httptest.NewRecorder()
|
||||
handler.ServeHTTP(secondBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))
|
||||
|
||||
authReq := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(`{"email":"pilot@example.com"}`))
|
||||
authReq.Header.Set("Content-Type", "application/json")
|
||||
authReq.RemoteAddr = "192.0.2.10:1234"
|
||||
authResp := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(authResp, authReq)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, firstBurst.Code)
|
||||
assert.Equal(t, http.StatusTooManyRequests, secondBurst.Code)
|
||||
assert.Equal(t, http.StatusOK, authResp.Code)
|
||||
assert.Equal(t, `{"challenge_id":"challenge-123"}`, authResp.Body.String())
|
||||
assert.Equal(t, 1, authService.sendEmailCodeCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicAntiAbuseSendEmailIdentityThrottle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := config.DefaultPublicHTTPConfig()
|
||||
cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
sendEmailCodeResult: SendEmailCodeResult{
|
||||
ChallengeID: "challenge-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
|
||||
|
||||
first := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
|
||||
second := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
|
||||
third := sendEmailCodeRequest(`{"email":"other@example.com"}`)
|
||||
|
||||
firstResp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(firstResp, first)
|
||||
|
||||
secondResp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(secondResp, second)
|
||||
|
||||
thirdResp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(thirdResp, third)
|
||||
|
||||
assert.Equal(t, http.StatusOK, firstResp.Code)
|
||||
assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
|
||||
assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
|
||||
assert.Equal(t, http.StatusOK, thirdResp.Code)
|
||||
assert.Equal(t, 2, authService.sendEmailCodeCalls)
|
||||
thirdInput := authService.sendEmailCodeInput
|
||||
assert.Equal(t, "other@example.com", thirdInput.Email)
|
||||
}
|
||||
|
||||
func TestPublicAntiAbuseConfirmEmailIdentityThrottle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := config.DefaultPublicHTTPConfig()
|
||||
cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.ConfirmEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
confirmEmailCodeResult: ConfirmEmailCodeResult{
|
||||
DeviceSessionID: "device-session-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
|
||||
|
||||
first := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
|
||||
second := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
|
||||
third := confirmEmailCodeRequest(`{"challenge_id":"challenge-456","code":"123456","client_public_key":"public-key-material"}`)
|
||||
|
||||
firstResp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(firstResp, first)
|
||||
|
||||
secondResp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(secondResp, second)
|
||||
|
||||
thirdResp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(thirdResp, third)
|
||||
|
||||
assert.Equal(t, http.StatusOK, firstResp.Code)
|
||||
assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
|
||||
assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
|
||||
assert.Equal(t, http.StatusOK, thirdResp.Code)
|
||||
assert.Equal(t, 2, authService.confirmEmailCodeCalls)
|
||||
assert.Equal(t, "challenge-456", authService.confirmEmailCodeInput.ChallengeID)
|
||||
}
|
||||
|
||||
func TestPublicAntiAbuseMalformedTelemetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantReason PublicMalformedRequestReason
|
||||
wantRecords int
|
||||
}{
|
||||
{
|
||||
name: "empty body",
|
||||
body: ``,
|
||||
wantReason: PublicMalformedRequestReasonEmptyBody,
|
||||
wantRecords: 1,
|
||||
},
|
||||
{
|
||||
name: "malformed json",
|
||||
body: `{"email":`,
|
||||
wantReason: PublicMalformedRequestReasonMalformedJSON,
|
||||
wantRecords: 1,
|
||||
},
|
||||
{
|
||||
name: "invalid json value",
|
||||
body: `{"email":123}`,
|
||||
wantReason: PublicMalformedRequestReasonInvalidJSONValue,
|
||||
wantRecords: 1,
|
||||
},
|
||||
{
|
||||
name: "unknown field",
|
||||
body: `{"email":"pilot@example.com","extra":"x"}`,
|
||||
wantReason: PublicMalformedRequestReasonUnknownField,
|
||||
wantRecords: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple objects",
|
||||
body: `{"email":"pilot@example.com"}{"email":"pilot@example.com"}`,
|
||||
wantReason: PublicMalformedRequestReasonMultipleJSONObjects,
|
||||
wantRecords: 1,
|
||||
},
|
||||
{
|
||||
name: "validation error does not count as malformed",
|
||||
body: `{"email":"not-an-email"}`,
|
||||
wantRecords: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
observer := &recordingPublicRequestObserver{}
|
||||
authService := &recordingAuthServiceClient{}
|
||||
handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
|
||||
AuthService: authService,
|
||||
Observer: observer,
|
||||
})
|
||||
|
||||
req := sendEmailCodeRequest(tt.body)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
assert.Equal(t, tt.wantRecords, len(observer.snapshot()))
|
||||
assert.Equal(t, 0, authService.sendEmailCodeCalls)
|
||||
if tt.wantRecords == 1 {
|
||||
assert.Equal(t, malformedObservation{
|
||||
class: PublicRouteClassPublicAuth,
|
||||
reason: tt.wantReason,
|
||||
}, observer.snapshot()[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryPublicRequestLimiterCleansExpiredBuckets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1000, 0)
|
||||
limiter := newInMemoryPublicRequestLimiter()
|
||||
limiter.now = func() time.Time {
|
||||
return now
|
||||
}
|
||||
limiter.cleanupInterval = time.Second
|
||||
|
||||
policy := config.PublicRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Minute,
|
||||
Burst: 1,
|
||||
}
|
||||
|
||||
firstDecision := limiter.Reserve("bucket-1", policy)
|
||||
secondDecision := limiter.Reserve("bucket-2", policy)
|
||||
require.True(t, firstDecision.Allowed)
|
||||
require.True(t, secondDecision.Allowed)
|
||||
require.Len(t, limiter.entries, 2)
|
||||
|
||||
now = now.Add(3 * time.Minute)
|
||||
|
||||
thirdDecision := limiter.Reserve("bucket-3", policy)
|
||||
require.True(t, thirdDecision.Allowed)
|
||||
assert.Len(t, limiter.entries, 1)
|
||||
_, exists := limiter.entries["bucket-3"]
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func sendEmailCodeRequest(body string) *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "192.0.2.10:1234"
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func confirmEmailCodeRequest(body string) *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/confirm-email-code", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "192.0.2.10:1234"
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
type malformedObservation struct {
|
||||
class PublicRouteClass
|
||||
reason PublicMalformedRequestReason
|
||||
}
|
||||
|
||||
type recordingPublicRequestObserver struct {
|
||||
mu sync.Mutex
|
||||
observations []malformedObservation
|
||||
}
|
||||
|
||||
func (o *recordingPublicRequestObserver) RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.observations = append(o.observations, malformedObservation{
|
||||
class: class,
|
||||
reason: reason,
|
||||
})
|
||||
}
|
||||
|
||||
func (o *recordingPublicRequestObserver) snapshot() []malformedObservation {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
return append([]malformedObservation(nil), o.observations...)
|
||||
}
|
||||
@@ -0,0 +1,446 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var errPublicAuthIdentityNotApplicable = errors.New("public auth identity does not apply to this route")
|
||||
|
||||
type malformedJSONRequestError struct {
|
||||
message string
|
||||
reason PublicMalformedRequestReason
|
||||
}
|
||||
|
||||
func (e *malformedJSONRequestError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return e.message
|
||||
}
|
||||
|
||||
type publicAuthIdentity struct {
|
||||
kind string
|
||||
value string
|
||||
}
|
||||
|
||||
// AuthServiceClient defines the consumer-side contract used by public auth
|
||||
// REST handlers to delegate unauthenticated authentication commands to the
|
||||
// Auth / Session Service.
|
||||
type AuthServiceClient interface {
|
||||
// SendEmailCode starts a login challenge for input.Email and returns the
|
||||
// challenge identifier that the client must later confirm.
|
||||
SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error)
|
||||
|
||||
// ConfirmEmailCode completes a previously issued challenge, registers
|
||||
// input.ClientPublicKey for the new device session, and returns the created
|
||||
// device session identifier.
|
||||
ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error)
|
||||
}
|
||||
|
||||
// SendEmailCodeInput describes the public REST and adapter payload used to
|
||||
// request a login code for a single e-mail address.
|
||||
type SendEmailCodeInput struct {
|
||||
// Email is the single client e-mail address that should receive the login
|
||||
// code challenge.
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// SendEmailCodeResult describes the public REST and adapter payload returned
|
||||
// after the Auth / Session Service creates a login challenge.
|
||||
type SendEmailCodeResult struct {
|
||||
// ChallengeID identifies the issued challenge that must be confirmed by the
|
||||
// client in the next public auth step.
|
||||
ChallengeID string `json:"challenge_id"`
|
||||
}
|
||||
|
||||
// ConfirmEmailCodeInput describes the public REST and adapter payload used to
|
||||
// complete a previously issued login challenge.
|
||||
type ConfirmEmailCodeInput struct {
|
||||
// ChallengeID identifies the challenge previously returned by
|
||||
// SendEmailCode.
|
||||
ChallengeID string `json:"challenge_id"`
|
||||
|
||||
// Code is the verification code delivered to the client by the Auth /
|
||||
// Session Service.
|
||||
Code string `json:"code"`
|
||||
|
||||
// ClientPublicKey is the standard base64-encoded raw 32-byte Ed25519 public
|
||||
// key that should be registered for the created device session.
|
||||
ClientPublicKey string `json:"client_public_key"`
|
||||
}
|
||||
|
||||
// ConfirmEmailCodeResult describes the public REST and adapter payload
|
||||
// returned after the Auth / Session Service creates a device session.
|
||||
type ConfirmEmailCodeResult struct {
|
||||
// DeviceSessionID is the stable identifier of the created device session.
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
}
|
||||
|
||||
// AuthServiceError allows an auth adapter to project a stable public REST
|
||||
// error without teaching the gateway transport layer about upstream business
|
||||
// rules.
|
||||
type AuthServiceError struct {
|
||||
// StatusCode is the HTTP status that the public REST handler should expose.
|
||||
StatusCode int
|
||||
|
||||
// Code is the stable edge-level error code written into the JSON envelope.
|
||||
Code string
|
||||
|
||||
// Message is the human-readable client-safe error description.
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns a readable representation of the projected auth service error.
|
||||
func (e *AuthServiceError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(e.Code) == "" && strings.TrimSpace(e.Message) == "":
|
||||
return http.StatusText(e.normalizedStatusCode())
|
||||
case strings.TrimSpace(e.Code) == "":
|
||||
return e.Message
|
||||
case strings.TrimSpace(e.Message) == "":
|
||||
return e.Code
|
||||
default:
|
||||
return e.Code + ": " + e.Message
|
||||
}
|
||||
}
|
||||
|
||||
func (e *AuthServiceError) normalizedStatusCode() int {
|
||||
if e == nil || e.StatusCode < 400 || e.StatusCode > 599 {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
return e.StatusCode
|
||||
}
|
||||
|
||||
func (e *AuthServiceError) normalizedCode() string {
|
||||
if e == nil {
|
||||
return errorCodeInternalError
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(e.Code)
|
||||
if code == "" {
|
||||
switch e.normalizedStatusCode() {
|
||||
case http.StatusServiceUnavailable:
|
||||
return errorCodeServiceUnavailable
|
||||
case http.StatusBadRequest:
|
||||
return errorCodeInvalidRequest
|
||||
default:
|
||||
return errorCodeInternalError
|
||||
}
|
||||
}
|
||||
|
||||
return code
|
||||
}
|
||||
|
||||
func (e *AuthServiceError) normalizedMessage() string {
|
||||
if e == nil {
|
||||
return "internal server error"
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(e.Message)
|
||||
if message == "" {
|
||||
switch e.normalizedStatusCode() {
|
||||
case http.StatusServiceUnavailable:
|
||||
return "auth service is unavailable"
|
||||
case http.StatusBadRequest:
|
||||
return "request is invalid"
|
||||
default:
|
||||
return "internal server error"
|
||||
}
|
||||
}
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
// unavailableAuthServiceClient keeps the public auth surface mounted until a
|
||||
// concrete upstream adapter is wired into the gateway process.
|
||||
type unavailableAuthServiceClient struct{}
|
||||
|
||||
func (unavailableAuthServiceClient) SendEmailCode(context.Context, SendEmailCodeInput) (SendEmailCodeResult, error) {
|
||||
return SendEmailCodeResult{}, &AuthServiceError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Code: errorCodeServiceUnavailable,
|
||||
Message: "auth service is unavailable",
|
||||
}
|
||||
}
|
||||
|
||||
func (unavailableAuthServiceClient) ConfirmEmailCode(context.Context, ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
|
||||
return ConfirmEmailCodeResult{}, &AuthServiceError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Code: errorCodeServiceUnavailable,
|
||||
Message: "auth service is unavailable",
|
||||
}
|
||||
}
|
||||
|
||||
func handleSendEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var input SendEmailCodeInput
|
||||
if err := decodeJSONRequest(c.Request, &input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
if err := validateSendEmailCodeInput(&input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := authService.SendEmailCode(callCtx, input)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
|
||||
return
|
||||
}
|
||||
abortWithAuthServiceError(c, err)
|
||||
return
|
||||
}
|
||||
if err := validateSendEmailCodeResult(&result); err != nil {
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
func handleConfirmEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var input ConfirmEmailCodeInput
|
||||
if err := decodeJSONRequest(c.Request, &input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
if err := validateConfirmEmailCodeInput(&input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := authService.ConfirmEmailCode(callCtx, input)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
|
||||
return
|
||||
}
|
||||
abortWithAuthServiceError(c, err)
|
||||
return
|
||||
}
|
||||
if err := validateConfirmEmailCodeResult(&result); err != nil {
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
func abortInvalidRequest(c *gin.Context, message string) {
|
||||
abortWithError(c, http.StatusBadRequest, errorCodeInvalidRequest, message)
|
||||
}
|
||||
|
||||
func abortWithAuthServiceError(c *gin.Context, err error) {
|
||||
var authErr *AuthServiceError
|
||||
if errors.As(err, &authErr) {
|
||||
abortWithError(c, authErr.normalizedStatusCode(), authErr.normalizedCode(), authErr.normalizedMessage())
|
||||
return
|
||||
}
|
||||
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
}
|
||||
|
||||
func decodeJSONRequest(r *http.Request, target any) error {
|
||||
if r == nil || r.Body == nil {
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must not be empty",
|
||||
reason: PublicMalformedRequestReasonEmptyBody,
|
||||
}
|
||||
}
|
||||
|
||||
return decodeJSONReader(r.Body, target)
|
||||
}
|
||||
|
||||
func decodeJSONBytes(bodyBytes []byte, target any) error {
|
||||
return decodeJSONReader(bytes.NewReader(bodyBytes), target)
|
||||
}
|
||||
|
||||
func decodeJSONReader(reader io.Reader, target any) error {
|
||||
decoder := json.NewDecoder(reader)
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
if err := decoder.Decode(target); err != nil {
|
||||
return describeJSONDecodeError(err)
|
||||
}
|
||||
|
||||
if err := decoder.Decode(&struct{}{}); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must contain a single JSON object",
|
||||
reason: PublicMalformedRequestReasonMultipleJSONObjects,
|
||||
}
|
||||
}
|
||||
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must contain a single JSON object",
|
||||
reason: PublicMalformedRequestReasonMultipleJSONObjects,
|
||||
}
|
||||
}
|
||||
|
||||
func describeJSONDecodeError(err error) error {
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must not be empty",
|
||||
reason: PublicMalformedRequestReasonEmptyBody,
|
||||
}
|
||||
case errors.As(err, &syntaxErr):
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains malformed JSON",
|
||||
reason: PublicMalformedRequestReasonMalformedJSON,
|
||||
}
|
||||
case errors.Is(err, io.ErrUnexpectedEOF):
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains malformed JSON",
|
||||
reason: PublicMalformedRequestReasonMalformedJSON,
|
||||
}
|
||||
case errors.As(err, &typeErr):
|
||||
if strings.TrimSpace(typeErr.Field) != "" {
|
||||
return &malformedJSONRequestError{
|
||||
message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
|
||||
reason: PublicMalformedRequestReasonInvalidJSONValue,
|
||||
}
|
||||
}
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains an invalid JSON value",
|
||||
reason: PublicMalformedRequestReasonInvalidJSONValue,
|
||||
}
|
||||
case strings.HasPrefix(err.Error(), "json: unknown field "):
|
||||
return &malformedJSONRequestError{
|
||||
message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
|
||||
reason: PublicMalformedRequestReasonUnknownField,
|
||||
}
|
||||
default:
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains invalid JSON",
|
||||
reason: PublicMalformedRequestReasonMalformedJSON,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateSendEmailCodeInput(input *SendEmailCodeInput) error {
|
||||
input.Email = strings.TrimSpace(input.Email)
|
||||
if input.Email == "" {
|
||||
return errors.New("email must not be empty")
|
||||
}
|
||||
|
||||
parsedAddress, err := mail.ParseAddress(input.Email)
|
||||
if err != nil || parsedAddress.Name != "" || parsedAddress.Address != input.Email {
|
||||
return errors.New("email must be a single valid email address")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSendEmailCodeResult(result *SendEmailCodeResult) error {
|
||||
result.ChallengeID = strings.TrimSpace(result.ChallengeID)
|
||||
if result.ChallengeID == "" {
|
||||
return errors.New("auth service returned an empty challenge_id")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfirmEmailCodeInput(input *ConfirmEmailCodeInput) error {
|
||||
input.ChallengeID = strings.TrimSpace(input.ChallengeID)
|
||||
if input.ChallengeID == "" {
|
||||
return errors.New("challenge_id must not be empty")
|
||||
}
|
||||
|
||||
input.Code = strings.TrimSpace(input.Code)
|
||||
if input.Code == "" {
|
||||
return errors.New("code must not be empty")
|
||||
}
|
||||
|
||||
input.ClientPublicKey = strings.TrimSpace(input.ClientPublicKey)
|
||||
if input.ClientPublicKey == "" {
|
||||
return errors.New("client_public_key must not be empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfirmEmailCodeResult(result *ConfirmEmailCodeResult) error {
|
||||
result.DeviceSessionID = strings.TrimSpace(result.DeviceSessionID)
|
||||
if result.DeviceSessionID == "" {
|
||||
return errors.New("auth service returned an empty device_session_id")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func malformedRequestReasonFromError(err error) (PublicMalformedRequestReason, bool) {
|
||||
var malformedErr *malformedJSONRequestError
|
||||
if !errors.As(err, &malformedErr) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return malformedErr.reason, true
|
||||
}
|
||||
|
||||
func extractPublicAuthIdentity(requestPath string, bodyBytes []byte) (publicAuthIdentity, error) {
|
||||
switch requestPath {
|
||||
case "/api/v1/public/auth/send-email-code":
|
||||
var input SendEmailCodeInput
|
||||
if err := decodeJSONBytes(bodyBytes, &input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
if err := validateSendEmailCodeInput(&input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
|
||||
return publicAuthIdentity{
|
||||
kind: "email",
|
||||
value: input.Email,
|
||||
}, nil
|
||||
case "/api/v1/public/auth/confirm-email-code":
|
||||
var input ConfirmEmailCodeInput
|
||||
if err := decodeJSONBytes(bodyBytes, &input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
if err := validateConfirmEmailCodeInput(&input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
|
||||
return publicAuthIdentity{
|
||||
kind: "challenge",
|
||||
value: input.ChallengeID,
|
||||
}, nil
|
||||
default:
|
||||
return publicAuthIdentity{}, errPublicAuthIdentityNotApplicable
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSendEmailCodeHandlerSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
sendEmailCodeResult: SendEmailCodeResult{
|
||||
ChallengeID: "challenge-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandler(ServerDependencies{AuthService: authService})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/send-email-code",
|
||||
strings.NewReader(`{"email":" pilot@example.com "}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String())
|
||||
assert.Equal(t, 1, authService.sendEmailCodeCalls)
|
||||
assert.Equal(t, 0, authService.confirmEmailCodeCalls)
|
||||
assert.Equal(t, SendEmailCodeInput{Email: "pilot@example.com"}, authService.sendEmailCodeInput)
|
||||
assert.True(t, authService.sendEmailCodeRouteClassOK)
|
||||
assert.Equal(t, PublicRouteClassPublicAuth, authService.sendEmailCodeRouteClass)
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeHandlerSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
confirmEmailCodeResult: ConfirmEmailCodeResult{
|
||||
DeviceSessionID: "device-session-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandler(ServerDependencies{AuthService: authService})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/confirm-email-code",
|
||||
strings.NewReader(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String())
|
||||
assert.Equal(t, 0, authService.sendEmailCodeCalls)
|
||||
assert.Equal(t, 1, authService.confirmEmailCodeCalls)
|
||||
assert.Equal(t, ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key-material",
|
||||
}, authService.confirmEmailCodeInput)
|
||||
assert.True(t, authService.confirmEmailCodeRouteClassOK)
|
||||
assert.Equal(t, PublicRouteClassPublicAuth, authService.confirmEmailCodeRouteClass)
|
||||
}
|
||||
|
||||
func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
target string
|
||||
body string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantSendCalls int
|
||||
wantConfirmCalls int
|
||||
}{
|
||||
{
|
||||
name: "send email malformed json",
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
body: `{"email":`,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
|
||||
wantSendCalls: 0,
|
||||
wantConfirmCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "send email validation error",
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
body: `{"email":"not-an-email"}`,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`,
|
||||
wantSendCalls: 0,
|
||||
wantConfirmCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "confirm email empty code",
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
body: `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{"error":{"code":"invalid_request","message":"code must not be empty"}}`,
|
||||
wantSendCalls: 0,
|
||||
wantConfirmCalls: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := &recordingAuthServiceClient{}
|
||||
handler := newPublicHandler(ServerDependencies{AuthService: authService})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantBody, recorder.Body.String())
|
||||
assert.Equal(t, tt.wantSendCalls, authService.sendEmailCodeCalls)
|
||||
assert.Equal(t, tt.wantConfirmCalls, authService.confirmEmailCodeCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicAuthHandlersMapAdapterErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
target string
|
||||
body string
|
||||
authClient *recordingAuthServiceClient
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "auth service projected bad request",
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
|
||||
authClient: &recordingAuthServiceClient{
|
||||
confirmEmailCodeErr: &AuthServiceError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Code: errorCodeInvalidRequest,
|
||||
Message: "confirmation code is invalid",
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{"error":{"code":"invalid_request","message":"confirmation code is invalid"}}`,
|
||||
},
|
||||
{
|
||||
name: "auth service projected custom too many requests",
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
body: `{"email":"pilot@example.com"}`,
|
||||
authClient: &recordingAuthServiceClient{
|
||||
sendEmailCodeErr: &AuthServiceError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Code: "upstream_rate_limited",
|
||||
Message: "too many attempts for this email",
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusTooManyRequests,
|
||||
wantBody: `{"error":{"code":"upstream_rate_limited","message":"too many attempts for this email"}}`,
|
||||
},
|
||||
{
|
||||
name: "auth service projected gateway normalizes blank gateway error fields",
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
|
||||
authClient: &recordingAuthServiceClient{
|
||||
confirmEmailCodeErr: &AuthServiceError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`,
|
||||
},
|
||||
{
|
||||
name: "unexpected auth service error",
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
body: `{"email":"pilot@example.com"}`,
|
||||
authClient: &recordingAuthServiceClient{
|
||||
sendEmailCodeErr: errors.New("boom"),
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newPublicHandler(ServerDependencies{AuthService: tt.authClient})
|
||||
req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantBody, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAuthServiceReturnsServiceUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newPublicHandler(ServerDependencies{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
target string
|
||||
body string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "send email code",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
body: `{"email":"pilot@example.com"}`,
|
||||
wantStatus: http.StatusServiceUnavailable,
|
||||
wantBody: `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
|
||||
},
|
||||
{
|
||||
name: "confirm email code",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
|
||||
wantStatus: http.StatusServiceUnavailable,
|
||||
wantBody: `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
|
||||
},
|
||||
{
|
||||
name: "healthz remains available",
|
||||
method: http.MethodGet,
|
||||
target: "/healthz",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: `{"status":"ok"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
|
||||
if tt.body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantBody, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
sendEmailCodeErr: context.DeadlineExceeded,
|
||||
}
|
||||
cfg := config.DefaultPublicHTTPConfig()
|
||||
cfg.AuthUpstreamTimeout = 5 * time.Millisecond
|
||||
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/send-email-code",
|
||||
strings.NewReader(`{"email":"pilot@example.com"}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
assert.Equal(t, `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`, recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, buffer := testutil.NewObservedLogger(t)
|
||||
handler := newPublicHandler(ServerDependencies{
|
||||
Logger: logger,
|
||||
AuthService: &recordingAuthServiceClient{
|
||||
confirmEmailCodeResult: ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"},
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/confirm-email-code",
|
||||
strings.NewReader(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
logOutput := buffer.String()
|
||||
assert.NotContains(t, logOutput, "challenge-123")
|
||||
assert.NotContains(t, logOutput, "123456")
|
||||
assert.NotContains(t, logOutput, "public-key-material")
|
||||
assert.NotContains(t, logOutput, "pilot@example.com")
|
||||
}
|
||||
|
||||
// recordingAuthServiceClient captures handler inputs and route classification
|
||||
// so tests can assert the exact adapter delegation contract.
|
||||
type recordingAuthServiceClient struct {
|
||||
sendEmailCodeResult SendEmailCodeResult
|
||||
sendEmailCodeErr error
|
||||
sendEmailCodeInput SendEmailCodeInput
|
||||
sendEmailCodeRouteClass PublicRouteClass
|
||||
sendEmailCodeRouteClassOK bool
|
||||
sendEmailCodeCalls int
|
||||
|
||||
confirmEmailCodeResult ConfirmEmailCodeResult
|
||||
confirmEmailCodeErr error
|
||||
confirmEmailCodeInput ConfirmEmailCodeInput
|
||||
confirmEmailCodeRouteClass PublicRouteClass
|
||||
confirmEmailCodeRouteClassOK bool
|
||||
confirmEmailCodeCalls int
|
||||
}
|
||||
|
||||
func (c *recordingAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
|
||||
c.sendEmailCodeCalls++
|
||||
c.sendEmailCodeInput = input
|
||||
|
||||
c.sendEmailCodeRouteClass, c.sendEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
|
||||
|
||||
return c.sendEmailCodeResult, c.sendEmailCodeErr
|
||||
}
|
||||
|
||||
func (c *recordingAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
|
||||
c.confirmEmailCodeCalls++
|
||||
c.confirmEmailCodeInput = input
|
||||
|
||||
c.confirmEmailCodeRouteClass, c.confirmEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
|
||||
|
||||
return c.confirmEmailCodeResult, c.confirmEmailCodeErr
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
// Package restapi exposes the unauthenticated public REST surface of the
|
||||
// gateway.
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonContentType = "application/json; charset=utf-8"
|
||||
|
||||
errorCodeInvalidRequest = "invalid_request"
|
||||
errorCodeNotFound = "not_found"
|
||||
errorCodeMethodNotAllowed = "method_not_allowed"
|
||||
errorCodeInternalError = "internal_error"
|
||||
errorCodeServiceUnavailable = "service_unavailable"
|
||||
|
||||
publicRESTBaseBucketKeyPrefix = "public_rest/class="
|
||||
)
|
||||
|
||||
// PublicRouteClass identifies the public traffic class assigned to an incoming
|
||||
// REST request before route handling and edge policy evaluation.
|
||||
type PublicRouteClass string
|
||||
|
||||
const (
|
||||
// PublicRouteClassPublicAuth identifies public authentication commands.
|
||||
PublicRouteClassPublicAuth PublicRouteClass = "public_auth"
|
||||
|
||||
// PublicRouteClassBrowserBootstrap identifies browser bootstrap traffic such
|
||||
// as the main document request.
|
||||
PublicRouteClassBrowserBootstrap PublicRouteClass = "browser_bootstrap"
|
||||
|
||||
// PublicRouteClassBrowserAsset identifies browser asset requests.
|
||||
PublicRouteClassBrowserAsset PublicRouteClass = "browser_asset"
|
||||
|
||||
// PublicRouteClassPublicMisc identifies public traffic that does not match a
|
||||
// more specific class.
|
||||
PublicRouteClassPublicMisc PublicRouteClass = "public_misc"
|
||||
)
|
||||
|
||||
var configureGinModeOnce sync.Once
|
||||
|
||||
// Normalized returns c when it belongs to the stable public route class set.
|
||||
// Unknown or empty values collapse to PublicRouteClassPublicMisc so edge policy
|
||||
// code can rely on a fixed anti-abuse namespace.
|
||||
func (c PublicRouteClass) Normalized() PublicRouteClass {
|
||||
switch c {
|
||||
case PublicRouteClassPublicAuth,
|
||||
PublicRouteClassBrowserBootstrap,
|
||||
PublicRouteClassBrowserAsset,
|
||||
PublicRouteClassPublicMisc:
|
||||
return c
|
||||
default:
|
||||
return PublicRouteClassPublicMisc
|
||||
}
|
||||
}
|
||||
|
||||
// BaseBucketKey returns the canonical base rate-limit namespace for c. The key
|
||||
// stays scoped only by the normalized public route class; callers may append
|
||||
// subject dimensions such as IP or identity without redefining the class
|
||||
// namespace.
|
||||
func (c PublicRouteClass) BaseBucketKey() string {
|
||||
return publicRESTBaseBucketKeyPrefix + string(c.Normalized())
|
||||
}
|
||||
|
||||
// PublicTrafficClassifier maps public REST requests to the public anti-abuse
|
||||
// class used by the gateway edge. The server normalizes classifier outputs to
|
||||
// the stable class set before storing them in request context.
|
||||
type PublicTrafficClassifier interface {
|
||||
Classify(*http.Request) PublicRouteClass
|
||||
}
|
||||
|
||||
// ServerDependencies describes the optional collaborators used by the public
|
||||
// REST server. The zero value is valid and keeps the process runnable with the
|
||||
// built-in defaults.
|
||||
type ServerDependencies struct {
|
||||
// Classifier assigns the public anti-abuse class before route handling.
|
||||
// When nil, the gateway default classifier is used.
|
||||
Classifier PublicTrafficClassifier
|
||||
|
||||
// AuthService delegates public auth commands to the Auth / Session Service.
|
||||
// When nil, public auth routes remain mounted and return a stable
|
||||
// service-unavailable response.
|
||||
AuthService AuthServiceClient
|
||||
|
||||
// Limiter applies the public REST rate-limit policy. When nil, a default
|
||||
// process-local in-memory limiter is used.
|
||||
Limiter PublicRequestLimiter
|
||||
|
||||
// Observer records malformed-request telemetry for the public REST layer.
|
||||
// When nil, a no-op observer is used.
|
||||
Observer PublicRequestObserver
|
||||
|
||||
// Logger writes structured transport logs for public REST traffic. When nil,
|
||||
// a no-op logger is used.
|
||||
Logger *zap.Logger
|
||||
|
||||
// Telemetry records low-cardinality edge metrics. When nil, metrics are
|
||||
// disabled.
|
||||
Telemetry *telemetry.Runtime
|
||||
}
|
||||
|
||||
// Server owns the public unauthenticated REST listener exposed by the gateway.
|
||||
type Server struct {
|
||||
cfg config.PublicHTTPConfig
|
||||
|
||||
handler http.Handler
|
||||
logger *zap.Logger
|
||||
|
||||
stateMu sync.RWMutex
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
// NewServer constructs a public REST server for the supplied listener
|
||||
// configuration and dependency bundle. Nil dependencies are replaced with safe
|
||||
// defaults so the gateway can still expose the documented public surface.
|
||||
func NewServer(cfg config.PublicHTTPConfig, deps ServerDependencies) *Server {
|
||||
deps = normalizeServerDependencies(deps)
|
||||
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
handler: newPublicHandlerWithConfig(cfg, deps),
|
||||
logger: deps.Logger.Named("public_http"),
|
||||
}
|
||||
}
|
||||
|
||||
// Run binds the configured listener and serves the public REST surface until
|
||||
// Shutdown closes the server.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("run public REST server: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.cfg.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("run public REST server: listen on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: s.handler,
|
||||
ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
|
||||
ReadTimeout: s.cfg.ReadTimeout,
|
||||
IdleTimeout: s.cfg.IdleTimeout,
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
s.server = server
|
||||
s.listener = listener
|
||||
s.stateMu.Unlock()
|
||||
|
||||
s.logger.Info("public REST server started", zap.String("addr", listener.Addr().String()))
|
||||
|
||||
defer func() {
|
||||
s.stateMu.Lock()
|
||||
s.server = nil
|
||||
s.listener = nil
|
||||
s.stateMu.Unlock()
|
||||
}()
|
||||
|
||||
err = server.Serve(listener)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, http.ErrServerClosed):
|
||||
s.logger.Info("public REST server stopped")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("run public REST server: serve on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the public REST server within ctx.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown public REST server: nil context")
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
server := s.server
|
||||
s.stateMu.RUnlock()
|
||||
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("shutdown public REST server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublicRouteClassFromContext returns the previously classified normalized
|
||||
// public route class stored in ctx.
|
||||
func PublicRouteClassFromContext(ctx context.Context) (PublicRouteClass, bool) {
|
||||
if ctx == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
class, ok := ctx.Value(publicRouteClassContextKey{}).(PublicRouteClass)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return class.Normalized(), true
|
||||
}
|
||||
|
||||
type publicRouteClassContextKey struct{}
|
||||
|
||||
type defaultPublicTrafficClassifier struct{}
|
||||
|
||||
// Classify maps the incoming request into a stable public route class that can
|
||||
// later drive anti-abuse policy and rate limiting.
|
||||
func (defaultPublicTrafficClassifier) Classify(r *http.Request) PublicRouteClass {
|
||||
switch {
|
||||
case isPublicAuthRequest(r):
|
||||
return PublicRouteClassPublicAuth
|
||||
case isBrowserBootstrapRequest(r):
|
||||
return PublicRouteClassBrowserBootstrap
|
||||
case isBrowserAssetRequest(r):
|
||||
return PublicRouteClassBrowserAsset
|
||||
default:
|
||||
return PublicRouteClassPublicMisc
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
|
||||
if deps.Classifier == nil {
|
||||
deps.Classifier = defaultPublicTrafficClassifier{}
|
||||
}
|
||||
if deps.AuthService == nil {
|
||||
deps.AuthService = unavailableAuthServiceClient{}
|
||||
}
|
||||
if deps.Limiter == nil {
|
||||
deps.Limiter = newInMemoryPublicRequestLimiter()
|
||||
}
|
||||
if deps.Observer == nil {
|
||||
deps.Observer = noopPublicRequestObserver{}
|
||||
}
|
||||
if deps.Logger == nil {
|
||||
deps.Logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return deps
|
||||
}
|
||||
|
||||
func newPublicHandler(deps ServerDependencies) http.Handler {
|
||||
return newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), deps)
|
||||
}
|
||||
|
||||
func newPublicHandlerWithConfig(cfg config.PublicHTTPConfig, deps ServerDependencies) http.Handler {
|
||||
configureGinModeOnce.Do(func() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
})
|
||||
|
||||
deps = normalizeServerDependencies(deps)
|
||||
|
||||
router := gin.New()
|
||||
router.HandleMethodNotAllowed = true
|
||||
router.Use(gin.CustomRecovery(func(c *gin.Context, _ any) {
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
}))
|
||||
router.Use(otelgin.Middleware("galaxy-edge-gateway-public"))
|
||||
router.Use(withPublicObservability(deps.Logger.Named("public_http"), deps.Telemetry))
|
||||
router.Use(withPublicRouteClass(deps.Classifier))
|
||||
router.Use(withPublicAntiAbuse(cfg.AntiAbuse, deps.Limiter, deps.Observer))
|
||||
|
||||
router.GET("/healthz", handleHealthz)
|
||||
router.GET("/readyz", handleReadyz)
|
||||
router.POST("/api/v1/public/auth/send-email-code", handleSendEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))
|
||||
router.POST("/api/v1/public/auth/confirm-email-code", handleConfirmEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))
|
||||
|
||||
router.NoMethod(func(c *gin.Context) {
|
||||
allowMethods := allowedMethodsForPath(c.Request.URL.Path)
|
||||
if allowMethods != "" {
|
||||
c.Header("Allow", allowMethods)
|
||||
}
|
||||
|
||||
abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
|
||||
})
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
abortWithError(c, http.StatusNotFound, errorCodeNotFound, "resource was not found")
|
||||
})
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func handleHealthz(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func handleReadyz(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, statusResponse{Status: "ready"})
|
||||
}
|
||||
|
||||
func withPublicRouteClass(classifier PublicTrafficClassifier) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
class := classifier.Classify(c.Request).Normalized()
|
||||
ctx := context.WithValue(c.Request.Context(), publicRouteClassContextKey{}, class)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isPublicAuthRequest(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && isPublicAuthPath(r.URL.Path)
|
||||
}
|
||||
|
||||
func isBrowserBootstrapRequest(r *http.Request) bool {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||
return true
|
||||
}
|
||||
|
||||
return matchesBrowserBootstrapRequestShape(r)
|
||||
}
|
||||
|
||||
func isBrowserAssetRequest(r *http.Request) bool {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchesBrowserAssetRequestShape(r)
|
||||
}
|
||||
|
||||
type statusResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Error errorBody `json:"error"`
|
||||
}
|
||||
|
||||
type errorBody struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func abortWithError(c *gin.Context, statusCode int, code string, message string) {
|
||||
if c != nil {
|
||||
c.Set(publicErrorCodeContextKey, code)
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, errorResponse{
|
||||
Error: errorBody{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
const publicErrorCodeContextKey = "public_error_code"
|
||||
|
||||
func allowedMethodsForPath(requestPath string) string {
|
||||
switch requestPath {
|
||||
case "/healthz", "/readyz":
|
||||
return http.MethodGet
|
||||
case "/api/v1/public/auth/send-email-code", "/api/v1/public/auth/confirm-email-code":
|
||||
return http.MethodPost
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listenAddr() string {
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
@@ -0,0 +1,459 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPublicHandlerHealthEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newPublicHandler(ServerDependencies{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
target string
|
||||
wantStatus int
|
||||
wantType string
|
||||
wantBody string
|
||||
wantAllow string
|
||||
}{
|
||||
{
|
||||
name: "healthz",
|
||||
method: http.MethodGet,
|
||||
target: "/healthz",
|
||||
wantStatus: http.StatusOK,
|
||||
wantType: jsonContentType,
|
||||
wantBody: `{"status":"ok"}`,
|
||||
},
|
||||
{
|
||||
name: "readyz",
|
||||
method: http.MethodGet,
|
||||
target: "/readyz",
|
||||
wantStatus: http.StatusOK,
|
||||
wantType: jsonContentType,
|
||||
wantBody: `{"status":"ready"}`,
|
||||
},
|
||||
{
|
||||
name: "wrong method on known route",
|
||||
method: http.MethodPost,
|
||||
target: "/healthz",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantType: jsonContentType,
|
||||
wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
|
||||
wantAllow: http.MethodGet,
|
||||
},
|
||||
{
|
||||
name: "unknown route",
|
||||
method: http.MethodGet,
|
||||
target: "/unknown",
|
||||
wantStatus: http.StatusNotFound,
|
||||
wantType: jsonContentType,
|
||||
wantBody: `{"error":{"code":"not_found","message":"resource was not found"}}`,
|
||||
},
|
||||
{
|
||||
name: "wrong method on public auth route",
|
||||
method: http.MethodGet,
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantType: jsonContentType,
|
||||
wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
|
||||
wantAllow: http.MethodPost,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(tt.method, tt.target, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, recorder.Code)
|
||||
assert.Equal(t, tt.wantType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantBody, recorder.Body.String())
|
||||
assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPublicTrafficClassifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
classifier := defaultPublicTrafficClassifier{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
target string
|
||||
accept string
|
||||
wantClass PublicRouteClass
|
||||
}{
|
||||
{
|
||||
name: "public auth route",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
wantClass: PublicRouteClassPublicAuth,
|
||||
},
|
||||
{
|
||||
name: "public auth confirm route",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
wantClass: PublicRouteClassPublicAuth,
|
||||
},
|
||||
{
|
||||
name: "browser bootstrap route",
|
||||
method: http.MethodGet,
|
||||
target: "/",
|
||||
wantClass: PublicRouteClassBrowserBootstrap,
|
||||
},
|
||||
{
|
||||
name: "browser asset route",
|
||||
method: http.MethodGet,
|
||||
target: "/assets/app.js",
|
||||
wantClass: PublicRouteClassBrowserAsset,
|
||||
},
|
||||
{
|
||||
name: "browser asset head request",
|
||||
method: http.MethodHead,
|
||||
target: "/assets/app.js",
|
||||
wantClass: PublicRouteClassBrowserAsset,
|
||||
},
|
||||
{
|
||||
name: "browser asset extension request",
|
||||
method: http.MethodGet,
|
||||
target: "/manifest.webmanifest",
|
||||
wantClass: PublicRouteClassBrowserAsset,
|
||||
},
|
||||
{
|
||||
name: "public misc route",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/unknown",
|
||||
wantClass: PublicRouteClassPublicMisc,
|
||||
},
|
||||
{
|
||||
name: "html accept bootstrap route",
|
||||
method: http.MethodGet,
|
||||
target: "/app",
|
||||
accept: "application/json, text/html;q=0.9",
|
||||
wantClass: PublicRouteClassBrowserBootstrap,
|
||||
},
|
||||
{
|
||||
name: "public auth wins over browser accept header",
|
||||
method: http.MethodPost,
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
accept: "text/html",
|
||||
wantClass: PublicRouteClassPublicAuth,
|
||||
},
|
||||
{
|
||||
name: "probe with html accept is bootstrap",
|
||||
method: http.MethodGet,
|
||||
target: "/healthz",
|
||||
accept: "text/html",
|
||||
wantClass: PublicRouteClassBrowserBootstrap,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(tt.method, tt.target, nil)
|
||||
if tt.accept != "" {
|
||||
req.Header.Set("Accept", tt.accept)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantClass, classifier.Classify(req))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicRouteClassNormalized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input PublicRouteClass
|
||||
want PublicRouteClass
|
||||
}{
|
||||
{
|
||||
name: "public auth",
|
||||
input: PublicRouteClassPublicAuth,
|
||||
want: PublicRouteClassPublicAuth,
|
||||
},
|
||||
{
|
||||
name: "browser bootstrap",
|
||||
input: PublicRouteClassBrowserBootstrap,
|
||||
want: PublicRouteClassBrowserBootstrap,
|
||||
},
|
||||
{
|
||||
name: "browser asset",
|
||||
input: PublicRouteClassBrowserAsset,
|
||||
want: PublicRouteClassBrowserAsset,
|
||||
},
|
||||
{
|
||||
name: "public misc",
|
||||
input: PublicRouteClassPublicMisc,
|
||||
want: PublicRouteClassPublicMisc,
|
||||
},
|
||||
{
|
||||
name: "unknown collapses to misc",
|
||||
input: PublicRouteClass("unexpected"),
|
||||
want: PublicRouteClassPublicMisc,
|
||||
},
|
||||
{
|
||||
name: "empty collapses to misc",
|
||||
input: PublicRouteClass(""),
|
||||
want: PublicRouteClassPublicMisc,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, tt.want, tt.input.Normalized())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicRouteClassBaseBucketKeyIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
class PublicRouteClass
|
||||
wantKey string
|
||||
}{
|
||||
{
|
||||
name: "public auth",
|
||||
class: PublicRouteClassPublicAuth,
|
||||
wantKey: "public_rest/class=public_auth",
|
||||
},
|
||||
{
|
||||
name: "browser bootstrap",
|
||||
class: PublicRouteClassBrowserBootstrap,
|
||||
wantKey: "public_rest/class=browser_bootstrap",
|
||||
},
|
||||
{
|
||||
name: "browser asset",
|
||||
class: PublicRouteClassBrowserAsset,
|
||||
wantKey: "public_rest/class=browser_asset",
|
||||
},
|
||||
{
|
||||
name: "public misc",
|
||||
class: PublicRouteClassPublicMisc,
|
||||
wantKey: "public_rest/class=public_misc",
|
||||
},
|
||||
{
|
||||
name: "unknown collapses to misc namespace",
|
||||
class: PublicRouteClass("unexpected"),
|
||||
wantKey: "public_rest/class=public_misc",
|
||||
},
|
||||
}
|
||||
|
||||
seenKeys := make(map[string]PublicRouteClass, len(tests))
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, tt.wantKey, tt.class.BaseBucketKey())
|
||||
})
|
||||
|
||||
normalizedClass := tt.class.Normalized()
|
||||
if normalizedClass == PublicRouteClassPublicMisc && tt.class != PublicRouteClassPublicMisc {
|
||||
continue
|
||||
}
|
||||
|
||||
if previousClass, exists := seenKeys[tt.wantKey]; exists {
|
||||
require.FailNowf(t, "bucket key collision", "class %q collides with %q on key %q", tt.class, previousClass, tt.wantKey)
|
||||
}
|
||||
|
||||
seenKeys[tt.wantKey] = tt.class
|
||||
}
|
||||
|
||||
assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserBootstrap.BaseBucketKey())
|
||||
assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserAsset.BaseBucketKey())
|
||||
}
|
||||
|
||||
func TestWithPublicRouteClassStoresClassInContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(withPublicRouteClass(staticClassifier{class: PublicRouteClassBrowserAsset}))
|
||||
router.GET("/assets/app.js", func(c *gin.Context) {
|
||||
class, ok := PublicRouteClassFromContext(c.Request.Context())
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, PublicRouteClassBrowserAsset, class)
|
||||
|
||||
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/assets/app.js", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
class PublicRouteClass
|
||||
}{
|
||||
{
|
||||
name: "unknown class",
|
||||
class: PublicRouteClass("unexpected"),
|
||||
},
|
||||
{
|
||||
name: "empty class",
|
||||
class: PublicRouteClass(""),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(withPublicRouteClass(staticClassifier{class: tt.class}))
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
class, ok := PublicRouteClassFromContext(c.Request.Context())
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, PublicRouteClassPublicMisc, class)
|
||||
|
||||
c.JSON(http.StatusOK, statusResponse{Status: "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := config.Config{
|
||||
ShutdownTimeout: time.Second,
|
||||
PublicHTTP: func() config.PublicHTTPConfig {
|
||||
publicHTTPCfg := config.DefaultPublicHTTPConfig()
|
||||
publicHTTPCfg.Addr = "127.0.0.1:0"
|
||||
publicHTTPCfg.AntiAbuse.PublicMisc.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 1000,
|
||||
Window: time.Minute,
|
||||
Burst: 1000,
|
||||
}
|
||||
return publicHTTPCfg
|
||||
}(),
|
||||
}
|
||||
|
||||
server := NewServer(cfg.PublicHTTP, ServerDependencies{})
|
||||
application := app.New(cfg, server)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- application.Run(ctx)
|
||||
}()
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
waitForHealthResponse(t, addr)
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-resultCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
require.FailNow(t, "Run() did not return after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
type staticClassifier struct {
|
||||
class PublicRouteClass
|
||||
}
|
||||
|
||||
func (c staticClassifier) Classify(*http.Request) PublicRouteClass {
|
||||
return c.class
|
||||
}
|
||||
|
||||
func waitForListenAddr(t *testing.T, server *Server) string {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if addr := server.listenAddr(); addr != "" {
|
||||
return addr
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
require.FailNow(t, "server did not start listening")
|
||||
return ""
|
||||
}
|
||||
|
||||
func waitForHealthResponse(t *testing.T, addr string) {
|
||||
t.Helper()
|
||||
|
||||
client := &http.Client{Timeout: 100 * time.Millisecond}
|
||||
url := "http://" + addr + "/healthz"
|
||||
deadline := time.Now().Add(time.Second)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
closeErr := resp.Body.Close()
|
||||
require.NoError(t, readErr)
|
||||
require.NoError(t, closeErr)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, `{"status":"ok"}`, strings.TrimSpace(string(body)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.FailNowf(t, "health check timed out", "url=%s", url)
|
||||
}
|
||||
Reference in New Issue
Block a user