Files
2026-04-02 19:18:42 +02:00

379 lines
11 KiB
Go

package restapi
import (
"bytes"
"errors"
"io"
"math"
"net"
"net/http"
"path"
"strconv"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
const (
	// errorCodeRequestTooLarge is the error code sent with the 413 response
	// produced when a request body exceeds the route class's MaxBodyBytes.
	errorCodeRequestTooLarge = "request_too_large"
	// errorCodeRateLimited is the error code sent with the 429 response
	// produced by abortRateLimited.
	errorCodeRateLimited = "rate_limited"
	// publicRESTIPBucketKeySegment separates the route-class base bucket key
	// from the client IP in per-IP rate-limit bucket keys.
	publicRESTIPBucketKeySegment = "/ip="
)
// errRequestBodyTooLarge is returned by bufferRequestBody when a request body
// holds more bytes than allowed; withPublicAntiAbuse maps it to a 413 response.
var errRequestBodyTooLarge = errors.New("request body exceeds the configured limit")
// PublicMalformedRequestReason identifies the stable malformed-request counter
// dimension recorded by the public REST anti-abuse middleware. Its values are
// fixed strings, which keeps the counter dimension low-cardinality.
type PublicMalformedRequestReason string
// Malformed-request reasons reported through PublicRequestObserver.
// PublicMalformedRequestReasonOversizedBody is recorded directly by
// withPublicAntiAbuse when a body exceeds the class limit. The other reasons
// are presumably produced by malformedRequestReasonFromError when auth
// identity extraction fails (NOTE(review): confirm against that helper).
const (
	// PublicMalformedRequestReasonEmptyBody records a missing request body.
	PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body"
	// PublicMalformedRequestReasonMalformedJSON records syntactically malformed
	// JSON.
	PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json"
	// PublicMalformedRequestReasonInvalidJSONValue records JSON values whose
	// types do not match the expected request schema.
	PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value"
	// PublicMalformedRequestReasonUnknownField records JSON objects with fields
	// outside the documented schema.
	PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field"
	// PublicMalformedRequestReasonMultipleJSONObjects records requests that
	// contain more than one JSON object.
	PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects"
	// PublicMalformedRequestReasonOversizedBody records requests whose bodies
	// exceed the configured class limit.
	PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body"
)
// PublicRateLimitDecision describes the outcome returned by a public REST
// limiter for one request bucket reservation attempt.
type PublicRateLimitDecision struct {
	// Allowed reports whether the request may proceed immediately.
	Allowed bool
	// RetryAfter is the minimum delay the client should wait before retrying
	// when Allowed is false. The middleware rounds it up to whole seconds (at
	// least one) for the Retry-After header. It is not read when Allowed is
	// true.
	RetryAfter time.Duration
}
// PublicRequestLimiter applies public REST rate-limit policy to a concrete
// bucket key.
//
// withPublicAntiAbuse calls Reserve once per request for the per-IP bucket of
// the route class. For requests that carry a public auth identity, it calls
// Reserve once more for that identity's bucket. One limiter value is shared by
// every request the middleware serves, so implementations presumably need to
// be safe for concurrent use.
type PublicRequestLimiter interface {
	// Reserve evaluates key under policy and returns whether the request may
	// proceed immediately. On denial, the decision's RetryAfter is surfaced
	// to the client.
	Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision
}
// PublicRequestObserver captures low-cardinality public REST anti-abuse
// telemetry.
type PublicRequestObserver interface {
	// RecordMalformedRequest records one malformed request in class for reason.
	// withPublicAntiAbuse calls it inline while handling the request.
	RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason)
}
// noopPublicRequestObserver is a PublicRequestObserver that discards every
// observation it is given.
type noopPublicRequestObserver struct{}

