package restapi import ( "bytes" "errors" "io" "math" "net" "net/http" "path" "strconv" "strings" "sync" "time" "galaxy/gateway/internal/config" "github.com/gin-gonic/gin" "golang.org/x/time/rate" ) const ( errorCodeRequestTooLarge = "request_too_large" errorCodeRateLimited = "rate_limited" publicRESTIPBucketKeySegment = "/ip=" ) var errRequestBodyTooLarge = errors.New("request body exceeds the configured limit") // PublicMalformedRequestReason identifies the stable malformed-request counter // dimension recorded by the public REST anti-abuse middleware. type PublicMalformedRequestReason string const ( // PublicMalformedRequestReasonEmptyBody records a missing request body. PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body" // PublicMalformedRequestReasonMalformedJSON records syntactically malformed // JSON. PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json" // PublicMalformedRequestReasonInvalidJSONValue records JSON values whose // types do not match the expected request schema. PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value" // PublicMalformedRequestReasonUnknownField records JSON objects with fields // outside the documented schema. PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field" // PublicMalformedRequestReasonMultipleJSONObjects records requests that // contain more than one JSON object. PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects" // PublicMalformedRequestReasonOversizedBody records requests whose bodies // exceed the configured class limit. PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body" ) // PublicRateLimitDecision describes the outcome returned by a public REST // limiter for one request bucket reservation attempt. type PublicRateLimitDecision struct { // Allowed reports whether the request may proceed immediately. Allowed bool // RetryAfter is the minimum delay the client should wait before retrying // when Allowed is false. RetryAfter time.Duration } // PublicRequestLimiter applies public REST rate-limit policy to a concrete // bucket key. type PublicRequestLimiter interface { // Reserve evaluates key under policy and returns whether the request may // proceed immediately. Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision } // PublicRequestObserver captures low-cardinality public REST anti-abuse // telemetry. type PublicRequestObserver interface { // RecordMalformedRequest records one malformed request in class for reason. RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason) } type noopPublicRequestObserver struct{} func (noopPublicRequestObserver) RecordMalformedRequest(PublicRouteClass, PublicMalformedRequestReason) { } type inMemoryPublicRequestLimiter struct { now func() time.Time cleanupInterval time.Duration mu sync.Mutex entries map[string]*publicRateLimiterEntry nextCleanup time.Time } type publicRateLimiterEntry struct { limiter *rate.Limiter limit rate.Limit burst int expiresAt time.Time } func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter { return &inMemoryPublicRequestLimiter{ now: time.Now, cleanupInterval: time.Minute, entries: make(map[string]*publicRateLimiterEntry), } } func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision { now := l.now() limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds()) l.mu.Lock() defer l.mu.Unlock() l.cleanupExpiredBucketsLocked(now) entry, ok := l.entries[key] if !ok || entry.limit != limit || entry.burst != policy.Burst { entry = &publicRateLimiterEntry{ limiter: rate.NewLimiter(limit, policy.Burst), limit: limit, burst: policy.Burst, } l.entries[key] = entry } entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window)) reservation := entry.limiter.ReserveN(now, 1) if !reservation.OK() { return PublicRateLimitDecision{ Allowed: false, RetryAfter: policy.Window, } } retryAfter := reservation.DelayFrom(now) if retryAfter > 0 { return PublicRateLimitDecision{ Allowed: false, RetryAfter: retryAfter, } } return PublicRateLimitDecision{Allowed: true} } func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) { if !l.nextCleanup.IsZero() && now.Before(l.nextCleanup) { return } for key, entry := range l.entries { if !entry.expiresAt.After(now) { delete(l.entries, key) } } l.nextCleanup = now.Add(l.cleanupInterval) } func publicRateLimiterEntryTTL(window time.Duration) time.Duration { if window < time.Minute { return time.Minute } return 2 * window } func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc { return func(c *gin.Context) { class, ok := PublicRouteClassFromContext(c.Request.Context()) if !ok { class = PublicRouteClassPublicMisc } allowedMethods := allowedMethodsForRequestShape(c.Request) if len(allowedMethods) > 0 && !isAllowedMethod(c.Request.Method, allowedMethods) { c.Header("Allow", strings.Join(allowedMethods, ", ")) abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route") return } classPolicy := publicRoutePolicyForClass(policy, class) bodyBytes, err := bufferRequestBody(c.Request, classPolicy.MaxBodyBytes) if err != nil { switch { case errors.Is(err, errRequestBodyTooLarge): observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody) abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit") default: abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error") } return } clientIP := clientIPFromRemoteAddr(c.Request.RemoteAddr) if decision := limiter.Reserve(publicRESTIPBucketKey(class, clientIP), classPolicy.RateLimit); !decision.Allowed { abortRateLimited(c, decision.RetryAfter) return } identity, err := extractPublicAuthIdentity(c.Request.URL.Path, bodyBytes) switch { case err == nil: identityPolicy := publicAuthIdentityPolicyForPath(c.Request.URL.Path, policy) if decision := limiter.Reserve(publicAuthIdentityBucketKey(class, identity.kind, identity.value), identityPolicy.RateLimit); !decision.Allowed { abortRateLimited(c, decision.RetryAfter) return } case errors.Is(err, errPublicAuthIdentityNotApplicable): default: if reason, malformed := malformedRequestReasonFromError(err); malformed { observer.RecordMalformedRequest(class, reason) } } c.Next() } } func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig { switch class.Normalized() { case PublicRouteClassPublicAuth: return policy.PublicAuth case PublicRouteClassBrowserBootstrap: return policy.BrowserBootstrap case PublicRouteClassBrowserAsset: return policy.BrowserAsset default: return policy.PublicMisc } } func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig { switch requestPath { case "/api/v1/public/auth/send-email-code": return policy.SendEmailCodeIdentity case "/api/v1/public/auth/confirm-email-code": return policy.ConfirmEmailCodeIdentity default: return config.PublicAuthIdentityPolicyConfig{} } } func allowedMethodsForRequestShape(r *http.Request) []string { switch { case isPublicAuthPath(r.URL.Path): return []string{http.MethodPost} case isProbePath(r.URL.Path): return []string{http.MethodGet} case matchesBrowserAssetRequestShape(r): return []string{http.MethodGet, http.MethodHead} case matchesBrowserBootstrapRequestShape(r): return []string{http.MethodGet, http.MethodHead} default: return nil } } func isAllowedMethod(method string, allowedMethods []string) bool { for _, allowedMethod := range allowedMethods { if method == allowedMethod { return true } } return false } func isPublicAuthPath(requestPath string) bool { switch requestPath { case "/api/v1/public/auth/send-email-code", "/api/v1/public/auth/confirm-email-code": return true default: return false } } func isProbePath(requestPath string) bool { switch requestPath { case "/healthz", "/readyz": return true default: return false } } func matchesBrowserBootstrapRequestShape(r *http.Request) bool { if r.URL.Path == "/" { return true } return strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html") } func matchesBrowserAssetRequestShape(r *http.Request) bool { if strings.HasPrefix(r.URL.Path, "/assets/") { return true } switch strings.ToLower(path.Ext(r.URL.Path)) { case ".js", ".mjs", ".css", ".map", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".json", ".webmanifest": return true default: return false } } func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) { if r == nil { return nil, nil } if r.Body == nil { r.Body = io.NopCloser(bytes.NewReader(nil)) return nil, nil } bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, maxBodyBytes+1)) closeErr := r.Body.Close() if err != nil { return nil, err } if closeErr != nil { return nil, closeErr } if int64(len(bodyBytes)) > maxBodyBytes { return nil, errRequestBodyTooLarge } r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) return bodyBytes, nil } func abortRateLimited(c *gin.Context, retryAfter time.Duration) { c.Header("Retry-After", retryAfterHeaderValue(retryAfter)) abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded") } func retryAfterHeaderValue(delay time.Duration) string { seconds := int64(math.Ceil(delay.Seconds())) if seconds < 1 { seconds = 1 } return strconv.FormatInt(seconds, 10) } func clientIPFromRemoteAddr(remoteAddr string) string { host, _, err := net.SplitHostPort(strings.TrimSpace(remoteAddr)) if err == nil { return host } remoteAddr = strings.TrimSpace(remoteAddr) if remoteAddr == "" { return "unknown" } return remoteAddr } func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string { return class.BaseBucketKey() + publicRESTIPBucketKeySegment + clientIP } func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string { return class.BaseBucketKey() + "/" + identityKind + "=" + identityValue }