379 lines
11 KiB
Go
379 lines
11 KiB
Go
package restapi
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"math"
|
|
"net"
|
|
"net/http"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"galaxy/gateway/internal/config"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// Stable error codes written into responses by the public anti-abuse
// middleware.
const (
	errorCodeRequestTooLarge = "request_too_large"
	errorCodeRateLimited     = "rate_limited"
)

// publicRESTIPBucketKeySegment separates a route class's base bucket key from
// the client IP in per-IP rate-limit bucket keys.
const publicRESTIPBucketKeySegment = "/ip="

var (
	// errRequestBodyTooLarge is the sentinel bufferRequestBody returns when a
	// request body is longer than the configured limit.
	errRequestBodyTooLarge = errors.New("request body exceeds the configured limit")
)
|
|
|
|
// PublicMalformedRequestReason labels why the public REST anti-abuse
// middleware counted a request as malformed. The string values are stable
// metric dimensions and must not be renamed.
type PublicMalformedRequestReason string

const (
	// PublicMalformedRequestReasonOversizedBody marks a request whose body is
	// longer than the limit configured for its route class.
	PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body"

	// PublicMalformedRequestReasonEmptyBody marks a request that arrived
	// without a body.
	PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body"

	// PublicMalformedRequestReasonMalformedJSON marks a body that is not
	// syntactically valid JSON.
	PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json"

	// PublicMalformedRequestReasonInvalidJSONValue marks JSON whose value types
	// disagree with the expected request schema.
	PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value"

	// PublicMalformedRequestReasonUnknownField marks a JSON object carrying a
	// field that the documented schema does not define.
	PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field"

	// PublicMalformedRequestReasonMultipleJSONObjects marks a body holding
	// more than one JSON object.
	PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects"
)
|
|
|
|
// PublicRateLimitDecision describes the outcome returned by a public REST
|
|
// limiter for one request bucket reservation attempt.
|
|
type PublicRateLimitDecision struct {
|
|
// Allowed reports whether the request may proceed immediately.
|
|
Allowed bool
|
|
|
|
// RetryAfter is the minimum delay the client should wait before retrying
|
|
// when Allowed is false.
|
|
RetryAfter time.Duration
|
|
}
|
|
|
|
// PublicRequestLimiter applies public REST rate-limit policy to a concrete
|
|
// bucket key.
|
|
type PublicRequestLimiter interface {
|
|
// Reserve evaluates key under policy and returns whether the request may
|
|
// proceed immediately.
|
|
Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision
|
|
}
|
|
|
|
// PublicRequestObserver captures low-cardinality public REST anti-abuse
|
|
// telemetry.
|
|
type PublicRequestObserver interface {
|
|
// RecordMalformedRequest records one malformed request in class for reason.
|
|
RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason)
|
|
}
|
|
|
|
type noopPublicRequestObserver struct{}
|
|
|
|
func (noopPublicRequestObserver) RecordMalformedRequest(PublicRouteClass, PublicMalformedRequestReason) {
|
|
}
|
|
|
|
type inMemoryPublicRequestLimiter struct {
|
|
now func() time.Time
|
|
cleanupInterval time.Duration
|
|
|
|
mu sync.Mutex
|
|
entries map[string]*publicRateLimiterEntry
|
|
nextCleanup time.Time
|
|
}
|
|
|
|
type publicRateLimiterEntry struct {
|
|
limiter *rate.Limiter
|
|
limit rate.Limit
|
|
burst int
|
|
expiresAt time.Time
|
|
}
|
|
|
|
func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter {
|
|
return &inMemoryPublicRequestLimiter{
|
|
now: time.Now,
|
|
cleanupInterval: time.Minute,
|
|
entries: make(map[string]*publicRateLimiterEntry),
|
|
}
|
|
}
|
|
|
|
func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision {
|
|
now := l.now()
|
|
limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())
|
|
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
l.cleanupExpiredBucketsLocked(now)
|
|
|
|
entry, ok := l.entries[key]
|
|
if !ok || entry.limit != limit || entry.burst != policy.Burst {
|
|
entry = &publicRateLimiterEntry{
|
|
limiter: rate.NewLimiter(limit, policy.Burst),
|
|
limit: limit,
|
|
burst: policy.Burst,
|
|
}
|
|
l.entries[key] = entry
|
|
}
|
|
|
|
entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window))
|
|
|
|
reservation := entry.limiter.ReserveN(now, 1)
|
|
if !reservation.OK() {
|
|
return PublicRateLimitDecision{
|
|
Allowed: false,
|
|
RetryAfter: policy.Window,
|
|
}
|
|
}
|
|
|
|
retryAfter := reservation.DelayFrom(now)
|
|
if retryAfter > 0 {
|
|
return PublicRateLimitDecision{
|
|
Allowed: false,
|
|
RetryAfter: retryAfter,
|
|
}
|
|
}
|
|
|
|
return PublicRateLimitDecision{Allowed: true}
|
|
}
|
|
|
|
func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) {
|
|
if !l.nextCleanup.IsZero() && now.Before(l.nextCleanup) {
|
|
return
|
|
}
|
|
|
|
for key, entry := range l.entries {
|
|
if !entry.expiresAt.After(now) {
|
|
delete(l.entries, key)
|
|
}
|
|
}
|
|
|
|
l.nextCleanup = now.Add(l.cleanupInterval)
|
|
}
|
|
|
|
// publicRateLimiterEntryTTL returns how long an idle bucket is retained.
// Windows shorter than one minute get a flat one-minute TTL. Windows of one
// minute or longer get twice the window.
func publicRateLimiterEntryTTL(window time.Duration) time.Duration {
	const shortWindowTTL = time.Minute
	if window >= shortWindowTTL {
		return 2 * window
	}
	return shortWindowTTL
}
|
|
|
|
func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
class, ok := PublicRouteClassFromContext(c.Request.Context())
|
|
if !ok {
|
|
class = PublicRouteClassPublicMisc
|
|
}
|
|
|
|
allowedMethods := allowedMethodsForRequestShape(c.Request)
|
|
if len(allowedMethods) > 0 && !isAllowedMethod(c.Request.Method, allowedMethods) {
|
|
c.Header("Allow", strings.Join(allowedMethods, ", "))
|
|
abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
|
|
return
|
|
}
|
|
|
|
classPolicy := publicRoutePolicyForClass(policy, class)
|
|
bodyBytes, err := bufferRequestBody(c.Request, classPolicy.MaxBodyBytes)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, errRequestBodyTooLarge):
|
|
observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody)
|
|
abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit")
|
|
default:
|
|
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
|
}
|
|
return
|
|
}
|
|
|
|
clientIP := clientIPFromRemoteAddr(c.Request.RemoteAddr)
|
|
if decision := limiter.Reserve(publicRESTIPBucketKey(class, clientIP), classPolicy.RateLimit); !decision.Allowed {
|
|
abortRateLimited(c, decision.RetryAfter)
|
|
return
|
|
}
|
|
|
|
identity, err := extractPublicAuthIdentity(c.Request.URL.Path, bodyBytes)
|
|
switch {
|
|
case err == nil:
|
|
identityPolicy := publicAuthIdentityPolicyForPath(c.Request.URL.Path, policy)
|
|
if decision := limiter.Reserve(publicAuthIdentityBucketKey(class, identity.kind, identity.value), identityPolicy.RateLimit); !decision.Allowed {
|
|
abortRateLimited(c, decision.RetryAfter)
|
|
return
|
|
}
|
|
case errors.Is(err, errPublicAuthIdentityNotApplicable):
|
|
default:
|
|
if reason, malformed := malformedRequestReasonFromError(err); malformed {
|
|
observer.RecordMalformedRequest(class, reason)
|
|
}
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig {
|
|
switch class.Normalized() {
|
|
case PublicRouteClassPublicAuth:
|
|
return policy.PublicAuth
|
|
case PublicRouteClassBrowserBootstrap:
|
|
return policy.BrowserBootstrap
|
|
case PublicRouteClassBrowserAsset:
|
|
return policy.BrowserAsset
|
|
default:
|
|
return policy.PublicMisc
|
|
}
|
|
}
|
|
|
|
func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig {
|
|
switch requestPath {
|
|
case "/api/v1/public/auth/send-email-code":
|
|
return policy.SendEmailCodeIdentity
|
|
case "/api/v1/public/auth/confirm-email-code":
|
|
return policy.ConfirmEmailCodeIdentity
|
|
default:
|
|
return config.PublicAuthIdentityPolicyConfig{}
|
|
}
|
|
}
|
|
|
|
func allowedMethodsForRequestShape(r *http.Request) []string {
|
|
switch {
|
|
case isPublicAuthPath(r.URL.Path):
|
|
return []string{http.MethodPost}
|
|
case isProbePath(r.URL.Path):
|
|
return []string{http.MethodGet}
|
|
case matchesBrowserAssetRequestShape(r):
|
|
return []string{http.MethodGet, http.MethodHead}
|
|
case matchesBrowserBootstrapRequestShape(r):
|
|
return []string{http.MethodGet, http.MethodHead}
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// isAllowedMethod reports whether method occurs in allowedMethods. The match
// is exact and case-sensitive; a nil or empty list allows nothing.
func isAllowedMethod(method string, allowedMethods []string) bool {
	for i := range allowedMethods {
		if allowedMethods[i] == method {
			return true
		}
	}
	return false
}
|
|
|
|
// isPublicAuthPath reports whether requestPath is exactly one of the two
// public e-mail-code auth endpoints. No prefix or trailing-slash variants
// match.
func isPublicAuthPath(requestPath string) bool {
	return requestPath == "/api/v1/public/auth/send-email-code" ||
		requestPath == "/api/v1/public/auth/confirm-email-code"
}

// isProbePath reports whether requestPath is exactly one of the probe
// endpoints, /healthz or /readyz.
func isProbePath(requestPath string) bool {
	return requestPath == "/healthz" || requestPath == "/readyz"
}
|
|
|
|
// matchesBrowserBootstrapRequestShape reports whether r looks like a browser
// page load. That is either the site root, or any request whose Accept header
// mentions text/html (compared case-insensitively).
func matchesBrowserBootstrapRequestShape(r *http.Request) bool {
	acceptsHTML := strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html")
	return r.URL.Path == "/" || acceptsHTML
}

// matchesBrowserAssetRequestShape reports whether r looks like a static asset
// fetch. That is anything under /assets/, or a path whose extension
// (case-insensitive) is a known script, style, source-map, image, font, JSON or
// web-manifest type.
func matchesBrowserAssetRequestShape(r *http.Request) bool {
	if strings.HasPrefix(r.URL.Path, "/assets/") {
		return true
	}

	switch ext := strings.ToLower(path.Ext(r.URL.Path)); ext {
	case ".js", ".mjs", ".css", ".map",
		".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico",
		".woff", ".woff2",
		".json", ".webmanifest":
		return true
	}
	return false
}
|
|
|
|
func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) {
|
|
if r == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
if r.Body == nil {
|
|
r.Body = io.NopCloser(bytes.NewReader(nil))
|
|
return nil, nil
|
|
}
|
|
|
|
bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, maxBodyBytes+1))
|
|
closeErr := r.Body.Close()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if closeErr != nil {
|
|
return nil, closeErr
|
|
}
|
|
if int64(len(bodyBytes)) > maxBodyBytes {
|
|
return nil, errRequestBodyTooLarge
|
|
}
|
|
|
|
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
return bodyBytes, nil
|
|
}
|
|
|
|
func abortRateLimited(c *gin.Context, retryAfter time.Duration) {
|
|
c.Header("Retry-After", retryAfterHeaderValue(retryAfter))
|
|
abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded")
|
|
}
|
|
|
|
// retryAfterHeaderValue renders delay as a Retry-After delta-seconds string.
// The delay is rounded up to the next whole second. Zero, negative and
// sub-second delays all become "1", so clients are never told to retry
// immediately.
func retryAfterHeaderValue(delay time.Duration) string {
	rounded := int64(math.Ceil(delay.Seconds()))
	if rounded >= 1 {
		return strconv.FormatInt(rounded, 10)
	}
	return "1"
}
|
|
|
|
// clientIPFromRemoteAddr derives the rate-limit client identifier from an
// http.Request RemoteAddr value:
//   - Surrounding whitespace is ignored.
//   - A "host:port" value yields the host; brackets are removed from IPv6
//     literals.
//   - Anything SplitHostPort rejects is returned trimmed but otherwise
//     unchanged.
//   - An empty value becomes "unknown".
//
// Only RemoteAddr is consulted; forwarding headers such as X-Forwarded-For are
// not. NOTE(review): behind a reverse proxy every client would therefore share
// the proxy's bucket — confirm the deployment topology.
func clientIPFromRemoteAddr(remoteAddr string) string {
	trimmed := strings.TrimSpace(remoteAddr)

	if host, _, err := net.SplitHostPort(trimmed); err == nil {
		return host
	}
	if trimmed == "" {
		return "unknown"
	}
	return trimmed
}
|
|
|
|
func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string {
|
|
return class.BaseBucketKey() + publicRESTIPBucketKeySegment + clientIP
|
|
}
|
|
|
|
func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string {
|
|
return class.BaseBucketKey() + "/" + identityKind + "=" + identityValue
|
|
}
|