// RecordMalformedRequest satisfies PublicRequestObserver and intentionally does
// nothing.
func (noopPublicRequestObserver) RecordMalformedRequest(_ PublicRouteClass, _ PublicMalformedRequestReason) {}
// inMemoryPublicRequestLimiter is a process-local PublicRequestLimiter that
// keeps one token bucket per key in a map. An idle bucket becomes eligible for
// eviction once its TTL passes. Eviction sweeps run at most once per
// cleanupInterval.
type inMemoryPublicRequestLimiter struct {
	// now supplies the current time. newInMemoryPublicRequestLimiter sets it to
	// time.Now (presumably so tests can substitute a fake clock).
	now func() time.Time
	// cleanupInterval is the minimum spacing between eviction sweeps.
	cleanupInterval time.Duration
	// mu guards entries and nextCleanup.
	mu sync.Mutex
	// entries maps a bucket key to that bucket's limiter state.
	entries map[string]*publicRateLimiterEntry
	// nextCleanup is the earliest time the next sweep may run. The zero value
	// makes the next Reserve sweep immediately.
	nextCleanup time.Time
}
// publicRateLimiterEntry is one key's token bucket, stored together with the
// parameters it was built from so Reserve can detect policy changes.
type publicRateLimiterEntry struct {
	limiter *rate.Limiter
	// limit and burst record the parameters limiter was created with.
	limit rate.Limit
	burst int
	// expiresAt is when the entry becomes eligible for eviction. Reserve moves
	// it forward on every use of the key.
	expiresAt time.Time
}
// newInMemoryPublicRequestLimiter returns an empty in-memory limiter. It reads
// the current time through time.Now and sweeps idle buckets at most once a
// minute.
func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter {
	limiter := new(inMemoryPublicRequestLimiter)
	limiter.now = time.Now
	limiter.cleanupInterval = time.Minute
	limiter.entries = make(map[string]*publicRateLimiterEntry)
	return limiter
}
// Reserve decides whether one request for key may proceed under policy.
//
// Each key owns a token bucket that refills at policy.Requests tokens per
// policy.Window and holds at most policy.Burst tokens. The bucket is created
// on first use. It is rebuilt from scratch whenever the policy's rate or burst
// no longer matches the stored bucket. Every call refreshes the bucket's
// expiry and lets cleanupExpiredBucketsLocked evict idle buckets.
//
// When no token is available, the request is denied and RetryAfter reports how
// long until one will be. In that case the tentative reservation is cancelled,
// so denied requests do not consume future capacity and a client that waits
// the advertised RetryAfter is admitted. If the bucket can never admit a
// single request (for example, a zero burst), RetryAfter falls back to
// policy.Window.
//
// Reserve is safe for concurrent use because all bucket state is guarded by
// l.mu.
func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision {
	now := l.now()
	// NOTE(review): assumes config validation guarantees a positive Window. A
	// zero Window makes this +Inf, or NaN when Requests is also zero.
	limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())

	l.mu.Lock()
	defer l.mu.Unlock()

	l.cleanupExpiredBucketsLocked(now)

	entry, ok := l.entries[key]
	if !ok || entry.limit != limit || entry.burst != policy.Burst {
		// A changed policy resets the bucket instead of migrating its state.
		entry = &publicRateLimiterEntry{
			limiter: rate.NewLimiter(limit, policy.Burst),
			limit:   limit,
			burst:   policy.Burst,
		}
		l.entries[key] = entry
	}
	entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window))

	reservation := entry.limiter.ReserveN(now, 1)
	if !reservation.OK() {
		return PublicRateLimitDecision{
			Allowed:    false,
			RetryAfter: policy.Window,
		}
	}

	if retryAfter := reservation.DelayFrom(now); retryAfter > 0 {
		// The request is rejected, so give the token back. Keeping the
		// reservation would let every denied request push the bucket further
		// into debt, making the advertised RetryAfter too short.
		reservation.CancelAt(now)
		return PublicRateLimitDecision{
			Allowed:    false,
			RetryAfter: retryAfter,
		}
	}

	return PublicRateLimitDecision{Allowed: true}
}
// cleanupExpiredBucketsLocked deletes every bucket whose expiresAt is at or
// before now. Sweeps are throttled to one per cleanupInterval; a zero
// nextCleanup forces the first sweep. The caller must hold l.mu.
func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) {
	sweepDue := l.nextCleanup.IsZero() || !now.Before(l.nextCleanup)
	if !sweepDue {
		return
	}

	for bucketKey, bucket := range l.entries {
		// Deleting the current key while ranging over a Go map is permitted.
		if expired := !now.Before(bucket.expiresAt); expired {
			delete(l.entries, bucketKey)
		}
	}

	l.nextCleanup = now.Add(l.cleanupInterval)
}
// publicRateLimiterEntryTTL returns how long a bucket for a policy with the
// given window survives without traffic before cleanup may evict it. The TTL
// is two windows, with a floor of one minute.
func publicRateLimiterEntryTTL(window time.Duration) time.Duration {
	// Compute two windows first and only then apply the floor. That way every
	// window, including those between 30s and 60s, keeps its bucket for at
	// least two windows.
	ttl := 2 * window
	if ttl < time.Minute {
		return time.Minute
	}
	return ttl
}
// withPublicAntiAbuse returns gin middleware that applies the public REST
// anti-abuse policy to every request, in this order:
//
//  1. Resolve the route class from the request context. Unclassified
//     requests are treated as PublicRouteClassPublicMisc.
//  2. Reject methods that do not fit the recognised request shape with 405
//     and an Allow header.
//  3. Buffer the body up to the class's MaxBodyBytes. Larger bodies are
//     rejected with 413 and counted as oversized_body. Read failures are
//     rejected with 500.
//  4. Rate-limit per client IP within the class. The IP comes from
//     RemoteAddr only; forwarding headers are not consulted.
//  5. For requests that carry a public auth identity, also rate-limit per
//     identity using the endpoint's identity policy.
//
// Identity extraction failures other than "not applicable" are only counted
// through observer. The request still reaches the next handler, which
// presumably reports the validation error itself.
func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc {
	return func(c *gin.Context) {
		req := c.Request

		class, classified := PublicRouteClassFromContext(req.Context())
		if !classified {
			class = PublicRouteClassPublicMisc
		}

		// Method filtering happens before any body bytes are read.
		if methods := allowedMethodsForRequestShape(req); len(methods) > 0 && !isAllowedMethod(req.Method, methods) {
			c.Header("Allow", strings.Join(methods, ", "))
			abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
			return
		}

		classPolicy := publicRoutePolicyForClass(policy, class)

		body, bufErr := bufferRequestBody(req, classPolicy.MaxBodyBytes)
		if errors.Is(bufErr, errRequestBodyTooLarge) {
			observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody)
			abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit")
			return
		}
		if bufErr != nil {
			abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
			return
		}

		ipKey := publicRESTIPBucketKey(class, clientIPFromRemoteAddr(req.RemoteAddr))
		if verdict := limiter.Reserve(ipKey, classPolicy.RateLimit); !verdict.Allowed {
			abortRateLimited(c, verdict.RetryAfter)
			return
		}

		identity, identityErr := extractPublicAuthIdentity(req.URL.Path, body)
		if identityErr == nil {
			identityPolicy := publicAuthIdentityPolicyForPath(req.URL.Path, policy)
			identityKey := publicAuthIdentityBucketKey(class, identity.kind, identity.value)
			if verdict := limiter.Reserve(identityKey, identityPolicy.RateLimit); !verdict.Allowed {
				abortRateLimited(c, verdict.RetryAfter)
				return
			}
		} else if !errors.Is(identityErr, errPublicAuthIdentityNotApplicable) {
			if reason, malformed := malformedRequestReasonFromError(identityErr); malformed {
				observer.RecordMalformedRequest(class, reason)
			}
		}

		c.Next()
	}
}
// publicRoutePolicyForClass selects the per-class route policy from policy.
// class is normalized first. Anything that does not normalize to the public
// auth, browser bootstrap, or browser asset class falls back to the
// public-misc policy.
func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig {
	normalized := class.Normalized()
	if normalized == PublicRouteClassPublicAuth {
		return policy.PublicAuth
	}
	if normalized == PublicRouteClassBrowserBootstrap {
		return policy.BrowserBootstrap
	}
	if normalized == PublicRouteClassBrowserAsset {
		return policy.BrowserAsset
	}
	return policy.PublicMisc
}
// publicAuthIdentityPolicyForPath returns the per-identity rate-limit policy for
// a public auth endpoint. Paths other than the send-email-code and
// confirm-email-code endpoints get the zero-valued policy.
//
// NOTE(review): with the in-memory limiter, a zero-valued policy (zero burst)
// appears to reject every request. Confirm that callers only reach this for
// the two auth paths.
func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig {
	var identityPolicy config.PublicAuthIdentityPolicyConfig
	switch requestPath {
	case "/api/v1/public/auth/send-email-code":
		identityPolicy = policy.SendEmailCodeIdentity
	case "/api/v1/public/auth/confirm-email-code":
		identityPolicy = policy.ConfirmEmailCodeIdentity
	}
	return identityPolicy
}
// allowedMethodsForRequestShape returns the HTTP methods permitted for the
// shape of r:
//   - POST for the public auth endpoints
//   - GET for the health/readiness probes
//   - GET and HEAD for browser assets and bootstrap documents
//
// It returns nil when the request matches none of these shapes, meaning no
// method restriction applies.
func allowedMethodsForRequestShape(r *http.Request) []string {
	requestPath := r.URL.Path
	if isPublicAuthPath(requestPath) {
		return []string{http.MethodPost}
	}
	if isProbePath(requestPath) {
		return []string{http.MethodGet}
	}
	// Assets and HTML bootstrap documents are both read-only.
	if matchesBrowserAssetRequestShape(r) || matchesBrowserBootstrapRequestShape(r) {
		return []string{http.MethodGet, http.MethodHead}
	}
	return nil
}
// isAllowedMethod reports whether method appears in allowedMethods. The
// comparison is exact and case-sensitive.
func isAllowedMethod(method string, allowedMethods []string) bool {
	for i := range allowedMethods {
		if allowedMethods[i] == method {
			return true
		}
	}
	return false
}
// isPublicAuthPath reports whether requestPath is exactly one of the two public
// email-code auth endpoints.
func isPublicAuthPath(requestPath string) bool {
	return requestPath == "/api/v1/public/auth/send-email-code" ||
		requestPath == "/api/v1/public/auth/confirm-email-code"
}
// isProbePath reports whether requestPath is exactly the liveness (/healthz) or
// readiness (/readyz) probe endpoint.
func isProbePath(requestPath string) bool {
	return requestPath == "/healthz" || requestPath == "/readyz"
}
// matchesBrowserBootstrapRequestShape reports whether r looks like a browser
// navigation that expects an HTML document. That means either the site root
// "/" or any request whose Accept header mentions text/html
// (case-insensitively).
func matchesBrowserBootstrapRequestShape(r *http.Request) bool {
	return r.URL.Path == "/" ||
		strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html")
}
// matchesBrowserAssetRequestShape reports whether r targets a static browser
// asset. That means anything under /assets/, or a path whose extension
// (case-insensitively) is a script, stylesheet, source map, image, font, JSON,
// or web app manifest type.
func matchesBrowserAssetRequestShape(r *http.Request) bool {
	requestPath := r.URL.Path
	if strings.HasPrefix(requestPath, "/assets/") {
		return true
	}

	switch extension := strings.ToLower(path.Ext(requestPath)); extension {
	case ".js", ".mjs", ".css", ".map",
		".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico",
		".woff", ".woff2",
		".json", ".webmanifest":
		return true
	}
	return false
}
// bufferRequestBody reads r's body into memory, up to maxBodyBytes bytes, and
// replaces r.Body with an in-memory reader over the same bytes so later
// handlers can read it again.
//
// It returns errRequestBodyTooLarge when the body holds more than
// maxBodyBytes bytes, and it returns read or close errors from the original
// body unchanged. A nil request is a no-op, and a nil body is replaced by an
// empty reader. Once read, the original body is always closed. On any error,
// r.Body is left pointing at that closed body.
func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) {
	if r == nil {
		return nil, nil
	}
	if r.Body == nil {
		r.Body = io.NopCloser(bytes.NewReader(nil))
		return nil, nil
	}

	// Read one byte past the limit so an over-limit body can be told apart
	// from one that is exactly maxBodyBytes long. The increment is guarded so
	// a limit of math.MaxInt64 ("unlimited") does not wrap to a negative read
	// limit, which would make LimitReader return no data at all.
	readLimit := maxBodyBytes
	if readLimit < math.MaxInt64 {
		readLimit++
	}

	bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, readLimit))
	closeErr := r.Body.Close()
	if err != nil {
		return nil, err
	}
	if closeErr != nil {
		return nil, closeErr
	}
	if int64(len(bodyBytes)) > maxBodyBytes {
		return nil, errRequestBodyTooLarge
	}

	r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
	return bodyBytes, nil
}
// abortRateLimited aborts the request with 429 Too Many Requests. It sets the
// Retry-After header to retryAfter, rounded up to whole seconds with a minimum
// of one.
func abortRateLimited(c *gin.Context, retryAfter time.Duration) {
	// The header is set before abortWithError runs, since that call
	// presumably writes the response.
	headerValue := retryAfterHeaderValue(retryAfter)
	c.Header("Retry-After", headerValue)
	abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded")
}
// retryAfterHeaderValue renders delay as a Retry-After delta-seconds value. It
// rounds up to the next whole second and never returns less than "1", so
// zero, negative, and sub-second delays all become "1".
func retryAfterHeaderValue(delay time.Duration) string {
	wholeSeconds := math.Ceil(delay.Seconds())
	if wholeSeconds < 1 {
		wholeSeconds = 1
	}
	return strconv.FormatInt(int64(wholeSeconds), 10)
}
// clientIPFromRemoteAddr extracts the host part of a "host:port" remote
// address, trimming surrounding whitespace first. An address without a port
// is returned as-is after trimming, and an empty or whitespace-only address
// becomes "unknown".
func clientIPFromRemoteAddr(remoteAddr string) string {
	addr := strings.TrimSpace(remoteAddr)
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	if addr == "" {
		return "unknown"
	}
	return addr
}
// publicRESTIPBucketKey builds the per-IP rate-limit bucket key for class:
// the class's base bucket key, then "/ip=", then the client IP.
func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string {
	base := class.BaseBucketKey()
	return base + publicRESTIPBucketKeySegment + clientIP
}
// publicAuthIdentityBucketKey builds the per-identity rate-limit bucket key
// for class in the form "<base>/<identityKind>=<identityValue>".
func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string {
	base := class.BaseBucketKey()
	segment := identityKind + "=" + identityValue
	return base + "/" + segment
}