diff --git a/.gitignore b/.gitignore index f3ec25c..242e02f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +.codex .vscode/ artifacts/ \ No newline at end of file diff --git a/README.md b/README.md index 39da9bc..ab1c3e1 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,27 @@ It is the starting point for implementing the external edge layer, authenticatio - Internal business services are **not reachable directly from outside**. - Any external command, except public auth commands, must be authenticated before it is routed further. - Gateway handles only edge concerns. Business validation and domain rules remain inside business services. -- Push / long-polling delivery is also handled by the gateway. +- Gateway owns external delivery channels; the v1 implementation uses + authenticated gRPC server-streaming push, while long-polling remains out of + scope. + +```mermaid +flowchart LR + Client["Clients\n(native and browser)"] + Gateway["Edge Gateway\npublic REST + authenticated gRPC"] + Auth["Auth / Session Service"] + Business["Business Services"] + Redis["Redis\nsession cache + replay keys + event streams"] + Telemetry["Telemetry Backends\nPrometheus / OTLP"] + + Client --> Gateway + Gateway --> Auth + Gateway --> Business + Gateway --> Redis + Gateway --> Telemetry + Auth --> Redis + Business --> Redis +``` ## Main Components @@ -33,7 +53,7 @@ Responsibilities: - rate limiting and abuse protection - command routing - basic policy enforcement -- long-polling / push connection handling +- authenticated gRPC server-streaming push connection handling - delivery of client-facing events from pub/sub The gateway must not implement domain-specific business logic. @@ -158,21 +178,26 @@ Flow: 7. gateway verifies anti-replay constraints 8. gateway applies rate limits and basic policy checks 9. gateway extracts authenticated context, including `user_id` -10. gateway routes the request to the target business service based on `command_type` +10. 
gateway routes the request to the target business service based on `message_type` No business service should receive an unauthenticated external request. -### Push / Long-Polling Flow +### Push Flow -The gateway owns external push / long-polling connections. +The gateway owns external delivery connections. +The v1 gateway uses authenticated gRPC server-streaming push. +Long-polling remains out of scope for the implemented gateway. Flow: -1. client opens authenticated push / long-polling connection through gateway +1. client opens authenticated push connection through gateway 2. gateway binds connection to `user_id` and `device_session_id` -3. gateway may send current server time for clock offset calculation -4. internal services publish client-facing events to pub/sub -5. gateway consumes those events and delivers them to the proper client connections +3. gateway starts the channel with a signed service event that includes the + current server time for clock offset calculation +4. internal services publish client-facing events to pub/sub targeted by + `user_id` and optionally by `device_session_id` +5. gateway consumes those events and delivers them to the proper client + connections Gateway is a delivery layer, not the source of business events. 
@@ -184,7 +209,7 @@ Typical internal authenticated context: - `user_id` - `device_session_id` -- `command_type` +- `message_type` - verified payload bytes - transport `request_id` - optional command id / trace id @@ -218,7 +243,7 @@ When a device session is revoked: - auth/session service publishes revoke/invalidation event - gateway updates or invalidates session cache - gateway rejects further requests for that session -- gateway closes active push / long-polling connections bound to that session, if applicable +- gateway closes active authenticated push streams bound to that session, if applicable ## Non-Goals diff --git a/README_security.md b/SECURITY.md similarity index 53% rename from README_security.md rename to SECURITY.md index d02fcb9..a9133df 100644 --- a/README_security.md +++ b/SECURITY.md @@ -15,13 +15,34 @@ It is the starting point for implementing authenticated device sessions, signed - Responses are authenticated by server-side signatures. - Transport integrity and freshness are verified before payload is processed. +```mermaid +sequenceDiagram + participant Client + participant Gateway + participant SessionCache + participant ReplayStore + participant Business + + Client->>Gateway: ExecuteCommand / SubscribeEvents\n(protocol_version, device_session_id,\nmessage_type, timestamp_ms, request_id,\npayload_hash, signature) + Gateway->>SessionCache: lookup(device_session_id) + SessionCache-->>Gateway: user_id, client_public_key, status + Gateway->>Gateway: verify payload_hash, signature,\nfreshness window + Gateway->>ReplayStore: reserve(device_session_id, request_id, ttl) + ReplayStore-->>Gateway: accepted / duplicate + Gateway->>Business: verified command context + Business-->>Gateway: response payload + Gateway-->>Client: signed response + Gateway-->>Client: signed push events on SubscribeEvents +``` + ## Device Session Model After successful login through e-mail code: 1. client generates an asymmetric key pair 2. 
private key remains on the client device -3. public key is registered on the server +3. public key is registered on the server as the standard base64-encoded raw + 32-byte Ed25519 public key 4. server creates a persistent `device_session` 5. client stores: - `device_session_id` @@ -31,7 +52,7 @@ The server stores at least: - `device_session_id` - `user_id` -- client public key +- base64-encoded raw 32-byte Ed25519 client public key - session status - revoke metadata @@ -66,11 +87,18 @@ Minimal required fields: - `request_id` - `payload_hash` +The supported request `protocol_version` literal for the v1 gateway transport +is `v1`. +The v1 authenticated request signature scheme is Ed25519. +The stored client public key is the standard base64-encoded raw 32-byte +Ed25519 public key, and the request `signature` field carries the raw +64-byte Ed25519 signature bytes. + ### Request Signing Input The client signs canonical bytes built from: -- request domain marker, for example `myapp-request-v1` +- request domain marker `galaxy-request-v1` - `protocol_version` - `device_session_id` - `message_type` @@ -78,7 +106,16 @@ The client signs canonical bytes built from: - `request_id` - `payload_hash` -`payload_hash` should be computed from raw `payload_bytes`. +The canonical v1 request signing input uses this binary encoding: + +- each `string` and `bytes` field is encoded as `uvarint(len(field_bytes))` + followed by raw bytes +- `timestamp_ms` is encoded as an 8-byte big-endian unsigned integer +- fields are appended in the exact order listed above + +`payload_hash` is the raw 32-byte SHA-256 digest computed from raw +`payload_bytes`. +Empty payloads still use the SHA-256 digest of the empty byte slice. 
The goal is to bind the signature to: @@ -109,15 +146,80 @@ Minimal required fields: The server signs canonical bytes built from: -- response domain marker, for example `myapp-response-v1` +- response domain marker `galaxy-response-v1` - `protocol_version` - `request_id` - `timestamp_ms` - `result_code` - `payload_hash` +The current gateway v1 response signature scheme is Ed25519. +The canonical v1 response signing input uses this binary encoding: + +- each `string` and `bytes` field is encoded as `uvarint(len(field_bytes))` + followed by raw bytes +- `timestamp_ms` is encoded as an 8-byte big-endian unsigned integer +- fields are appended in the exact order listed above + +The gateway server loads the response signing key from a PKCS#8 PEM-encoded +Ed25519 private key. The client verifies the signature using a trusted server public key. +## Event Structure + +Each server push event logically contains: + +- `payload_bytes` +- `event_envelope` +- `signature` + +### Event Envelope + +Minimal required fields: + +- `event_type` +- `event_id` +- `timestamp_ms` +- `payload_hash` + +Optional fields: + +- `request_id` +- `trace_id` + +The current gateway v1 stream-event signature scheme is Ed25519. +The gateway currently signs unary responses and stream events with the same +PKCS#8 PEM-encoded Ed25519 private key. +The bootstrap event implemented for `SubscribeEvents` uses +`event_type = gateway.server_time`, reuses the opening subscribe `request_id` +as `event_id`, and encodes `server_time_ms` in a FlatBuffers +`gateway.ServerTimeEvent` payload. +Later client-facing push events are sourced from internal pub/sub with target +metadata `user_id` and optional `device_session_id`, plus `event_type`, +`event_id`, `payload_bytes`, and optional `request_id` / `trace_id`. +The gateway derives `timestamp_ms`, recomputes `payload_hash`, signs the +event at delivery time, and only then forwards it to the matching active +streams. 
+ +### Event Signing Input + +The server signs canonical bytes built from: + +- event domain marker `galaxy-event-v1` +- `event_type` +- `event_id` +- `timestamp_ms` +- `request_id` +- `trace_id` +- `payload_hash` + +The canonical v1 event signing input uses this binary encoding: + +- each `string` and `bytes` field is encoded as `uvarint(len(field_bytes))` + followed by raw bytes +- `timestamp_ms` is encoded as an 8-byte big-endian unsigned integer +- fields are appended in the exact order listed above + ## Verification Order on Server Before processing payload, the server/gateway must: @@ -140,6 +242,14 @@ Before accepting response payload, the client must: 4. verify timestamp freshness if applicable 5. only then accept the response payload +Before accepting push-event payload, the client must: + +1. verify server event signature +2. verify `payload_hash` +3. verify `request_id` when the event is correlated to the opening request +4. verify timestamp freshness if applicable +5. only then accept the event payload + ## Anti-Replay Model Transport anti-replay uses: @@ -148,7 +258,11 @@ Transport anti-replay uses: - `request_id` The server accepts requests only inside an allowed time window. +The current gateway v1 freshness window is symmetric `±5 minutes` around +server time. Recently seen `request_id` values must be tracked for the corresponding session and rejected on reuse. +Replay reservations should remain active until `timestamp_ms + freshness_window` +so future-skewed but still valid requests stay protected after acceptance. This protects transport freshness. It does not replace business idempotency. @@ -159,12 +273,13 @@ Clients use server time offset instead of trusting local clock directly. 
Expected approach: -- client establishes authenticated long-polling / push connection +- client establishes an authenticated `SubscribeEvents` gRPC stream - server provides current server time - client computes local offset - subsequent signed requests use adjusted time -No extra sync request is required if push / long-polling already exists. +No extra sync request is required when the authenticated push stream is already +open. ## TLS and MITM Considerations diff --git a/gateway/.env.example b/gateway/.env.example new file mode 100644 index 0000000..004718f --- /dev/null +++ b/gateway/.env.example @@ -0,0 +1,27 @@ +# Required startup settings. +GATEWAY_SESSION_CACHE_REDIS_ADDR=127.0.0.1:6379 +GATEWAY_SESSION_EVENTS_REDIS_STREAM=gateway:session-events +GATEWAY_CLIENT_EVENTS_REDIS_STREAM=gateway:client-events +GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH=./secrets/response-signer.pem + +# Main listeners. +GATEWAY_PUBLIC_HTTP_ADDR=127.0.0.1:8080 +GATEWAY_AUTHENTICATED_GRPC_ADDR=127.0.0.1:9090 + +# Optional admin listener. +# GATEWAY_ADMIN_HTTP_ADDR=127.0.0.1:9091 + +# Optional Redis tuning. +# GATEWAY_SESSION_CACHE_REDIS_DB=0 +# GATEWAY_SESSION_CACHE_REDIS_KEY_PREFIX=gateway:session: +# GATEWAY_REPLAY_REDIS_KEY_PREFIX=gateway:replay: +# GATEWAY_SESSION_CACHE_REDIS_TLS_ENABLED=false + +# Optional public-auth integration. Without an injected adapter the routes stay +# mounted and return 503 service_unavailable. +# GATEWAY_PUBLIC_AUTH_UPSTREAM_TIMEOUT=3s + +# Optional shutdown and telemetry tuning. +# GATEWAY_SHUTDOWN_TIMEOUT=5s +# GATEWAY_LOG_LEVEL=info +# OTEL_TRACES_EXPORTER=none diff --git a/gateway/PLAN.md b/gateway/PLAN.md index 3731c97..fc66e88 100644 --- a/gateway/PLAN.md +++ b/gateway/PLAN.md @@ -21,11 +21,19 @@ The intended v1 architecture is: - `protocol_version` covers transport and envelope compatibility, not business payload schema compatibility. - FlatBuffers are used for business payload bytes only. 
+- Phase 3 public auth uses a challenge-token REST flow: + `send-email-code(email) -> challenge_id` and + `confirm-email-code(challenge_id, code, client_public_key) -> device_session_id`. +- Phase 3 uses a consumer-side `AuthServiceClient` inside `gateway`; the + default process wiring keeps public auth routes mounted and returns + `503 service_unavailable` until a concrete upstream adapter is added. - Browser bootstrap and asset traffic are within gateway scope, even when backed by a pluggable proxy or handler. - Long-polling is out of scope for v1. -## Phase 1. Module Skeleton +## ~~Phase 1.~~ Module Skeleton + +Status: implemented. Goal: create the runnable gateway process skeleton. @@ -49,7 +57,9 @@ Targeted tests: - startup with valid config; - shutdown without leaked goroutines. -## Phase 2. Public REST Server +## ~~Phase 2.~~ Public REST Server + +Status: implemented. Goal: add the unauthenticated HTTP server shell. @@ -73,7 +83,9 @@ Targeted tests: - health endpoint responses; - request classification smoke tests. -## Phase 3. Public Auth REST Handlers +## ~~Phase 3.~~ Public Auth REST Handlers + +Status: implemented. Goal: expose unauthenticated auth commands through REST/JSON. @@ -96,7 +108,9 @@ Targeted tests: - success and validation errors for both routes; - no session lookup on public auth paths. -## Phase 4. Public Traffic Classification +## ~~Phase 4.~~ Public Traffic Classification + +Status: implemented. Goal: isolate public traffic into stable anti-abuse classes. @@ -118,7 +132,9 @@ Targeted tests: - per-class routing tests; - bucket isolation tests. -## Phase 5. Public REST Anti-Abuse +## ~~Phase 5.~~ Public REST Anti-Abuse + +Status: implemented. Goal: add coarse protection to unauthenticated REST traffic. @@ -142,7 +158,9 @@ Targeted tests: - bootstrap burst stays outside auth abuse counters; - invalid methods and oversized bodies are rejected. -## Phase 6. 
gRPC Server and Public Contracts +## ~~Phase 6.~~ gRPC Server and Public Contracts + +Status: implemented. Goal: bring up authenticated transport over gRPC and HTTP/2. @@ -165,7 +183,9 @@ Targeted tests: - unary transport smoke test; - stream transport smoke test. -## Phase 7. Envelope Parsing and Protocol Gate +## ~~Phase 7.~~ Envelope Parsing and Protocol Gate + +Status: implemented. Goal: validate the gRPC control envelope before security checks continue. @@ -186,7 +206,9 @@ Targeted tests: - missing field rejection; - unsupported `protocol_version` rejection. -## Phase 8. Session Cache Lookup +## ~~Phase 8.~~ Session Cache Lookup + +Status: implemented. Goal: resolve authenticated identity from cache. @@ -208,7 +230,9 @@ Targeted tests: - cache miss reject; - revoked session reject. -## Phase 9. Payload Hash and Signing Input +## ~~Phase 9.~~ Payload Hash and Signing Input + +Status: implemented. Goal: verify payload integrity before signature verification. @@ -228,7 +252,9 @@ Targeted tests: - payload hash mismatch reject; - canonical bytes differ when signed fields change. -## Phase 10. Client Signature Verification +## ~~Phase 10.~~ Client Signature Verification + +Status: implemented. Goal: authenticate the request origin using the session public key. @@ -249,7 +275,9 @@ Targeted tests: - bad signature reject; - wrong-key reject. -## Phase 11. Freshness and Anti-Replay +## ~~Phase 11.~~ Freshness and Anti-Replay + +Status: implemented. Goal: enforce transport freshness and replay protection. @@ -271,7 +299,9 @@ Targeted tests: - replay reject for same session and request ID; - distinct sessions do not collide. -## Phase 12. Authenticated Rate Limits and Policy +## ~~Phase 12.~~ Authenticated Rate Limits and Policy + +Status: implemented. Goal: apply edge policy after transport authenticity is established. @@ -291,7 +321,10 @@ Targeted tests: - per-dimension throttling; - bucket isolation from public traffic. -## Phase 13. 
Internal Authenticated Command and Routing +## ~~Phase 13.~~ Internal Authenticated Command and Routing + +Status: implemented. +Note: delivered together with Phase 14 signed unary responses. Goal: forward only verified context to downstream services. @@ -313,7 +346,9 @@ Targeted tests: - route selection by `message_type`; - downstream receives the expected authenticated context. -## Phase 14. Signed Unary Responses +## ~~Phase 14.~~ Signed Unary Responses + +Status: implemented as part of Phase 13 delivery. Goal: return verifiable server responses to authenticated clients. @@ -335,7 +370,9 @@ Targeted tests: - response correlation test; - server signature generation test. -## Phase 15. Session Update and Revocation Events +## ~~Phase 15.~~ Session Update and Revocation Events + +Status: implemented. Goal: keep gateway session state current without synchronous hot-path lookups. @@ -357,7 +394,9 @@ Targeted tests: - cache update from event; - revocation event invalidates cached session. -## Phase 16. Authenticated Push Stream +## ~~Phase 16.~~ Authenticated Push Stream + +Status: implemented. Goal: open a verified server-streaming channel for client-facing delivery. @@ -379,7 +418,9 @@ Targeted tests: - rejected stream open for invalid session; - first event contains server time. -## Phase 17. Event Fan-Out +## ~~Phase 17.~~ Event Fan-Out + +Status: implemented. Goal: deliver client-facing events from internal pub/sub to active streams. @@ -401,7 +442,9 @@ Targeted tests: - multi-device delivery for one user; - unrelated sessions do not receive the event. -## Phase 18. Revocation-Driven Stream Teardown +## ~~Phase 18.~~ Revocation-Driven Stream Teardown + +Status: implemented. Goal: terminate active delivery channels when a session is revoked. @@ -422,7 +465,12 @@ Targeted tests: - revoke closes active stream; - revoked session cannot reopen the stream. -## Phase 19. 
Observability and Shutdown Hardening +## ~~Phase 19.~~ Observability and Shutdown Hardening + +Status: implemented. +Note: delivered with `zap` structured logging, OpenTelemetry tracing and +metrics, the optional private admin `/metrics` listener, timeout budgets, and +shutdown-driven push-stream teardown. Goal: make the service operable in production. @@ -446,7 +494,12 @@ Targeted tests: - shutdown closes listeners and active streams; - secret and signature values are not logged. -## Phase 20. Acceptance Pass +## ~~Phase 20.~~ Acceptance Pass + +Status: implemented. +Note: acceptance pass reconciled README/OpenAPI/root architecture +documentation, fixed the documented public-auth projected-error contract, and +added focused regression coverage including OpenAPI validation. Goal: reconcile implementation, documentation, and regression coverage. diff --git a/gateway/README.md b/gateway/README.md index dddb9af..414842a 100644 --- a/gateway/README.md +++ b/gateway/README.md @@ -1,5 +1,46 @@ # Edge Gateway +## Run and Dependencies + +`cmd/gateway` starts with built-in listener defaults, but it still requires: + +- one reachable Redis deployment for session lookup, replay reservations, and + both internal event streams; +- one configured session event stream via `GATEWAY_SESSION_EVENTS_REDIS_STREAM`; +- one configured client event stream via `GATEWAY_CLIENT_EVENTS_REDIS_STREAM`; +- one PKCS#8 PEM-encoded Ed25519 response-signer key referenced by + `GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH`. + +Required startup environment variables: + +- `GATEWAY_SESSION_CACHE_REDIS_ADDR` +- `GATEWAY_SESSION_EVENTS_REDIS_STREAM` +- `GATEWAY_CLIENT_EVENTS_REDIS_STREAM` +- `GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH` + +Optional integrations: + +- `GATEWAY_ADMIN_HTTP_ADDR` enables the private `/metrics` listener; +- an injected `AuthServiceClient` enables real public auth handling; +- injected downstream routes are required for successful `ExecuteCommand`. 
+ +Operational caveats: + +- public auth routes stay mounted and return `503 service_unavailable` until an + auth adapter is wired; +- authenticated gRPC starts without downstream routes, but `ExecuteCommand` + returns gRPC `UNIMPLEMENTED` until routing is configured. + +Additional module docs: + +- [Public REST contract](openapi.yaml) +- [Documentation index](docs/README.md) +- [Runtime and components](docs/runtime.md) +- [Request and push flows](docs/flows.md) +- [Operator runbook](docs/runbook.md) +- [Configuration and contract examples](docs/examples.md) +- [Example `.env`](.env.example) + ## Purpose `Edge Gateway` is the only public ingress for Galaxy Plus clients. @@ -40,29 +81,97 @@ The gateway exposes two external transport classes. | Transport | Audience | Authentication | Payload format | Primary use | | --- | --- | --- | --- | --- | -| REST/JSON | Public, unauthenticated traffic | No device session auth | JSON | Public auth commands, health checks, browser/bootstrap traffic | +| REST/JSON | Public, unauthenticated traffic | No device session auth | JSON | Health checks, public auth commands, and browser/bootstrap traffic | | gRPC over HTTP/2 | Authenticated clients only | Required | FlatBuffers payload inside protobuf control envelope | Verified commands and push delivery | ### Public REST Surface The public REST surface is used for commands that must work before a device session exists and for browser-originated traffic that may share the same edge. +It covers the probe endpoints, public auth routes, and coarse public +anti-abuse. -Stable public endpoints: +Currently implemented public endpoints: -- `POST /api/v1/public/auth/send-email-code` -- `POST /api/v1/public/auth/confirm-email-code` - `GET /healthz` - `GET /readyz` +- `POST /api/v1/public/auth/send-email-code` +- `POST /api/v1/public/auth/confirm-email-code` + +The implemented REST contract is documented in [`openapi.yaml`](openapi.yaml). 
+The listener address is configured by `GATEWAY_PUBLIC_HTTP_ADDR`. +The public REST listener read budgets are configured by: + +- `GATEWAY_PUBLIC_HTTP_READ_HEADER_TIMEOUT` with default `2s`; +- `GATEWAY_PUBLIC_HTTP_READ_TIMEOUT` with default `10s`; +- `GATEWAY_PUBLIC_HTTP_IDLE_TIMEOUT` with default `1m`. + +The public auth JSON contract uses a challenge-token flow: + +- `send-email-code` accepts `email` and returns `challenge_id`; +- `confirm-email-code` accepts `challenge_id`, `code`, and + `client_public_key`, then returns `device_session_id`. + +`client_public_key` is the standard base64-encoded raw 32-byte Ed25519 public +key for the device session being created. + +These routes remain unauthenticated and delegate only through an injected +`AuthServiceClient`. +The default wiring used by `cmd/gateway` keeps the routes mounted and returns +`503 service_unavailable` until a concrete upstream auth adapter is supplied. +Public auth adapter calls are wrapped in +`GATEWAY_PUBLIC_AUTH_UPSTREAM_TIMEOUT`, which defaults to `3s`. +When that timeout expires, the gateway preserves the public REST contract and +returns `503 service_unavailable`. +When an injected auth adapter returns `*AuthServiceError`, the gateway projects +that client-safe `4xx/5xx` status, `code`, and `message` back to the caller +after normalizing blank or invalid fields. Unexpected non-`AuthServiceError` +adapter failures fail closed as `500 internal_error`. + +Public anti-abuse is process-local and in-memory. +Per-IP buckets are derived only from the TCP peer `RemoteAddr`. +Forwarded proxy headers such as `X-Forwarded-For` and `Forwarded` are +intentionally ignored. +Oversized public REST bodies are rejected with `413 request_too_large`. +Rate-limited requests are rejected with `429 rate_limited` and a +`Retry-After` header. In addition to the fixed endpoints above, the gateway may front browser bootstrap or asset traffic through a pluggable public handler or proxy. 
That traffic belongs to dedicated public route classes and must not share rate limit buckets or abuse counters with the public auth API. +### Operational Admin Surface + +The gateway may expose one private operational HTTP listener used for metrics. + +The admin listener is disabled by default and is enabled only when +`GATEWAY_ADMIN_HTTP_ADDR` is non-empty. +When enabled, it serves: + +- `GET /metrics` + +The admin listener read budgets are configured by: + +- `GATEWAY_ADMIN_HTTP_READ_HEADER_TIMEOUT` with default `2s`; +- `GATEWAY_ADMIN_HTTP_READ_TIMEOUT` with default `10s`; +- `GATEWAY_ADMIN_HTTP_IDLE_TIMEOUT` with default `1m`. + +`/metrics` is intentionally not mounted on the public REST ingress. +It is also intentionally excluded from [`openapi.yaml`](openapi.yaml), because +that specification covers only the public REST ingress. +The endpoint exposes metrics in the Prometheus text exposition format described +in the official Prometheus documentation: +. + ### Authenticated gRPC Surface All authenticated client requests use HTTP/2 and gRPC. +The listener address is configured by `GATEWAY_AUTHENTICATED_GRPC_ADDR`. +Inbound authenticated gRPC connection setup is bounded by +`GATEWAY_AUTHENTICATED_GRPC_CONNECTION_TIMEOUT`, which defaults to `5s`. +The accepted client timestamp skew is configured by +`GATEWAY_AUTHENTICATED_GRPC_FRESHNESS_WINDOW` and defaults to `5m`. The public gRPC service exposes two methods: @@ -72,10 +181,133 @@ The public gRPC service exposes two methods: `ExecuteCommand` is a generic unary RPC. The gateway routes the request downstream by `message_type` after transport verification succeeds. +Downstream unary execution is bounded by +`GATEWAY_AUTHENTICATED_DOWNSTREAM_TIMEOUT`, which defaults to `5s`. +When that timeout expires, the gateway preserves the authenticated gRPC +contract and returns gRPC `UNAVAILABLE` with message +`downstream service is unavailable`. `SubscribeEvents` is an authenticated server-streaming RPC. 
It binds the stream to `user_id` and `device_session_id` and starts by sending -a service event that includes the current server time in milliseconds. +a signed service event that includes the current server time in milliseconds. + +The v1 protobuf contract lives in +`proto/galaxy/gateway/v1/edge_gateway.proto` under package +`galaxy.gateway.v1` and service `EdgeGateway`. +Generated Go bindings are committed under `proto/galaxy/gateway/v1/` and are +regenerated with: + +```bash +buf generate +``` + +The gateway validates the request envelope, device-session +cache lookup, `payload_hash`, the client Ed25519 signature, timestamp +freshness, replay reservation, authenticated rate limits, and the +authenticated policy hook before any later routing or push step runs. +Malformed envelopes are rejected with gRPC `INVALID_ARGUMENT`. +Requests with a non-empty but unsupported `protocol_version` are rejected with +gRPC `FAILED_PRECONDITION`. +The supported request `protocol_version` literal is `v1`. +Requests with an unknown `device_session_id` are rejected with gRPC +`UNAUTHENTICATED`. +Requests for revoked sessions are rejected with gRPC `FAILED_PRECONDITION`. +SessionCache backend failures, including Redis lookup or record-decode +failures, are rejected with gRPC `UNAVAILABLE`. +Requests with a `payload_hash` that is not a 32-byte SHA-256 digest or does +not match `payload_bytes` are rejected with gRPC `INVALID_ARGUMENT`. +Requests with an invalid client signature or a signature created by a +different key are rejected with gRPC `UNAUTHENTICATED` and message +`invalid request signature`. +Requests with malformed cached `client_public_key` material fail closed as +gRPC `UNAVAILABLE`. +Requests with a `timestamp_ms` outside the symmetric freshness window around +current server time are rejected with gRPC `FAILED_PRECONDITION` and message +`request timestamp is outside the freshness window`. 
+Requests that reuse the same `request_id` for the same `device_session_id` +inside the active replay window are rejected with gRPC +`FAILED_PRECONDITION` and message `request replay detected`. +ReplayStore backend failures fail closed with gRPC `UNAVAILABLE` and message +`replay store is unavailable`. +Authenticated rate limits are enforced independently by transport peer IP, +authenticated `device_session_id`, authenticated `user_id`, and authenticated +message class. The gateway uses the full verified `message_type` literal as the +stable v1 message-class key because the transport does not yet define a +coarser authenticated class taxonomy. The peer IP is derived only from the +gRPC transport peer address; if it is missing or cannot be parsed, the +request falls back to the stable `unknown` IP bucket. +Requests that exceed any authenticated rate-limit bucket are rejected with +gRPC `RESOURCE_EXHAUSTED` and message +`authenticated request rate limit exceeded`. +The authenticated edge policy hook runs after those rate limits and defaults +to allow-all until a concrete policy evaluator is wired into the process. +`ExecuteCommand` builds an internal authenticated command context, +resolves one exact-match downstream route by the full verified `message_type` +literal, executes the downstream unary client, and signs the response before +it is returned to the caller. When no exact downstream route is registered, +`ExecuteCommand` is rejected with gRPC `UNIMPLEMENTED` and message +`message_type is not routed`. Downstream availability failures are rejected +with gRPC `UNAVAILABLE` and message `downstream service is unavailable`. +Unexpected downstream route-resolution or execution failures are rejected with +gRPC `INTERNAL`. Successful unary responses preserve the original +`request_id`, carry a SHA-256 `payload_hash` of the returned `payload_bytes`, +and are signed with the configured server Ed25519 response signer. 
+The default `cmd/gateway` wiring currently installs an empty static +downstream router, so verified `ExecuteCommand` requests still return gRPC +`UNIMPLEMENTED` until concrete downstream routes are injected. +`SubscribeEvents` applies the full authenticated ingress pipeline, binds +the stream to the verified `user_id` and `device_session_id`, sends one +signed `gateway.server_time` bootstrap event whose FlatBuffers payload carries +`server_time_ms`, registers the active stream in the in-memory `PushHub`, and +then forwards signed client-facing events consumed from the configured client +event Redis stream. User-targeted events fan out to every active stream for +that user. Session-targeted events fan out only to streams whose +`user_id` and `device_session_id` both match the event target. Each active +stream uses a bounded in-memory queue; when that queue overflows, only the +affected stream is closed with gRPC `RESOURCE_EXHAUSTED` and message +`push stream overflowed`. When the session lifecycle stream reports that the +same `device_session_id` was revoked, every active `SubscribeEvents` stream +bound to that exact session is closed with gRPC `FAILED_PRECONDITION` and +message `device session is revoked`. During gateway shutdown, the in-memory +push hub is closed before gRPC graceful stop, and every active +`SubscribeEvents` stream is terminated with gRPC `UNAVAILABLE` and message +`gateway is shutting down`. +Authenticated anti-abuse budgets are configured by the +`GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_*` environment variables. + +Current authenticated gRPC defaults: + +- per-IP: `120 requests / minute`, `burst=40`; +- per-session: `60 requests / minute`, `burst=20`; +- per-user: `120 requests / minute`, `burst=40`; +- per-message-class: `60 requests / minute`, `burst=20`. 
+ +Authenticated anti-abuse configuration surface: + +- per-IP: + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_IP_RATE_LIMIT_REQUESTS` default + `120`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_IP_RATE_LIMIT_WINDOW` default `1m`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_IP_RATE_LIMIT_BURST` default `40`; +- per-session: + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_SESSION_RATE_LIMIT_REQUESTS` default + `60`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_SESSION_RATE_LIMIT_WINDOW` default + `1m`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_SESSION_RATE_LIMIT_BURST` default + `20`; +- per-user: + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_USER_RATE_LIMIT_REQUESTS` default + `120`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_USER_RATE_LIMIT_WINDOW` default `1m`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_USER_RATE_LIMIT_BURST` default `40`; +- per-message-class: + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_REQUESTS` + default `60`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_WINDOW` + default `1m`, + `GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_BURST` + default `20`. ## Envelope and Payload Model @@ -86,10 +318,25 @@ The authenticated transport uses a split contract: - signatures are computed over canonical envelope fields and a hash of raw FlatBuffers bytes. -The gateway treats `payload_bytes` as opaque business data. +The gateway treats authenticated request `payload_bytes` as opaque business +data. It verifies integrity and forwards verified bytes downstream without rewriting them. +The request envelope version literal is `v1`. +`payload_hash` is the raw 32-byte SHA-256 digest of `payload_bytes`. +`ExecuteCommand` hashes the raw FlatBuffers payload bytes exactly as sent, +while `SubscribeEvents` with an empty payload still requires +`sha256([]byte{})` rather than a special-case value. +The v1 request signature scheme is Ed25519. 
+`client_public_key` is the standard base64-encoded raw 32-byte Ed25519 public +key registered during `confirm-email-code`. +`signature` carries the raw 64-byte Ed25519 signature computed over the +canonical request signing input. + +The v1 stream bootstrap payload uses the shared FlatBuffers schema +`pkg/schema/fbs/gateway.fbs` with root table `gateway.ServerTimeEvent`. + ### ExecuteCommandRequest Required fields: @@ -119,6 +366,22 @@ Required fields: - `payload_hash` - `signature` +The v1 unary response signature scheme is Ed25519 with response +domain marker `galaxy-response-v1`. +The response signing input uses the same canonical binary encoding shape as +the request signer: + +- each `string` and `bytes` field is encoded as `uvarint(len(field_bytes))` + followed by raw bytes; +- `timestamp_ms` is encoded as an 8-byte big-endian unsigned integer; +- the signed field order is `galaxy-response-v1`, `protocol_version`, + `request_id`, `timestamp_ms`, `result_code`, `payload_hash`. + +`cmd/gateway` loads the unary response signer from +`GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH`, which must point to a PKCS#8 +PEM-encoded Ed25519 private key. Startup fails when the file is absent, +unreadable, not strict PEM, not PKCS#8, or not Ed25519. + ### SubscribeEventsRequest The stream open request reuses the authenticated request model. @@ -158,6 +421,33 @@ Optional fields: - `request_id` - `trace_id` +The v1 stream-event signature scheme is Ed25519 with event domain +marker `galaxy-event-v1`. +The event signing input uses the same canonical binary encoding shape as the +request and unary response signers: + +- each `string` and `bytes` field is encoded as `uvarint(len(field_bytes))` + followed by raw bytes; +- `timestamp_ms` is encoded as an 8-byte big-endian unsigned integer; +- the signed field order is `galaxy-event-v1`, `event_type`, `event_id`, + `timestamp_ms`, `request_id`, `trace_id`, `payload_hash`. 
+ +The bootstrap event uses: + +- `event_type = "gateway.server_time"`; +- `event_id = request_id` from the opening `SubscribeEvents` request; +- `payload_bytes` encoded as FlatBuffers `gateway.ServerTimeEvent` with + `server_time_ms`; +- the same loaded Ed25519 signer configured by + `GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH`. + +Client-facing fan-out events are sourced from the internal client +event stream. Internal publishers provide the event target and business +payload only: `user_id`, optional `device_session_id`, `event_type`, +`event_id`, `payload_bytes`, and optional `request_id` / `trace_id`. The +gateway derives `timestamp_ms`, recomputes `payload_hash`, signs the event, +and only then forwards it to the matching `SubscribeEvents` streams. + ## Verification and Routing Pipeline The gateway applies the same strict verification order for authenticated gRPC @@ -178,6 +468,38 @@ ingress. No downstream business service should receive a request that has not passed this full verification pipeline. +`ExecuteCommand` enforces steps 1 through 11 and +signs the successful unary response afterward. `SubscribeEvents` enforces +steps 1 through 9, binds the verified stream identity, sends the initial +signed server-time bootstrap event, and then keeps the stream open for push +delivery. +Malformed envelopes fail with gRPC `INVALID_ARGUMENT`. +Unsupported non-empty `protocol_version` values fail with gRPC +`FAILED_PRECONDITION`. +Unknown sessions fail with gRPC `UNAUTHENTICATED`. +Revoked sessions fail with gRPC `FAILED_PRECONDITION`. +SessionCache backend failures fail with gRPC `UNAVAILABLE`. +`payload_hash` values that are not raw 32-byte SHA-256 digests fail with gRPC +`INVALID_ARGUMENT` and message `payload_hash must be a 32-byte SHA-256 digest`. +`payload_hash` values that do not match `payload_bytes` fail with gRPC +`INVALID_ARGUMENT` and message `payload_hash does not match payload_bytes`. 
+Invalid request signatures fail with gRPC `UNAUTHENTICATED` and message +`invalid request signature`. +Malformed cached `client_public_key` values fail closed with gRPC +`UNAVAILABLE` and message `session cache is unavailable`. +Requests with a `timestamp_ms` outside the accepted freshness window fail with +gRPC `FAILED_PRECONDITION` and message +`request timestamp is outside the freshness window`. +Requests that reuse the same `request_id` for the same `device_session_id` +inside the active replay window fail with gRPC `FAILED_PRECONDITION` and +message `request replay detected`. +ReplayStore backend failures fail with gRPC `UNAVAILABLE` and message +`replay store is unavailable`. +Unrouted exact-match `message_type` values fail with gRPC `UNIMPLEMENTED` and +message `message_type is not routed`. +Downstream availability failures fail with gRPC `UNAVAILABLE` and message +`downstream service is unavailable`. + ## Internal Authenticated Contract Downstream services should receive an internal authenticated command rather than @@ -206,7 +528,7 @@ Expected session fields available to the gateway: - `device_session_id` - `user_id` -- client public key +- base64-encoded raw 32-byte Ed25519 client public key - session status - revoke metadata - optional client metadata @@ -217,12 +539,189 @@ Expected session fields available to the gateway: - session existence checks; - `device_session_id -> user_id`; -- access to the client public key used for signature verification; +- access to the base64-encoded raw Ed25519 client public key used for + signature verification; - revoked versus active status checks. Cache updates are event-driven. TTL is allowed only as a safety net and must not replace invalidation events. +The gateway keeps a process-local in-memory snapshot +cache in front of the Redis fallback backend. Authenticated requests read the +local snapshot first. 
A local miss performs one bounded Redis lookup and seeds
+the local snapshot so later requests for the same session avoid another Redis
+round-trip unless a later session event changes the cached state.
+
+The local snapshot cache intentionally has no TTL and no size-based
+eviction policy. Session lifecycle events are the authoritative mechanism for
+keeping the hot path current, while Redis fallback remains the safety net for
+cold misses and process restarts.
+
+The Redis fallback implementation uses `go-redis/v9`.
+`cmd/gateway` requires the Redis fallback backend during startup, issues a
+bounded `PING`, and refuses to start when Redis is misconfigured or
+unavailable.
+
+Required environment variable:
+
+- `GATEWAY_SESSION_CACHE_REDIS_ADDR`
+
+Optional environment variables:
+
+- `GATEWAY_SESSION_CACHE_REDIS_USERNAME`
+- `GATEWAY_SESSION_CACHE_REDIS_PASSWORD`
+- `GATEWAY_SESSION_CACHE_REDIS_DB` with default `0`
+- `GATEWAY_SESSION_CACHE_REDIS_KEY_PREFIX` with default `gateway:session:`
+- `GATEWAY_SESSION_CACHE_REDIS_LOOKUP_TIMEOUT` with default `250ms`
+- `GATEWAY_SESSION_CACHE_REDIS_TLS_ENABLED` with default `false`
+
+The Redis key format is:
+
+- `<key_prefix><device_session_id>`
+
+The Redis value is one strict JSON object:
+
+- `device_session_id`
+- `user_id`
+- `client_public_key`
+- `status`
+- optional `revoked_at_ms`
+
+`client_public_key` stores the standard base64-encoded raw 32-byte Ed25519
+public key registered for the device session.
+
+Malformed JSON, missing required fields, unsupported `status`, or a
+`device_session_id` mismatch between the Redis value and the lookup key are
+treated as SessionCache backend failures rather than as valid session states.
+
+### Session Event Stream
+
+The gateway keeps the process-local session snapshot cache synchronized from one
+Redis Stream consumed through `go-redis/v9`. 
+
+`cmd/gateway` requires the session event stream configuration during startup,
+issues a bounded `PING` against the same Redis deployment used for
+`SessionCache`, and refuses to start when that Redis backend is unavailable.
+
+Required environment variable:
+
+- `GATEWAY_SESSION_EVENTS_REDIS_STREAM`
+
+Optional environment variable:
+
+- `GATEWAY_SESSION_EVENTS_REDIS_READ_BLOCK_TIMEOUT` with default `1s`
+
+The subscriber reuses the same Redis address, ACL credentials, logical
+database, timeout, and TLS settings configured for `SessionCache`.
+
+Each gateway replica keeps its own in-memory last-seen stream ID and consumes
+the stream with plain `XREAD`, not a shared consumer group.
+On startup the replica resolves the current stream tail and begins from that
+point, which preserves the same fresh-process semantics as Redis `$` while
+avoiding a race before the first blocking read.
+
+The session event payload is one strict full snapshot with these
+fields:
+
+- `device_session_id`
+- `user_id`
+- `client_public_key`
+- `status`
+- optional `revoked_at_ms`
+
+Valid active and revoked snapshots upsert or replace the local session state.
+Later stream entries win.
+Malformed events are skipped without stopping the subscriber; when
+`device_session_id` can still be extracted, the gateway evicts the local
+snapshot for that session so it cannot continue using stale state.
+
+Session event publishers must keep the stream bounded by using
+`XADD ... MAXLEN ~ <count>` or an equivalent retention policy.
+The gateway intentionally does not trim the stream from the consumer side,
+because consumer-side trimming could drop updates that another gateway replica
+has not read yet.
+
+### Client Event Stream
+
+The gateway delivers client-facing push events from one dedicated Redis Stream
+consumed through `go-redis/v9`. 
+
+`cmd/gateway` requires the client event stream configuration during startup,
+issues a bounded `PING` against the same Redis deployment used for
+`SessionCache`, and refuses to start when that Redis backend is unavailable.
+
+Required environment variable:
+
+- `GATEWAY_CLIENT_EVENTS_REDIS_STREAM`
+
+Optional environment variable:
+
+- `GATEWAY_CLIENT_EVENTS_REDIS_READ_BLOCK_TIMEOUT` with default `1s`
+
+The subscriber reuses the same Redis address, ACL credentials, logical
+database, timeout, and TLS settings configured for `SessionCache`.
+
+Each gateway replica keeps its own in-memory last-seen stream ID and consumes
+the stream with plain `XREAD`, not a shared consumer group.
+On startup the replica resolves the current stream tail and begins from that
+point, which preserves the same fresh-process semantics as Redis `$` while
+avoiding a race before the first blocking read.
+
+The client event payload is one strict target-plus-payload entry with
+these fields:
+
+- `user_id`
+- optional `device_session_id`
+- `event_type`
+- `event_id`
+- `payload_bytes`
+- optional `request_id`
+- optional `trace_id`
+
+`payload_bytes` carries the raw binary-safe business payload bytes for the
+outbound client event.
+When `device_session_id` is absent or blank, the gateway fans the event out to
+every active stream for `user_id`.
+When `device_session_id` is present, the gateway fans the event out only to
+active streams whose `user_id` and `device_session_id` both match.
+Malformed client event entries are skipped without stopping the subscriber or
+delivering partial data to clients.
+
+Client event publishers must keep the stream bounded by using
+`XADD ... MAXLEN ~ <count>` or an equivalent retention policy.
+The gateway intentionally does not trim the stream from the consumer side,
+because consumer-side trimming could drop updates that another gateway replica
+has not read yet. 
+
+### Replay Store
+
+`ReplayStore` provides the hot-path anti-replay reservation for:
+
+- duplicate detection by `device_session_id + request_id`;
+- bounded replay protection for the authenticated freshness window.
+
+The ReplayStore uses Redis through `go-redis/v9`.
+`cmd/gateway` requires the ReplayStore backend during startup, issues a
+bounded `PING`, and refuses to start when Redis is misconfigured or
+unavailable.
+
+The ReplayStore reuses the same Redis deployment settings as `SessionCache`
+and adds two replay-specific environment variables:
+
+- `GATEWAY_REPLAY_REDIS_KEY_PREFIX` with default `gateway:replay:`
+- `GATEWAY_REPLAY_REDIS_RESERVE_TIMEOUT` with default `250ms`
+
+Replay keys use this format:
+
+- `<key_prefix><device_session_id>:<request_id>`
+
+For each accepted request, the replay reservation TTL is computed as:
+
+- `timestamp_ms + freshness_window - now`
+
+The TTL is clamped to a minimum positive duration so requests accepted exactly
+on the freshness boundary still reserve their replay key.
+
 ### Revocation Behavior
 
 When a device session is revoked:
@@ -231,7 +730,9 @@ When a device session is revoked:
 2. it publishes a session update or revoke event;
 3. the gateway invalidates or updates `SessionCache`;
 4. new unary gRPC requests for that session are rejected;
-5. active `SubscribeEvents` streams for that session are closed.
+5. active `SubscribeEvents` streams for that exact `device_session_id` are
+   closed with gRPC `FAILED_PRECONDITION` and message
+   `device session is revoked`.
 
 ## Public Anti-Abuse Model
 
@@ -245,9 +746,15 @@ The gateway uses these public route classes:
 - `browser_asset`
 - `public_misc`
 
+Any classifier result outside this fixed set is normalized to `public_misc`
+before the class is stored in request context or used for policy derivation.
+The canonical base bucket namespace for public REST policy is
+`public_rest/class=<class>`.
+
 ### Public Auth
 
-`public_auth` includes `send-email-code` and `confirm-email-code`. 
+`public_auth` is the stable route class for `send-email-code` and +`confirm-email-code`. This class uses stricter limits and abuse scoring because it directly touches account and session creation flows. @@ -259,6 +766,36 @@ Controls include: - malformed request counters; - elevated logging and security telemetry for repeated failures. +Current defaults: + +- per-IP: `30 requests / minute`, `burst=10`; +- `send-email-code` identity buckets: `3 requests / 10 minutes`, `burst=1`, + keyed by normalized `email`; +- `confirm-email-code` identity buckets: `6 requests / 10 minutes`, + `burst=2`, keyed by normalized `challenge_id`; +- maximum request body size: `8192` bytes; +- only `POST` is accepted for public auth routes. + +Configuration surface: + +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_MAX_BODY_BYTES` default `8192`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_RATE_LIMIT_REQUESTS` default + `30`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_RATE_LIMIT_WINDOW` default `1m`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_RATE_LIMIT_BURST` default `10`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_SEND_EMAIL_CODE_IDENTITY_RATE_LIMIT_REQUESTS` + default `3`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_SEND_EMAIL_CODE_IDENTITY_RATE_LIMIT_WINDOW` + default `10m`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_SEND_EMAIL_CODE_IDENTITY_RATE_LIMIT_BURST` + default `1`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_CONFIRM_EMAIL_CODE_IDENTITY_RATE_LIMIT_REQUESTS` + default `6`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_CONFIRM_EMAIL_CODE_IDENTITY_RATE_LIMIT_WINDOW` + default `10m`; +- `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_CONFIRM_EMAIL_CODE_IDENTITY_RATE_LIMIT_BURST` + default `2`. + ### Browser Bootstrap and Asset Traffic `browser_bootstrap` and `browser_asset` use separate coarse-grained budgets. @@ -275,6 +812,40 @@ This traffic is still constrained by: The gateway must not merge these buckets or counters with `public_auth`. 
+Current defaults: + +- `browser_bootstrap`: `60 requests / minute`, `burst=20`, `GET` and `HEAD` + only, and no request body; +- `browser_asset`: `300 requests / minute`, `burst=80`, `GET` and `HEAD` + only, and no request body; +- `public_misc`: `30 requests / minute`, `burst=10`, and no request body. + +Configuration surface: + +- `browser_bootstrap`: + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_MAX_BODY_BYTES` default + `0`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_RATE_LIMIT_REQUESTS` + default `60`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_RATE_LIMIT_WINDOW` default + `1m`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_RATE_LIMIT_BURST` default + `20`; +- `browser_asset`: + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_MAX_BODY_BYTES` default `0`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_RATE_LIMIT_REQUESTS` default + `300`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_RATE_LIMIT_WINDOW` default + `1m`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_RATE_LIMIT_BURST` default + `80`; +- `public_misc`: + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_MAX_BODY_BYTES` default `0`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_RATE_LIMIT_REQUESTS` default + `30`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_RATE_LIMIT_WINDOW` default `1m`, + `GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_RATE_LIMIT_BURST` default `10`. + ## Push Delivery Model The v1 push channel is a gRPC server stream. @@ -285,15 +856,34 @@ Expected stream behavior: 1. the client opens `SubscribeEvents`; 2. the gateway applies the full authenticated ingress verification pipeline; 3. the stream is bound to `user_id` and `device_session_id`; -4. the first service event includes `server_time_ms`; -5. client-facing events from internal pub/sub are fanned out to matching active - streams; -6. revoke events close affected streams. +4. the first signed service event is `gateway.server_time` and its + FlatBuffers payload includes `server_time_ms`; +5. 
after that bootstrap event, the stream is registered in `PushHub` and + remains open until client cancellation, server shutdown, queue overflow, + session revoke for the same `device_session_id`, or a later send failure; +6. internal pub/sub may target all active streams for one `user_id` or only + one `device_session_id` within that user; +7. the current per-stream in-memory queue capacity is `64` events and + overflow closes only the affected stream; +8. session revoke closes only streams bound to the same exact + `device_session_id` and returns gRPC `FAILED_PRECONDITION` with message + `device session is revoked`. + +## Lifecycle and Shutdown + +Gateway process shutdown is coordinated across the public REST listener, +authenticated gRPC listener, optional admin listener, internal Redis +subscribers, and telemetry runtime. + +`GATEWAY_SHUTDOWN_TIMEOUT` configures the per-component graceful shutdown +budget and defaults to `5s`. +During authenticated gRPC shutdown, the in-memory `PushHub` closes active +streams before gRPC graceful stop, so active `SubscribeEvents` calls terminate +with gRPC `UNAVAILABLE` and message `gateway is shutting down`. ## Recommended Package Layout -The initial package layout should keep transport, policy, and downstream -adapters separate: +The package layout keeps transport, policy, and downstream adapters separate: - `cmd/gateway` - `internal/app` @@ -317,11 +907,17 @@ The gateway should be built around explicit consumer-side interfaces. Provides cached session lookup by `device_session_id`. Returns enough data to verify signatures and identify the authenticated user. +The current production implementation is a process-local read-through cache in +front of a Redis fallback adapter that uses strict JSON records under a +configurable key prefix. ### ReplayStore Tracks recently seen `request_id` values per device session and rejects replayed requests inside the accepted freshness window. 
+The current production adapter is Redis-backed, uses a dedicated configurable
+key prefix, and reserves keys with a TTL derived from
+`timestamp_ms + freshness_window - now`.
 
 ### RateLimiter
 
@@ -333,24 +929,44 @@ Applies independent policies for:
 
 - authenticated gRPC requests by user;
 - authenticated gRPC requests by message class.
 
+The current rate limiter is process-local and in-memory.
+Public REST keys stay under the `public_rest/...` namespace, while
+authenticated gRPC keys stay under `authenticated_grpc/...`, so both traffic
+surfaces keep independent buckets even when they share the same limiter
+backend.
+
 ### PublicTrafficClassifier
 
 Maps incoming public REST requests to one of the public route classes so that
 limits and anti-abuse counters remain isolated.
 
+The gateway normalizes any unsupported or empty classifier output to
+`public_misc`, and public policy code derives the base bucket namespace from
+the normalized class as `public_rest/class=<class>`.
 
 ### AuthServiceClient
 
 Handles public auth commands and session-related updates exchanged with the
 Auth / Session Service.
 
+The gateway contract is:
+
+- `SendEmailCode(email) -> challenge_id`
+- `ConfirmEmailCode(challenge_id, code, client_public_key) -> device_session_id`
+
+When no concrete implementation is wired, the gateway keeps the public routes
+available and returns a stable `503 service_unavailable` response instead of
+failing process startup.
 
 ### DownstreamRouter
 
-Resolves the target downstream service or adapter by `message_type`.
+Resolves the target downstream service or adapter by the full exact-match
+`message_type` literal.
 
 ### DownstreamClient
 
 Executes a verified authenticated command against a downstream internal service
-and returns response payload bytes plus a stable result code.
+and returns response payload bytes plus a stable opaque result code.
+An empty or whitespace-only result code is treated as an internal downstream
+contract violation. 
### EventSubscriber @@ -360,15 +976,25 @@ Subscribes to internal pub/sub topics used for: - revocations; - client-facing event delivery. +The implementation consumes two Redis Streams with replica-safe plain +`XREAD`: one strict full-session snapshot stream for the process-local session +cache and one client-facing event stream for live push fan-out. + ### PushHub Tracks active `SubscribeEvents` streams, binds them to authenticated identities, and delivers events to the correct connections. +The implementation uses one bounded in-memory queue per stream with a +default capacity of `64` events; overflowing one queue closes only that stream +and leaves the remaining streams active. ### ResponseSigner Signs unary responses and stream events so clients can verify server-originated messages. +The implementation uses one Ed25519 signer loaded from +`GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH`, which must reference a PKCS#8 +PEM-encoded private key. ### Clock @@ -382,6 +1008,7 @@ internal implementation details. Minimum error categories: - malformed request; +- request too large; - unsupported protocol; - unknown session; - revoked session; @@ -389,7 +1016,10 @@ Minimum error categories: - stale request; - replay detected; - rate limited; +- policy denied; - downstream unavailable; +- backend unavailable; +- gateway shutting down; - internal error. Observability requirements: @@ -400,6 +1030,51 @@ Observability requirements: - metrics keyed by route class, message type, result code, and reject reason; - no logging of secrets, raw private material, or raw signatures. +The service uses: + +- `go.uber.org/zap` for structured JSON logs; +- `otelgin` for the public REST listener; +- `otelgrpc` for the authenticated gRPC listener; +- OpenTelemetry metrics exported through Prometheus on the optional admin + `/metrics` listener. 
+ +Current custom metric families: + +- `gateway.public_http.requests` +- `gateway.public_http.duration` +- `gateway.authenticated_grpc.requests` +- `gateway.authenticated_grpc.duration` +- `gateway.push.active_streams` +- `gateway.push.stream_closures` +- `gateway.internal_event_drops` + +The process-wide log level is configured by `GATEWAY_LOG_LEVEL` and +defaults to `info`. +The default OpenTelemetry resource uses `service.name=galaxy-edge-gateway` +when `OTEL_SERVICE_NAME` is unset. +If `OTEL_TRACES_EXPORTER` is unset or set to `none`, the gateway keeps tracing +runtime enabled but installs no external trace exporter. +If `OTEL_TRACES_EXPORTER=otlp`, the gateway uses the standard +`OTEL_EXPORTER_OTLP_*` environment variables to configure the OTLP trace +exporter protocol and endpoint. +The protocol selection specifically honors +`OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` first and falls back to +`OTEL_EXPORTER_OTLP_PROTOCOL` when the trace-specific variable is unset. +Supported values are `http/protobuf` and `grpc`; when both variables are +unset, the gateway defaults to `http/protobuf`. + +Structured logs intentionally omit: + +- public auth e-mail addresses, login codes, and challenge IDs; +- client public keys; +- raw payload bytes and payload hashes; +- raw request or response signatures; +- response-signer private key material and Redis credentials. + +Malformed internal session and client-event stream entries are no longer +silently dropped: the gateway logs the drop and increments +`gateway.internal_event_drops`. 
+ ## Non-Goals The gateway is not a business authorization layer and must not grow into a diff --git a/gateway/buf.gen.yaml b/gateway/buf.gen.yaml new file mode 100644 index 0000000..e576cda --- /dev/null +++ b/gateway/buf.gen.yaml @@ -0,0 +1,11 @@ +version: v2 + +plugins: + - remote: buf.build/protocolbuffers/go:v1.36.11 + out: proto + opt: + - paths=source_relative + - remote: buf.build/grpc/go:v1.6.1 + out: proto + opt: + - paths=source_relative diff --git a/gateway/buf.lock b/gateway/buf.lock new file mode 100644 index 0000000..d15a117 --- /dev/null +++ b/gateway/buf.lock @@ -0,0 +1,6 @@ +# Generated by buf. DO NOT EDIT. +version: v2 +deps: + - name: buf.build/bufbuild/protovalidate + commit: 80ab13bee0bf4272b6161a72bf7034e0 + digest: b5:1aa6a965be5d02d64e1d81954fa2e78ef9d1e33a0c30f92bc2626039006a94deb3a5b05f14ed8893f5c3ffce444ac008f7e968188ad225c4c29c813aa5f2daa1 diff --git a/gateway/buf.yaml b/gateway/buf.yaml new file mode 100644 index 0000000..641797b --- /dev/null +++ b/gateway/buf.yaml @@ -0,0 +1,15 @@ +version: v2 + +modules: + - path: proto + +deps: + - buf.build/bufbuild/protovalidate + +lint: + use: + - STANDARD + +breaking: + use: + - FILE diff --git a/gateway/cmd/gateway/main.go b/gateway/cmd/gateway/main.go new file mode 100644 index 0000000..61ca2a0 --- /dev/null +++ b/gateway/cmd/gateway/main.go @@ -0,0 +1,209 @@ +package main + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + + "galaxy/gateway/internal/adminapi" + "galaxy/gateway/internal/app" + "galaxy/gateway/internal/authn" + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/downstream" + "galaxy/gateway/internal/events" + "galaxy/gateway/internal/grpcapi" + "galaxy/gateway/internal/logging" + "galaxy/gateway/internal/push" + "galaxy/gateway/internal/replay" + "galaxy/gateway/internal/restapi" + "galaxy/gateway/internal/session" + "galaxy/gateway/internal/telemetry" + + "go.uber.org/zap" +) + +// main loads the gateway configuration, runs the process 
lifecycle, and exits +// with a non-zero status when startup or runtime fails. +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + if err := run(ctx); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run(ctx context.Context) (err error) { + cfg, err := config.LoadFromEnv() + if err != nil { + return err + } + + logger, err := logging.New(cfg.Logging) + if err != nil { + return fmt.Errorf("build gateway logger: %w", err) + } + + telemetryRuntime, err := telemetry.New(ctx, logger) + if err != nil { + return fmt.Errorf("build gateway telemetry: %w", err) + } + + grpcDeps, components, cleanup, err := newAuthenticatedGRPCDependencies(ctx, cfg, logger, telemetryRuntime) + if err != nil { + return err + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout) + defer cancel() + + err = errors.Join( + err, + cleanup(), + telemetryRuntime.Shutdown(shutdownCtx), + logging.Sync(logger), + ) + }() + + restServer := restapi.NewServer(cfg.PublicHTTP, restapi.ServerDependencies{ + Logger: logger, + Telemetry: telemetryRuntime, + }) + grpcServer := grpcapi.NewServer(cfg.AuthenticatedGRPC, grpcDeps) + + applicationComponents := []app.Component{ + restServer, + grpcServer, + } + if adminServer := adminapi.NewServer(cfg.AdminHTTP, telemetryRuntime.Handler(), logger); adminServer.Enabled() { + applicationComponents = append(applicationComponents, adminServer) + } + applicationComponents = append(applicationComponents, components...) + + logger.Info("gateway application starting", + zap.String("public_http_addr", cfg.PublicHTTP.Addr), + zap.String("authenticated_grpc_addr", cfg.AuthenticatedGRPC.Addr), + zap.String("admin_http_addr", cfg.AdminHTTP.Addr), + ) + + application := app.New(cfg, applicationComponents...) 
+ + err = application.Run(ctx) + return err +} + +func newAuthenticatedGRPCDependencies(ctx context.Context, cfg config.Config, logger *zap.Logger, telemetryRuntime *telemetry.Runtime) (grpcapi.ServerDependencies, []app.Component, func() error, error) { + responseSigner, err := authn.LoadEd25519ResponseSignerFromPEMFile(cfg.ResponseSigner.PrivateKeyPEMPath) + if err != nil { + return grpcapi.ServerDependencies{}, nil, nil, fmt.Errorf("build authenticated grpc dependencies: load response signer: %w", err) + } + + fallbackSessionCache, err := session.NewRedisCache(cfg.SessionCacheRedis) + if err != nil { + return grpcapi.ServerDependencies{}, nil, nil, fmt.Errorf("build authenticated grpc dependencies: %w", err) + } + + replayStore, err := replay.NewRedisStore(cfg.SessionCacheRedis, cfg.ReplayRedis) + if err != nil { + closeErr := fallbackSessionCache.Close() + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + localSessionCache := session.NewMemoryCache() + sessionCache, err := session.NewReadThroughCache(localSessionCache, fallbackSessionCache) + if err != nil { + closeErr := errors.Join( + fallbackSessionCache.Close(), + replayStore.Close(), + ) + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + pushHub := push.NewHubWithObserver(0, telemetry.NewPushObserver(telemetryRuntime)) + sessionSubscriber, err := events.NewRedisSessionSubscriberWithObservability(cfg.SessionCacheRedis, cfg.SessionEventsRedis, localSessionCache, pushHub, logger, telemetryRuntime) + if err != nil { + closeErr := errors.Join( + fallbackSessionCache.Close(), + replayStore.Close(), + ) + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + clientEventSubscriber, err := 
events.NewRedisClientEventSubscriberWithObservability(cfg.SessionCacheRedis, cfg.ClientEventsRedis, pushHub, logger, telemetryRuntime) + if err != nil { + closeErr := errors.Join( + fallbackSessionCache.Close(), + replayStore.Close(), + sessionSubscriber.Close(), + ) + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + cleanup := func() error { + return errors.Join( + fallbackSessionCache.Close(), + replayStore.Close(), + sessionSubscriber.Close(), + clientEventSubscriber.Close(), + ) + } + + if err := fallbackSessionCache.Ping(ctx); err != nil { + closeErr := cleanup() + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + if err := replayStore.Ping(ctx); err != nil { + closeErr := cleanup() + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + if err := sessionSubscriber.Ping(ctx); err != nil { + closeErr := cleanup() + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + if err := clientEventSubscriber.Ping(ctx); err != nil { + closeErr := cleanup() + return grpcapi.ServerDependencies{}, nil, nil, errors.Join( + fmt.Errorf("build authenticated grpc dependencies: %w", err), + closeErr, + ) + } + + return grpcapi.ServerDependencies{ + Service: grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, nil, logger), + Router: downstream.NewStaticRouter(nil), + ResponseSigner: responseSigner, + SessionCache: sessionCache, + ReplayStore: replayStore, + Logger: logger, + Telemetry: telemetryRuntime, + PushHub: pushHub, + }, []app.Component{sessionSubscriber, clientEventSubscriber}, cleanup, nil +} diff --git a/gateway/cmd/gateway/main_test.go b/gateway/cmd/gateway/main_test.go new file mode 
100644 index 0000000..60ac7f6 --- /dev/null +++ b/gateway/cmd/gateway/main_test.go @@ -0,0 +1,275 @@ +package main + +import ( + "context" + "crypto/ed25519" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "net" + "os" + "path/filepath" + "testing" + "time" + + "galaxy/gateway/internal/config" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestNewAuthenticatedGRPCDependencies(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + responseSignerPEMPath := writeTestResponseSignerPEMFile(t) + + tests := []struct { + name string + cfg config.Config + wantErr string + }{ + { + name: "success", + cfg: config.Config{ + SessionCacheRedis: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + KeyPrefix: "gateway:session:", + LookupTimeout: 250 * time.Millisecond, + }, + ReplayRedis: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + SessionEventsRedis: config.SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: time.Second, + }, + ClientEventsRedis: config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: time.Second, + }, + ResponseSigner: config.ResponseSignerConfig{ + PrivateKeyPEMPath: responseSignerPEMPath, + }, + }, + }, + { + name: "invalid redis config", + cfg: config.Config{ + SessionCacheRedis: config.SessionCacheRedisConfig{ + LookupTimeout: 250 * time.Millisecond, + }, + ReplayRedis: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + SessionEventsRedis: config.SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: time.Second, + }, + ClientEventsRedis: config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: time.Second, + }, + ResponseSigner: config.ResponseSignerConfig{ + PrivateKeyPEMPath: responseSignerPEMPath, + }, 
+ }, + wantErr: "redis addr must not be empty", + }, + { + name: "startup ping failure", + cfg: config.Config{ + SessionCacheRedis: config.SessionCacheRedisConfig{ + Addr: unusedTCPAddr(t), + KeyPrefix: "gateway:session:", + LookupTimeout: 100 * time.Millisecond, + }, + ReplayRedis: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 100 * time.Millisecond, + }, + SessionEventsRedis: config.SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: time.Second, + }, + ClientEventsRedis: config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: time.Second, + }, + ResponseSigner: config.ResponseSignerConfig{ + PrivateKeyPEMPath: responseSignerPEMPath, + }, + }, + wantErr: "ping redis session cache", + }, + { + name: "invalid replay config", + cfg: config.Config{ + SessionCacheRedis: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + KeyPrefix: "gateway:session:", + LookupTimeout: 250 * time.Millisecond, + }, + ReplayRedis: config.ReplayRedisConfig{ + ReserveTimeout: 250 * time.Millisecond, + }, + SessionEventsRedis: config.SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: time.Second, + }, + ClientEventsRedis: config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: time.Second, + }, + ResponseSigner: config.ResponseSignerConfig{ + PrivateKeyPEMPath: responseSignerPEMPath, + }, + }, + wantErr: "replay key prefix must not be empty", + }, + { + name: "invalid client event config", + cfg: config.Config{ + SessionCacheRedis: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + KeyPrefix: "gateway:session:", + LookupTimeout: 250 * time.Millisecond, + }, + ReplayRedis: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + SessionEventsRedis: config.SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: time.Second, + }, + ClientEventsRedis: 
config.ClientEventsRedisConfig{ + ReadBlockTimeout: time.Second, + }, + ResponseSigner: config.ResponseSignerConfig{ + PrivateKeyPEMPath: responseSignerPEMPath, + }, + }, + wantErr: "client event subscriber: stream must not be empty", + }, + { + name: "missing response signer path", + cfg: config.Config{ + SessionCacheRedis: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + KeyPrefix: "gateway:session:", + LookupTimeout: 250 * time.Millisecond, + }, + ReplayRedis: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + SessionEventsRedis: config.SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: time.Second, + }, + ClientEventsRedis: config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: time.Second, + }, + }, + wantErr: "load response signer", + }, + { + name: "invalid response signer pem", + cfg: config.Config{ + SessionCacheRedis: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + KeyPrefix: "gateway:session:", + LookupTimeout: 250 * time.Millisecond, + }, + ReplayRedis: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + SessionEventsRedis: config.SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: time.Second, + }, + ClientEventsRedis: config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: time.Second, + }, + ResponseSigner: config.ResponseSignerConfig{ + PrivateKeyPEMPath: writeInvalidPEMFile(t), + }, + }, + wantErr: "response signer private key", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + deps, components, cleanup, err := newAuthenticatedGRPCDependencies(context.Background(), tt.cfg, zap.NewNop(), nil) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + require.NotNil(t, 
deps.SessionCache) + require.NotNil(t, deps.ReplayStore) + require.NotNil(t, deps.ResponseSigner) + require.NotNil(t, deps.Router) + require.NotNil(t, deps.Service) + require.Len(t, components, 2) + require.NotNil(t, cleanup) + assert.NoError(t, cleanup()) + }) + } +} + +func unusedTCPAddr(t *testing.T) string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + addr := listener.Addr().String() + require.NoError(t, listener.Close()) + + return addr +} + +func writeTestResponseSignerPEMFile(t *testing.T) string { + t.Helper() + + seed := sha256.Sum256([]byte("gateway-main-test-response-signer")) + privateKey := ed25519.NewKeyFromSeed(seed[:]) + + encodedPrivateKey, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + path := filepath.Join(t.TempDir(), "response-signer.pem") + err = os.WriteFile(path, pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: encodedPrivateKey, + }), 0o600) + require.NoError(t, err) + + return path +} + +func writeInvalidPEMFile(t *testing.T) string { + t.Helper() + + path := filepath.Join(t.TempDir(), "invalid-response-signer.pem") + err := os.WriteFile(path, []byte("not a valid pem"), 0o600) + require.NoError(t, err) + + return path +} diff --git a/gateway/docs/README.md b/gateway/docs/README.md new file mode 100644 index 0000000..f9f1584 --- /dev/null +++ b/gateway/docs/README.md @@ -0,0 +1,20 @@ +# Edge Gateway Docs + +This directory keeps service-local documentation that is too detailed for the +root architecture documents and too diagram-heavy for the module README. 
+ +Sections: + +- [Runtime and components](runtime.md) +- [Public auth, command, and push flows](flows.md) +- [Operator runbook](runbook.md) +- [Configuration and contract examples](examples.md) +- [Example `.env`](../.env.example) + +Primary references: + +- [`../README.md`](../README.md) for service scope, contracts, configuration, + and operational behavior +- [`../openapi.yaml`](../openapi.yaml) for the public REST contract +- [`../../README.md`](../../README.md) for workspace-level architecture +- [`../../SECURITY.md`](../../SECURITY.md) for the transport security model diff --git a/gateway/docs/examples.md b/gateway/docs/examples.md new file mode 100644 index 0000000..8651347 --- /dev/null +++ b/gateway/docs/examples.md @@ -0,0 +1,179 @@ +# Configuration And Contract Examples + +The examples below are illustrative. Values such as signatures, payload hashes, +and FlatBuffers payload bytes are placeholders unless explicitly stated +otherwise. + +## Example `.env` + +The repository also includes a ready-to-copy sample file: + +- [`../.env.example`](../.env.example) + +The sample keeps all secrets blank and shows only the settings needed to boot +the process and expose the main listeners. 
+ +## Public Auth HTTP Examples + +Start an e-mail challenge: + +```bash +curl -X POST http://127.0.0.1:8080/api/v1/public/auth/send-email-code \ + -H 'Content-Type: application/json' \ + -d '{"email":"pilot@example.com"}' +``` + +Example response: + +```json +{ + "challenge_id": "challenge-123" +} +``` + +Confirm the challenge and register the device public key: + +```bash +curl -X POST http://127.0.0.1:8080/api/v1/public/auth/confirm-email-code \ + -H 'Content-Type: application/json' \ + -d '{ + "challenge_id": "challenge-123", + "code": "123456", + "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=" + }' +``` + +Example response: + +```json +{ + "device_session_id": "device-session-123" +} +``` + +## Authenticated gRPC Envelope Examples + +The authenticated transport is gRPC/protobuf, not JSON over HTTP. The examples +below use protobuf-style JSON only to make the logical envelope readable. +`bytes` fields are shown as base64 strings, matching the standard protobuf JSON +mapping. 
+ +Example `ExecuteCommandRequest`: + +```json +{ + "protocolVersion": "v1", + "deviceSessionId": "device-session-123", + "messageType": "fleet.move", + "timestampMs": "1775121600000", + "requestId": "request-123", + "payloadBytes": "RkxBVEJVRkZFUlNfUEFZTE9BRA==", + "payloadHash": "5fY6Q8V9mK8x2B7v6v0V0m0i1rQ2QF0rQ8V1Yt1r8Ys=", + "signature": "3o4v8f3h0Y6I0x1bS7zY+8m0bV1Lk4D3yq8J2n8F1rD7yK9v8M1Q0w2s4a6f8d0Q0m3L6y8R1t5w7x9z0a2cA==", + "traceId": "trace-123" +} +``` + +Example `ExecuteCommandResponse`: + +```json +{ + "protocolVersion": "v1", + "requestId": "request-123", + "timestampMs": "1775121600123", + "resultCode": "ok", + "payloadBytes": "RkxBVEJVRkZFUlNfUkVTUE9OU0U=", + "payloadHash": "wL4n8H1aR2x3M4b5C6d7E8f9G0h1J2k3L4m5N6o7P8Q=", + "signature": "2Xb7l9m0n1p2q3r4s5t6u7v8w9x0y1z2A3B4C5D6E7F8G9H0J1K2L3M4N5O6P7Q8R9S0T1U2V3W4X5Y6Z7a8b9cQ==" +} +``` + +Example bootstrap `GatewayEvent` sent after `SubscribeEvents` opens: + +```json +{ + "eventType": "gateway.server_time", + "eventId": "request-123", + "timestampMs": "1775121600456", + "payloadBytes": "RkxBVEJVRkZFUlNfU0VSVkVSX1RJTUU=", + "payloadHash": "2b1U3m4N5p6Q7r8S9t0U1v2W3x4Y5z6A7b8C9d0E1f2=", + "signature": "4Nf8k2p6s0w4y8A2d6g0j4m8p2t6w0z4C8F2I6L0O4R8U2X6a0d4g8j2m6p0s4v8yA2d6g0j4m8p2t6w0z4C8F2I6A==", + "requestId": "request-123", + "traceId": "trace-123" +} +``` + +## Redis Examples + +### Session Cache Record + +Example Redis key and JSON value used by the fallback session cache: + +```text +gateway:session:device-session-123 +``` + +```json +{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=", + "status": "active" +} +``` + +### Session Event Stream Entry + +Example session snapshot entry: + +```bash +redis-cli XADD gateway:session-events '*' \ + device_session_id device-session-123 \ + user_id user-123 \ + client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \ + status active +``` + +Revocation entry: + 
+```bash +redis-cli XADD gateway:session-events '*' \ + device_session_id device-session-123 \ + user_id user-123 \ + client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \ + status revoked \ + revoked_at_ms 1775121700000 +``` + +### Client Event Stream Entry + +User-wide event: + +```bash +redis-cli XADD gateway:client-events '*' \ + user_id user-123 \ + event_type fleet.updated \ + event_id event-123 \ + payload_bytes payload-v1 +``` + +Session-targeted event with correlation: + +```bash +redis-cli XADD gateway:client-events '*' \ + user_id user-123 \ + device_session_id device-session-123 \ + event_type fleet.updated \ + event_id event-124 \ + payload_bytes payload-v2 \ + request_id request-123 \ + trace_id trace-123 +``` + +Notes: + +- `payload_bytes` in Redis Stream entries must be binary-safe payload data; +- the gateway derives `timestamp_ms`, recomputes `payload_hash`, and signs the + outgoing event at delivery time; +- each gateway replica consumes streams with plain `XREAD`, so publishers must + keep retention bounded with `MAXLEN`. 
diff --git a/gateway/docs/flows.md b/gateway/docs/flows.md new file mode 100644 index 0000000..de39e00 --- /dev/null +++ b/gateway/docs/flows.md @@ -0,0 +1,86 @@ +# Request and Push Flows + +## Public Auth Flow + +```mermaid +sequenceDiagram + participant Client + participant Gateway + participant Limiter as Public anti-abuse + participant Auth as AuthServiceClient + + Client->>Gateway: POST /api/v1/public/auth/send-email-code + Gateway->>Limiter: classify + rate-limit + body checks + Limiter-->>Gateway: allowed + Gateway->>Auth: SendEmailCode(email) + Auth-->>Gateway: challenge_id + Gateway-->>Client: 200 {challenge_id} + + Client->>Gateway: POST /api/v1/public/auth/confirm-email-code + Gateway->>Limiter: classify + rate-limit + body checks + Limiter-->>Gateway: allowed + Gateway->>Auth: ConfirmEmailCode(challenge_id, code, client_public_key) + Auth-->>Gateway: device_session_id + Gateway-->>Client: 200 {device_session_id} +``` + +## Authenticated ExecuteCommand Flow + +```mermaid +sequenceDiagram + participant Client + participant Gateway + participant Cache as SessionCache + participant Replay as ReplayStore + participant Policy as Rate limit / policy + participant Downstream + + Client->>Gateway: ExecuteCommand(envelope, payload_bytes, signature) + Gateway->>Gateway: validate envelope + protocol_version + Gateway->>Cache: lookup(device_session_id) + Cache-->>Gateway: session record + Gateway->>Gateway: verify payload_hash + Gateway->>Gateway: verify Ed25519 signature + Gateway->>Gateway: verify freshness window + Gateway->>Replay: reserve(device_session_id, request_id, ttl) + Replay-->>Gateway: accepted + Gateway->>Policy: apply IP/session/user/message_type budgets + Policy-->>Gateway: allowed + Gateway->>Downstream: verified authenticated command + Downstream-->>Gateway: result_code + payload_bytes + Gateway->>Gateway: hash payload + sign response + Gateway-->>Client: ExecuteCommandResponse + signature +``` + +## SubscribeEvents Lifecycle + +```mermaid 
+sequenceDiagram + participant Client + participant Gateway + participant Cache as SessionCache + participant Replay as ReplayStore + participant Hub as PushHub + participant Stream as Client event stream + participant Sess as Session event stream + + Client->>Gateway: SubscribeEvents(envelope, signature) + Gateway->>Gateway: validate envelope + verify request + Gateway->>Cache: lookup(device_session_id) + Cache-->>Gateway: session record + Gateway->>Replay: reserve(device_session_id, request_id, ttl) + Replay-->>Gateway: accepted + Gateway->>Client: gateway.server_time event + Gateway->>Hub: register(user_id, device_session_id) + + Stream-->>Gateway: client-facing event for user_id / device_session_id + Gateway->>Hub: publish signed event + Hub-->>Client: matching event delivery + + Sess-->>Gateway: revoked session snapshot + Gateway->>Hub: revoke(device_session_id) + Hub-->>Client: stream closes with FAILED_PRECONDITION + + Note over Gateway,Hub: During shutdown the gateway closes PushHub before gRPC graceful stop. + Hub-->>Client: stream closes with UNAVAILABLE +``` diff --git a/gateway/docs/runbook.md b/gateway/docs/runbook.md new file mode 100644 index 0000000..beae48c --- /dev/null +++ b/gateway/docs/runbook.md @@ -0,0 +1,143 @@ +# Operator Runbook + +This runbook covers the checks that matter most during startup, steady-state +readiness, shutdown, and push or revoke incidents. + +## Startup Checks + +Before starting the process, confirm: + +- `GATEWAY_SESSION_CACHE_REDIS_ADDR` points to the Redis deployment used for + session lookup and both internal event streams. +- `GATEWAY_SESSION_EVENTS_REDIS_STREAM` and + `GATEWAY_CLIENT_EVENTS_REDIS_STREAM` reference existing Redis Stream keys or + the names publishers will use. +- `GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH` points to a readable PKCS#8 + PEM-encoded Ed25519 private key. +- the configured Redis ACL, DB, TLS, and key-prefix settings match the target + environment. 
+ +At startup the process performs bounded `PING` checks for: + +- the Redis-backed session cache adapter; +- the replay store; +- the session event subscriber; +- the client event subscriber. + +Startup fails fast if any of those checks fail or if the signer key cannot be +loaded. + +Expected listener state after a healthy start: + +- public HTTP is enabled on `GATEWAY_PUBLIC_HTTP_ADDR` or its default `:8080`; +- authenticated gRPC is enabled on + `GATEWAY_AUTHENTICATED_GRPC_ADDR` or its default `:9090`; +- admin HTTP is enabled only when `GATEWAY_ADMIN_HTTP_ADDR` is non-empty. + +Known startup caveats: + +- public auth routes stay mounted without an upstream adapter and return + `503 service_unavailable`; +- authenticated gRPC starts with an empty static router, so `ExecuteCommand` + returns gRPC `UNIMPLEMENTED` until downstream routes are injected. + +## Readiness + +Use the probes according to what they actually guarantee: + +- `GET /healthz` confirms that the public HTTP listener is alive; +- `GET /readyz` confirms that the current process is ready to serve public HTTP + traffic; +- `GET /metrics` is available only on the optional admin listener. + +`/readyz` is process-local. It does not confirm: + +- downstream business-service reachability; +- auth upstream adapter reachability; +- Redis health after startup; +- push fan-out health. + +For a practical readiness check in production: + +1. confirm the process emitted startup logs for the public and authenticated + listeners; +2. check `GET /healthz`; +3. check `GET /readyz`; +4. if admin HTTP is enabled, scrape `GET /metrics`; +5. verify the expected Redis deployment and stream names from config. + +## Shutdown + +The process handles `SIGINT` and `SIGTERM`. 
+ +Shutdown behavior: + +- the per-component shutdown budget is controlled by + `GATEWAY_SHUTDOWN_TIMEOUT`; +- internal subscribers are stopped as part of application shutdown; +- the in-memory `PushHub` is closed before gRPC graceful stop; +- active `SubscribeEvents` streams terminate with gRPC `UNAVAILABLE` and + message `gateway is shutting down`. + +During planned restarts: + +1. send `SIGTERM`; +2. wait for listener shutdown and component-stop logs; +3. expect connected clients to reconnect after the gateway closes the stream; +4. investigate only if shutdown exceeds `GATEWAY_SHUTDOWN_TIMEOUT` or streams + remain open unexpectedly. + +## Revoke And Push Failure Triage + +### Revocation Does Not Take Effect + +If a revoked session still sends traffic or keeps an active stream: + +1. verify that the auth/session side published a session snapshot with the + same `device_session_id` and `status=revoked`; +2. verify that the event was written to + `GATEWAY_SESSION_EVENTS_REDIS_STREAM`; +3. verify the gateway is connected to the same Redis address, DB, and stream; +4. confirm the snapshot fields are complete and well-formed; +5. check that a later active snapshot did not overwrite the revoked one. + +Expected gateway behavior after the revoke snapshot is consumed: + +- new authenticated requests for that `device_session_id` fail with gRPC + `FAILED_PRECONDITION`; +- active `SubscribeEvents` streams for that exact `device_session_id` close + with the same status. + +### Push Events Are Not Delivered + +If a client reports missing push events: + +1. confirm that the client successfully opened `SubscribeEvents`; +2. confirm the stream received the initial `gateway.server_time` bootstrap + event; +3. confirm the gateway consumed the expected entry from + `GATEWAY_CLIENT_EVENTS_REDIS_STREAM`; +4. verify `user_id` and optional `device_session_id` in the stream entry match + the intended target; +5. 
confirm the event payload fields are well-formed and not dropped as + malformed; +6. check whether the stream was closed earlier because of revoke, shutdown, or + overflow. + +### Stream Closed Unexpectedly + +Use the terminal gRPC status first: + +- `FAILED_PRECONDITION` with `device session is revoked` means the session was + revoked; +- `RESOURCE_EXHAUSTED` with `push stream overflowed` means that stream stopped + consuming fast enough and its in-memory queue overflowed; +- `UNAVAILABLE` with `gateway is shutting down` means normal process shutdown; +- client-side cancellation or transport errors should be investigated on the + client or network side. + +For overflow incidents: + +- treat the issue as stream-local, not a global push outage; +- inspect client receive behavior and reconnect logic; +- look at push metrics and logs around the affected user/session. diff --git a/gateway/docs/runtime.md b/gateway/docs/runtime.md new file mode 100644 index 0000000..3400417 --- /dev/null +++ b/gateway/docs/runtime.md @@ -0,0 +1,59 @@ +# Runtime and Components + +The diagram below focuses on the deployed `galaxy/gateway` process and its +runtime dependencies. 
+ +```mermaid +flowchart LR + subgraph Clients + Public["Public REST clients"] + Authd["Authenticated gRPC clients"] + end + + subgraph Gateway["Edge Gateway process"] + PublicHTTP["Public HTTP listener\n/healthz /readyz /api/v1/public/auth/*"] + AuthGRPC["Authenticated gRPC listener\nExecuteCommand / SubscribeEvents"] + AdminHTTP["Optional admin HTTP listener\n/metrics"] + SessionSnap["In-memory session snapshot cache"] + Replay["Replay reservation client"] + PushHub["PushHub"] + SessSub["Session event subscriber"] + ClientSub["Client event subscriber"] + Telemetry["Logs, traces, metrics"] + end + + Public --> PublicHTTP + Authd --> AuthGRPC + AuthGRPC --> SessionSnap + AuthGRPC --> Replay + AuthGRPC --> PushHub + SessSub --> SessionSnap + SessSub --> PushHub + ClientSub --> PushHub + PublicHTTP --> Telemetry + AuthGRPC --> Telemetry + AdminHTTP --> Telemetry + + Redis["Redis\nsession records + replay keys + streams"] + AuthSvc["Auth / Session Service"] + Downstream["Downstream business services"] + Metrics["Prometheus / OTLP collectors"] + + PublicHTTP -. public auth adapter .-> AuthSvc + SessionSnap --> Redis + Replay --> Redis + SessSub --> Redis + ClientSub --> Redis + AuthGRPC --> Downstream + Telemetry --> Metrics +``` + +Notes: + +- `cmd/gateway` refuses startup when Redis connectivity or the response signer + is misconfigured. +- The admin listener is optional and serves only Prometheus text metrics. +- Public auth routing stays available without an upstream adapter, but returns + `503 service_unavailable`. +- Authenticated gRPC starts with an empty static router; `ExecuteCommand` + remains `UNIMPLEMENTED` until downstream routes are injected. 
diff --git a/gateway/go.mod b/gateway/go.mod index 63d35cb..f8afb6d 100644 --- a/gateway/go.mod +++ b/gateway/go.mod @@ -1,3 +1,98 @@ module galaxy/gateway go 1.26.0 + +require ( + buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20260209202127-80ab13bee0bf.1 + buf.build/go/protovalidate v1.1.3 + github.com/alicebob/miniredis/v2 v2.37.0 + github.com/getkin/kin-openapi v0.134.0 + github.com/gin-gonic/gin v1.12.0 + github.com/google/flatbuffers v25.12.19+incompatible + github.com/prometheus/client_golang v1.23.2 + github.com/redis/go-redis/v9 v9.18.0 + github.com/stretchr/testify v1.11.1 + go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 + go.opentelemetry.io/otel v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 + go.opentelemetry.io/otel/exporters/prometheus v0.64.0 + go.opentelemetry.io/otel/metric v1.42.0 + go.opentelemetry.io/otel/sdk v1.42.0 + go.opentelemetry.io/otel/sdk/metric v1.42.0 + go.opentelemetry.io/otel/trace v1.42.0 + go.uber.org/zap v1.27.1 + golang.org/x/time v0.15.0 + google.golang.org/grpc v1.80.0 + google.golang.org/protobuf v1.36.11 +) + +require ( + cel.dev/expr v0.25.1 // indirect + github.com/antlr4-go/antlr/v4 v4.13.1 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/gabriel-vasile/mimetype v1.4.13 // indirect + github.com/gin-contrib/sse v1.1.0 // 
indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.30.1 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/goccy/go-yaml v1.19.2 // indirect + github.com/google/cel-go v0.27.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c // indirect + github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.67.5 // indirect + github.com/prometheus/otlptranslator v1.0.0 // indirect + github.com/prometheus/procfs v0.19.2 // indirect + github.com/quic-go/qpack v0.6.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.3.1 // indirect + github.com/woodsbury/decimal128 v1.3.0 // indirect + 
github.com/yuin/gopher-lua v1.1.1 // indirect + go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 // indirect + go.opentelemetry.io/proto/otlp v1.9.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect + golang.org/x/arch v0.24.0 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 // indirect + golang.org/x/net v0.51.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/gateway/go.sum b/gateway/go.sum index e69de29..bbb0ce6 100644 --- a/gateway/go.sum +++ b/gateway/go.sum @@ -0,0 +1,236 @@ +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20260209202127-80ab13bee0bf.1 h1:PMmTMyvHScV9Mn8wc6ASge9uRcHy0jtqPd+fM35LmsQ= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20260209202127-80ab13bee0bf.1/go.mod h1:tvtbpgaVXZX4g6Pn+AnzFycuRK3MOz5HJfEGeEllXYM= +buf.build/go/protovalidate v1.1.3 h1:m2GVEgQWd7rk+vIoAZ+f0ygGjvQTuqPQapBBdcpWVPE= +buf.build/go/protovalidate v1.1.3/go.mod h1:9XIuohWz+kj+9JVn3WQneHA5LZP50mjvneZMnbLkiIE= +cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= +cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68= +github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= +github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= +github.com/antlr4-go/antlr/v4 v4.13.1/go.mod 
h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/brianvoe/gofakeit/v6 v6.28.0 h1:Xib46XXuQfmlLS2EXRuJpqcw8St6qSZz75OUo0tgAW4= +github.com/brianvoe/gofakeit/v6 v6.28.0/go.mod h1:Xj58BMSnFqcn/fAQeSK+/PLtC5kSb7FJIq4JyGa8vEs= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= +github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/getkin/kin-openapi v0.134.0 h1:/L5+1+kfe6dXh8Ot/wqiTgUkjOIEJiC0bbYVziHB8rU= +github.com/getkin/kin-openapi v0.134.0/go.mod h1:wK6ZLG/VgoETO9pcLJ/VmAtIcl/DNlMayNTb716EUxE= +github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= +github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= +github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= +github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod 
h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= +github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= +github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/cel-go v0.27.0 h1:e7ih85+4qVrBuqQWTW4FKSqZYokVuc3HnhH5keboFTo= +github.com/google/cel-go v0.27.0/go.mod h1:tTJ11FWqnhw5KKpnWpvW9CJC3Y9GK4EIS0WXnBbebzw= +github.com/google/flatbuffers v25.12.19+incompatible h1:haMV2JRRJCe1998HeW/p0X9UaMTK6SDo0ffLn2+DbLs= +github.com/google/flatbuffers v25.12.19+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid 
v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c h1:7ACFcSaQsrWtrH4WHHfUqE1C+f8r2uv8KGaW0jTNjus= +github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c/go.mod h1:JKox4Gszkxt57kj27u7rvi7IFoIULvCZHUsBTUmQM/s= +github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b h1:vivRhVUAa9t1q0Db4ZmezBP8pWQWnXHFokZj0AOea2g= +github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= +github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= +github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos= +github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= +github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= +github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= +github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= +github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/rodaine/protogofakeit v0.1.1 h1:ZKouljuRM3A+TArppfBqnH8tGZHOwM/pjvtXe9DaXH8= +github.com/rodaine/protogofakeit v0.1.1/go.mod h1:pXn/AstBYMaSfc1/RqH3N82pBuxtWgejz1AlYpY1mI0= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= +github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +github.com/woodsbury/decimal128 v1.3.0 h1:8pffMNWIlC0O5vbyHWFZAt5yWvWcrHA+3ovIIjVWss0= +github.com/woodsbury/decimal128 v1.3.0/go.mod h1:C5UTmyTjW3JftjUFzOVhC20BEQa2a4ZKOB5I6Zjb+ds= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= +go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod 
h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0 h1:E7DmskpIO7ZR6QI6zKSEKIDNUYoKw9oHXP23gzbCdU0= +go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0/go.mod h1:WB2cS9y+AwqqKhoo9gw6/ZxlSjFBUQGZ8BQOaD3FVXM= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= +go.opentelemetry.io/contrib/propagators/b3 v1.42.0 h1:B2Pew5ufEtgkjLF+tSkXjgYZXQr9m7aCm1wLKB0URbU= +go.opentelemetry.io/contrib/propagators/b3 v1.42.0/go.mod h1:iPgUcSEF5DORW6+yNbdw/YevUy+QqJ508ncjhrRSCjc= +go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= +go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 h1:THuZiwpQZuHPul65w4WcwEnkX2QIuMT+UFoOrygtoJw= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0/go.mod h1:J2pvYM5NGHofZ2/Ru6zw/TNWnEQp5crgyDeSrYpXkAw= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 h1:zWWrB1U6nqhS/k6zYB74CjRpuiitRtLLi68VcgmOEto= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0/go.mod h1:2qXPNBX1OVRC0IwOnfo1ljoid+RD0QK3443EaqVlsOU= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 h1:uLXP+3mghfMf7XmV4PkGfFhFKuNWoCvvx5wP/wOXo0o= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0/go.mod h1:v0Tj04armyT59mnURNUJf7RCKcKzq+lgJs6QSjHjaTc= +go.opentelemetry.io/otel/exporters/prometheus v0.64.0 h1:g0LRDXMX/G1SEZtK8zl8Chm4K6GBwRkjPKE36LxiTYs= 
+go.opentelemetry.io/otel/exporters/prometheus v0.64.0/go.mod h1:UrgcjnarfdlBDP3GjDIJWe6HTprwSazNjwsI+Ru6hro= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0 h1:s/1iRkCKDfhlh1JF26knRneorus8aOwVIDhvYx9WoDw= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0/go.mod h1:UI3wi0FXg1Pofb8ZBiBLhtMzgoTm1TYkMvn71fAqDzs= +go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= +go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= +go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= +go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= +go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= +go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= +go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= +go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= +go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod 
h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= +golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 h1:SbTAbRFnd5kjQXbczszQ0hdk3ctwYf3qBNH9jIsGclE= +golang.org/x/exp v0.0.0-20250813145105-42675adae3e6/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0= +google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod 
h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gateway/internal/adminapi/server.go b/gateway/internal/adminapi/server.go new file mode 100644 index 0000000..a4ea170 --- /dev/null +++ b/gateway/internal/adminapi/server.go @@ -0,0 +1,133 @@ +// Package adminapi exposes the optional private admin HTTP listener used for +// operational endpoints such as Prometheus metrics. +package adminapi + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + + "galaxy/gateway/internal/config" + + "go.uber.org/zap" +) + +// Server owns the optional admin HTTP listener exposed by the gateway. 
+type Server struct { + cfg config.AdminHTTPConfig + handler http.Handler + logger *zap.Logger + + stateMu sync.RWMutex + server *http.Server + listener net.Listener +} + +// NewServer constructs an admin HTTP server for cfg and handler. +func NewServer(cfg config.AdminHTTPConfig, handler http.Handler, logger *zap.Logger) *Server { + if handler == nil { + handler = http.NotFoundHandler() + } + if logger == nil { + logger = zap.NewNop() + } + + return &Server{ + cfg: cfg, + handler: handler, + logger: logger.Named("admin_http"), + } +} + +// Enabled reports whether the admin listener should run. +func (s *Server) Enabled() bool { + return s != nil && s.cfg.Addr != "" +} + +// Run binds the configured listener and serves the admin HTTP surface until +// Shutdown closes the server. A disabled admin server returns when ctx is +// canceled. +func (s *Server) Run(ctx context.Context) error { + if ctx == nil { + return errors.New("run admin HTTP server: nil context") + } + if err := ctx.Err(); err != nil { + return err + } + if !s.Enabled() { + <-ctx.Done() + return nil + } + + listener, err := net.Listen("tcp", s.cfg.Addr) + if err != nil { + return fmt.Errorf("run admin HTTP server: listen on %q: %w", s.cfg.Addr, err) + } + + server := &http.Server{ + Handler: s.handler, + ReadHeaderTimeout: s.cfg.ReadHeaderTimeout, + ReadTimeout: s.cfg.ReadTimeout, + IdleTimeout: s.cfg.IdleTimeout, + } + + s.stateMu.Lock() + s.server = server + s.listener = listener + s.stateMu.Unlock() + + s.logger.Info("admin HTTP server started", zap.String("addr", listener.Addr().String())) + + defer func() { + s.stateMu.Lock() + s.server = nil + s.listener = nil + s.stateMu.Unlock() + }() + + err = server.Serve(listener) + switch { + case err == nil: + return nil + case errors.Is(err, http.ErrServerClosed): + s.logger.Info("admin HTTP server stopped") + return nil + default: + return fmt.Errorf("run admin HTTP server: serve on %q: %w", s.cfg.Addr, err) + } +} + +// Shutdown gracefully stops the 
admin HTTP server within ctx. +func (s *Server) Shutdown(ctx context.Context) error { + if ctx == nil { + return errors.New("shutdown admin HTTP server: nil context") + } + + s.stateMu.RLock() + server := s.server + s.stateMu.RUnlock() + + if server == nil { + return nil + } + + if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("shutdown admin HTTP server: %w", err) + } + + return nil +} + +func (s *Server) listenAddr() string { + s.stateMu.RLock() + defer s.stateMu.RUnlock() + + if s.listener == nil { + return "" + } + + return s.listener.Addr().String() +} diff --git a/gateway/internal/adminapi/server_test.go b/gateway/internal/adminapi/server_test.go new file mode 100644 index 0000000..732c214 --- /dev/null +++ b/gateway/internal/adminapi/server_test.go @@ -0,0 +1,102 @@ +package adminapi + +import ( + "context" + "net" + "net/http" + "testing" + "time" + + "galaxy/gateway/internal/app" + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/restapi" + "galaxy/gateway/internal/testutil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMetricsAreReachableOnlyOnAdminListener(t *testing.T) { + t.Parallel() + + logger, _ := testutil.NewObservedLogger(t) + telemetryRuntime := testutil.NewTelemetryRuntime(t, logger) + + publicAddr := unusedTCPAddr(t) + adminAddr := unusedTCPAddr(t) + + publicCfg := config.DefaultPublicHTTPConfig() + publicCfg.Addr = publicAddr + adminCfg := config.DefaultAdminHTTPConfig() + adminCfg.Addr = adminAddr + + restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{ + Logger: logger, + Telemetry: telemetryRuntime, + }) + adminServer := NewServer(adminCfg, telemetryRuntime.Handler(), logger) + + application := app.New( + config.Config{ + ShutdownTimeout: time.Second, + PublicHTTP: publicCfg, + AdminHTTP: adminCfg, + AuthenticatedGRPC: config.DefaultAuthenticatedGRPCConfig(), + }, + restServer, + adminServer, + ) + + ctx, 
cancel := context.WithCancel(context.Background()) + resultCh := make(chan error, 1) + go func() { + resultCh <- application.Run(ctx) + }() + defer func() { + cancel() + select { + case err := <-resultCh: + require.NoError(t, err) + case <-time.After(time.Second): + require.FailNow(t, "application did not stop") + } + }() + + waitForHTTPStatus(t, "http://"+publicAddr+"/healthz", http.StatusOK) + waitForHTTPStatus(t, "http://"+adminAddr+"/metrics", http.StatusOK) + + publicMetricsResp, err := http.Get("http://" + publicAddr + "/metrics") + require.NoError(t, err) + defer func() { + require.NoError(t, publicMetricsResp.Body.Close()) + }() + assert.Equal(t, http.StatusNotFound, publicMetricsResp.StatusCode) +} + +func waitForHTTPStatus(t *testing.T, rawURL string, wantStatus int) { + t.Helper() + + require.Eventually(t, func() bool { + resp, err := http.Get(rawURL) + if err != nil { + return false + } + defer func() { + _ = resp.Body.Close() + }() + + return resp.StatusCode == wantStatus + }, time.Second, 10*time.Millisecond) +} + +func unusedTCPAddr(t *testing.T) string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + addr := listener.Addr().String() + require.NoError(t, listener.Close()) + + return addr +} diff --git a/gateway/internal/app/app.go b/gateway/internal/app/app.go new file mode 100644 index 0000000..ee70b5a --- /dev/null +++ b/gateway/internal/app/app.go @@ -0,0 +1,178 @@ +// Package app wires the gateway process lifecycle and coordinates component +// startup and graceful shutdown. +package app + +import ( + "context" + "errors" + "fmt" + "sync" + + "galaxy/gateway/internal/config" +) + +// Component is a long-lived gateway subsystem that participates in coordinated +// startup and graceful shutdown. +type Component interface { + // Run starts the component and blocks until it stops. + Run(context.Context) error + + // Shutdown stops the component within the provided timeout-bounded context. 
+ Shutdown(context.Context) error +} + +// App owns the process-level lifecycle of the gateway and its registered +// components. +type App struct { + cfg config.Config + components []Component +} + +// New constructs an App with a defensive copy of the supplied components. +func New(cfg config.Config, components ...Component) *App { + clonedComponents := append([]Component(nil), components...) + + return &App{ + cfg: cfg, + components: clonedComponents, + } +} + +// Run starts all configured components, waits for cancellation or the first +// component failure, and then executes best-effort graceful shutdown for every +// component. +func (a *App) Run(ctx context.Context) error { + if ctx == nil { + return errors.New("run gateway app: nil context") + } + if err := a.validate(); err != nil { + return err + } + if len(a.components) == 0 { + <-ctx.Done() + return nil + } + + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + + results := make(chan componentResult, len(a.components)) + var runWG sync.WaitGroup + + for idx, component := range a.components { + runWG.Add(1) + + go func(index int, component Component) { + defer runWG.Done() + results <- componentResult{ + index: index, + err: component.Run(runCtx), + } + }(idx, component) + } + + var runErr error + + select { + case <-ctx.Done(): + case result := <-results: + runErr = classifyComponentResult(ctx, result) + } + + cancel() + + shutdownErr := a.shutdownComponents() + waitErr := a.waitForComponents(&runWG) + + return errors.Join(runErr, shutdownErr, waitErr) +} + +// componentResult captures the first observed exit from a running component. +type componentResult struct { + index int + err error +} + +// validate confirms that the App has a safe shutdown budget and no nil +// components before goroutines are started. 
+func (a *App) validate() error { + if a.cfg.ShutdownTimeout <= 0 { + return fmt.Errorf("run gateway app: shutdown timeout must be positive, got %s", a.cfg.ShutdownTimeout) + } + + for idx, component := range a.components { + if component == nil { + return fmt.Errorf("run gateway app: component %d is nil", idx) + } + } + + return nil +} + +// classifyComponentResult maps the first component exit into the error that +// should control the application result. +func classifyComponentResult(parentCtx context.Context, result componentResult) error { + switch { + case result.err == nil: + if parentCtx.Err() != nil { + return nil + } + return fmt.Errorf("run gateway app: component %d exited without error before shutdown", result.index) + case errors.Is(result.err, context.Canceled) && parentCtx.Err() != nil: + return nil + default: + return fmt.Errorf("run gateway app: component %d: %w", result.index, result.err) + } +} + +// shutdownComponents calls Shutdown on every registered component using a fresh +// timeout-bounded context per component and joins any shutdown failures. +func (a *App) shutdownComponents() error { + var shutdownWG sync.WaitGroup + errs := make(chan error, len(a.components)) + + for idx, component := range a.components { + shutdownWG.Add(1) + + go func(index int, component Component) { + defer shutdownWG.Done() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout) + defer cancel() + + if err := component.Shutdown(shutdownCtx); err != nil { + errs <- fmt.Errorf("shutdown gateway component %d: %w", index, err) + } + }(idx, component) + } + + shutdownWG.Wait() + close(errs) + + var joined error + for err := range errs { + joined = errors.Join(joined, err) + } + + return joined +} + +// waitForComponents waits for running components to return after shutdown and +// reports when they outlive the configured shutdown budget. 
+func (a *App) waitForComponents(runWG *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + runWG.Wait() + close(done) + }() + + waitCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout) + defer cancel() + + select { + case <-done: + return nil + case <-waitCtx.Done(): + return fmt.Errorf("wait for gateway components: %w", waitCtx.Err()) + } +} diff --git a/gateway/internal/app/app_test.go b/gateway/internal/app/app_test.go new file mode 100644 index 0000000..a50d02b --- /dev/null +++ b/gateway/internal/app/app_test.go @@ -0,0 +1,268 @@ +package app + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "galaxy/gateway/internal/config" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAppRunWaitsForCancellationWithoutComponents(t *testing.T) { + t.Parallel() + + application := New(config.Config{ShutdownTimeout: 50 * time.Millisecond}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan error, 1) + go func() { + resultCh <- application.Run(ctx) + }() + + select { + case err := <-resultCh: + require.FailNowf(t, "Run() returned early", "error=%v", err) + case <-time.After(50 * time.Millisecond): + } + + cancel() + + select { + case err := <-resultCh: + require.NoError(t, err) + case <-time.After(time.Second): + require.FailNow(t, "Run() did not return after cancellation") + } +} + +func TestAppRunCancelsComponentsAndCallsShutdownOnce(t *testing.T) { + t.Parallel() + + first := newLifecycleComponent() + second := newLifecycleComponent() + + application := New( + config.Config{ShutdownTimeout: time.Second}, + first, + second, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan error, 1) + go func() { + resultCh <- application.Run(ctx) + }() + + first.waitStarted(t) + second.waitStarted(t) + + cancel() + + select { + case err := <-resultCh: + require.NoError(t, err) 
+ case <-time.After(time.Second): + require.FailNow(t, "Run() did not return after cancellation") + } + + first.waitRunExited(t) + second.waitRunExited(t) + + assert.Equal(t, 1, first.shutdownCalls()) + assert.Equal(t, 1, second.shutdownCalls()) +} + +func TestAppRunReturnsComponentErrorAndStillShutsDown(t *testing.T) { + t.Parallel() + + runErr := errors.New("boom") + failing := newFailingComponent(runErr) + blocking := newLifecycleComponent() + + application := New( + config.Config{ShutdownTimeout: time.Second}, + failing, + blocking, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan error, 1) + go func() { + resultCh <- application.Run(ctx) + }() + + failing.waitStarted(t) + blocking.waitStarted(t) + failing.releaseRun() + + select { + case err := <-resultCh: + require.Error(t, err) + assert.ErrorIs(t, err, runErr) + case <-time.After(time.Second): + require.FailNow(t, "Run() did not return after component failure") + } + + failing.waitRunExited(t) + blocking.waitRunExited(t) + + assert.Equal(t, 1, failing.shutdownCalls()) + assert.Equal(t, 1, blocking.shutdownCalls()) +} + +// lifecycleComponent blocks in Run until the application calls Shutdown. +type lifecycleComponent struct { + startedCh chan struct{} + runDoneCh chan struct{} + stopCh chan struct{} + shutdownMu sync.Mutex + shutdownCnt int +} + +// newLifecycleComponent builds a component that exits Run only after Shutdown +// signals its stop channel. +func newLifecycleComponent() *lifecycleComponent { + return &lifecycleComponent{ + startedCh: make(chan struct{}), + runDoneCh: make(chan struct{}), + stopCh: make(chan struct{}), + } +} + +// Run marks the component as started, waits for cancellation, and then blocks +// until Shutdown releases the stop channel. 
+func (c *lifecycleComponent) Run(ctx context.Context) error { + close(c.startedCh) + defer close(c.runDoneCh) + + <-ctx.Done() + <-c.stopCh + return nil +} + +// Shutdown records the call and releases the run loop. +func (c *lifecycleComponent) Shutdown(context.Context) error { + c.shutdownMu.Lock() + defer c.shutdownMu.Unlock() + + c.shutdownCnt++ + if c.shutdownCnt == 1 { + close(c.stopCh) + } + + return nil +} + +// waitStarted blocks until Run has started or fails the test on timeout. +func (c *lifecycleComponent) waitStarted(t *testing.T) { + t.Helper() + + select { + case <-c.startedCh: + case <-time.After(time.Second): + require.FailNow(t, "component did not start") + } +} + +// waitRunExited blocks until Run exits or fails the test on timeout. +func (c *lifecycleComponent) waitRunExited(t *testing.T) { + t.Helper() + + select { + case <-c.runDoneCh: + case <-time.After(time.Second): + require.FailNow(t, "component run did not exit") + } +} + +// shutdownCalls returns the number of observed Shutdown invocations. +func (c *lifecycleComponent) shutdownCalls() int { + c.shutdownMu.Lock() + defer c.shutdownMu.Unlock() + + return c.shutdownCnt +} + +// failingComponent returns a predefined error once released by the test and +// still tracks shutdown calls. +type failingComponent struct { + startedCh chan struct{} + releaseCh chan struct{} + runDoneCh chan struct{} + shutdownMu sync.Mutex + shutdownCnt int + err error +} + +// newFailingComponent builds a component whose Run returns err after release. +func newFailingComponent(err error) *failingComponent { + return &failingComponent{ + startedCh: make(chan struct{}), + releaseCh: make(chan struct{}), + runDoneCh: make(chan struct{}), + err: err, + } +} + +// Run waits until the test releases it and then returns the configured error. 
+func (c *failingComponent) Run(context.Context) error { + close(c.startedCh) + defer close(c.runDoneCh) + + <-c.releaseCh + return c.err +} + +// Shutdown records that the application attempted graceful shutdown. +func (c *failingComponent) Shutdown(context.Context) error { + c.shutdownMu.Lock() + defer c.shutdownMu.Unlock() + + c.shutdownCnt++ + return nil +} + +// waitStarted blocks until Run has started or fails the test on timeout. +func (c *failingComponent) waitStarted(t *testing.T) { + t.Helper() + + select { + case <-c.startedCh: + case <-time.After(time.Second): + require.FailNow(t, "failing component did not start") + } +} + +// releaseRun allows Run to return its configured error. +func (c *failingComponent) releaseRun() { + close(c.releaseCh) +} + +// waitRunExited blocks until Run exits or fails the test on timeout. +func (c *failingComponent) waitRunExited(t *testing.T) { + t.Helper() + + select { + case <-c.runDoneCh: + case <-time.After(time.Second): + require.FailNow(t, "failing component run did not exit") + } +} + +// shutdownCalls returns the number of observed Shutdown invocations. +func (c *failingComponent) shutdownCalls() int { + c.shutdownMu.Lock() + defer c.shutdownMu.Unlock() + + return c.shutdownCnt +} diff --git a/gateway/internal/authn/event.go b/gateway/internal/authn/event.go new file mode 100644 index 0000000..2e0edb7 --- /dev/null +++ b/gateway/internal/authn/event.go @@ -0,0 +1,80 @@ +package authn + +import ( + "crypto/ed25519" + "encoding/binary" + "errors" +) + +const ( + // EventDomainMarkerV1 binds the v1 server event signature to the Galaxy + // gateway transport contract. + EventDomainMarkerV1 = "galaxy-event-v1" +) + +var ( + // ErrInvalidEventSignature reports that a gateway stream event signature is + // not a raw Ed25519 signature for the canonical event signing input. 
+ ErrInvalidEventSignature = errors.New("invalid event signature") +) + +// EventSigningFields contains the canonical v1 stream-event fields that are +// bound into the server signing input. +type EventSigningFields struct { + // EventType identifies the stable client-facing event category. + EventType string + + // EventID is the stable event correlation identifier. + EventID string + + // TimestampMS carries the server event timestamp in milliseconds. + TimestampMS int64 + + // RequestID optionally correlates the event to the opening client request. + RequestID string + + // TraceID optionally carries the client-supplied tracing correlation value. + TraceID string + + // PayloadHash is the raw SHA-256 digest of event payload bytes. + PayloadHash []byte +} + +// BuildEventSigningInput returns the canonical byte sequence the v1 gateway +// stream-event signature covers. String and byte fields are length-prefixed +// with uvarint(len(field)) followed by raw bytes, while TimestampMS is +// appended as an 8-byte big-endian uint64. +func BuildEventSigningInput(fields EventSigningFields) []byte { + size := len(EventDomainMarkerV1) + + len(fields.EventType) + + len(fields.EventID) + + len(fields.RequestID) + + len(fields.TraceID) + + len(fields.PayloadHash) + + (6 * binary.MaxVarintLen64) + + 8 + + buf := make([]byte, 0, size) + buf = appendLengthPrefixedString(buf, EventDomainMarkerV1) + buf = appendLengthPrefixedString(buf, fields.EventType) + buf = appendLengthPrefixedString(buf, fields.EventID) + buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS)) + buf = appendLengthPrefixedString(buf, fields.RequestID) + buf = appendLengthPrefixedString(buf, fields.TraceID) + buf = appendLengthPrefixedBytes(buf, fields.PayloadHash) + + return buf +} + +// VerifyEventSignature verifies that signature authenticates fields under +// publicKey using the canonical v1 event signing input. 
+func VerifyEventSignature(publicKey ed25519.PublicKey, signature []byte, fields EventSigningFields) error { + if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize { + return ErrInvalidEventSignature + } + if !ed25519.Verify(publicKey, BuildEventSigningInput(fields), signature) { + return ErrInvalidEventSignature + } + + return nil +} diff --git a/gateway/internal/authn/event_test.go b/gateway/internal/authn/event_test.go new file mode 100644 index 0000000..9b610c1 --- /dev/null +++ b/gateway/internal/authn/event_test.go @@ -0,0 +1,111 @@ +package authn + +import ( + "bytes" + "crypto/ed25519" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildEventSigningInputChangesWhenSignedFieldChanges(t *testing.T) { + t.Parallel() + + base := EventSigningFields{ + EventType: "gateway.server_time", + EventID: "request-123", + TimestampMS: 123456789, + RequestID: "request-123", + TraceID: "trace-123", + PayloadHash: mustSHA256([]byte("payload")), + } + + baseInput := BuildEventSigningInput(base) + + tests := []struct { + name string + mutate func(EventSigningFields) EventSigningFields + }{ + { + name: "event type", + mutate: func(fields EventSigningFields) EventSigningFields { + fields.EventType = "gateway.other" + return fields + }, + }, + { + name: "event id", + mutate: func(fields EventSigningFields) EventSigningFields { + fields.EventID = "request-456" + return fields + }, + }, + { + name: "timestamp", + mutate: func(fields EventSigningFields) EventSigningFields { + fields.TimestampMS++ + return fields + }, + }, + { + name: "request id", + mutate: func(fields EventSigningFields) EventSigningFields { + fields.RequestID = "request-456" + return fields + }, + }, + { + name: "trace id", + mutate: func(fields EventSigningFields) EventSigningFields { + fields.TraceID = "trace-456" + return fields + }, + }, + { + name: "payload hash", + mutate: func(fields 
EventSigningFields) EventSigningFields { + fields.PayloadHash = mustSHA256([]byte("other")) + return fields + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mutated := BuildEventSigningInput(tt.mutate(base)) + assert.False(t, bytes.Equal(baseInput, mutated)) + }) + } +} + +func TestSignAndVerifyEventSignature(t *testing.T) { + t.Parallel() + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + signer, err := NewEd25519ResponseSigner(privateKey) + require.NoError(t, err) + + fields := EventSigningFields{ + EventType: "gateway.server_time", + EventID: "request-123", + TimestampMS: 123456789, + RequestID: "request-123", + TraceID: "trace-123", + PayloadHash: mustSHA256([]byte("payload")), + } + + signature, err := signer.SignEvent(fields) + require.NoError(t, err) + require.NoError(t, VerifyEventSignature(signer.PublicKey(), signature, fields)) + + fields.TraceID = "changed" + require.ErrorIs(t, VerifyEventSignature(signer.PublicKey(), signature, fields), ErrInvalidEventSignature) +} diff --git a/gateway/internal/authn/request.go b/gateway/internal/authn/request.go new file mode 100644 index 0000000..d7c6e23 --- /dev/null +++ b/gateway/internal/authn/request.go @@ -0,0 +1,101 @@ +// Package authn defines authenticated transport helpers shared by the gateway +// edge verification pipeline. +package authn + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "errors" +) + +const ( + // RequestDomainMarkerV1 binds the v1 client request signature to the Galaxy + // gateway transport contract. + RequestDomainMarkerV1 = "galaxy-request-v1" +) + +var ( + // ErrInvalidPayloadHash reports that payloadHash is not a raw SHA-256 digest. + ErrInvalidPayloadHash = errors.New("payload_hash must be a 32-byte SHA-256 digest") + + // ErrPayloadHashMismatch reports that payloadHash does not match payloadBytes. 
+ ErrPayloadHashMismatch = errors.New("payload_hash does not match payload_bytes") +) + +// RequestSigningFields contains the canonical v1 request fields that are bound +// into the client signing input after the gateway validates and normalizes the +// request envelope. +type RequestSigningFields struct { + // ProtocolVersion identifies the transport envelope version. + ProtocolVersion string + + // DeviceSessionID identifies the authenticated device session bound to the + // request. + DeviceSessionID string + + // MessageType is the stable downstream routing key. + MessageType string + + // TimestampMS carries the client request timestamp in milliseconds. + TimestampMS int64 + + // RequestID is the transport correlation and anti-replay identifier. + RequestID string + + // PayloadHash is the raw SHA-256 digest of payload bytes. + PayloadHash []byte +} + +// BuildRequestSigningInput returns the canonical byte sequence the v1 client +// request signature covers. String and byte fields are length-prefixed with +// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as +// an 8-byte big-endian uint64. The caller is expected to pass fields that have +// already passed earlier envelope validation. 
+func BuildRequestSigningInput(fields RequestSigningFields) []byte { + size := len(RequestDomainMarkerV1) + + len(fields.ProtocolVersion) + + len(fields.DeviceSessionID) + + len(fields.MessageType) + + len(fields.RequestID) + + len(fields.PayloadHash) + + (6 * binary.MaxVarintLen64) + + 8 + + buf := make([]byte, 0, size) + buf = appendLengthPrefixedString(buf, RequestDomainMarkerV1) + buf = appendLengthPrefixedString(buf, fields.ProtocolVersion) + buf = appendLengthPrefixedString(buf, fields.DeviceSessionID) + buf = appendLengthPrefixedString(buf, fields.MessageType) + buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS)) + buf = appendLengthPrefixedString(buf, fields.RequestID) + buf = appendLengthPrefixedBytes(buf, fields.PayloadHash) + + return buf +} + +// VerifyPayloadHash checks that payloadHash is the raw SHA-256 digest of +// payloadBytes. Empty payloadBytes are valid and must use sha256.Sum256(nil). +func VerifyPayloadHash(payloadBytes, payloadHash []byte) error { + if len(payloadHash) != sha256.Size { + return ErrInvalidPayloadHash + } + + sum := sha256.Sum256(payloadBytes) + if !bytes.Equal(sum[:], payloadHash) { + return ErrPayloadHashMismatch + } + + return nil +} + +func appendLengthPrefixedString(dst []byte, value string) []byte { + return appendLengthPrefixedBytes(dst, []byte(value)) +} + +func appendLengthPrefixedBytes(dst []byte, value []byte) []byte { + dst = binary.AppendUvarint(dst, uint64(len(value))) + dst = append(dst, value...) 
+ + return dst +} diff --git a/gateway/internal/authn/request_test.go b/gateway/internal/authn/request_test.go new file mode 100644 index 0000000..bd57bf5 --- /dev/null +++ b/gateway/internal/authn/request_test.go @@ -0,0 +1,163 @@ +package authn + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestVerifyPayloadHash(t *testing.T) { + t.Parallel() + + payloadSum := sha256.Sum256([]byte("payload")) + emptySum := sha256.Sum256(nil) + otherSum := sha256.Sum256([]byte("other")) + + tests := []struct { + name string + payload []byte + payloadHash []byte + wantErr error + }{ + { + name: "matches non-empty payload", + payload: []byte("payload"), + payloadHash: payloadSum[:], + }, + { + name: "matches empty payload", + payload: nil, + payloadHash: emptySum[:], + }, + { + name: "rejects digest with invalid length", + payload: []byte("payload"), + payloadHash: []byte("short"), + wantErr: ErrInvalidPayloadHash, + }, + { + name: "rejects digest mismatch", + payload: []byte("payload"), + payloadHash: otherSum[:], + wantErr: ErrPayloadHashMismatch, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := VerifyPayloadHash(tt.payload, tt.payloadHash) + if tt.wantErr == nil { + require.NoError(t, err) + return + } + + require.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestBuildRequestSigningInput(t *testing.T) { + t.Parallel() + + fields := RequestSigningFields{ + ProtocolVersion: "v1", + DeviceSessionID: "device-session-123", + MessageType: "fleet.move", + TimestampMS: 123456789, + RequestID: "request-123", + PayloadHash: mustSHA256([]byte("payload")), + } + + got := BuildRequestSigningInput(fields) + + want, err := 
hex.DecodeString("1167616c6178792d726571756573742d7631027631126465766963652d73657373696f6e2d3132330a666c6565742e6d6f766500000000075bcd150b726571756573742d31323320239f59ed55e737c77147cf55ad0c1b030b6d7ee748a7426952f9b852d5a935e5") + require.NoError(t, err) + assert.Equal(t, want, got) +} + +func TestBuildRequestSigningInputChangesWhenSignedFieldChanges(t *testing.T) { + t.Parallel() + + base := RequestSigningFields{ + ProtocolVersion: "v1", + DeviceSessionID: "device-session-123", + MessageType: "fleet.move", + TimestampMS: 123456789, + RequestID: "request-123", + PayloadHash: mustSHA256([]byte("payload")), + } + + baseInput := BuildRequestSigningInput(base) + + tests := []struct { + name string + mutate func(RequestSigningFields) RequestSigningFields + }{ + { + name: "protocol version", + mutate: func(fields RequestSigningFields) RequestSigningFields { + fields.ProtocolVersion = "v2" + return fields + }, + }, + { + name: "device session id", + mutate: func(fields RequestSigningFields) RequestSigningFields { + fields.DeviceSessionID = "device-session-456" + return fields + }, + }, + { + name: "message type", + mutate: func(fields RequestSigningFields) RequestSigningFields { + fields.MessageType = "fleet.attack" + return fields + }, + }, + { + name: "timestamp", + mutate: func(fields RequestSigningFields) RequestSigningFields { + fields.TimestampMS++ + return fields + }, + }, + { + name: "request id", + mutate: func(fields RequestSigningFields) RequestSigningFields { + fields.RequestID = "request-456" + return fields + }, + }, + { + name: "payload hash", + mutate: func(fields RequestSigningFields) RequestSigningFields { + fields.PayloadHash = mustSHA256([]byte("other")) + return fields + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mutated := BuildRequestSigningInput(tt.mutate(base)) + assert.False(t, bytes.Equal(baseInput, mutated)) + }) + } +} + +func mustSHA256(payload []byte) []byte { + sum := 
sha256.Sum256(payload) + return sum[:] +} diff --git a/gateway/internal/authn/response.go b/gateway/internal/authn/response.go new file mode 100644 index 0000000..44f7cd6 --- /dev/null +++ b/gateway/internal/authn/response.go @@ -0,0 +1,189 @@ +package authn + +import ( + "bytes" + "crypto/ed25519" + "crypto/x509" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "os" +) + +const ( + // ResponseDomainMarkerV1 binds the v1 server response signature to the + // Galaxy gateway transport contract. + ResponseDomainMarkerV1 = "galaxy-response-v1" +) + +var ( + // ErrInvalidResponsePrivateKeyPEM reports that the configured response + // signer private key is not a strict PKCS#8 PEM-encoded private key. + ErrInvalidResponsePrivateKeyPEM = errors.New("response signer private key is not a valid PKCS#8 PEM block") + + // ErrInvalidResponsePrivateKey reports that the configured response signer + // private key is not an Ed25519 private key. + ErrInvalidResponsePrivateKey = errors.New("response signer private key must be an Ed25519 PKCS#8 private key") + + // ErrInvalidResponseSignature reports that a server response signature is + // not a raw Ed25519 signature for the canonical response signing input. + ErrInvalidResponseSignature = errors.New("invalid response signature") +) + +// ResponseSigningFields contains the canonical v1 response fields that are +// bound into the server signing input. +type ResponseSigningFields struct { + // ProtocolVersion identifies the transport envelope version. + ProtocolVersion string + + // RequestID is the transport correlation identifier copied from the + // authenticated request. + RequestID string + + // TimestampMS carries the server response timestamp in milliseconds. + TimestampMS int64 + + // ResultCode is the opaque downstream result code returned to the client. + ResultCode string + + // PayloadHash is the raw SHA-256 digest of response payload bytes. 
+ PayloadHash []byte +} + +// ResponseSigner signs authenticated unary responses and client-facing stream +// events with one server-side key. +type ResponseSigner interface { + // SignResponse returns the raw Ed25519 signature for the canonical response + // signing input built from fields. + SignResponse(fields ResponseSigningFields) ([]byte, error) + + // SignEvent returns the raw Ed25519 signature for the canonical event + // signing input built from fields. + SignEvent(fields EventSigningFields) ([]byte, error) +} + +// Ed25519ResponseSigner signs authenticated responses with one Ed25519 private +// key loaded during process startup. +type Ed25519ResponseSigner struct { + privateKey ed25519.PrivateKey +} + +// NewEd25519ResponseSigner validates privateKey and constructs a signer using +// a defensive key copy. +func NewEd25519ResponseSigner(privateKey ed25519.PrivateKey) (*Ed25519ResponseSigner, error) { + if len(privateKey) != ed25519.PrivateKeySize { + return nil, ErrInvalidResponsePrivateKey + } + + return &Ed25519ResponseSigner{ + privateKey: bytes.Clone(privateKey), + }, nil +} + +// LoadEd25519ResponseSignerFromPEMFile loads a strict PKCS#8 PEM-encoded +// Ed25519 private key from path and constructs a signer. +func LoadEd25519ResponseSignerFromPEMFile(path string) (*Ed25519ResponseSigner, error) { + pemBytes, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read response signer private key PEM: %w", err) + } + + signer, err := ParseEd25519ResponseSignerPEM(pemBytes) + if err != nil { + return nil, err + } + + return signer, nil +} + +// ParseEd25519ResponseSignerPEM parses one strict PKCS#8 PEM-encoded Ed25519 +// private key and constructs a signer from it. 
+func ParseEd25519ResponseSignerPEM(pemBytes []byte) (*Ed25519ResponseSigner, error) { + block, rest := pem.Decode(pemBytes) + if block == nil || block.Type != "PRIVATE KEY" || len(bytes.TrimSpace(rest)) > 0 { + return nil, ErrInvalidResponsePrivateKeyPEM + } + + parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, ErrInvalidResponsePrivateKeyPEM + } + + privateKey, ok := parsedKey.(ed25519.PrivateKey) + if !ok { + return nil, ErrInvalidResponsePrivateKey + } + + return NewEd25519ResponseSigner(privateKey) +} + +// PublicKey returns the Ed25519 public key that corresponds to the configured +// response signer private key. +func (s *Ed25519ResponseSigner) PublicKey() ed25519.PublicKey { + if s == nil { + return nil + } + + publicKey, _ := s.privateKey.Public().(ed25519.PublicKey) + return bytes.Clone(publicKey) +} + +// SignResponse signs the canonical v1 response signing input built from +// fields. +func (s *Ed25519ResponseSigner) SignResponse(fields ResponseSigningFields) ([]byte, error) { + if s == nil || len(s.privateKey) != ed25519.PrivateKeySize { + return nil, ErrInvalidResponsePrivateKey + } + + signature := ed25519.Sign(s.privateKey, BuildResponseSigningInput(fields)) + return bytes.Clone(signature), nil +} + +// SignEvent signs the canonical v1 stream-event signing input built from +// fields. +func (s *Ed25519ResponseSigner) SignEvent(fields EventSigningFields) ([]byte, error) { + if s == nil || len(s.privateKey) != ed25519.PrivateKeySize { + return nil, ErrInvalidResponsePrivateKey + } + + signature := ed25519.Sign(s.privateKey, BuildEventSigningInput(fields)) + return bytes.Clone(signature), nil +} + +// BuildResponseSigningInput returns the canonical byte sequence the v1 server +// response signature covers. String and byte fields are length-prefixed with +// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as +// an 8-byte big-endian uint64. 
+func BuildResponseSigningInput(fields ResponseSigningFields) []byte { + size := len(ResponseDomainMarkerV1) + + len(fields.ProtocolVersion) + + len(fields.RequestID) + + len(fields.ResultCode) + + len(fields.PayloadHash) + + (5 * binary.MaxVarintLen64) + + 8 + + buf := make([]byte, 0, size) + buf = appendLengthPrefixedString(buf, ResponseDomainMarkerV1) + buf = appendLengthPrefixedString(buf, fields.ProtocolVersion) + buf = appendLengthPrefixedString(buf, fields.RequestID) + buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS)) + buf = appendLengthPrefixedString(buf, fields.ResultCode) + buf = appendLengthPrefixedBytes(buf, fields.PayloadHash) + + return buf +} + +// VerifyResponseSignature verifies that signature authenticates fields under +// publicKey using the canonical v1 response signing input. +func VerifyResponseSignature(publicKey ed25519.PublicKey, signature []byte, fields ResponseSigningFields) error { + if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize { + return ErrInvalidResponseSignature + } + if !ed25519.Verify(publicKey, BuildResponseSigningInput(fields), signature) { + return ErrInvalidResponseSignature + } + + return nil +} diff --git a/gateway/internal/authn/response_test.go b/gateway/internal/authn/response_test.go new file mode 100644 index 0000000..3efbc83 --- /dev/null +++ b/gateway/internal/authn/response_test.go @@ -0,0 +1,146 @@ +package authn + +import ( + "bytes" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildResponseSigningInputChangesWhenSignedFieldChanges(t *testing.T) { + t.Parallel() + + base := ResponseSigningFields{ + ProtocolVersion: "v1", + RequestID: "request-123", + TimestampMS: 123456789, + ResultCode: "ok", + PayloadHash: mustSHA256([]byte("payload")), + } + + baseInput := BuildResponseSigningInput(base) + + tests := []struct 
{ + name string + mutate func(ResponseSigningFields) ResponseSigningFields + }{ + { + name: "protocol version", + mutate: func(fields ResponseSigningFields) ResponseSigningFields { + fields.ProtocolVersion = "v2" + return fields + }, + }, + { + name: "request id", + mutate: func(fields ResponseSigningFields) ResponseSigningFields { + fields.RequestID = "request-456" + return fields + }, + }, + { + name: "timestamp", + mutate: func(fields ResponseSigningFields) ResponseSigningFields { + fields.TimestampMS++ + return fields + }, + }, + { + name: "result code", + mutate: func(fields ResponseSigningFields) ResponseSigningFields { + fields.ResultCode = "denied" + return fields + }, + }, + { + name: "payload hash", + mutate: func(fields ResponseSigningFields) ResponseSigningFields { + fields.PayloadHash = mustSHA256([]byte("other")) + return fields + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mutated := BuildResponseSigningInput(tt.mutate(base)) + assert.False(t, bytes.Equal(baseInput, mutated)) + }) + } +} + +func TestParseEd25519ResponseSignerPEMRejectsMalformedPEM(t *testing.T) { + t.Parallel() + + _, err := ParseEd25519ResponseSignerPEM([]byte("not-pem")) + require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM) +} + +func TestParseEd25519ResponseSignerPEMRejectsNonPKCS8PEM(t *testing.T) { + t.Parallel() + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pemBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + block := pem.Block{ + Type: "ED25519 PRIVATE KEY", + Bytes: pemBytes, + } + + _, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&block)) + require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM) +} + +func TestParseEd25519ResponseSignerPEMRejectsNonEd25519Key(t *testing.T) { + t.Parallel() + + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + pemBytes, err := 
x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + _, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: pemBytes, + })) + require.ErrorIs(t, err, ErrInvalidResponsePrivateKey) +} + +func TestSignAndVerifyResponseSignature(t *testing.T) { + t.Parallel() + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + signer, err := NewEd25519ResponseSigner(privateKey) + require.NoError(t, err) + + fields := ResponseSigningFields{ + ProtocolVersion: "v1", + RequestID: "request-123", + TimestampMS: 123456789, + ResultCode: "ok", + PayloadHash: mustSHA256([]byte("payload")), + } + + signature, err := signer.SignResponse(fields) + require.NoError(t, err) + require.NoError(t, VerifyResponseSignature(signer.PublicKey(), signature, fields)) + + fields.ResultCode = "changed" + require.ErrorIs(t, VerifyResponseSignature(signer.PublicKey(), signature, fields), ErrInvalidResponseSignature) +} diff --git a/gateway/internal/authn/signature.go b/gateway/internal/authn/signature.go new file mode 100644 index 0000000..baa0557 --- /dev/null +++ b/gateway/internal/authn/signature.go @@ -0,0 +1,47 @@ +package authn + +import ( + "crypto/ed25519" + "encoding/base64" + "errors" +) + +var ( + // ErrInvalidClientPublicKey reports that cached client public key material + // is not a base64-encoded raw Ed25519 public key. + ErrInvalidClientPublicKey = errors.New("client_public_key is not a valid base64-encoded Ed25519 public key") + + // ErrInvalidRequestSignature reports that a request signature is not a raw + // Ed25519 signature for the canonical request signing input. + ErrInvalidRequestSignature = errors.New("invalid request signature") +) + +// VerifyRequestSignature validates the base64-encoded raw Ed25519 public key +// from session cache, builds the canonical v1 signing input from fields, and +// verifies that signature authenticates the request. 
+func VerifyRequestSignature(clientPublicKey string, signature []byte, fields RequestSigningFields) error { + publicKey, err := decodeClientPublicKey(clientPublicKey) + if err != nil { + return err + } + if len(signature) != ed25519.SignatureSize { + return ErrInvalidRequestSignature + } + if !ed25519.Verify(publicKey, BuildRequestSigningInput(fields), signature) { + return ErrInvalidRequestSignature + } + + return nil +} + +func decodeClientPublicKey(value string) (ed25519.PublicKey, error) { + decoded, err := base64.StdEncoding.Strict().DecodeString(value) + if err != nil { + return nil, ErrInvalidClientPublicKey + } + if len(decoded) != ed25519.PublicKeySize { + return nil, ErrInvalidClientPublicKey + } + + return ed25519.PublicKey(decoded), nil +} diff --git a/gateway/internal/authn/signature_test.go b/gateway/internal/authn/signature_test.go new file mode 100644 index 0000000..062b66e --- /dev/null +++ b/gateway/internal/authn/signature_test.go @@ -0,0 +1,137 @@ +package authn + +import ( + "crypto/ed25519" + "crypto/sha256" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVerifyRequestSignature(t *testing.T) { + t.Parallel() + + clientPrivateKey := newTestPrivateKey("primary") + clientPublicKey := clientPrivateKey.Public().(ed25519.PublicKey) + otherPrivateKey := newTestPrivateKey("other") + + fields := RequestSigningFields{ + ProtocolVersion: "v1", + DeviceSessionID: "device-session-123", + MessageType: "fleet.move", + TimestampMS: 123456789, + RequestID: "request-123", + PayloadHash: mustSHA256([]byte("payload")), + } + + signature := ed25519.Sign(clientPrivateKey, BuildRequestSigningInput(fields)) + + tests := []struct { + name string + clientPublicKey string + signature []byte + fields RequestSigningFields + wantErr error + }{ + { + name: "valid signature", + clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey), + signature: signature, + fields: fields, + }, + { + name: "message type change rejects 
signature", + clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey), + signature: signature, + fields: func() RequestSigningFields { + mutated := fields + mutated.MessageType = "fleet.attack" + return mutated + }(), + wantErr: ErrInvalidRequestSignature, + }, + { + name: "request id change rejects signature", + clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey), + signature: signature, + fields: func() RequestSigningFields { + mutated := fields + mutated.RequestID = "request-456" + return mutated + }(), + wantErr: ErrInvalidRequestSignature, + }, + { + name: "payload hash change rejects signature", + clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey), + signature: signature, + fields: func() RequestSigningFields { + mutated := fields + mutated.PayloadHash = mustSHA256([]byte("other")) + return mutated + }(), + wantErr: ErrInvalidRequestSignature, + }, + { + name: "wrong key rejects signature", + clientPublicKey: base64.StdEncoding.EncodeToString(otherPrivateKey.Public().(ed25519.PublicKey)), + signature: signature, + fields: fields, + wantErr: ErrInvalidRequestSignature, + }, + { + name: "bit flipped signature rejects", + clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey), + signature: func() []byte { + corrupted := append([]byte(nil), signature...) 
+ corrupted[0] ^= 0xff + return corrupted + }(), + fields: fields, + wantErr: ErrInvalidRequestSignature, + }, + { + name: "invalid signature length rejects", + clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey), + signature: signature[:len(signature)-1], + fields: fields, + wantErr: ErrInvalidRequestSignature, + }, + { + name: "invalid base64 public key rejects", + clientPublicKey: "%%%not-base64%%%", + signature: signature, + fields: fields, + wantErr: ErrInvalidClientPublicKey, + }, + { + name: "invalid public key length rejects", + clientPublicKey: base64.StdEncoding.EncodeToString([]byte("short")), + signature: signature, + fields: fields, + wantErr: ErrInvalidClientPublicKey, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := VerifyRequestSignature(tt.clientPublicKey, tt.signature, tt.fields) + if tt.wantErr == nil { + require.NoError(t, err) + return + } + + require.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func newTestPrivateKey(label string) ed25519.PrivateKey { + seed := sha256.Sum256([]byte("gateway-authn-signature-test-" + label)) + return ed25519.NewKeyFromSeed(seed[:]) +} diff --git a/gateway/internal/clock/clock.go b/gateway/internal/clock/clock.go new file mode 100644 index 0000000..d929186 --- /dev/null +++ b/gateway/internal/clock/clock.go @@ -0,0 +1,20 @@ +// Package clock provides the gateway time source abstraction used by +// authenticated transport checks. +package clock + +import "time" + +// Clock returns current server time for freshness checks and time-dependent +// transport behavior. +type Clock interface { + // Now returns the current server time. + Now() time.Time +} + +// System returns the current process time using the local system clock. +type System struct{} + +// Now returns the current UTC time from the system clock. 
+func (System) Now() time.Time { + return time.Now().UTC() +} diff --git a/gateway/internal/config/config.go b/gateway/internal/config/config.go new file mode 100644 index 0000000..d35081d --- /dev/null +++ b/gateway/internal/config/config.go @@ -0,0 +1,1333 @@ +// Package config loads process-level gateway configuration from environment +// variables. +package config + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +const ( + // shutdownTimeoutEnvVar names the environment variable that controls the + // maximum time granted to each component shutdown call. + shutdownTimeoutEnvVar = "GATEWAY_SHUTDOWN_TIMEOUT" + + // logLevelEnvVar names the environment variable that configures the process + // log level used by structured JSON logging. + logLevelEnvVar = "GATEWAY_LOG_LEVEL" + + // publicHTTPAddrEnvVar names the environment variable that configures the + // public REST listener address. + publicHTTPAddrEnvVar = "GATEWAY_PUBLIC_HTTP_ADDR" + + // publicHTTPReadHeaderTimeoutEnvVar names the environment variable that + // configures the maximum time allowed to read public REST request headers. + publicHTTPReadHeaderTimeoutEnvVar = "GATEWAY_PUBLIC_HTTP_READ_HEADER_TIMEOUT" + + // publicHTTPReadTimeoutEnvVar names the environment variable that configures + // the maximum time allowed to read the full public REST request. + publicHTTPReadTimeoutEnvVar = "GATEWAY_PUBLIC_HTTP_READ_TIMEOUT" + + // publicHTTPIdleTimeoutEnvVar names the environment variable that configures + // the keep-alive idle timeout for the public REST listener. + publicHTTPIdleTimeoutEnvVar = "GATEWAY_PUBLIC_HTTP_IDLE_TIMEOUT" + + // publicAuthUpstreamTimeoutEnvVar names the environment variable that + // configures the timeout budget used for public auth upstream calls. + publicAuthUpstreamTimeoutEnvVar = "GATEWAY_PUBLIC_AUTH_UPSTREAM_TIMEOUT" + + // adminHTTPAddrEnvVar names the environment variable that configures the + // private admin HTTP listener address. 
When it is empty, the admin listener + // remains disabled. + adminHTTPAddrEnvVar = "GATEWAY_ADMIN_HTTP_ADDR" + + // adminHTTPReadHeaderTimeoutEnvVar names the environment variable that + // configures the maximum time allowed to read admin listener request + // headers. + adminHTTPReadHeaderTimeoutEnvVar = "GATEWAY_ADMIN_HTTP_READ_HEADER_TIMEOUT" + + // adminHTTPReadTimeoutEnvVar names the environment variable that configures + // the maximum time allowed to read one admin listener request. + adminHTTPReadTimeoutEnvVar = "GATEWAY_ADMIN_HTTP_READ_TIMEOUT" + + // adminHTTPIdleTimeoutEnvVar names the environment variable that configures + // the keep-alive idle timeout for the admin listener. + adminHTTPIdleTimeoutEnvVar = "GATEWAY_ADMIN_HTTP_IDLE_TIMEOUT" + + // authenticatedGRPCAddrEnvVar names the environment variable that configures + // the authenticated gRPC listener address. + authenticatedGRPCAddrEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ADDR" + + // authenticatedGRPCConnectionTimeoutEnvVar names the environment variable + // that configures the inbound connection handshake timeout for the + // authenticated gRPC listener. + authenticatedGRPCConnectionTimeoutEnvVar = "GATEWAY_AUTHENTICATED_GRPC_CONNECTION_TIMEOUT" + + // authenticatedGRPCDownstreamTimeoutEnvVar names the environment variable + // that configures the timeout budget used for downstream unary execution. + authenticatedGRPCDownstreamTimeoutEnvVar = "GATEWAY_AUTHENTICATED_DOWNSTREAM_TIMEOUT" + + // authenticatedGRPCFreshnessWindowEnvVar names the environment variable that + // configures the accepted client timestamp skew window for authenticated + // gRPC requests. + authenticatedGRPCFreshnessWindowEnvVar = "GATEWAY_AUTHENTICATED_GRPC_FRESHNESS_WINDOW" + + // authenticatedGRPCIPRateLimitRequestsEnvVar names the environment + // variable that configures the authenticated gRPC per-IP request budget per + // window. 
+ authenticatedGRPCIPRateLimitRequestsEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_IP_RATE_LIMIT_REQUESTS" + + // authenticatedGRPCIPRateLimitWindowEnvVar names the environment variable + // that configures the authenticated gRPC per-IP rate-limit window. + authenticatedGRPCIPRateLimitWindowEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_IP_RATE_LIMIT_WINDOW" + + // authenticatedGRPCIPRateLimitBurstEnvVar names the environment variable + // that configures the authenticated gRPC per-IP rate-limit burst. + authenticatedGRPCIPRateLimitBurstEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_IP_RATE_LIMIT_BURST" + + // authenticatedGRPCSessionRateLimitRequestsEnvVar names the environment + // variable that configures the authenticated gRPC per-session request + // budget per window. + authenticatedGRPCSessionRateLimitRequestsEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_SESSION_RATE_LIMIT_REQUESTS" + + // authenticatedGRPCSessionRateLimitWindowEnvVar names the environment + // variable that configures the authenticated gRPC per-session rate-limit + // window. + authenticatedGRPCSessionRateLimitWindowEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_SESSION_RATE_LIMIT_WINDOW" + + // authenticatedGRPCSessionRateLimitBurstEnvVar names the environment + // variable that configures the authenticated gRPC per-session rate-limit + // burst. + authenticatedGRPCSessionRateLimitBurstEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_SESSION_RATE_LIMIT_BURST" + + // authenticatedGRPCUserRateLimitRequestsEnvVar names the environment + // variable that configures the authenticated gRPC per-user request budget + // per window. + authenticatedGRPCUserRateLimitRequestsEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_USER_RATE_LIMIT_REQUESTS" + + // authenticatedGRPCUserRateLimitWindowEnvVar names the environment + // variable that configures the authenticated gRPC per-user rate-limit + // window. 
+ authenticatedGRPCUserRateLimitWindowEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_USER_RATE_LIMIT_WINDOW" + + // authenticatedGRPCUserRateLimitBurstEnvVar names the environment variable + // that configures the authenticated gRPC per-user rate-limit burst. + authenticatedGRPCUserRateLimitBurstEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_USER_RATE_LIMIT_BURST" + + // authenticatedGRPCMessageClassRateLimitRequestsEnvVar names the + // environment variable that configures the authenticated gRPC per-message + // class request budget per window. + authenticatedGRPCMessageClassRateLimitRequestsEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_REQUESTS" + + // authenticatedGRPCMessageClassRateLimitWindowEnvVar names the environment + // variable that configures the authenticated gRPC per-message-class + // rate-limit window. + authenticatedGRPCMessageClassRateLimitWindowEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_WINDOW" + + // authenticatedGRPCMessageClassRateLimitBurstEnvVar names the environment + // variable that configures the authenticated gRPC per-message-class + // rate-limit burst. + authenticatedGRPCMessageClassRateLimitBurstEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_BURST" + + // sessionCacheRedisAddrEnvVar names the environment variable that configures + // the Redis address used for SessionCache lookups. + sessionCacheRedisAddrEnvVar = "GATEWAY_SESSION_CACHE_REDIS_ADDR" + + // sessionCacheRedisUsernameEnvVar names the environment variable that + // configures the Redis username used for SessionCache lookups. + sessionCacheRedisUsernameEnvVar = "GATEWAY_SESSION_CACHE_REDIS_USERNAME" + + // sessionCacheRedisPasswordEnvVar names the environment variable that + // configures the Redis password used for SessionCache lookups. 
+ sessionCacheRedisPasswordEnvVar = "GATEWAY_SESSION_CACHE_REDIS_PASSWORD" + + // sessionCacheRedisDBEnvVar names the environment variable that configures + // the Redis logical database used for SessionCache lookups. + sessionCacheRedisDBEnvVar = "GATEWAY_SESSION_CACHE_REDIS_DB" + + // sessionCacheRedisKeyPrefixEnvVar names the environment variable that + // configures the Redis key prefix used for SessionCache records. + sessionCacheRedisKeyPrefixEnvVar = "GATEWAY_SESSION_CACHE_REDIS_KEY_PREFIX" + + // sessionCacheRedisLookupTimeoutEnvVar names the environment variable that + // configures the timeout used for SessionCache Redis lookups and startup + // connectivity checks. + sessionCacheRedisLookupTimeoutEnvVar = "GATEWAY_SESSION_CACHE_REDIS_LOOKUP_TIMEOUT" + + // sessionCacheRedisTLSEnabledEnvVar names the environment variable that + // configures whether SessionCache Redis connections use TLS. + sessionCacheRedisTLSEnabledEnvVar = "GATEWAY_SESSION_CACHE_REDIS_TLS_ENABLED" + + // replayRedisKeyPrefixEnvVar names the environment variable that configures + // the Redis key prefix used for authenticated replay reservations. + replayRedisKeyPrefixEnvVar = "GATEWAY_REPLAY_REDIS_KEY_PREFIX" + + // replayRedisReserveTimeoutEnvVar names the environment variable that + // configures the timeout used for authenticated replay reservations and + // startup connectivity checks. + replayRedisReserveTimeoutEnvVar = "GATEWAY_REPLAY_REDIS_RESERVE_TIMEOUT" + + // sessionEventsRedisStreamEnvVar names the environment variable that + // configures the Redis Stream key consumed for session lifecycle updates. + sessionEventsRedisStreamEnvVar = "GATEWAY_SESSION_EVENTS_REDIS_STREAM" + + // sessionEventsRedisReadBlockTimeoutEnvVar names the environment variable + // that configures the blocking read timeout used by the session event + // subscriber. 
+ sessionEventsRedisReadBlockTimeoutEnvVar = "GATEWAY_SESSION_EVENTS_REDIS_READ_BLOCK_TIMEOUT" + + // clientEventsRedisStreamEnvVar names the environment variable that + // configures the Redis Stream key consumed for client-facing push events. + clientEventsRedisStreamEnvVar = "GATEWAY_CLIENT_EVENTS_REDIS_STREAM" + + // clientEventsRedisReadBlockTimeoutEnvVar names the environment variable + // that configures the blocking read timeout used by the client-event + // subscriber. + clientEventsRedisReadBlockTimeoutEnvVar = "GATEWAY_CLIENT_EVENTS_REDIS_READ_BLOCK_TIMEOUT" + + // responseSignerPrivateKeyPEMPathEnvVar names the environment variable that + // configures the path to the PKCS#8 PEM-encoded Ed25519 private key used to + // sign authenticated unary responses and stream events. + responseSignerPrivateKeyPEMPathEnvVar = "GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH" + + // publicAuthMaxBodyBytesEnvVar names the environment variable that + // configures the maximum accepted request body size for public_auth. + publicAuthMaxBodyBytesEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_MAX_BODY_BYTES" + + // publicAuthRateLimitRequestsEnvVar names the environment variable that + // configures the public_auth request budget per window. + publicAuthRateLimitRequestsEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_RATE_LIMIT_REQUESTS" + + // publicAuthRateLimitWindowEnvVar names the environment variable that + // configures the public_auth rate-limit window. + publicAuthRateLimitWindowEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_RATE_LIMIT_WINDOW" + + // publicAuthRateLimitBurstEnvVar names the environment variable that + // configures the public_auth rate-limit burst. + publicAuthRateLimitBurstEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_AUTH_RATE_LIMIT_BURST" + + // browserBootstrapMaxBodyBytesEnvVar names the environment variable that + // configures the maximum accepted request body size for browser_bootstrap. 
+ browserBootstrapMaxBodyBytesEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_MAX_BODY_BYTES" + + // browserBootstrapRateLimitRequestsEnvVar names the environment variable + // that configures the browser_bootstrap request budget per window. + browserBootstrapRateLimitRequestsEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_RATE_LIMIT_REQUESTS" + + // browserBootstrapRateLimitWindowEnvVar names the environment variable that + // configures the browser_bootstrap rate-limit window. + browserBootstrapRateLimitWindowEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_RATE_LIMIT_WINDOW" + + // browserBootstrapRateLimitBurstEnvVar names the environment variable that + // configures the browser_bootstrap rate-limit burst. + browserBootstrapRateLimitBurstEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_BOOTSTRAP_RATE_LIMIT_BURST" + + // browserAssetMaxBodyBytesEnvVar names the environment variable that + // configures the maximum accepted request body size for browser_asset. + browserAssetMaxBodyBytesEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_MAX_BODY_BYTES" + + // browserAssetRateLimitRequestsEnvVar names the environment variable that + // configures the browser_asset request budget per window. + browserAssetRateLimitRequestsEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_RATE_LIMIT_REQUESTS" + + // browserAssetRateLimitWindowEnvVar names the environment variable that + // configures the browser_asset rate-limit window. + browserAssetRateLimitWindowEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_RATE_LIMIT_WINDOW" + + // browserAssetRateLimitBurstEnvVar names the environment variable that + // configures the browser_asset rate-limit burst. + browserAssetRateLimitBurstEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_BROWSER_ASSET_RATE_LIMIT_BURST" + + // publicMiscMaxBodyBytesEnvVar names the environment variable that + // configures the maximum accepted request body size for public_misc. 
+ publicMiscMaxBodyBytesEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_MAX_BODY_BYTES" + + // publicMiscRateLimitRequestsEnvVar names the environment variable that + // configures the public_misc request budget per window. + publicMiscRateLimitRequestsEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_RATE_LIMIT_REQUESTS" + + // publicMiscRateLimitWindowEnvVar names the environment variable that + // configures the public_misc rate-limit window. + publicMiscRateLimitWindowEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_RATE_LIMIT_WINDOW" + + // publicMiscRateLimitBurstEnvVar names the environment variable that + // configures the public_misc rate-limit burst. + publicMiscRateLimitBurstEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_PUBLIC_MISC_RATE_LIMIT_BURST" + + // sendEmailCodeIdentityRateLimitRequestsEnvVar names the environment + // variable that configures the send-email-code identity request budget per + // window. + sendEmailCodeIdentityRateLimitRequestsEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_SEND_EMAIL_CODE_IDENTITY_RATE_LIMIT_REQUESTS" + + // sendEmailCodeIdentityRateLimitWindowEnvVar names the environment variable + // that configures the send-email-code identity rate-limit window. + sendEmailCodeIdentityRateLimitWindowEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_SEND_EMAIL_CODE_IDENTITY_RATE_LIMIT_WINDOW" + + // sendEmailCodeIdentityRateLimitBurstEnvVar names the environment variable + // that configures the send-email-code identity rate-limit burst. + sendEmailCodeIdentityRateLimitBurstEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_SEND_EMAIL_CODE_IDENTITY_RATE_LIMIT_BURST" + + // confirmEmailCodeIdentityRateLimitRequestsEnvVar names the environment + // variable that configures the confirm-email-code identity request budget + // per window. 
+ confirmEmailCodeIdentityRateLimitRequestsEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_CONFIRM_EMAIL_CODE_IDENTITY_RATE_LIMIT_REQUESTS" + + // confirmEmailCodeIdentityRateLimitWindowEnvVar names the environment + // variable that configures the confirm-email-code identity rate-limit + // window. + confirmEmailCodeIdentityRateLimitWindowEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_CONFIRM_EMAIL_CODE_IDENTITY_RATE_LIMIT_WINDOW" + + // confirmEmailCodeIdentityRateLimitBurstEnvVar names the environment + // variable that configures the confirm-email-code identity rate-limit burst. + confirmEmailCodeIdentityRateLimitBurstEnvVar = "GATEWAY_PUBLIC_HTTP_ANTI_ABUSE_CONFIRM_EMAIL_CODE_IDENTITY_RATE_LIMIT_BURST" + + // defaultShutdownTimeout is applied when shutdownTimeoutEnvVar is absent. + defaultShutdownTimeout = 5 * time.Second + + // defaultLogLevel is applied when logLevelEnvVar is absent. + defaultLogLevel = "info" + + // defaultPublicHTTPAddr is applied when publicHTTPAddrEnvVar is absent. + defaultPublicHTTPAddr = ":8080" + + defaultPublicHTTPReadHeaderTimeout = 2 * time.Second + defaultPublicHTTPReadTimeout = 10 * time.Second + defaultPublicHTTPIdleTimeout = time.Minute + defaultPublicAuthUpstreamTimeout = 3 * time.Second + + defaultAdminHTTPReadHeaderTimeout = 2 * time.Second + defaultAdminHTTPReadTimeout = 10 * time.Second + defaultAdminHTTPIdleTimeout = time.Minute + + // defaultAuthenticatedGRPCAddr is applied when + // authenticatedGRPCAddrEnvVar is absent. 
+ defaultAuthenticatedGRPCAddr = ":9090" + + defaultAuthenticatedGRPCConnectionTimeout = 5 * time.Second + defaultAuthenticatedGRPCDownstreamTimeout = 5 * time.Second + defaultAuthenticatedGRPCFreshnessWindow = 5 * time.Minute + + defaultAuthenticatedGRPCIPRateLimitRequests = 120 + defaultAuthenticatedGRPCIPRateLimitBurst = 40 + + defaultAuthenticatedGRPCSessionRateLimitRequests = 60 + defaultAuthenticatedGRPCSessionRateLimitBurst = 20 + + defaultAuthenticatedGRPCUserRateLimitRequests = 120 + defaultAuthenticatedGRPCUserRateLimitBurst = 40 + + defaultAuthenticatedGRPCMessageClassRateLimitRequests = 60 + defaultAuthenticatedGRPCMessageClassRateLimitBurst = 20 + + defaultSessionCacheRedisDB = 0 + defaultSessionCacheRedisKeyPrefix = "gateway:session:" + defaultSessionCacheRedisLookupTimeout = 250 * time.Millisecond + + defaultReplayRedisKeyPrefix = "gateway:replay:" + defaultReplayRedisReserveTimeout = 250 * time.Millisecond + + defaultSessionEventsRedisReadBlockTimeout = time.Second + defaultClientEventsRedisReadBlockTimeout = time.Second + + defaultPublicAuthMaxBodyBytes = int64(8192) + + defaultPublicAuthRateLimitRequests = 30 + defaultPublicAuthRateLimitBurst = 10 + + defaultBrowserBootstrapRateLimitRequests = 60 + defaultBrowserBootstrapRateLimitBurst = 20 + + defaultBrowserAssetRateLimitRequests = 300 + defaultBrowserAssetRateLimitBurst = 80 + + defaultPublicMiscRateLimitRequests = 30 + defaultPublicMiscRateLimitBurst = 10 + + defaultSendEmailCodeIdentityRateLimitRequests = 3 + defaultSendEmailCodeIdentityRateLimitBurst = 1 + + defaultConfirmEmailCodeIdentityRateLimitRequests = 6 + defaultConfirmEmailCodeIdentityRateLimitBurst = 2 +) + +var ( + defaultClassRateLimitWindow = time.Minute + + defaultIdentityRateLimitWindow = 10 * time.Minute +) + +// RateLimitConfig describes a single rate-limit budget. +type RateLimitConfig struct { + // Requests is the number of accepted requests replenished per Window. 
+ Requests int + + // Window is the interval over which Requests are replenished. + Window time.Duration + + // Burst is the maximum number of immediately available tokens. + Burst int +} + +// PublicRateLimitConfig identifies the generic rate-limit budget shape used by +// public REST policy. +type PublicRateLimitConfig = RateLimitConfig + +// AuthenticatedRateLimitConfig identifies the generic rate-limit budget shape +// used by authenticated gRPC policy. +type AuthenticatedRateLimitConfig = RateLimitConfig + +// PublicRoutePolicyConfig describes the anti-abuse policy enforced for one +// stable public REST traffic class. +type PublicRoutePolicyConfig struct { + // MaxBodyBytes is the maximum accepted request body size. Zero means that + // the request must not carry a body. + MaxBodyBytes int64 + + // RateLimit configures the per-IP budget for the route class. + RateLimit PublicRateLimitConfig +} + +// PublicAuthIdentityPolicyConfig describes the additional identity-based +// limiter applied to one public auth command. +type PublicAuthIdentityPolicyConfig struct { + // RateLimit configures the accepted request budget for one normalized public + // auth identity key. + RateLimit PublicRateLimitConfig +} + +// PublicHTTPAntiAbuseConfig describes the public REST anti-abuse policy used +// before route handling. +type PublicHTTPAntiAbuseConfig struct { + // PublicAuth applies to the stable public_auth route class. + PublicAuth PublicRoutePolicyConfig + + // BrowserBootstrap applies to the stable browser_bootstrap route class. + BrowserBootstrap PublicRoutePolicyConfig + + // BrowserAsset applies to the stable browser_asset route class. + BrowserAsset PublicRoutePolicyConfig + + // PublicMisc applies to the stable public_misc route class. + PublicMisc PublicRoutePolicyConfig + + // SendEmailCodeIdentity applies the additional identity limiter for + // send-email-code. 
+ SendEmailCodeIdentity PublicAuthIdentityPolicyConfig + + // ConfirmEmailCodeIdentity applies the additional identity limiter for + // confirm-email-code. + ConfirmEmailCodeIdentity PublicAuthIdentityPolicyConfig +} + +// AuthenticatedGRPCAntiAbuseConfig describes the authenticated gRPC +// rate-limit budgets enforced after request authenticity has been established. +type AuthenticatedGRPCAntiAbuseConfig struct { + // IP applies to the transport peer IP derived from the gRPC connection. + IP AuthenticatedRateLimitConfig + + // Session applies to the authenticated device_session_id. + Session AuthenticatedRateLimitConfig + + // User applies to the authenticated user_id resolved from SessionCache. + User AuthenticatedRateLimitConfig + + // MessageClass applies to the current authenticated message class. The + // gateway uses the full message_type literal as the stable v1 class key. + MessageClass AuthenticatedRateLimitConfig +} + +// PublicHTTPConfig describes the public unauthenticated REST listener exposed +// by the gateway. +type PublicHTTPConfig struct { + // Addr is the TCP listen address used by the public REST server. + Addr string + + // ReadHeaderTimeout bounds how long the listener may spend reading request + // headers before the gateway rejects the connection. + ReadHeaderTimeout time.Duration + + // ReadTimeout bounds how long the listener may spend reading one public + // request. + ReadTimeout time.Duration + + // IdleTimeout bounds how long the listener keeps an idle keep-alive + // connection open. + IdleTimeout time.Duration + + // AuthUpstreamTimeout bounds one public auth adapter call. + AuthUpstreamTimeout time.Duration + + // AntiAbuse configures the public REST anti-abuse middleware. + AntiAbuse PublicHTTPAntiAbuseConfig +} + +// AdminHTTPConfig describes the private operational HTTP listener used for +// metrics exposure. The listener remains disabled when Addr is empty. 
+type AdminHTTPConfig struct { + // Addr is the TCP listen address used by the admin HTTP server. An empty + // value disables the listener. + Addr string + + // ReadHeaderTimeout bounds how long the listener may spend reading request + // headers before the gateway rejects the connection. + ReadHeaderTimeout time.Duration + + // ReadTimeout bounds how long the listener may spend reading one admin + // request. + ReadTimeout time.Duration + + // IdleTimeout bounds how long the listener keeps an idle keep-alive + // connection open. + IdleTimeout time.Duration +} + +// AuthenticatedGRPCConfig describes the authenticated gRPC listener exposed by +// the gateway. +type AuthenticatedGRPCConfig struct { + // Addr is the TCP listen address used by the authenticated gRPC server. + Addr string + + // ConnectionTimeout bounds one inbound connection handshake. + ConnectionTimeout time.Duration + + // DownstreamTimeout bounds one downstream unary execution after the request + // has passed the full authenticated ingress pipeline. + DownstreamTimeout time.Duration + + // FreshnessWindow is the accepted skew window around current server time + // used for client request timestamps. + FreshnessWindow time.Duration + + // AntiAbuse configures the authenticated gRPC rate limits enforced after + // the request passes the transport authenticity checks. + AntiAbuse AuthenticatedGRPCAntiAbuseConfig +} + +// SessionCacheRedisConfig describes the Redis connection used for authenticated +// SessionCache lookups. +type SessionCacheRedisConfig struct { + // Addr is the Redis endpoint used for SessionCache requests. + Addr string + + // Username is the optional Redis ACL username used for authentication. + Username string + + // Password is the optional Redis password used for authentication. + Password string + + // DB is the Redis logical database number used for SessionCache keys. + DB int + + // KeyPrefix is prepended to every SessionCache Redis key. 
+ KeyPrefix string + + // LookupTimeout bounds individual SessionCache Redis operations. + LookupTimeout time.Duration + + // TLSEnabled reports whether SessionCache Redis connections should use TLS. + TLSEnabled bool +} + +// ReplayRedisConfig describes the Redis namespace and timeout used for +// authenticated replay reservations. +type ReplayRedisConfig struct { + // KeyPrefix is prepended to every ReplayStore Redis key. + KeyPrefix string + + // ReserveTimeout bounds individual ReplayStore Redis operations. + ReserveTimeout time.Duration +} + +// SessionEventsRedisConfig describes the Redis Stream consumed by the gateway +// to keep the process-local session cache synchronized with session lifecycle +// updates. +type SessionEventsRedisConfig struct { + // Stream is the Redis Stream key carrying full session snapshot events. + Stream string + + // ReadBlockTimeout bounds one blocking XREAD call so shutdown remains + // responsive even when the stream is idle. + ReadBlockTimeout time.Duration +} + +// ClientEventsRedisConfig describes the Redis Stream consumed by the gateway +// to deliver client-facing events to active push streams. +type ClientEventsRedisConfig struct { + // Stream is the Redis Stream key carrying client-facing event entries. + Stream string + + // ReadBlockTimeout bounds one blocking XREAD call so shutdown remains + // responsive even when the stream is idle. + ReadBlockTimeout time.Duration +} + +// ResponseSignerConfig describes the private-key material used to sign +// authenticated unary responses and stream events. +type ResponseSignerConfig struct { + // PrivateKeyPEMPath is the filesystem path to the PKCS#8 PEM-encoded + // Ed25519 private key loaded during startup. + PrivateKeyPEMPath string +} + +// LoggingConfig describes the process-wide structured logging settings. +type LoggingConfig struct { + // Level is the configured minimum log level literal. 
+ Level string +} + +// Config describes process-wide settings required to start and stop the +// gateway safely. +type Config struct { + // ShutdownTimeout limits how long each component may spend in Shutdown + // before the gateway reports a timeout. + ShutdownTimeout time.Duration + + // Logging configures the process-wide structured logger. + Logging LoggingConfig + + // PublicHTTP configures the public unauthenticated REST listener. + PublicHTTP PublicHTTPConfig + + // AdminHTTP configures the optional private admin listener used for metrics + // exposure. + AdminHTTP AdminHTTPConfig + + // AuthenticatedGRPC configures the authenticated gRPC listener. + AuthenticatedGRPC AuthenticatedGRPCConfig + + // SessionCacheRedis configures the Redis-backed authenticated SessionCache. + SessionCacheRedis SessionCacheRedisConfig + + // ReplayRedis configures the Redis-backed authenticated ReplayStore. + ReplayRedis ReplayRedisConfig + + // SessionEventsRedis configures the Redis Stream consumed for session cache + // updates and revocations. + SessionEventsRedis SessionEventsRedisConfig + + // ClientEventsRedis configures the Redis Stream consumed for client-facing + // push delivery. + ClientEventsRedis ClientEventsRedisConfig + + // ResponseSigner configures the authenticated response and event signer + // loaded during startup. + ResponseSigner ResponseSignerConfig +} + +// DefaultPublicHTTPConfig returns the default listener and anti-abuse settings +// for the public REST surface. 
+func DefaultPublicHTTPConfig() PublicHTTPConfig { + return PublicHTTPConfig{ + Addr: defaultPublicHTTPAddr, + ReadHeaderTimeout: defaultPublicHTTPReadHeaderTimeout, + ReadTimeout: defaultPublicHTTPReadTimeout, + IdleTimeout: defaultPublicHTTPIdleTimeout, + AuthUpstreamTimeout: defaultPublicAuthUpstreamTimeout, + AntiAbuse: PublicHTTPAntiAbuseConfig{ + PublicAuth: PublicRoutePolicyConfig{ + MaxBodyBytes: defaultPublicAuthMaxBodyBytes, + RateLimit: PublicRateLimitConfig{ + Requests: defaultPublicAuthRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultPublicAuthRateLimitBurst, + }, + }, + BrowserBootstrap: PublicRoutePolicyConfig{ + RateLimit: PublicRateLimitConfig{ + Requests: defaultBrowserBootstrapRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultBrowserBootstrapRateLimitBurst, + }, + }, + BrowserAsset: PublicRoutePolicyConfig{ + RateLimit: PublicRateLimitConfig{ + Requests: defaultBrowserAssetRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultBrowserAssetRateLimitBurst, + }, + }, + PublicMisc: PublicRoutePolicyConfig{ + RateLimit: PublicRateLimitConfig{ + Requests: defaultPublicMiscRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultPublicMiscRateLimitBurst, + }, + }, + SendEmailCodeIdentity: PublicAuthIdentityPolicyConfig{ + RateLimit: PublicRateLimitConfig{ + Requests: defaultSendEmailCodeIdentityRateLimitRequests, + Window: defaultIdentityRateLimitWindow, + Burst: defaultSendEmailCodeIdentityRateLimitBurst, + }, + }, + ConfirmEmailCodeIdentity: PublicAuthIdentityPolicyConfig{ + RateLimit: PublicRateLimitConfig{ + Requests: defaultConfirmEmailCodeIdentityRateLimitRequests, + Window: defaultIdentityRateLimitWindow, + Burst: defaultConfirmEmailCodeIdentityRateLimitBurst, + }, + }, + }, + } +} + +// DefaultAdminHTTPConfig returns the default settings for the optional private +// admin listener. The zero address keeps the listener disabled by default. 
+func DefaultAdminHTTPConfig() AdminHTTPConfig { + return AdminHTTPConfig{ + ReadHeaderTimeout: defaultAdminHTTPReadHeaderTimeout, + ReadTimeout: defaultAdminHTTPReadTimeout, + IdleTimeout: defaultAdminHTTPIdleTimeout, + } +} + +// DefaultAuthenticatedGRPCConfig returns the default listener, freshness, and +// anti-abuse settings for the authenticated gRPC surface. +func DefaultAuthenticatedGRPCConfig() AuthenticatedGRPCConfig { + return AuthenticatedGRPCConfig{ + Addr: defaultAuthenticatedGRPCAddr, + ConnectionTimeout: defaultAuthenticatedGRPCConnectionTimeout, + DownstreamTimeout: defaultAuthenticatedGRPCDownstreamTimeout, + FreshnessWindow: defaultAuthenticatedGRPCFreshnessWindow, + AntiAbuse: AuthenticatedGRPCAntiAbuseConfig{ + IP: AuthenticatedRateLimitConfig{ + Requests: defaultAuthenticatedGRPCIPRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultAuthenticatedGRPCIPRateLimitBurst, + }, + Session: AuthenticatedRateLimitConfig{ + Requests: defaultAuthenticatedGRPCSessionRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultAuthenticatedGRPCSessionRateLimitBurst, + }, + User: AuthenticatedRateLimitConfig{ + Requests: defaultAuthenticatedGRPCUserRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultAuthenticatedGRPCUserRateLimitBurst, + }, + MessageClass: AuthenticatedRateLimitConfig{ + Requests: defaultAuthenticatedGRPCMessageClassRateLimitRequests, + Window: defaultClassRateLimitWindow, + Burst: defaultAuthenticatedGRPCMessageClassRateLimitBurst, + }, + }, + } +} + +// DefaultLoggingConfig returns the default structured logging settings. +func DefaultLoggingConfig() LoggingConfig { + return LoggingConfig{Level: defaultLogLevel} +} + +// DefaultSessionCacheRedisConfig returns the default optional settings for the +// Redis-backed authenticated SessionCache. Addr remains empty and must be +// supplied explicitly. 
+func DefaultSessionCacheRedisConfig() SessionCacheRedisConfig { + return SessionCacheRedisConfig{ + DB: defaultSessionCacheRedisDB, + KeyPrefix: defaultSessionCacheRedisKeyPrefix, + LookupTimeout: defaultSessionCacheRedisLookupTimeout, + } +} + +// DefaultReplayRedisConfig returns the default Redis key namespace and timeout +// used for authenticated replay reservations. +func DefaultReplayRedisConfig() ReplayRedisConfig { + return ReplayRedisConfig{ + KeyPrefix: defaultReplayRedisKeyPrefix, + ReserveTimeout: defaultReplayRedisReserveTimeout, + } +} + +// DefaultSessionEventsRedisConfig returns the default optional settings for the +// session lifecycle event subscriber. Stream remains empty and must be +// supplied explicitly. +func DefaultSessionEventsRedisConfig() SessionEventsRedisConfig { + return SessionEventsRedisConfig{ + ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout, + } +} + +// DefaultClientEventsRedisConfig returns the default optional settings for the +// client-facing event subscriber. Stream remains empty and must be supplied +// explicitly. +func DefaultClientEventsRedisConfig() ClientEventsRedisConfig { + return ClientEventsRedisConfig{ + ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout, + } +} + +// DefaultResponseSignerConfig returns the default response-signer settings. +// The private key path remains empty and must be supplied explicitly. +func DefaultResponseSignerConfig() ResponseSignerConfig { + return ResponseSignerConfig{} +} + +// LoadFromEnv loads Config from the process environment, applies defaults for +// omitted settings, and validates the resulting values. 
+func LoadFromEnv() (Config, error) { + cfg := Config{ + ShutdownTimeout: defaultShutdownTimeout, + Logging: DefaultLoggingConfig(), + PublicHTTP: DefaultPublicHTTPConfig(), + AdminHTTP: DefaultAdminHTTPConfig(), + AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(), + SessionCacheRedis: DefaultSessionCacheRedisConfig(), + ReplayRedis: DefaultReplayRedisConfig(), + SessionEventsRedis: DefaultSessionEventsRedisConfig(), + ClientEventsRedis: DefaultClientEventsRedisConfig(), + ResponseSigner: DefaultResponseSignerConfig(), + } + + rawShutdownTimeout, ok := os.LookupEnv(shutdownTimeoutEnvVar) + if ok { + shutdownTimeout, err := time.ParseDuration(rawShutdownTimeout) + if err != nil { + return Config{}, fmt.Errorf("load gateway config: parse %s: %w", shutdownTimeoutEnvVar, err) + } + cfg.ShutdownTimeout = shutdownTimeout + } + + rawLogLevel, ok := os.LookupEnv(logLevelEnvVar) + if ok { + cfg.Logging.Level = rawLogLevel + } + + rawPublicHTTPAddr, ok := os.LookupEnv(publicHTTPAddrEnvVar) + if ok { + cfg.PublicHTTP.Addr = rawPublicHTTPAddr + } + + publicHTTPReadHeaderTimeout, err := loadDurationEnvWithDefault(publicHTTPReadHeaderTimeoutEnvVar, cfg.PublicHTTP.ReadHeaderTimeout) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.ReadHeaderTimeout = publicHTTPReadHeaderTimeout + + publicHTTPReadTimeout, err := loadDurationEnvWithDefault(publicHTTPReadTimeoutEnvVar, cfg.PublicHTTP.ReadTimeout) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.ReadTimeout = publicHTTPReadTimeout + + publicHTTPIdleTimeout, err := loadDurationEnvWithDefault(publicHTTPIdleTimeoutEnvVar, cfg.PublicHTTP.IdleTimeout) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.IdleTimeout = publicHTTPIdleTimeout + + publicAuthUpstreamTimeout, err := loadDurationEnvWithDefault(publicAuthUpstreamTimeoutEnvVar, cfg.PublicHTTP.AuthUpstreamTimeout) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.AuthUpstreamTimeout = publicAuthUpstreamTimeout + + rawAdminHTTPAddr, ok 
:= os.LookupEnv(adminHTTPAddrEnvVar) + if ok { + cfg.AdminHTTP.Addr = rawAdminHTTPAddr + } + + adminHTTPReadHeaderTimeout, err := loadDurationEnvWithDefault(adminHTTPReadHeaderTimeoutEnvVar, cfg.AdminHTTP.ReadHeaderTimeout) + if err != nil { + return Config{}, err + } + cfg.AdminHTTP.ReadHeaderTimeout = adminHTTPReadHeaderTimeout + + adminHTTPReadTimeout, err := loadDurationEnvWithDefault(adminHTTPReadTimeoutEnvVar, cfg.AdminHTTP.ReadTimeout) + if err != nil { + return Config{}, err + } + cfg.AdminHTTP.ReadTimeout = adminHTTPReadTimeout + + adminHTTPIdleTimeout, err := loadDurationEnvWithDefault(adminHTTPIdleTimeoutEnvVar, cfg.AdminHTTP.IdleTimeout) + if err != nil { + return Config{}, err + } + cfg.AdminHTTP.IdleTimeout = adminHTTPIdleTimeout + + rawAuthenticatedGRPCAddr, ok := os.LookupEnv(authenticatedGRPCAddrEnvVar) + if ok { + cfg.AuthenticatedGRPC.Addr = rawAuthenticatedGRPCAddr + } + + authenticatedGRPCConnectionTimeout, err := loadDurationEnvWithDefault(authenticatedGRPCConnectionTimeoutEnvVar, cfg.AuthenticatedGRPC.ConnectionTimeout) + if err != nil { + return Config{}, err + } + cfg.AuthenticatedGRPC.ConnectionTimeout = authenticatedGRPCConnectionTimeout + + authenticatedGRPCDownstreamTimeout, err := loadDurationEnvWithDefault(authenticatedGRPCDownstreamTimeoutEnvVar, cfg.AuthenticatedGRPC.DownstreamTimeout) + if err != nil { + return Config{}, err + } + cfg.AuthenticatedGRPC.DownstreamTimeout = authenticatedGRPCDownstreamTimeout + + authenticatedGRPCFreshnessWindow, err := loadDurationEnvWithDefault(authenticatedGRPCFreshnessWindowEnvVar, cfg.AuthenticatedGRPC.FreshnessWindow) + if err != nil { + return Config{}, err + } + cfg.AuthenticatedGRPC.FreshnessWindow = authenticatedGRPCFreshnessWindow + + authenticatedGRPCIPRateLimit, err := loadRateLimitConfigFromEnv( + cfg.AuthenticatedGRPC.AntiAbuse.IP, + authenticatedGRPCIPRateLimitRequestsEnvVar, + authenticatedGRPCIPRateLimitWindowEnvVar, + authenticatedGRPCIPRateLimitBurstEnvVar, + ) + if err != nil { + 
return Config{}, err + } + cfg.AuthenticatedGRPC.AntiAbuse.IP = authenticatedGRPCIPRateLimit + + authenticatedGRPCSessionRateLimit, err := loadRateLimitConfigFromEnv( + cfg.AuthenticatedGRPC.AntiAbuse.Session, + authenticatedGRPCSessionRateLimitRequestsEnvVar, + authenticatedGRPCSessionRateLimitWindowEnvVar, + authenticatedGRPCSessionRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.AuthenticatedGRPC.AntiAbuse.Session = authenticatedGRPCSessionRateLimit + + authenticatedGRPCUserRateLimit, err := loadRateLimitConfigFromEnv( + cfg.AuthenticatedGRPC.AntiAbuse.User, + authenticatedGRPCUserRateLimitRequestsEnvVar, + authenticatedGRPCUserRateLimitWindowEnvVar, + authenticatedGRPCUserRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.AuthenticatedGRPC.AntiAbuse.User = authenticatedGRPCUserRateLimit + + messageClassRateLimit, err := loadRateLimitConfigFromEnv( + cfg.AuthenticatedGRPC.AntiAbuse.MessageClass, + authenticatedGRPCMessageClassRateLimitRequestsEnvVar, + authenticatedGRPCMessageClassRateLimitWindowEnvVar, + authenticatedGRPCMessageClassRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.AuthenticatedGRPC.AntiAbuse.MessageClass = messageClassRateLimit + + rawSessionCacheRedisAddr, ok := os.LookupEnv(sessionCacheRedisAddrEnvVar) + if ok { + cfg.SessionCacheRedis.Addr = rawSessionCacheRedisAddr + } + + rawSessionCacheRedisUsername, ok := os.LookupEnv(sessionCacheRedisUsernameEnvVar) + if ok { + cfg.SessionCacheRedis.Username = rawSessionCacheRedisUsername + } + + rawSessionCacheRedisPassword, ok := os.LookupEnv(sessionCacheRedisPasswordEnvVar) + if ok { + cfg.SessionCacheRedis.Password = rawSessionCacheRedisPassword + } + + sessionCacheRedisDB, err := loadIntEnvWithDefault(sessionCacheRedisDBEnvVar, cfg.SessionCacheRedis.DB) + if err != nil { + return Config{}, err + } + cfg.SessionCacheRedis.DB = sessionCacheRedisDB + + rawSessionCacheRedisKeyPrefix, ok := 
os.LookupEnv(sessionCacheRedisKeyPrefixEnvVar) + if ok { + cfg.SessionCacheRedis.KeyPrefix = rawSessionCacheRedisKeyPrefix + } + + sessionCacheRedisLookupTimeout, err := loadDurationEnvWithDefault(sessionCacheRedisLookupTimeoutEnvVar, cfg.SessionCacheRedis.LookupTimeout) + if err != nil { + return Config{}, err + } + cfg.SessionCacheRedis.LookupTimeout = sessionCacheRedisLookupTimeout + + sessionCacheRedisTLSEnabled, err := loadBoolEnvWithDefault(sessionCacheRedisTLSEnabledEnvVar, cfg.SessionCacheRedis.TLSEnabled) + if err != nil { + return Config{}, err + } + cfg.SessionCacheRedis.TLSEnabled = sessionCacheRedisTLSEnabled + + rawReplayRedisKeyPrefix, ok := os.LookupEnv(replayRedisKeyPrefixEnvVar) + if ok { + cfg.ReplayRedis.KeyPrefix = rawReplayRedisKeyPrefix + } + + replayRedisReserveTimeout, err := loadDurationEnvWithDefault(replayRedisReserveTimeoutEnvVar, cfg.ReplayRedis.ReserveTimeout) + if err != nil { + return Config{}, err + } + cfg.ReplayRedis.ReserveTimeout = replayRedisReserveTimeout + + rawSessionEventsRedisStream, ok := os.LookupEnv(sessionEventsRedisStreamEnvVar) + if ok { + cfg.SessionEventsRedis.Stream = rawSessionEventsRedisStream + } + + sessionEventsRedisReadBlockTimeout, err := loadDurationEnvWithDefault(sessionEventsRedisReadBlockTimeoutEnvVar, cfg.SessionEventsRedis.ReadBlockTimeout) + if err != nil { + return Config{}, err + } + cfg.SessionEventsRedis.ReadBlockTimeout = sessionEventsRedisReadBlockTimeout + + rawClientEventsRedisStream, ok := os.LookupEnv(clientEventsRedisStreamEnvVar) + if ok { + cfg.ClientEventsRedis.Stream = rawClientEventsRedisStream + } + + clientEventsRedisReadBlockTimeout, err := loadDurationEnvWithDefault(clientEventsRedisReadBlockTimeoutEnvVar, cfg.ClientEventsRedis.ReadBlockTimeout) + if err != nil { + return Config{}, err + } + cfg.ClientEventsRedis.ReadBlockTimeout = clientEventsRedisReadBlockTimeout + + rawSignerKeyPath, ok := os.LookupEnv(responseSignerPrivateKeyPEMPathEnvVar) + if ok { + 
cfg.ResponseSigner.PrivateKeyPEMPath = rawSignerKeyPath + } + + publicAuthPolicy, err := loadPublicRoutePolicyConfigFromEnv( + cfg.PublicHTTP.AntiAbuse.PublicAuth, + publicAuthMaxBodyBytesEnvVar, + publicAuthRateLimitRequestsEnvVar, + publicAuthRateLimitWindowEnvVar, + publicAuthRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.AntiAbuse.PublicAuth = publicAuthPolicy + + browserBootstrapPolicy, err := loadPublicRoutePolicyConfigFromEnv( + cfg.PublicHTTP.AntiAbuse.BrowserBootstrap, + browserBootstrapMaxBodyBytesEnvVar, + browserBootstrapRateLimitRequestsEnvVar, + browserBootstrapRateLimitWindowEnvVar, + browserBootstrapRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.AntiAbuse.BrowserBootstrap = browserBootstrapPolicy + + browserAssetPolicy, err := loadPublicRoutePolicyConfigFromEnv( + cfg.PublicHTTP.AntiAbuse.BrowserAsset, + browserAssetMaxBodyBytesEnvVar, + browserAssetRateLimitRequestsEnvVar, + browserAssetRateLimitWindowEnvVar, + browserAssetRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.AntiAbuse.BrowserAsset = browserAssetPolicy + + publicMiscPolicy, err := loadPublicRoutePolicyConfigFromEnv( + cfg.PublicHTTP.AntiAbuse.PublicMisc, + publicMiscMaxBodyBytesEnvVar, + publicMiscRateLimitRequestsEnvVar, + publicMiscRateLimitWindowEnvVar, + publicMiscRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.AntiAbuse.PublicMisc = publicMiscPolicy + + sendIdentityPolicy, err := loadPublicAuthIdentityPolicyConfigFromEnv( + cfg.PublicHTTP.AntiAbuse.SendEmailCodeIdentity, + sendEmailCodeIdentityRateLimitRequestsEnvVar, + sendEmailCodeIdentityRateLimitWindowEnvVar, + sendEmailCodeIdentityRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.AntiAbuse.SendEmailCodeIdentity = sendIdentityPolicy + + confirmIdentityPolicy, err := loadPublicAuthIdentityPolicyConfigFromEnv( + 
cfg.PublicHTTP.AntiAbuse.ConfirmEmailCodeIdentity, + confirmEmailCodeIdentityRateLimitRequestsEnvVar, + confirmEmailCodeIdentityRateLimitWindowEnvVar, + confirmEmailCodeIdentityRateLimitBurstEnvVar, + ) + if err != nil { + return Config{}, err + } + cfg.PublicHTTP.AntiAbuse.ConfirmEmailCodeIdentity = confirmIdentityPolicy + + if cfg.ShutdownTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", shutdownTimeoutEnvVar) + } + if err := validateLogLevel(cfg.Logging.Level); err != nil { + return Config{}, fmt.Errorf("load gateway config: %w", err) + } + if strings.TrimSpace(cfg.PublicHTTP.Addr) == "" { + return Config{}, fmt.Errorf("load gateway config: %s must not be empty", publicHTTPAddrEnvVar) + } + if cfg.PublicHTTP.ReadHeaderTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", publicHTTPReadHeaderTimeoutEnvVar) + } + if cfg.PublicHTTP.ReadTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", publicHTTPReadTimeoutEnvVar) + } + if cfg.PublicHTTP.IdleTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", publicHTTPIdleTimeoutEnvVar) + } + if cfg.PublicHTTP.AuthUpstreamTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", publicAuthUpstreamTimeoutEnvVar) + } + if addr := strings.TrimSpace(cfg.AdminHTTP.Addr); addr != "" { + cfg.AdminHTTP.Addr = addr + } + if cfg.AdminHTTP.ReadHeaderTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", adminHTTPReadHeaderTimeoutEnvVar) + } + if cfg.AdminHTTP.ReadTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", adminHTTPReadTimeoutEnvVar) + } + if cfg.AdminHTTP.IdleTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", adminHTTPIdleTimeoutEnvVar) + } + if strings.TrimSpace(cfg.AuthenticatedGRPC.Addr) == "" { + return Config{}, fmt.Errorf("load gateway config: 
%s must not be empty", authenticatedGRPCAddrEnvVar) + } + if cfg.AuthenticatedGRPC.ConnectionTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", authenticatedGRPCConnectionTimeoutEnvVar) + } + if cfg.AuthenticatedGRPC.DownstreamTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", authenticatedGRPCDownstreamTimeoutEnvVar) + } + if cfg.AuthenticatedGRPC.FreshnessWindow <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", authenticatedGRPCFreshnessWindowEnvVar) + } + if err := validateRateLimitConfig( + cfg.AuthenticatedGRPC.AntiAbuse.IP, + authenticatedGRPCIPRateLimitRequestsEnvVar, + authenticatedGRPCIPRateLimitWindowEnvVar, + authenticatedGRPCIPRateLimitBurstEnvVar, + ); err != nil { + return Config{}, err + } + if err := validateRateLimitConfig( + cfg.AuthenticatedGRPC.AntiAbuse.Session, + authenticatedGRPCSessionRateLimitRequestsEnvVar, + authenticatedGRPCSessionRateLimitWindowEnvVar, + authenticatedGRPCSessionRateLimitBurstEnvVar, + ); err != nil { + return Config{}, err + } + if err := validateRateLimitConfig( + cfg.AuthenticatedGRPC.AntiAbuse.User, + authenticatedGRPCUserRateLimitRequestsEnvVar, + authenticatedGRPCUserRateLimitWindowEnvVar, + authenticatedGRPCUserRateLimitBurstEnvVar, + ); err != nil { + return Config{}, err + } + if err := validateRateLimitConfig( + cfg.AuthenticatedGRPC.AntiAbuse.MessageClass, + authenticatedGRPCMessageClassRateLimitRequestsEnvVar, + authenticatedGRPCMessageClassRateLimitWindowEnvVar, + authenticatedGRPCMessageClassRateLimitBurstEnvVar, + ); err != nil { + return Config{}, err + } + if strings.TrimSpace(cfg.SessionCacheRedis.Addr) == "" { + return Config{}, fmt.Errorf("load gateway config: %s must not be empty", sessionCacheRedisAddrEnvVar) + } + if cfg.SessionCacheRedis.DB < 0 { + return Config{}, fmt.Errorf("load gateway config: %s must not be negative", sessionCacheRedisDBEnvVar) + } + if 
cfg.SessionCacheRedis.LookupTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", sessionCacheRedisLookupTimeoutEnvVar) + } + if strings.TrimSpace(cfg.ReplayRedis.KeyPrefix) == "" { + return Config{}, fmt.Errorf("load gateway config: %s must not be empty", replayRedisKeyPrefixEnvVar) + } + if cfg.ReplayRedis.ReserveTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", replayRedisReserveTimeoutEnvVar) + } + if strings.TrimSpace(cfg.SessionEventsRedis.Stream) == "" { + return Config{}, fmt.Errorf("load gateway config: %s must not be empty", sessionEventsRedisStreamEnvVar) + } + if cfg.SessionEventsRedis.ReadBlockTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", sessionEventsRedisReadBlockTimeoutEnvVar) + } + if strings.TrimSpace(cfg.ClientEventsRedis.Stream) == "" { + return Config{}, fmt.Errorf("load gateway config: %s must not be empty", clientEventsRedisStreamEnvVar) + } + if cfg.ClientEventsRedis.ReadBlockTimeout <= 0 { + return Config{}, fmt.Errorf("load gateway config: %s must be positive", clientEventsRedisReadBlockTimeoutEnvVar) + } + if strings.TrimSpace(cfg.ResponseSigner.PrivateKeyPEMPath) == "" { + return Config{}, fmt.Errorf("load gateway config: %s must not be empty", responseSignerPrivateKeyPEMPathEnvVar) + } + if err := validatePublicRoutePolicyConfig(cfg.PublicHTTP.AntiAbuse.PublicAuth, publicAuthMaxBodyBytesEnvVar, publicAuthRateLimitRequestsEnvVar, publicAuthRateLimitWindowEnvVar, publicAuthRateLimitBurstEnvVar); err != nil { + return Config{}, err + } + if err := validatePublicRoutePolicyConfig(cfg.PublicHTTP.AntiAbuse.BrowserBootstrap, browserBootstrapMaxBodyBytesEnvVar, browserBootstrapRateLimitRequestsEnvVar, browserBootstrapRateLimitWindowEnvVar, browserBootstrapRateLimitBurstEnvVar); err != nil { + return Config{}, err + } + if err := validatePublicRoutePolicyConfig(cfg.PublicHTTP.AntiAbuse.BrowserAsset, browserAssetMaxBodyBytesEnvVar, 
browserAssetRateLimitRequestsEnvVar, browserAssetRateLimitWindowEnvVar, browserAssetRateLimitBurstEnvVar); err != nil { + return Config{}, err + } + if err := validatePublicRoutePolicyConfig(cfg.PublicHTTP.AntiAbuse.PublicMisc, publicMiscMaxBodyBytesEnvVar, publicMiscRateLimitRequestsEnvVar, publicMiscRateLimitWindowEnvVar, publicMiscRateLimitBurstEnvVar); err != nil { + return Config{}, err + } + if err := validatePublicAuthIdentityPolicyConfig(cfg.PublicHTTP.AntiAbuse.SendEmailCodeIdentity, sendEmailCodeIdentityRateLimitRequestsEnvVar, sendEmailCodeIdentityRateLimitWindowEnvVar, sendEmailCodeIdentityRateLimitBurstEnvVar); err != nil { + return Config{}, err + } + if err := validatePublicAuthIdentityPolicyConfig(cfg.PublicHTTP.AntiAbuse.ConfirmEmailCodeIdentity, confirmEmailCodeIdentityRateLimitRequestsEnvVar, confirmEmailCodeIdentityRateLimitWindowEnvVar, confirmEmailCodeIdentityRateLimitBurstEnvVar); err != nil { + return Config{}, err + } + + return cfg, nil +} + +func loadPublicRoutePolicyConfigFromEnv(defaults PublicRoutePolicyConfig, maxBodyEnvVar string, requestsEnvVar string, windowEnvVar string, burstEnvVar string) (PublicRoutePolicyConfig, error) { + policy := defaults + + maxBodyBytes, err := loadInt64EnvWithDefault(maxBodyEnvVar, defaults.MaxBodyBytes) + if err != nil { + return PublicRoutePolicyConfig{}, err + } + policy.MaxBodyBytes = maxBodyBytes + + rateLimit, err := loadRateLimitConfigFromEnv(defaults.RateLimit, requestsEnvVar, windowEnvVar, burstEnvVar) + if err != nil { + return PublicRoutePolicyConfig{}, err + } + policy.RateLimit = rateLimit + + return policy, nil +} + +func loadPublicAuthIdentityPolicyConfigFromEnv(defaults PublicAuthIdentityPolicyConfig, requestsEnvVar string, windowEnvVar string, burstEnvVar string) (PublicAuthIdentityPolicyConfig, error) { + rateLimit, err := loadRateLimitConfigFromEnv(defaults.RateLimit, requestsEnvVar, windowEnvVar, burstEnvVar) + if err != nil { + return PublicAuthIdentityPolicyConfig{}, err + } + + 
return PublicAuthIdentityPolicyConfig{RateLimit: rateLimit}, nil +} + +func loadRateLimitConfigFromEnv(defaults RateLimitConfig, requestsEnvVar string, windowEnvVar string, burstEnvVar string) (RateLimitConfig, error) { + cfg := defaults + + requests, err := loadIntEnvWithDefault(requestsEnvVar, defaults.Requests) + if err != nil { + return RateLimitConfig{}, err + } + cfg.Requests = requests + + window, err := loadDurationEnvWithDefault(windowEnvVar, defaults.Window) + if err != nil { + return RateLimitConfig{}, err + } + cfg.Window = window + + burst, err := loadIntEnvWithDefault(burstEnvVar, defaults.Burst) + if err != nil { + return RateLimitConfig{}, err + } + cfg.Burst = burst + + return cfg, nil +} + +func validateLogLevel(level string) error { + switch strings.ToLower(strings.TrimSpace(level)) { + case "debug", "info", "warn", "error", "dpanic", "panic", "fatal": + return nil + default: + return fmt.Errorf("%s must be one of debug, info, warn, error, dpanic, panic, fatal", logLevelEnvVar) + } +} + +func loadIntEnvWithDefault(envVar string, fallback int) (int, error) { + rawValue, ok := os.LookupEnv(envVar) + if !ok { + return fallback, nil + } + + value, err := strconv.Atoi(rawValue) + if err != nil { + return 0, fmt.Errorf("load gateway config: parse %s: %w", envVar, err) + } + + return value, nil +} + +func loadInt64EnvWithDefault(envVar string, fallback int64) (int64, error) { + rawValue, ok := os.LookupEnv(envVar) + if !ok { + return fallback, nil + } + + value, err := strconv.ParseInt(rawValue, 10, 64) + if err != nil { + return 0, fmt.Errorf("load gateway config: parse %s: %w", envVar, err) + } + + return value, nil +} + +func loadDurationEnvWithDefault(envVar string, fallback time.Duration) (time.Duration, error) { + rawValue, ok := os.LookupEnv(envVar) + if !ok { + return fallback, nil + } + + value, err := time.ParseDuration(rawValue) + if err != nil { + return 0, fmt.Errorf("load gateway config: parse %s: %w", envVar, err) + } + + return value, 
nil +} + +func loadBoolEnvWithDefault(envVar string, fallback bool) (bool, error) { + rawValue, ok := os.LookupEnv(envVar) + if !ok { + return fallback, nil + } + + value, err := strconv.ParseBool(rawValue) + if err != nil { + return false, fmt.Errorf("load gateway config: parse %s: %w", envVar, err) + } + + return value, nil +} + +func validatePublicRoutePolicyConfig(cfg PublicRoutePolicyConfig, maxBodyEnvVar string, requestsEnvVar string, windowEnvVar string, burstEnvVar string) error { + if cfg.MaxBodyBytes < 0 { + return fmt.Errorf("load gateway config: %s must not be negative", maxBodyEnvVar) + } + + return validateRateLimitConfig(cfg.RateLimit, requestsEnvVar, windowEnvVar, burstEnvVar) +} + +func validatePublicAuthIdentityPolicyConfig(cfg PublicAuthIdentityPolicyConfig, requestsEnvVar string, windowEnvVar string, burstEnvVar string) error { + return validateRateLimitConfig(cfg.RateLimit, requestsEnvVar, windowEnvVar, burstEnvVar) +} + +func validateRateLimitConfig(cfg RateLimitConfig, requestsEnvVar string, windowEnvVar string, burstEnvVar string) error { + if cfg.Requests <= 0 { + return fmt.Errorf("load gateway config: %s must be positive", requestsEnvVar) + } + if cfg.Window <= 0 { + return fmt.Errorf("load gateway config: %s must be positive", windowEnvVar) + } + if cfg.Burst <= 0 { + return fmt.Errorf("load gateway config: %s must be positive", burstEnvVar) + } + + return nil +} diff --git a/gateway/internal/config/config_test.go b/gateway/internal/config/config_test.go new file mode 100644 index 0000000..caf74a1 --- /dev/null +++ b/gateway/internal/config/config_test.go @@ -0,0 +1,1276 @@ +package config + +import ( + "crypto/ed25519" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadFromEnv(t *testing.T) { + customResponseSignerPrivateKeyPEMPath := new(string) + *customResponseSignerPrivateKeyPEMPath = 
writeTestResponseSignerPEMFile(t) + + customShutdownTimeout := new(string) + *customShutdownTimeout = "17s" + + customPublicHTTPAddr := new(string) + *customPublicHTTPAddr = "127.0.0.1:9090" + + customAuthenticatedGRPCAddr := new(string) + *customAuthenticatedGRPCAddr = "127.0.0.1:9191" + + customAuthenticatedGRPCFreshnessWindow := new(string) + *customAuthenticatedGRPCFreshnessWindow = "90s" + + customSessionCacheRedisAddr := new(string) + *customSessionCacheRedisAddr = "127.0.0.1:6379" + + customSessionEventsRedisStream := new(string) + *customSessionEventsRedisStream = "gateway:session_events" + + customClientEventsRedisStream := new(string) + *customClientEventsRedisStream = "gateway:client_events" + + emptyPublicHTTPAddr := new(string) + *emptyPublicHTTPAddr = "" + + whitespacePublicHTTPAddr := new(string) + *whitespacePublicHTTPAddr = " " + + emptyAuthenticatedGRPCAddr := new(string) + *emptyAuthenticatedGRPCAddr = "" + + whitespaceAuthenticatedGRPCAddr := new(string) + *whitespaceAuthenticatedGRPCAddr = " " + + emptySessionCacheRedisAddr := new(string) + *emptySessionCacheRedisAddr = "" + + whitespaceSessionCacheRedisAddr := new(string) + *whitespaceSessionCacheRedisAddr = " " + + zeroShutdownTimeout := new(string) + *zeroShutdownTimeout = "0s" + + negativeShutdownTimeout := new(string) + *negativeShutdownTimeout = "-1s" + + invalidShutdownTimeout := new(string) + *invalidShutdownTimeout = "later" + + zeroAuthenticatedGRPCFreshnessWindow := new(string) + *zeroAuthenticatedGRPCFreshnessWindow = "0s" + + invalidAuthenticatedGRPCFreshnessWindow := new(string) + *invalidAuthenticatedGRPCFreshnessWindow = "later" + + tests := []struct { + name string + shutdownTimeout *string + publicHTTPAddr *string + authenticatedGRPCAddr *string + authenticatedGRPCFreshnessWindow *string + sessionCacheRedisAddr *string + responseSignerPrivateKeyPEMPath *string + want Config + wantErr string + }{ + { + name: "required redis address with default optional values", + 
sessionCacheRedisAddr: customSessionCacheRedisAddr, + responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath, + want: Config{ + ShutdownTimeout: 5 * time.Second, + Logging: DefaultLoggingConfig(), + PublicHTTP: DefaultPublicHTTPConfig(), + AdminHTTP: DefaultAdminHTTPConfig(), + AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(), + SessionCacheRedis: SessionCacheRedisConfig{ + Addr: "127.0.0.1:6379", + DB: defaultSessionCacheRedisDB, + KeyPrefix: defaultSessionCacheRedisKeyPrefix, + LookupTimeout: defaultSessionCacheRedisLookupTimeout, + }, + ReplayRedis: DefaultReplayRedisConfig(), + SessionEventsRedis: SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout, + }, + ClientEventsRedis: ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout, + }, + ResponseSigner: ResponseSignerConfig{ + PrivateKeyPEMPath: *customResponseSignerPrivateKeyPEMPath, + }, + }, + }, + { + name: "custom shutdown timeout", + shutdownTimeout: customShutdownTimeout, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath, + want: Config{ + ShutdownTimeout: 17 * time.Second, + Logging: DefaultLoggingConfig(), + PublicHTTP: DefaultPublicHTTPConfig(), + AdminHTTP: DefaultAdminHTTPConfig(), + AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(), + SessionCacheRedis: SessionCacheRedisConfig{ + Addr: "127.0.0.1:6379", + DB: defaultSessionCacheRedisDB, + KeyPrefix: defaultSessionCacheRedisKeyPrefix, + LookupTimeout: defaultSessionCacheRedisLookupTimeout, + }, + ReplayRedis: DefaultReplayRedisConfig(), + SessionEventsRedis: SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout, + }, + ClientEventsRedis: ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: 
defaultClientEventsRedisReadBlockTimeout, + }, + ResponseSigner: ResponseSignerConfig{ + PrivateKeyPEMPath: *customResponseSignerPrivateKeyPEMPath, + }, + }, + }, + { + name: "custom public http address", + publicHTTPAddr: customPublicHTTPAddr, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath, + want: Config{ + ShutdownTimeout: 5 * time.Second, + Logging: DefaultLoggingConfig(), + PublicHTTP: func() PublicHTTPConfig { + cfg := DefaultPublicHTTPConfig() + cfg.Addr = "127.0.0.1:9090" + return cfg + }(), + AdminHTTP: DefaultAdminHTTPConfig(), + AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(), + SessionCacheRedis: SessionCacheRedisConfig{ + Addr: "127.0.0.1:6379", + DB: defaultSessionCacheRedisDB, + KeyPrefix: defaultSessionCacheRedisKeyPrefix, + LookupTimeout: defaultSessionCacheRedisLookupTimeout, + }, + ReplayRedis: DefaultReplayRedisConfig(), + SessionEventsRedis: SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout, + }, + ClientEventsRedis: ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout, + }, + ResponseSigner: ResponseSignerConfig{ + PrivateKeyPEMPath: *customResponseSignerPrivateKeyPEMPath, + }, + }, + }, + { + name: "custom authenticated grpc address", + authenticatedGRPCAddr: customAuthenticatedGRPCAddr, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath, + want: Config{ + ShutdownTimeout: 5 * time.Second, + Logging: DefaultLoggingConfig(), + PublicHTTP: DefaultPublicHTTPConfig(), + AdminHTTP: DefaultAdminHTTPConfig(), + AuthenticatedGRPC: func() AuthenticatedGRPCConfig { + cfg := DefaultAuthenticatedGRPCConfig() + cfg.Addr = "127.0.0.1:9191" + return cfg + }(), + SessionCacheRedis: SessionCacheRedisConfig{ + Addr: "127.0.0.1:6379", + DB: 
defaultSessionCacheRedisDB, + KeyPrefix: defaultSessionCacheRedisKeyPrefix, + LookupTimeout: defaultSessionCacheRedisLookupTimeout, + }, + ReplayRedis: DefaultReplayRedisConfig(), + SessionEventsRedis: SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout, + }, + ClientEventsRedis: ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout, + }, + ResponseSigner: ResponseSignerConfig{ + PrivateKeyPEMPath: *customResponseSignerPrivateKeyPEMPath, + }, + }, + }, + { + name: "custom authenticated grpc freshness window", + authenticatedGRPCFreshnessWindow: customAuthenticatedGRPCFreshnessWindow, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath, + want: Config{ + ShutdownTimeout: 5 * time.Second, + Logging: DefaultLoggingConfig(), + PublicHTTP: DefaultPublicHTTPConfig(), + AdminHTTP: DefaultAdminHTTPConfig(), + AuthenticatedGRPC: func() AuthenticatedGRPCConfig { + cfg := DefaultAuthenticatedGRPCConfig() + cfg.FreshnessWindow = 90 * time.Second + return cfg + }(), + SessionCacheRedis: SessionCacheRedisConfig{ + Addr: "127.0.0.1:6379", + DB: defaultSessionCacheRedisDB, + KeyPrefix: defaultSessionCacheRedisKeyPrefix, + LookupTimeout: defaultSessionCacheRedisLookupTimeout, + }, + ReplayRedis: DefaultReplayRedisConfig(), + SessionEventsRedis: SessionEventsRedisConfig{ + Stream: "gateway:session_events", + ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout, + }, + ClientEventsRedis: ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout, + }, + ResponseSigner: ResponseSignerConfig{ + PrivateKeyPEMPath: *customResponseSignerPrivateKeyPEMPath, + }, + }, + }, + { + name: "zero shutdown timeout", + shutdownTimeout: zeroShutdownTimeout, + wantErr: "must be positive", + }, + { + name: "negative shutdown 
timeout", + shutdownTimeout: negativeShutdownTimeout, + wantErr: "must be positive", + }, + { + name: "invalid shutdown timeout", + shutdownTimeout: invalidShutdownTimeout, + wantErr: "parse GATEWAY_SHUTDOWN_TIMEOUT", + }, + { + name: "empty public http address", + publicHTTPAddr: emptyPublicHTTPAddr, + wantErr: "GATEWAY_PUBLIC_HTTP_ADDR must not be empty", + }, + { + name: "whitespace public http address", + publicHTTPAddr: whitespacePublicHTTPAddr, + wantErr: "GATEWAY_PUBLIC_HTTP_ADDR must not be empty", + }, + { + name: "empty authenticated grpc address", + authenticatedGRPCAddr: emptyAuthenticatedGRPCAddr, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + wantErr: "GATEWAY_AUTHENTICATED_GRPC_ADDR must not be empty", + }, + { + name: "whitespace authenticated grpc address", + authenticatedGRPCAddr: whitespaceAuthenticatedGRPCAddr, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + wantErr: "GATEWAY_AUTHENTICATED_GRPC_ADDR must not be empty", + }, + { + name: "zero authenticated grpc freshness window", + authenticatedGRPCFreshnessWindow: zeroAuthenticatedGRPCFreshnessWindow, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + wantErr: authenticatedGRPCFreshnessWindowEnvVar + " must be positive", + }, + { + name: "invalid authenticated grpc freshness window", + authenticatedGRPCFreshnessWindow: invalidAuthenticatedGRPCFreshnessWindow, + sessionCacheRedisAddr: customSessionCacheRedisAddr, + wantErr: "parse " + authenticatedGRPCFreshnessWindowEnvVar, + }, + { + name: "missing session cache redis address", + responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath, + wantErr: "GATEWAY_SESSION_CACHE_REDIS_ADDR must not be empty", + }, + { + name: "empty session cache redis address", + sessionCacheRedisAddr: emptySessionCacheRedisAddr, + responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath, + wantErr: "GATEWAY_SESSION_CACHE_REDIS_ADDR must not be empty", + }, + { + name: "whitespace session cache redis address", + 
sessionCacheRedisAddr:           whitespaceSessionCacheRedisAddr,
			responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath,
			wantErr:                         "GATEWAY_SESSION_CACHE_REDIS_ADDR must not be empty",
		},
		{
			name:                  "missing response signer private key path",
			sessionCacheRedisAddr: customSessionCacheRedisAddr,
			wantErr:               responseSignerPrivateKeyPEMPathEnvVar + " must not be empty",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			restoreEnvs(t,
				shutdownTimeoutEnvVar,
				publicHTTPAddrEnvVar,
				authenticatedGRPCAddrEnvVar,
				authenticatedGRPCFreshnessWindowEnvVar,
				sessionCacheRedisAddrEnvVar,
				sessionEventsRedisStreamEnvVar,
				clientEventsRedisStreamEnvVar,
				responseSignerPrivateKeyPEMPathEnvVar,
			)

			setEnvValue(t, shutdownTimeoutEnvVar, tt.shutdownTimeout)
			setEnvValue(t, publicHTTPAddrEnvVar, tt.publicHTTPAddr)
			setEnvValue(t, authenticatedGRPCAddrEnvVar, tt.authenticatedGRPCAddr)
			setEnvValue(t, authenticatedGRPCFreshnessWindowEnvVar, tt.authenticatedGRPCFreshnessWindow)
			setEnvValue(t, sessionCacheRedisAddrEnvVar, tt.sessionCacheRedisAddr)
			setEnvValue(t, sessionEventsRedisStreamEnvVar, customSessionEventsRedisStream)
			setEnvValue(t, clientEventsRedisStreamEnvVar, customClientEventsRedisStream)
			setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, tt.responseSignerPrivateKeyPEMPath)

			cfg, err := LoadFromEnv()
			if tt.wantErr != "" {
				require.Error(t, err)
				require.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tt.want, cfg)
		})
	}
}

// TestLoadFromEnvOperationalSettings verifies that logging and listener tuning
// variables override the defaults, and that an invalid log level is rejected.
//
// This test must NOT call t.Parallel(): its subtests mutate process-wide
// environment variables via setEnvValue (os.Setenv / os.Unsetenv) and snapshot
// them with restoreEnvs. Running the subtests (or the test itself) in parallel
// races on that shared environment state and makes the assertions flaky.
func TestLoadFromEnvOperationalSettings(t *testing.T) {
	customSessionCacheRedisAddr := new(string)
	*customSessionCacheRedisAddr = "127.0.0.1:6379"

	customSessionEventsRedisStream := new(string)
	*customSessionEventsRedisStream = "gateway:session_events"

	customClientEventsRedisStream := new(string)
	*customClientEventsRedisStream = "gateway:client_events"

	customResponseSignerPrivateKeyPEMPath := new(string)
	*customResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t)

	customLogLevel := new(string)
	*customLogLevel = "debug"

	customAdminAddr := new(string)
	*customAdminAddr = "127.0.0.1:8081"

	customAdminReadTimeout := new(string)
	*customAdminReadTimeout = "4s"

	customPublicReadTimeout := new(string)
	*customPublicReadTimeout = "12s"

	customPublicAuthUpstreamTimeout := new(string)
	*customPublicAuthUpstreamTimeout = "1500ms"

	customGRPCConnectionTimeout := new(string)
	*customGRPCConnectionTimeout = "7s"

	customGRPCDownstreamTimeout := new(string)
	*customGRPCDownstreamTimeout = "9s"

	invalidLogLevel := new(string)
	*invalidLogLevel = "verbose"

	tests := []struct {
		name    string
		envs    map[string]*string
		assert  func(t *testing.T, cfg Config)
		wantErr string
	}{
		{
			name: "custom operational settings",
			envs: map[string]*string{
				sessionCacheRedisAddrEnvVar:              customSessionCacheRedisAddr,
				sessionEventsRedisStreamEnvVar:           customSessionEventsRedisStream,
				clientEventsRedisStreamEnvVar:            customClientEventsRedisStream,
				responseSignerPrivateKeyPEMPathEnvVar:    customResponseSignerPrivateKeyPEMPath,
				logLevelEnvVar:                           customLogLevel,
				adminHTTPAddrEnvVar:                      customAdminAddr,
				adminHTTPReadTimeoutEnvVar:               customAdminReadTimeout,
				publicHTTPReadTimeoutEnvVar:              customPublicReadTimeout,
				publicAuthUpstreamTimeoutEnvVar:          customPublicAuthUpstreamTimeout,
				authenticatedGRPCConnectionTimeoutEnvVar: customGRPCConnectionTimeout,
				authenticatedGRPCDownstreamTimeoutEnvVar: customGRPCDownstreamTimeout,
			},
			assert: func(t *testing.T, cfg Config) {
				t.Helper()
				assert.Equal(t, "debug", cfg.Logging.Level)
				assert.Equal(t, "127.0.0.1:8081", cfg.AdminHTTP.Addr)
				assert.Equal(t, 4*time.Second, cfg.AdminHTTP.ReadTimeout)
				assert.Equal(t, 12*time.Second, cfg.PublicHTTP.ReadTimeout)
				assert.Equal(t, 1500*time.Millisecond, cfg.PublicHTTP.AuthUpstreamTimeout)
				assert.Equal(t, 7*time.Second, cfg.AuthenticatedGRPC.ConnectionTimeout)
				assert.Equal(t, 9*time.Second, cfg.AuthenticatedGRPC.DownstreamTimeout)
			},
		},
		{
			name: "invalid log level",
			envs: map[string]*string{
				sessionCacheRedisAddrEnvVar:           customSessionCacheRedisAddr,
				sessionEventsRedisStreamEnvVar:        customSessionEventsRedisStream,
				clientEventsRedisStreamEnvVar:         customClientEventsRedisStream,
				responseSignerPrivateKeyPEMPathEnvVar: customResponseSignerPrivateKeyPEMPath,
				logLevelEnvVar:                        invalidLogLevel,
			},
			wantErr: logLevelEnvVar + " must be one of",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			// No t.Parallel() here: setEnvValue mutates the shared process
			// environment, so subtests must run sequentially.
			restoreEnvs(t, append(
				append(
					append(
						append(operationalEnvVars(), sessionCacheRedisEnvVars()...),
						sessionEventsRedisEnvVars()...,
					),
					clientEventsRedisEnvVars()...,
				),
				responseSignerPrivateKeyPEMPathEnvVar,
			)...)

			for envVar, value := range tt.envs {
				setEnvValue(t, envVar, value)
			}

			cfg, err := LoadFromEnv()
			if tt.wantErr != "" {
				require.Error(t, err)
				require.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			tt.assert(t, cfg)
		})
	}
}

func TestLoadFromEnvAuthenticatedGRPCAntiAbuse(t *testing.T) {
	customSessionCacheRedisAddr := new(string)
	*customSessionCacheRedisAddr = "127.0.0.1:6379"

	customSessionEventsRedisStream := new(string)
	*customSessionEventsRedisStream = "gateway:session_events"

	customClientEventsRedisStream := new(string)
	*customClientEventsRedisStream = "gateway:client_events"

	customResponseSignerPrivateKeyPEMPath := new(string)
	*customResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t)

	customIPRequests := new(string)
	*customIPRequests = "240"

	customIPWindow := new(string)
	*customIPWindow = "2m"

	customIPBurst := new(string)
	*customIPBurst = "60"

	customSessionRequests := new(string)
	*customSessionRequests = "120"

	customSessionWindow := new(string)
*customSessionWindow = "90s" + + customSessionBurst := new(string) + *customSessionBurst = "30" + + customUserRequests := new(string) + *customUserRequests = "180" + + customUserWindow := new(string) + *customUserWindow = "3m" + + customUserBurst := new(string) + *customUserBurst = "45" + + customMessageClassRequests := new(string) + *customMessageClassRequests = "75" + + customMessageClassWindow := new(string) + *customMessageClassWindow = "45s" + + customMessageClassBurst := new(string) + *customMessageClassBurst = "15" + + zeroIPRequests := new(string) + *zeroIPRequests = "0" + + tests := []struct { + name string + ipRequests *string + ipWindow *string + ipBurst *string + sessionRequests *string + sessionWindow *string + sessionBurst *string + userRequests *string + userWindow *string + userBurst *string + messageClassRequests *string + messageClassWindow *string + messageClassBurst *string + want AuthenticatedGRPCAntiAbuseConfig + wantErr string + }{ + { + name: "custom authenticated grpc anti abuse config", + ipRequests: customIPRequests, + ipWindow: customIPWindow, + ipBurst: customIPBurst, + sessionRequests: customSessionRequests, + sessionWindow: customSessionWindow, + sessionBurst: customSessionBurst, + userRequests: customUserRequests, + userWindow: customUserWindow, + userBurst: customUserBurst, + messageClassRequests: customMessageClassRequests, + messageClassWindow: customMessageClassWindow, + messageClassBurst: customMessageClassBurst, + want: AuthenticatedGRPCAntiAbuseConfig{ + IP: AuthenticatedRateLimitConfig{ + Requests: 240, + Window: 2 * time.Minute, + Burst: 60, + }, + Session: AuthenticatedRateLimitConfig{ + Requests: 120, + Window: 90 * time.Second, + Burst: 30, + }, + User: AuthenticatedRateLimitConfig{ + Requests: 180, + Window: 3 * time.Minute, + Burst: 45, + }, + MessageClass: AuthenticatedRateLimitConfig{ + Requests: 75, + Window: 45 * time.Second, + Burst: 15, + }, + }, + }, + { + name: "zero authenticated grpc ip requests", + 
ipRequests: zeroIPRequests, + wantErr: authenticatedGRPCIPRateLimitRequestsEnvVar + " must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + restoreEnvs( + t, + sessionCacheRedisAddrEnvVar, + authenticatedGRPCIPRateLimitRequestsEnvVar, + authenticatedGRPCIPRateLimitWindowEnvVar, + authenticatedGRPCIPRateLimitBurstEnvVar, + authenticatedGRPCSessionRateLimitRequestsEnvVar, + authenticatedGRPCSessionRateLimitWindowEnvVar, + authenticatedGRPCSessionRateLimitBurstEnvVar, + authenticatedGRPCUserRateLimitRequestsEnvVar, + authenticatedGRPCUserRateLimitWindowEnvVar, + authenticatedGRPCUserRateLimitBurstEnvVar, + authenticatedGRPCMessageClassRateLimitRequestsEnvVar, + authenticatedGRPCMessageClassRateLimitWindowEnvVar, + authenticatedGRPCMessageClassRateLimitBurstEnvVar, + sessionEventsRedisStreamEnvVar, + clientEventsRedisStreamEnvVar, + responseSignerPrivateKeyPEMPathEnvVar, + ) + + setEnvValue(t, sessionCacheRedisAddrEnvVar, customSessionCacheRedisAddr) + setEnvValue(t, sessionEventsRedisStreamEnvVar, customSessionEventsRedisStream) + setEnvValue(t, clientEventsRedisStreamEnvVar, customClientEventsRedisStream) + setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, customResponseSignerPrivateKeyPEMPath) + setEnvValue(t, authenticatedGRPCIPRateLimitRequestsEnvVar, tt.ipRequests) + setEnvValue(t, authenticatedGRPCIPRateLimitWindowEnvVar, tt.ipWindow) + setEnvValue(t, authenticatedGRPCIPRateLimitBurstEnvVar, tt.ipBurst) + setEnvValue(t, authenticatedGRPCSessionRateLimitRequestsEnvVar, tt.sessionRequests) + setEnvValue(t, authenticatedGRPCSessionRateLimitWindowEnvVar, tt.sessionWindow) + setEnvValue(t, authenticatedGRPCSessionRateLimitBurstEnvVar, tt.sessionBurst) + setEnvValue(t, authenticatedGRPCUserRateLimitRequestsEnvVar, tt.userRequests) + setEnvValue(t, authenticatedGRPCUserRateLimitWindowEnvVar, tt.userWindow) + setEnvValue(t, authenticatedGRPCUserRateLimitBurstEnvVar, tt.userBurst) + setEnvValue(t, 
authenticatedGRPCMessageClassRateLimitRequestsEnvVar, tt.messageClassRequests) + setEnvValue(t, authenticatedGRPCMessageClassRateLimitWindowEnvVar, tt.messageClassWindow) + setEnvValue(t, authenticatedGRPCMessageClassRateLimitBurstEnvVar, tt.messageClassBurst) + + cfg, err := LoadFromEnv() + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, cfg.AuthenticatedGRPC.AntiAbuse) + }) + } +} + +func TestLoadFromEnvSessionCacheRedis(t *testing.T) { + customResponseSignerPrivateKeyPEMPath := new(string) + *customResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t) + + customSessionEventsRedisStream := new(string) + *customSessionEventsRedisStream = "gateway:session_events" + + customClientEventsRedisStream := new(string) + *customClientEventsRedisStream = "gateway:client_events" + + customRedisAddr := new(string) + *customRedisAddr = "127.0.0.1:6380" + + customRedisUsername := new(string) + *customRedisUsername = "gateway" + + customRedisPassword := new(string) + *customRedisPassword = "secret" + + customRedisDB := new(string) + *customRedisDB = "7" + + customRedisKeyPrefix := new(string) + *customRedisKeyPrefix = "edge:session:" + + customRedisLookupTimeout := new(string) + *customRedisLookupTimeout = "750ms" + + customRedisTLSEnabled := new(string) + *customRedisTLSEnabled = "true" + + negativeRedisDB := new(string) + *negativeRedisDB = "-1" + + invalidRedisLookupTimeout := new(string) + *invalidRedisLookupTimeout = "later" + + invalidRedisTLSEnabled := new(string) + *invalidRedisTLSEnabled = "maybe" + + tests := []struct { + name string + envs map[string]*string + want SessionCacheRedisConfig + wantErr string + }{ + { + name: "custom redis config", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customRedisAddr, + sessionCacheRedisUsernameEnvVar: customRedisUsername, + sessionCacheRedisPasswordEnvVar: customRedisPassword, + 
sessionCacheRedisDBEnvVar: customRedisDB, + sessionCacheRedisKeyPrefixEnvVar: customRedisKeyPrefix, + sessionCacheRedisLookupTimeoutEnvVar: customRedisLookupTimeout, + sessionCacheRedisTLSEnabledEnvVar: customRedisTLSEnabled, + }, + want: SessionCacheRedisConfig{ + Addr: "127.0.0.1:6380", + Username: "gateway", + Password: "secret", + DB: 7, + KeyPrefix: "edge:session:", + LookupTimeout: 750 * time.Millisecond, + TLSEnabled: true, + }, + }, + { + name: "negative redis db", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customRedisAddr, + sessionCacheRedisDBEnvVar: negativeRedisDB, + }, + wantErr: sessionCacheRedisDBEnvVar + " must not be negative", + }, + { + name: "invalid redis lookup timeout", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customRedisAddr, + sessionCacheRedisLookupTimeoutEnvVar: invalidRedisLookupTimeout, + }, + wantErr: "parse " + sessionCacheRedisLookupTimeoutEnvVar, + }, + { + name: "invalid redis tls flag", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customRedisAddr, + sessionCacheRedisTLSEnabledEnvVar: invalidRedisTLSEnabled, + }, + wantErr: "parse " + sessionCacheRedisTLSEnabledEnvVar, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + restoreEnvs(t, append(append(append(sessionCacheRedisEnvVars(), sessionEventsRedisEnvVars()...), clientEventsRedisEnvVars()...), responseSignerPrivateKeyPEMPathEnvVar)...) 
+ setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, customResponseSignerPrivateKeyPEMPath) + setEnvValue(t, sessionEventsRedisStreamEnvVar, customSessionEventsRedisStream) + setEnvValue(t, clientEventsRedisStreamEnvVar, customClientEventsRedisStream) + + for envVar, value := range tt.envs { + setEnvValue(t, envVar, value) + } + + cfg, err := LoadFromEnv() + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, cfg.SessionCacheRedis) + }) + } +} + +func TestLoadFromEnvReplayRedis(t *testing.T) { + customSessionCacheRedisAddr := new(string) + *customSessionCacheRedisAddr = "127.0.0.1:6380" + + customSessionEventsRedisStream := new(string) + *customSessionEventsRedisStream = "gateway:session_events" + + customClientEventsRedisStream := new(string) + *customClientEventsRedisStream = "gateway:client_events" + + customResponseSignerPrivateKeyPEMPath := new(string) + *customResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t) + + customReplayRedisKeyPrefix := new(string) + *customReplayRedisKeyPrefix = "edge:replay:" + + customReplayRedisReserveTimeout := new(string) + *customReplayRedisReserveTimeout = "500ms" + + emptyReplayRedisKeyPrefix := new(string) + *emptyReplayRedisKeyPrefix = "" + + invalidReplayRedisReserveTimeout := new(string) + *invalidReplayRedisReserveTimeout = "later" + + tests := []struct { + name string + envs map[string]*string + want ReplayRedisConfig + wantErr string + }{ + { + name: "custom replay redis config", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + replayRedisKeyPrefixEnvVar: customReplayRedisKeyPrefix, + replayRedisReserveTimeoutEnvVar: customReplayRedisReserveTimeout, + }, + want: ReplayRedisConfig{ + KeyPrefix: "edge:replay:", + ReserveTimeout: 500 * time.Millisecond, + }, + }, + { + name: "empty replay redis key prefix", + envs: map[string]*string{ + 
sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + replayRedisKeyPrefixEnvVar: emptyReplayRedisKeyPrefix, + }, + wantErr: replayRedisKeyPrefixEnvVar + " must not be empty", + }, + { + name: "invalid replay redis reserve timeout", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + replayRedisReserveTimeoutEnvVar: invalidReplayRedisReserveTimeout, + }, + wantErr: "parse " + replayRedisReserveTimeoutEnvVar, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + restoreEnvs(t, append(append(append(append(sessionCacheRedisEnvVars(), replayRedisEnvVars()...), sessionEventsRedisEnvVars()...), clientEventsRedisEnvVars()...), responseSignerPrivateKeyPEMPathEnvVar)...) + setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, customResponseSignerPrivateKeyPEMPath) + setEnvValue(t, sessionEventsRedisStreamEnvVar, customSessionEventsRedisStream) + setEnvValue(t, clientEventsRedisStreamEnvVar, customClientEventsRedisStream) + + for envVar, value := range tt.envs { + setEnvValue(t, envVar, value) + } + + cfg, err := LoadFromEnv() + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, cfg.ReplayRedis) + }) + } +} + +func TestLoadFromEnvSessionEventsRedis(t *testing.T) { + customSessionCacheRedisAddr := new(string) + *customSessionCacheRedisAddr = "127.0.0.1:6380" + + customResponseSignerPrivateKeyPEMPath := new(string) + *customResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t) + + customClientEventsRedisStream := new(string) + *customClientEventsRedisStream = "gateway:client_events" + + customStream := new(string) + *customStream = "edge:session_events" + + customReadBlockTimeout := new(string) + *customReadBlockTimeout = "1500ms" + + emptyStream := new(string) + *emptyStream = "" + + invalidReadBlockTimeout := new(string) + *invalidReadBlockTimeout = "later" + + tests := 
[]struct { + name string + envs map[string]*string + want SessionEventsRedisConfig + wantErr string + }{ + { + name: "custom session events redis config", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + sessionEventsRedisStreamEnvVar: customStream, + sessionEventsRedisReadBlockTimeoutEnvVar: customReadBlockTimeout, + }, + want: SessionEventsRedisConfig{ + Stream: "edge:session_events", + ReadBlockTimeout: 1500 * time.Millisecond, + }, + }, + { + name: "missing session events redis stream", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + }, + wantErr: sessionEventsRedisStreamEnvVar + " must not be empty", + }, + { + name: "empty session events redis stream", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + sessionEventsRedisStreamEnvVar: emptyStream, + }, + wantErr: sessionEventsRedisStreamEnvVar + " must not be empty", + }, + { + name: "invalid session events read block timeout", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + sessionEventsRedisStreamEnvVar: customStream, + sessionEventsRedisReadBlockTimeoutEnvVar: invalidReadBlockTimeout, + }, + wantErr: "parse " + sessionEventsRedisReadBlockTimeoutEnvVar, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + restoreEnvs(t, append(append(append(sessionCacheRedisEnvVars(), sessionEventsRedisEnvVars()...), clientEventsRedisEnvVars()...), responseSignerPrivateKeyPEMPathEnvVar)...) 
+ setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, customResponseSignerPrivateKeyPEMPath) + setEnvValue(t, clientEventsRedisStreamEnvVar, customClientEventsRedisStream) + + for envVar, value := range tt.envs { + setEnvValue(t, envVar, value) + } + + cfg, err := LoadFromEnv() + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, cfg.SessionEventsRedis) + }) + } +} + +func TestLoadFromEnvClientEventsRedis(t *testing.T) { + customSessionCacheRedisAddr := new(string) + *customSessionCacheRedisAddr = "127.0.0.1:6380" + + customResponseSignerPrivateKeyPEMPath := new(string) + *customResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t) + + customSessionEventsRedisStream := new(string) + *customSessionEventsRedisStream = "gateway:session_events" + + customStream := new(string) + *customStream = "edge:client_events" + + customReadBlockTimeout := new(string) + *customReadBlockTimeout = "1500ms" + + emptyStream := new(string) + *emptyStream = "" + + invalidReadBlockTimeout := new(string) + *invalidReadBlockTimeout = "later" + + tests := []struct { + name string + envs map[string]*string + want ClientEventsRedisConfig + wantErr string + }{ + { + name: "custom client events redis config", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + clientEventsRedisStreamEnvVar: customStream, + clientEventsRedisReadBlockTimeoutEnvVar: customReadBlockTimeout, + }, + want: ClientEventsRedisConfig{ + Stream: "edge:client_events", + ReadBlockTimeout: 1500 * time.Millisecond, + }, + }, + { + name: "missing client events redis stream", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + }, + wantErr: clientEventsRedisStreamEnvVar + " must not be empty", + }, + { + name: "empty client events redis stream", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + 
clientEventsRedisStreamEnvVar: emptyStream, + }, + wantErr: clientEventsRedisStreamEnvVar + " must not be empty", + }, + { + name: "invalid client events read block timeout", + envs: map[string]*string{ + sessionCacheRedisAddrEnvVar: customSessionCacheRedisAddr, + clientEventsRedisStreamEnvVar: customStream, + clientEventsRedisReadBlockTimeoutEnvVar: invalidReadBlockTimeout, + }, + wantErr: "parse " + clientEventsRedisReadBlockTimeoutEnvVar, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + restoreEnvs(t, append(append(append(sessionCacheRedisEnvVars(), sessionEventsRedisEnvVars()...), clientEventsRedisEnvVars()...), responseSignerPrivateKeyPEMPathEnvVar)...) + setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, customResponseSignerPrivateKeyPEMPath) + setEnvValue(t, sessionEventsRedisStreamEnvVar, customSessionEventsRedisStream) + + for envVar, value := range tt.envs { + setEnvValue(t, envVar, value) + } + + cfg, err := LoadFromEnv() + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, cfg.ClientEventsRedis) + }) + } +} + +func TestLoadFromEnvPublicHTTPAntiAbuse(t *testing.T) { + requiredSessionCacheRedisAddr := new(string) + *requiredSessionCacheRedisAddr = "127.0.0.1:6379" + + requiredSessionEventsRedisStream := new(string) + *requiredSessionEventsRedisStream = "gateway:session_events" + + requiredClientEventsRedisStream := new(string) + *requiredClientEventsRedisStream = "gateway:client_events" + + requiredResponseSignerPrivateKeyPEMPath := new(string) + *requiredResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t) + + customPublicAuthMaxBodyBytes := new(string) + *customPublicAuthMaxBodyBytes = "4096" + + customBrowserAssetRequests := new(string) + *customBrowserAssetRequests = "150" + + customBrowserAssetWindow := new(string) + *customBrowserAssetWindow = "2m" + + customConfirmBurst := 
new(string) + *customConfirmBurst = "3" + + negativePublicAuthMaxBodyBytes := new(string) + *negativePublicAuthMaxBodyBytes = "-1" + + zeroPublicMiscRequests := new(string) + *zeroPublicMiscRequests = "0" + + invalidSendIdentityWindow := new(string) + *invalidSendIdentityWindow = "later" + + tests := []struct { + name string + envs map[string]*string + want PublicHTTPAntiAbuseConfig + wantErr string + }{ + { + name: "custom anti abuse config", + envs: map[string]*string{ + publicAuthMaxBodyBytesEnvVar: customPublicAuthMaxBodyBytes, + browserAssetRateLimitRequestsEnvVar: customBrowserAssetRequests, + browserAssetRateLimitWindowEnvVar: customBrowserAssetWindow, + confirmEmailCodeIdentityRateLimitBurstEnvVar: customConfirmBurst, + }, + want: func() PublicHTTPAntiAbuseConfig { + cfg := DefaultPublicHTTPConfig().AntiAbuse + cfg.PublicAuth.MaxBodyBytes = 4096 + cfg.BrowserAsset.RateLimit.Requests = 150 + cfg.BrowserAsset.RateLimit.Window = 2 * time.Minute + cfg.ConfirmEmailCodeIdentity.RateLimit.Burst = 3 + return cfg + }(), + }, + { + name: "negative public auth max body bytes", + envs: map[string]*string{ + publicAuthMaxBodyBytesEnvVar: negativePublicAuthMaxBodyBytes, + }, + wantErr: publicAuthMaxBodyBytesEnvVar + " must not be negative", + }, + { + name: "zero public misc requests", + envs: map[string]*string{ + publicMiscRateLimitRequestsEnvVar: zeroPublicMiscRequests, + }, + wantErr: publicMiscRateLimitRequestsEnvVar + " must be positive", + }, + { + name: "invalid send identity window", + envs: map[string]*string{ + sendEmailCodeIdentityRateLimitWindowEnvVar: invalidSendIdentityWindow, + }, + wantErr: "parse " + sendEmailCodeIdentityRateLimitWindowEnvVar, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + restoreEnvs(t, append(append(append(append(publicAntiAbuseEnvVars(), sessionCacheRedisAddrEnvVar), sessionEventsRedisEnvVars()...), clientEventsRedisEnvVars()...), responseSignerPrivateKeyPEMPathEnvVar)...) 
+ setEnvValue(t, sessionCacheRedisAddrEnvVar, requiredSessionCacheRedisAddr) + setEnvValue(t, sessionEventsRedisStreamEnvVar, requiredSessionEventsRedisStream) + setEnvValue(t, clientEventsRedisStreamEnvVar, requiredClientEventsRedisStream) + setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, requiredResponseSignerPrivateKeyPEMPath) + + for envVar, value := range tt.envs { + setEnvValue(t, envVar, value) + } + + cfg, err := LoadFromEnv() + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, cfg.PublicHTTP.AntiAbuse) + }) + } +} + +// restoreEnv resets envVar after the test mutates process-wide environment +// state. +func restoreEnv(t *testing.T, envVar string) { + t.Helper() + + previousValue, hadPreviousValue := os.LookupEnv(envVar) + t.Cleanup(func() { + var err error + if hadPreviousValue { + err = os.Setenv(envVar, previousValue) + } else { + err = os.Unsetenv(envVar) + } + require.NoError(t, err) + }) +} + +// setEnvValue updates envVar to value or unsets it when value is nil. 
+func setEnvValue(t *testing.T, envVar string, value *string) { + t.Helper() + + var err error + if value == nil { + err = os.Unsetenv(envVar) + } else { + err = os.Setenv(envVar, *value) + } + require.NoError(t, err) +} + +func restoreEnvs(t *testing.T, envVars ...string) { + t.Helper() + + for _, envVar := range envVars { + restoreEnv(t, envVar) + } +} + +func publicAntiAbuseEnvVars() []string { + return []string{ + publicAuthMaxBodyBytesEnvVar, + publicAuthRateLimitRequestsEnvVar, + publicAuthRateLimitWindowEnvVar, + publicAuthRateLimitBurstEnvVar, + browserBootstrapMaxBodyBytesEnvVar, + browserBootstrapRateLimitRequestsEnvVar, + browserBootstrapRateLimitWindowEnvVar, + browserBootstrapRateLimitBurstEnvVar, + browserAssetMaxBodyBytesEnvVar, + browserAssetRateLimitRequestsEnvVar, + browserAssetRateLimitWindowEnvVar, + browserAssetRateLimitBurstEnvVar, + publicMiscMaxBodyBytesEnvVar, + publicMiscRateLimitRequestsEnvVar, + publicMiscRateLimitWindowEnvVar, + publicMiscRateLimitBurstEnvVar, + sendEmailCodeIdentityRateLimitRequestsEnvVar, + sendEmailCodeIdentityRateLimitWindowEnvVar, + sendEmailCodeIdentityRateLimitBurstEnvVar, + confirmEmailCodeIdentityRateLimitRequestsEnvVar, + confirmEmailCodeIdentityRateLimitWindowEnvVar, + confirmEmailCodeIdentityRateLimitBurstEnvVar, + } +} + +func operationalEnvVars() []string { + return []string{ + logLevelEnvVar, + publicHTTPAddrEnvVar, + publicHTTPReadHeaderTimeoutEnvVar, + publicHTTPReadTimeoutEnvVar, + publicHTTPIdleTimeoutEnvVar, + publicAuthUpstreamTimeoutEnvVar, + adminHTTPAddrEnvVar, + adminHTTPReadHeaderTimeoutEnvVar, + adminHTTPReadTimeoutEnvVar, + adminHTTPIdleTimeoutEnvVar, + authenticatedGRPCAddrEnvVar, + authenticatedGRPCConnectionTimeoutEnvVar, + authenticatedGRPCDownstreamTimeoutEnvVar, + authenticatedGRPCFreshnessWindowEnvVar, + } +} + +func sessionCacheRedisEnvVars() []string { + return []string{ + sessionCacheRedisAddrEnvVar, + sessionCacheRedisUsernameEnvVar, + sessionCacheRedisPasswordEnvVar, + 
sessionCacheRedisDBEnvVar, + sessionCacheRedisKeyPrefixEnvVar, + sessionCacheRedisLookupTimeoutEnvVar, + sessionCacheRedisTLSEnabledEnvVar, + } +} + +func replayRedisEnvVars() []string { + return []string{ + replayRedisKeyPrefixEnvVar, + replayRedisReserveTimeoutEnvVar, + } +} + +func sessionEventsRedisEnvVars() []string { + return []string{ + sessionEventsRedisStreamEnvVar, + sessionEventsRedisReadBlockTimeoutEnvVar, + } +} + +func clientEventsRedisEnvVars() []string { + return []string{ + clientEventsRedisStreamEnvVar, + clientEventsRedisReadBlockTimeoutEnvVar, + } +} + +func writeTestResponseSignerPEMFile(t *testing.T) string { + t.Helper() + + seed := sha256.Sum256([]byte("gateway-config-test-response-signer")) + privateKey := ed25519.NewKeyFromSeed(seed[:]) + + encodedPrivateKey, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + path := filepath.Join(t.TempDir(), "response-signer.pem") + err = os.WriteFile(path, pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: encodedPrivateKey, + }), 0o600) + require.NoError(t, err) + + return path +} diff --git a/gateway/internal/downstream/downstream.go b/gateway/internal/downstream/downstream.go new file mode 100644 index 0000000..b9da9eb --- /dev/null +++ b/gateway/internal/downstream/downstream.go @@ -0,0 +1,108 @@ +// Package downstream defines the verified internal command contract used by the +// gateway after the authenticated edge pipeline succeeds. +package downstream + +import ( + "context" + "errors" +) + +var ( + // ErrRouteNotFound reports that Router does not have an exact-match handler + // for the supplied authenticated message type. + ErrRouteNotFound = errors.New("downstream route not found") + + // ErrDownstreamUnavailable reports that the resolved downstream dependency is + // temporarily unavailable. 
+ ErrDownstreamUnavailable = errors.New("downstream service is unavailable") +) + +// AuthenticatedCommand is the minimum verified unary command context the +// gateway may forward to downstream business services. +type AuthenticatedCommand struct { + // ProtocolVersion is the authenticated transport protocol version accepted + // by the gateway. + ProtocolVersion string + + // UserID is the authenticated user identity resolved from SessionCache. + UserID string + + // DeviceSessionID is the authenticated device session that originated the + // command. + DeviceSessionID string + + // MessageType is the stable exact-match downstream routing key. + MessageType string + + // TimestampMS is the client-supplied request timestamp that already passed + // freshness verification. + TimestampMS int64 + + // RequestID is the transport correlation and anti-replay identifier. + RequestID string + + // TraceID is the optional client-supplied correlation identifier. + TraceID string + + // PayloadBytes carries the verified opaque business payload bytes. + PayloadBytes []byte +} + +// UnaryResult is the minimum downstream unary result the gateway needs in +// order to build a signed authenticated client response. +type UnaryResult struct { + // ResultCode is the stable opaque downstream result code returned to the + // client without business reinterpretation by the gateway. + ResultCode string + + // PayloadBytes carries the opaque downstream response payload bytes. + PayloadBytes []byte +} + +// Client executes a verified authenticated unary command against one concrete +// downstream service or adapter. +type Client interface { + // ExecuteCommand executes command and returns the downstream unary result. + ExecuteCommand(ctx context.Context, command AuthenticatedCommand) (UnaryResult, error) +} + +// Router resolves the downstream unary client for one exact authenticated +// message_type value. +type Router interface { + // Route returns the downstream client for messageType. 
Implementations must + // wrap ErrRouteNotFound when the route table does not contain messageType. + Route(messageType string) (Client, error) +} + +// StaticRouter resolves exact message_type literals from an immutable route +// map supplied at construction time. +type StaticRouter struct { + routes map[string]Client +} + +// NewStaticRouter constructs a StaticRouter with a defensive copy of routes. +func NewStaticRouter(routes map[string]Client) *StaticRouter { + clonedRoutes := make(map[string]Client, len(routes)) + for messageType, client := range routes { + if client == nil { + continue + } + clonedRoutes[messageType] = client + } + + return &StaticRouter{routes: clonedRoutes} +} + +// Route returns the exact-match client for messageType. +func (r *StaticRouter) Route(messageType string) (Client, error) { + if r == nil { + return nil, ErrRouteNotFound + } + + client, ok := r.routes[messageType] + if !ok || client == nil { + return nil, ErrRouteNotFound + } + + return client, nil +} diff --git a/gateway/internal/downstream/downstream_test.go b/gateway/internal/downstream/downstream_test.go new file mode 100644 index 0000000..fbda9d7 --- /dev/null +++ b/gateway/internal/downstream/downstream_test.go @@ -0,0 +1,39 @@ +package downstream + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStaticRouterRoutesExactMessageType(t *testing.T) { + t.Parallel() + + want := &stubClient{} + router := NewStaticRouter(map[string]Client{ + "fleet.move": want, + }) + + got, err := router.Route("fleet.move") + require.NoError(t, err) + assert.Same(t, want, got) +} + +func TestStaticRouterRejectsUnknownMessageType(t *testing.T) { + t.Parallel() + + router := NewStaticRouter(map[string]Client{ + "fleet.move": &stubClient{}, + }) + + _, err := router.Route("fleet.rename") + require.ErrorIs(t, err, ErrRouteNotFound) +} + +type stubClient struct{} + +func (*stubClient) ExecuteCommand(context.Context, 
AuthenticatedCommand) (UnaryResult, error) { + return UnaryResult{}, nil +} diff --git a/gateway/internal/events/client_subscriber.go b/gateway/internal/events/client_subscriber.go new file mode 100644 index 0000000..4ef4356 --- /dev/null +++ b/gateway/internal/events/client_subscriber.go @@ -0,0 +1,341 @@ +package events + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "strings" + "sync" + "time" + + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/push" + "galaxy/gateway/internal/telemetry" + + "github.com/redis/go-redis/v9" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" +) + +const clientEventReadCount int64 = 128 + +// ClientEventPublisher accepts decoded client-facing events from the internal +// event subscriber. +type ClientEventPublisher interface { + // Publish fans out event to the currently active push streams. + Publish(event push.Event) +} + +// RedisClientEventSubscriber consumes client-facing events from one Redis +// Stream and forwards them to the configured publisher. +type RedisClientEventSubscriber struct { + client *redis.Client + stream string + pingTimeout time.Duration + readBlockTimeout time.Duration + publisher ClientEventPublisher + logger *zap.Logger + metrics *telemetry.Runtime + + closeOnce sync.Once + startedOnce sync.Once + started chan struct{} +} + +// NewRedisClientEventSubscriber constructs a Redis Stream subscriber that +// reuses the SessionCache Redis connection settings and forwards decoded +// client-facing events to publisher. +func NewRedisClientEventSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) { + return NewRedisClientEventSubscriberWithObservability(sessionCfg, eventsCfg, publisher, nil, nil) +} + +// NewRedisClientEventSubscriberWithObservability constructs a Redis Stream +// subscriber that also records malformed or dropped internal events. 
+func NewRedisClientEventSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) { + if strings.TrimSpace(sessionCfg.Addr) == "" { + return nil, errors.New("new redis client event subscriber: redis addr must not be empty") + } + if sessionCfg.DB < 0 { + return nil, errors.New("new redis client event subscriber: redis db must not be negative") + } + if sessionCfg.LookupTimeout <= 0 { + return nil, errors.New("new redis client event subscriber: lookup timeout must be positive") + } + if strings.TrimSpace(eventsCfg.Stream) == "" { + return nil, errors.New("new redis client event subscriber: stream must not be empty") + } + if eventsCfg.ReadBlockTimeout <= 0 { + return nil, errors.New("new redis client event subscriber: read block timeout must be positive") + } + if publisher == nil { + return nil, errors.New("new redis client event subscriber: nil publisher") + } + + options := &redis.Options{ + Addr: sessionCfg.Addr, + Username: sessionCfg.Username, + Password: sessionCfg.Password, + DB: sessionCfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if sessionCfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + if logger == nil { + logger = zap.NewNop() + } + + return &RedisClientEventSubscriber{ + client: redis.NewClient(options), + stream: eventsCfg.Stream, + pingTimeout: sessionCfg.LookupTimeout, + readBlockTimeout: eventsCfg.ReadBlockTimeout, + publisher: publisher, + logger: logger.Named("client_event_subscriber"), + metrics: metrics, + started: make(chan struct{}), + }, nil +} + +// Ping verifies that the Redis backend used for client-facing event fan-out is +// reachable within the configured timeout budget. 
+func (s *RedisClientEventSubscriber) Ping(ctx context.Context) error {
+	if s == nil || s.client == nil {
+		return errors.New("ping redis client event subscriber: nil subscriber")
+	}
+	if ctx == nil {
+		return errors.New("ping redis client event subscriber: nil context")
+	}
+
+	pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
+	defer cancel()
+
+	if err := s.client.Ping(pingCtx).Err(); err != nil {
+		return fmt.Errorf("ping redis client event subscriber: %w", err)
+	}
+
+	return nil
+}
+
+// Run consumes client-facing events until ctx is canceled or Redis returns an
+// unexpected error.
+func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
+	if s == nil || s.client == nil {
+		return errors.New("run redis client event subscriber: nil subscriber")
+	}
+	if ctx == nil {
+		return errors.New("run redis client event subscriber: nil context")
+	}
+	if err := ctx.Err(); err != nil {
+		return err
+	}
+
+	lastID, err := s.resolveStartID(ctx)
+	if err != nil {
+		return err
+	}
+
+	s.signalStarted()
+
+	for {
+		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
+			Streams: []string{s.stream, lastID},
+			Count:   clientEventReadCount,
+			Block:   s.readBlockTimeout,
+		}).Result()
+		switch {
+		case err == nil:
+			for _, stream := range streams {
+				for _, message := range stream.Messages {
+					s.publishMessage(message)
+					lastID = message.ID
+				}
+			}
+			continue
+		case errors.Is(err, redis.Nil):
+			// Block timeout elapsed with no new entries; poll again.
+			continue
+		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
+			// Cancellation surfaced through the Redis client: report ctx.Err().
+			return ctx.Err()
+		default:
+			// Also covers Canceled/DeadlineExceeded/ErrClosed without an active
+			// ctx error; the previous duplicate case arm was folded in here.
+			return fmt.Errorf("run redis client event subscriber: %w", err)
+		}
+	}
+}
+
+func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) { // returns the current stream tail (or "0-0" when empty) so Run only delivers new entries
+	messages, err := 
s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result() + switch { + case err == nil: + case errors.Is(err, redis.Nil): + return "0-0", nil + default: + return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", err) + } + + if len(messages) == 0 { + return "0-0", nil + } + + return messages[0].ID, nil +} + +// Shutdown closes the Redis client so a blocking stream read can terminate +// promptly during gateway shutdown. +func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error { + if ctx == nil { + return errors.New("shutdown redis client event subscriber: nil context") + } + + return s.Close() +} + +// Close releases the underlying Redis client resources. +func (s *RedisClientEventSubscriber) Close() error { + if s == nil || s.client == nil { + return nil + } + + var err error + s.closeOnce.Do(func() { + err = s.client.Close() + }) + + return err +} + +func (s *RedisClientEventSubscriber) signalStarted() { + s.startedOnce.Do(func() { + close(s.started) + }) +} + +func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) { + event, err := decodeClientEvent(message.Values) + if err != nil { + s.logger.Warn("dropped malformed client event", + zap.String("stream", s.stream), + zap.String("message_id", message.ID), + zap.Error(err), + ) + s.metrics.RecordInternalEventDrop(context.Background(), + attribute.String("component", "client_event_subscriber"), + attribute.String("reason", "malformed_event"), + ) + return + } + + s.publisher.Publish(event) +} + +func decodeClientEvent(values map[string]any) (push.Event, error) { + requiredKeys := map[string]struct{}{ + "user_id": {}, + "event_type": {}, + "event_id": {}, + "payload_bytes": {}, + } + optionalKeys := map[string]struct{}{ + "device_session_id": {}, + "request_id": {}, + "trace_id": {}, + } + + for key := range values { + if _, ok := requiredKeys[key]; ok { + continue + } + if _, ok := optionalKeys[key]; ok { + continue + } + + return push.Event{}, 
fmt.Errorf("decode client event: unsupported field %q", key)
+	}
+
+	userID, err := requiredStringField(values, "user_id") // requiredStringField/coerceString are defined elsewhere in this package — not visible in this chunk
+	if err != nil {
+		return push.Event{}, err
+	}
+	eventType, err := requiredStringField(values, "event_type")
+	if err != nil {
+		return push.Event{}, err
+	}
+	eventID, err := requiredStringField(values, "event_id")
+	if err != nil {
+		return push.Event{}, err
+	}
+	payloadBytes, err := requiredBytesField(values, "payload_bytes")
+	if err != nil {
+		return push.Event{}, err
+	}
+
+	event := push.Event{
+		UserID:       userID,
+		EventType:    eventType,
+		EventID:      eventID,
+		PayloadBytes: payloadBytes,
+	}
+
+	if deviceSessionID, ok, err := optionalStringField(values, "device_session_id"); err != nil {
+		return push.Event{}, err
+	} else if ok {
+		event.DeviceSessionID = strings.TrimSpace(deviceSessionID) // trimmed here unlike request_id/trace_id — NOTE(review): presumably a whitespace-only value should mean user-wide targeting; confirm against the push hub
+	}
+
+	if requestID, ok, err := optionalStringField(values, "request_id"); err != nil {
+		return push.Event{}, err
+	} else if ok {
+		event.RequestID = requestID
+	}
+
+	if traceID, ok, err := optionalStringField(values, "trace_id"); err != nil {
+		return push.Event{}, err
+	}	else if ok {
+		event.TraceID = traceID
+	}
+
+	return event, nil
+}
+
+func requiredBytesField(values map[string]any, field string) ([]byte, error) { // mandatory field coerced to []byte; string values are accepted and copied
+	value, ok := values[field]
+	if !ok {
+		return nil, fmt.Errorf("decode client event: missing %s", field)
+	}
+
+	byteValue, err := coerceBytes(value)
+	if err != nil {
+		return nil, fmt.Errorf("decode client event: %s: %w", field, err)
+	}
+
+	return byteValue, nil
+}
+
+func optionalStringField(values map[string]any, field string) (string, bool, error) { // ok=false means absent (not an error); present-but-wrong-type is an error
+	value, ok := values[field]
+	if !ok {
+		return "", false, nil
+	}
+
+	stringValue, err := coerceString(value)
+	if err != nil {
+		return "", false, fmt.Errorf("decode client event: %s: %w", field, err)
+	}
+
+	return stringValue, true, nil
+}
+
+func coerceBytes(value any) ([]byte, error) { // bytes.Clone detaches the result from Redis client internals
+	switch typed := value.(type) {
+	case string:
+		return []byte(typed), nil
+	case 
[]byte: + return bytes.Clone(typed), nil + default: + return nil, fmt.Errorf("unsupported type %T", value) + } +} diff --git a/gateway/internal/events/client_subscriber_test.go b/gateway/internal/events/client_subscriber_test.go new file mode 100644 index 0000000..2d2c347 --- /dev/null +++ b/gateway/internal/events/client_subscriber_test.go @@ -0,0 +1,294 @@ +package events + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/push" + "galaxy/gateway/internal/testutil" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRedisClientEventSubscriberPublishesValidEvent(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := &recordingClientEventPublisher{} + subscriber := newTestRedisClientEventSubscriber(t, server, publisher) + running := runTestClientEventSubscriber(t, subscriber) + defer running.stop(t) + + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "device_session_id": "device-session-123", + "event_type": "fleet.updated", + "event_id": "event-123", + "payload_bytes": []byte("payload-123"), + "request_id": "request-123", + "trace_id": "trace-123", + }) + + require.Eventually(t, func() bool { + return len(publisher.events()) == 1 + }, time.Second, 10*time.Millisecond) + + assert.Equal(t, []push.Event{{ + UserID: "user-123", + DeviceSessionID: "device-session-123", + EventType: "fleet.updated", + EventID: "event-123", + PayloadBytes: []byte("payload-123"), + RequestID: "request-123", + TraceID: "trace-123", + }}, publisher.events()) +} + +func TestRedisClientEventSubscriberSkipsMalformedEventAndContinues(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := &recordingClientEventPublisher{} + subscriber := newTestRedisClientEventSubscriber(t, server, publisher) + running := 
runTestClientEventSubscriber(t, subscriber) + defer running.stop(t) + + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "event_type": "fleet.updated", + "event_id": "event-bad", + "payload_bytes": []byte("payload-bad"), + "unexpected": "boom", + }) + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "event_type": "fleet.updated", + "event_id": "event-good", + "payload_bytes": []byte("payload-good"), + }) + + require.Eventually(t, func() bool { + events := publisher.events() + return len(events) == 1 && events[0].EventID == "event-good" + }, time.Second, 10*time.Millisecond) +} + +func TestRedisClientEventSubscriberStartsFromCurrentTail(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := &recordingClientEventPublisher{} + + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "event_type": "fleet.updated", + "event_id": "event-old", + "payload_bytes": []byte("payload-old"), + }) + + subscriber := newTestRedisClientEventSubscriber(t, server, publisher) + running := runTestClientEventSubscriber(t, subscriber) + defer running.stop(t) + + assert.Never(t, func() bool { + return len(publisher.events()) > 0 + }, 100*time.Millisecond, 10*time.Millisecond) + + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "event_type": "fleet.updated", + "event_id": "event-new", + "payload_bytes": []byte("payload-new"), + }) + + require.Eventually(t, func() bool { + events := publisher.events() + return len(events) == 1 && events[0].EventID == "event-new" + }, time.Second, 10*time.Millisecond) +} + +func TestRedisClientEventSubscriberShutdownInterruptsBlockingRead(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := &recordingClientEventPublisher{} + subscriber := newTestRedisClientEventSubscriber(t, server, publisher) + + ctx, cancel := 
context.WithCancel(context.Background()) + resultCh := make(chan error, 1) + go func() { + resultCh <- subscriber.Run(ctx) + }() + + select { + case <-subscriber.started: + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not start") + } + + cancel() + require.NoError(t, subscriber.Shutdown(context.Background())) + + select { + case err := <-resultCh: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not stop after shutdown") + } +} + +func TestRedisClientEventSubscriberLogsAndCountsMalformedEvents(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := &recordingClientEventPublisher{} + logger, logBuffer := testutil.NewObservedLogger(t) + telemetryRuntime := testutil.NewTelemetryRuntime(t, logger) + + subscriber, err := NewRedisClientEventSubscriberWithObservability( + config.SessionCacheRedisConfig{ + Addr: server.Addr(), + LookupTimeout: 250 * time.Millisecond, + }, + config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: 25 * time.Millisecond, + }, + publisher, + logger, + telemetryRuntime, + ) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, subscriber.Close()) + }) + + running := runTestClientEventSubscriber(t, subscriber) + defer running.stop(t) + + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "event_type": "fleet.updated", + "event_id": "event-bad", + "payload_bytes": []byte("payload-bad"), + "unexpected": "boom", + }) + + require.Eventually(t, func() bool { + return strings.Contains(logBuffer.String(), "dropped malformed client event") + }, time.Second, 10*time.Millisecond) + + metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler()) + assert.Contains(t, metricsText, `gateway_internal_event_drops_total`) + assert.Contains(t, metricsText, `component="client_event_subscriber"`) + assert.Contains(t, metricsText, `reason="malformed_event"`) +} + 
+func newTestRedisClientEventSubscriber(t *testing.T, server *miniredis.Miniredis, publisher ClientEventPublisher) *RedisClientEventSubscriber { + t.Helper() + + subscriber, err := NewRedisClientEventSubscriber( + config.SessionCacheRedisConfig{ + Addr: server.Addr(), + LookupTimeout: 250 * time.Millisecond, + }, + config.ClientEventsRedisConfig{ + Stream: "gateway:client_events", + ReadBlockTimeout: 25 * time.Millisecond, + }, + publisher, + ) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, subscriber.Close()) + }) + + return subscriber +} + +func addClientEvent(t *testing.T, server *miniredis.Miniredis, stream string, values map[string]any) { + t.Helper() + + client := redis.NewClient(&redis.Options{ + Addr: server.Addr(), + Protocol: 2, + DisableIdentity: true, + }) + defer func() { + assert.NoError(t, client.Close()) + }() + + err := client.XAdd(context.Background(), &redis.XAddArgs{ + Stream: stream, + Values: values, + }).Err() + require.NoError(t, err) +} + +type runningClientEventSubscriber struct { + cancel context.CancelFunc + resultCh chan error +} + +func runTestClientEventSubscriber(t *testing.T, subscriber *RedisClientEventSubscriber) runningClientEventSubscriber { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + resultCh := make(chan error, 1) + go func() { + resultCh <- subscriber.Run(ctx) + }() + + select { + case <-subscriber.started: + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not start") + } + + return runningClientEventSubscriber{ + cancel: cancel, + resultCh: resultCh, + } +} + +func (r runningClientEventSubscriber) stop(t *testing.T) { + t.Helper() + + r.cancel() + + select { + case err := <-r.resultCh: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not stop") + } +} + +type recordingClientEventPublisher struct { + mu sync.Mutex + records []push.Event +} + +func (p *recordingClientEventPublisher) 
Publish(event push.Event) { + p.mu.Lock() + defer p.mu.Unlock() + + p.records = append(p.records, event) +} + +func (p *recordingClientEventPublisher) events() []push.Event { + p.mu.Lock() + defer p.mu.Unlock() + + cloned := make([]push.Event, len(p.records)) + copy(cloned, p.records) + return cloned +} diff --git a/gateway/internal/events/grpc_integration_test.go b/gateway/internal/events/grpc_integration_test.go new file mode 100644 index 0000000..5dab310 --- /dev/null +++ b/gateway/internal/events/grpc_integration_test.go @@ -0,0 +1,385 @@ +package events + +import ( + "context" + "crypto/ed25519" + "crypto/sha256" + "encoding/base64" + "errors" + "net" + "sync" + "testing" + "time" + + "galaxy/gateway/internal/app" + "galaxy/gateway/internal/authn" + "galaxy/gateway/internal/clock" + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/downstream" + "galaxy/gateway/internal/grpcapi" + "galaxy/gateway/internal/replay" + "galaxy/gateway/internal/session" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" +) + +var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC) + +func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + local := session.NewMemoryCache() + fallback := &countingSessionCache{ + records: map[string]session.Record{ + "device-session-123": newActiveSessionRecord("user-123"), + }, + } + readThrough, err := session.NewReadThroughCache(local, fallback) + require.NoError(t, err) + + subscriber := newTestRedisSessionSubscriber(t, server, local) + downstreamClient := &recordingDownstreamClient{} + addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient) + defer running.stop(t) + + conn := 
dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1")) + require.NoError(t, err) + assert.Equal(t, 1, fallback.lookupCalls()) + + _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2")) + require.NoError(t, err) + assert.Equal(t, 1, fallback.lookupCalls()) + assert.Len(t, downstreamClient.commands(), 2) +} + +func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + local := session.NewMemoryCache() + fallback := &countingSessionCache{ + records: map[string]session.Record{ + "device-session-123": newActiveSessionRecord("user-123"), + }, + } + readThrough, err := session.NewReadThroughCache(local, fallback) + require.NoError(t, err) + + subscriber := newTestRedisSessionSubscriber(t, server, local) + downstreamClient := &recordingDownstreamClient{} + addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient) + defer running.stop(t) + + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1")) + require.NoError(t, err) + assert.Equal(t, 1, fallback.lookupCalls()) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-456", + "client_public_key": testClientPublicKeyBase64(), + "status": string(session.StatusActive), + }) + + require.Eventually(t, func() bool { + record, lookupErr := local.Lookup(context.Background(), "device-session-123") + return lookupErr == nil && record.UserID == "user-456" + }, time.Second, 10*time.Millisecond) + + _, err = client.ExecuteCommand(context.Background(), 
newExecuteCommandRequest("request-2")) + require.NoError(t, err) + assert.Equal(t, 1, fallback.lookupCalls()) + + commands := downstreamClient.commands() + require.Len(t, commands, 2) + assert.Equal(t, "user-456", commands[1].UserID) +} + +func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + local := session.NewMemoryCache() + fallback := &countingSessionCache{ + records: map[string]session.Record{ + "device-session-123": newActiveSessionRecord("user-123"), + }, + } + readThrough, err := session.NewReadThroughCache(local, fallback) + require.NoError(t, err) + + subscriber := newTestRedisSessionSubscriber(t, server, local) + downstreamClient := &recordingDownstreamClient{} + addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient) + defer running.stop(t) + + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1")) + require.NoError(t, err) + assert.Equal(t, 1, fallback.lookupCalls()) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": testClientPublicKeyBase64(), + "status": string(session.StatusRevoked), + "revoked_at_ms": "123456789", + }) + + require.Eventually(t, func() bool { + record, lookupErr := local.Lookup(context.Background(), "device-session-123") + return lookupErr == nil && record.Status == session.StatusRevoked + }, time.Second, 10*time.Millisecond) + + _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2")) + require.Error(t, err) + assert.Equal(t, codes.FailedPrecondition, status.Code(err)) + assert.Equal(t, "device session is revoked", status.Convert(err).Message()) + assert.Equal(t, 1, 
fallback.lookupCalls()) +} + +type runningAuthenticatedGateway struct { + cancel context.CancelFunc + resultCh chan error +} + +func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) { + t.Helper() + + addr := unusedTCPAddr(t) + grpcCfg := config.DefaultAuthenticatedGRPCConfig() + grpcCfg.Addr = addr + grpcCfg.FreshnessWindow = 5 * time.Minute + + router := downstream.NewStaticRouter(map[string]downstream.Client{ + "fleet.move": downstreamClient, + }) + + gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{ + Router: router, + ResponseSigner: newTestResponseSigner(t), + SessionCache: sessionCache, + ReplayStore: staticReplayStore{}, + Clock: fixedClock{now: testNow}, + }) + + application := app.New( + config.Config{ + ShutdownTimeout: time.Second, + AuthenticatedGRPC: grpcCfg, + }, + gateway, + subscriber, + ) + + ctx, cancel := context.WithCancel(context.Background()) + resultCh := make(chan error, 1) + go func() { + resultCh <- application.Run(ctx) + }() + + select { + case <-subscriber.started: + case <-time.After(time.Second): + require.FailNow(t, "session subscriber did not start") + } + + return addr, runningAuthenticatedGateway{ + cancel: cancel, + resultCh: resultCh, + } +} + +func (g runningAuthenticatedGateway) stop(t *testing.T) { + t.Helper() + + g.cancel() + + select { + case err := <-g.resultCh: + require.NoError(t, err) + case <-time.After(2 * time.Second): + require.FailNow(t, "gateway did not stop after cancellation") + } +} + +func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + conn, err := grpc.DialContext( + ctx, + addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) + require.NoError(t, err) + + return conn +} + +func unusedTCPAddr(t *testing.T) 
string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + addr := listener.Addr().String() + require.NoError(t, listener.Close()) + + return addr +} + +func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest { + payloadBytes := []byte("payload") + payloadHash := sha256.Sum256(payloadBytes) + + req := &gatewayv1.ExecuteCommandRequest{ + ProtocolVersion: "v1", + DeviceSessionId: "device-session-123", + MessageType: "fleet.move", + TimestampMs: testNow.UnixMilli(), + RequestId: requestID, + PayloadBytes: payloadBytes, + PayloadHash: payloadHash[:], + TraceId: "trace-123", + } + req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{ + ProtocolVersion: req.GetProtocolVersion(), + DeviceSessionID: req.GetDeviceSessionId(), + MessageType: req.GetMessageType(), + TimestampMS: req.GetTimestampMs(), + RequestID: req.GetRequestId(), + PayloadHash: req.GetPayloadHash(), + })) + + return req +} + +func newActiveSessionRecord(userID string) session.Record { + return session.Record{ + DeviceSessionID: "device-session-123", + UserID: userID, + ClientPublicKey: testClientPublicKeyBase64(), + Status: session.StatusActive, + } +} + +func testClientPrivateKey() ed25519.PrivateKey { + seed := sha256.Sum256([]byte("gateway-events-grpc-test-client")) + return ed25519.NewKeyFromSeed(seed[:]) +} + +func testClientPublicKeyBase64() string { + return base64.StdEncoding.EncodeToString(testClientPrivateKey().Public().(ed25519.PublicKey)) +} + +func newTestResponseSigner(t *testing.T) authn.ResponseSigner { + t.Helper() + + seed := sha256.Sum256([]byte("gateway-events-grpc-test-response")) + signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(seed[:])) + require.NoError(t, err) + + return signer +} + +type fixedClock struct { + now time.Time +} + +func (c fixedClock) Now() time.Time { + return c.now +} + +var _ clock.Clock = fixedClock{} + +type 
staticReplayStore struct{} + +func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error { + return nil +} + +var _ replay.Store = staticReplayStore{} + +type countingSessionCache struct { + mu sync.Mutex + records map[string]session.Record + lookupCount int +} + +func (c *countingSessionCache) Lookup(context.Context, string) (session.Record, error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.lookupCount++ + + record, ok := c.records["device-session-123"] + if !ok { + return session.Record{}, errors.New("lookup session from counting cache: session cache record not found") + } + + return record, nil +} + +func (c *countingSessionCache) lookupCalls() int { + c.mu.Lock() + defer c.mu.Unlock() + + return c.lookupCount +} + +type recordingDownstreamClient struct { + mu sync.Mutex + captured []downstream.AuthenticatedCommand +} + +func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { + c.mu.Lock() + c.captured = append(c.captured, command) + c.mu.Unlock() + + return downstream.UnaryResult{ + ResultCode: "ok", + PayloadBytes: []byte("response"), + }, nil +} + +func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand { + c.mu.Lock() + defer c.mu.Unlock() + + cloned := make([]downstream.AuthenticatedCommand, len(c.captured)) + copy(cloned, c.captured) + return cloned +} diff --git a/gateway/internal/events/push_grpc_integration_test.go b/gateway/internal/events/push_grpc_integration_test.go new file mode 100644 index 0000000..7cad073 --- /dev/null +++ b/gateway/internal/events/push_grpc_integration_test.go @@ -0,0 +1,416 @@ +package events + +import ( + "context" + "crypto/ed25519" + "crypto/sha256" + "encoding/base64" + "testing" + "time" + + "galaxy/gateway/internal/app" + "galaxy/gateway/internal/authn" + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/grpcapi" + "galaxy/gateway/internal/push" + 
"galaxy/gateway/internal/session" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + sessionCache := session.NewMemoryCache() + require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123"))) + require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123"))) + require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999"))) + + pushHub := push.NewHub(4) + clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub) + addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber) + defer running.stop(t) + + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + client := gatewayv1.NewEdgeGatewayClient(conn) + + targetOneCtx, cancelTargetOne := context.WithCancel(context.Background()) + defer cancelTargetOne() + targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1")) + require.NoError(t, err) + assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1") + + targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background()) + defer cancelTargetTwo() + targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2")) + require.NoError(t, err) + assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2") + + unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background()) + defer cancelUnrelated() + unrelated, err := 
client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3")) + require.NoError(t, err) + assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3") + + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "event_type": "fleet.updated", + "event_id": "event-123", + "payload_bytes": []byte("payload-123"), + "request_id": "request-123", + "trace_id": "trace-123", + }) + + assertSignedPushEvent(t, recvPushEvent(t, targetOne), push.Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-123", + PayloadBytes: []byte("payload-123"), + RequestID: "request-123", + TraceID: "trace-123", + }) + assertSignedPushEvent(t, recvPushEvent(t, targetTwo), push.Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-123", + PayloadBytes: []byte("payload-123"), + RequestID: "request-123", + TraceID: "trace-123", + }) + assertNoPushEvent(t, unrelated, cancelUnrelated) +} + +func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + sessionCache := session.NewMemoryCache() + require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123"))) + require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123"))) + + pushHub := push.NewHub(4) + clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub) + addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber) + defer running.stop(t) + + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + client := gatewayv1.NewEdgeGatewayClient(conn) + + otherCtx, cancelOther := context.WithCancel(context.Background()) + defer cancelOther() + otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1")) + require.NoError(t, err) 
+ assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1") + + targetCtx, cancelTarget := context.WithCancel(context.Background()) + defer cancelTarget() + targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2")) + require.NoError(t, err) + assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2") + + addClientEvent(t, server, "gateway:client_events", map[string]any{ + "user_id": "user-123", + "device_session_id": "device-session-2", + "event_type": "fleet.updated", + "event_id": "event-456", + "payload_bytes": []byte("payload-456"), + }) + + assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{ + UserID: "user-123", + DeviceSessionID: "device-session-2", + EventType: "fleet.updated", + EventID: "event-456", + PayloadBytes: []byte("payload-456"), + }) + assertNoPushEvent(t, otherStream, cancelOther) +} + +func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + sessionCache := session.NewMemoryCache() + require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123"))) + + pushHub := push.NewHub(4) + clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub) + sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub) + addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber) + defer running.stop(t) + + select { + case <-sessionSubscriber.started: + case <-time.After(time.Second): + require.FailNow(t, "session subscriber did not start") + } + + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + client := gatewayv1.NewEdgeGatewayClient(conn) + + streamCtx, cancelStream := context.WithCancel(context.Background()) + defer cancelStream() + + stream, err := 
client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1")) + require.NoError(t, err) + assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1") + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-1", + "user_id": "user-123", + "client_public_key": pushClientPublicKeyBase64(), + "status": string(session.StatusRevoked), + "revoked_at_ms": "123456789", + }) + + require.Eventually(t, func() bool { + record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1") + return lookupErr == nil && record.Status == session.StatusRevoked + }, time.Second, 10*time.Millisecond) + + recvErrCh := make(chan error, 1) + go func() { + _, recvErr := stream.Recv() + recvErrCh <- recvErr + }() + + select { + case recvErr := <-recvErrCh: + require.Error(t, recvErr) + assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr)) + assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message()) + case <-time.After(time.Second): + require.FailNow(t, "stream did not close after revoke") + } + + reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2")) + if err == nil { + _, err = reopened.Recv() + } + + require.Error(t, err) + assert.Equal(t, codes.FailedPrecondition, status.Code(err)) + assert.Equal(t, "device session is revoked", status.Convert(err).Message()) +} + +func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + sessionCache := session.NewMemoryCache() + require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123"))) + + pushHub := push.NewHub(4) + clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub) + addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber) + defer running.stop(t) + + conn := 
dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + client := gatewayv1.NewEdgeGatewayClient(conn) + + stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1")) + require.NoError(t, err) + assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1") + + recvErrCh := make(chan error, 1) + go func() { + _, recvErr := stream.Recv() + recvErrCh <- recvErr + }() + + running.cancel() + + select { + case recvErr := <-recvErrCh: + require.Error(t, recvErr) + assert.Equal(t, codes.Unavailable, status.Code(recvErr)) + assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message()) + case <-time.After(time.Second): + require.FailNow(t, "stream did not close after gateway shutdown") + } +} + +func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) { + t.Helper() + + addr := unusedTCPAddr(t) + grpcCfg := config.DefaultAuthenticatedGRPCConfig() + grpcCfg.Addr = addr + grpcCfg.FreshnessWindow = 5 * time.Minute + + responseSigner := newTestResponseSigner(t) + gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{ + Service: grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()), + ResponseSigner: responseSigner, + SessionCache: sessionCache, + ReplayStore: staticReplayStore{}, + Clock: fixedClock{now: testNow}, + PushHub: pushHub, + }) + + components := []app.Component{gateway, clientSubscriber} + components = append(components, extraComponents...) 
+ application := app.New( + config.Config{ + ShutdownTimeout: time.Second, + AuthenticatedGRPC: grpcCfg, + }, + components..., + ) + + ctx, cancel := context.WithCancel(context.Background()) + resultCh := make(chan error, 1) + go func() { + resultCh <- application.Run(ctx) + }() + + select { + case <-clientSubscriber.started: + case <-time.After(time.Second): + require.FailNow(t, "client event subscriber did not start") + } + + return addr, runningAuthenticatedGateway{ + cancel: cancel, + resultCh: resultCh, + } +} + +func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record { + return session.Record{ + DeviceSessionID: deviceSessionID, + UserID: userID, + ClientPublicKey: pushClientPublicKeyBase64(), + Status: session.StatusActive, + } +} + +func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest { + payloadHash := sha256.Sum256(nil) + traceID := "trace-" + deviceSessionID + + req := &gatewayv1.SubscribeEventsRequest{ + ProtocolVersion: "v1", + DeviceSessionId: deviceSessionID, + MessageType: "gateway.subscribe", + TimestampMs: testNow.UnixMilli(), + RequestId: requestID, + PayloadHash: payloadHash[:], + TraceId: traceID, + } + req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{ + ProtocolVersion: req.GetProtocolVersion(), + DeviceSessionID: req.GetDeviceSessionId(), + MessageType: req.GetMessageType(), + TimestampMS: req.GetTimestampMs(), + RequestID: req.GetRequestId(), + PayloadHash: req.GetPayloadHash(), + })) + + return req +} + +func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent { + t.Helper() + + event, err := stream.Recv() + require.NoError(t, err) + return event +} + +func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) { + t.Helper() + + require.NotNil(t, event) + assert.Equal(t, 
"gateway.server_time", event.GetEventType()) + assert.Equal(t, wantRequestID, event.GetEventId()) + assert.Equal(t, wantRequestID, event.GetRequestId()) + assert.Equal(t, wantTraceID, event.GetTraceId()) + require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash())) + require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{ + EventType: event.GetEventType(), + EventID: event.GetEventId(), + TimestampMS: event.GetTimestampMs(), + RequestID: event.GetRequestId(), + TraceID: event.GetTraceId(), + PayloadHash: event.GetPayloadHash(), + })) +} + +func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) { + t.Helper() + + require.NotNil(t, event) + assert.Equal(t, want.EventType, event.GetEventType()) + assert.Equal(t, want.EventID, event.GetEventId()) + assert.Equal(t, want.RequestID, event.GetRequestId()) + assert.Equal(t, want.TraceID, event.GetTraceId()) + assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes()) + require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash())) + require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{ + EventType: event.GetEventType(), + EventID: event.GetEventId(), + TimestampMS: event.GetTimestampMs(), + RequestID: event.GetRequestId(), + TraceID: event.GetTraceId(), + PayloadHash: event.GetPayloadHash(), + })) +} + +func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) { + t.Helper() + + recvCh := make(chan *gatewayv1.GatewayEvent, 1) + errCh := make(chan error, 1) + go func() { + event, err := stream.Recv() + if err != nil { + errCh <- err + return + } + recvCh <- event + }() + + select { + case event := <-recvCh: + require.FailNowf(t, "unexpected push event delivered", "%+v", event) + case <-time.After(100 * time.Millisecond): + 
cancel() + case err := <-errCh: + require.FailNowf(t, "stream closed unexpectedly", "%v", err) + } +} + +func pushClientPrivateKey() ed25519.PrivateKey { + seed := sha256.Sum256([]byte("gateway-push-grpc-test-client")) + return ed25519.NewKeyFromSeed(seed[:]) +} + +func pushClientPublicKeyBase64() string { + return base64.StdEncoding.EncodeToString(pushClientPrivateKey().Public().(ed25519.PublicKey)) +} + +func pushResponseSignerPublicKey() ed25519.PublicKey { + seed := sha256.Sum256([]byte("gateway-events-grpc-test-response")) + return ed25519.NewKeyFromSeed(seed[:]).Public().(ed25519.PublicKey) +} diff --git a/gateway/internal/events/subscriber.go b/gateway/internal/events/subscriber.go new file mode 100644 index 0000000..605f80c --- /dev/null +++ b/gateway/internal/events/subscriber.go @@ -0,0 +1,389 @@ +// Package events subscribes to internal session lifecycle streams used to keep +// the gateway hot-path session cache synchronized without per-request upstream +// lookups. +package events + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/session" + "galaxy/gateway/internal/telemetry" + + "github.com/redis/go-redis/v9" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" +) + +const sessionEventReadCount int64 = 128 + +// SessionRevocationHandler reacts to a successfully applied revoked session +// snapshot and may tear down active resources bound to that session. +type SessionRevocationHandler interface { + // RevokeDeviceSession tears down active resources bound to deviceSessionID. + RevokeDeviceSession(deviceSessionID string) +} + +// RedisSessionSubscriber consumes full session snapshots from one Redis Stream +// and applies them to a process-local session snapshot store. 
+type RedisSessionSubscriber struct {
+	// client is a dedicated Redis connection for stream reads; it is released
+	// exactly once via Close/Shutdown.
+	client *redis.Client
+	// stream is the Redis Stream key carrying session lifecycle snapshots.
+	stream string
+	// pingTimeout bounds each health-check Ping; it reuses the session-cache
+	// lookup timeout budget.
+	pingTimeout time.Duration
+	// readBlockTimeout is the XREAD BLOCK duration for each polling read.
+	readBlockTimeout time.Duration
+	// store receives decoded session snapshots (upsert) and evictions (delete).
+	store session.SnapshotStore
+	// revocationHandler, when non-nil, tears down live resources after a
+	// revoked snapshot is applied.
+	revocationHandler SessionRevocationHandler
+	// logger reports dropped/malformed events; never nil after construction.
+	logger *zap.Logger
+	// metrics counts internal event drops. NOTE(review): may be nil when built
+	// via NewRedisSessionSubscriber — presumably telemetry.Runtime methods are
+	// nil-receiver safe; confirm before relying on it.
+	metrics *telemetry.Runtime
+
+	closeOnce   sync.Once     // guards the single client.Close
+	startedOnce sync.Once     // guards closing started exactly once
+	started     chan struct{} // closed after the start stream ID is resolved
+}
+
+// NewRedisSessionSubscriber constructs a Redis Stream subscriber that reuses
+// the SessionCache Redis connection settings and applies updates to store.
+func NewRedisSessionSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
+	// No revocation handler, logger, or metrics: logging degrades to a no-op
+	// logger inside the full constructor.
+	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, nil, nil, nil)
+}
+
+// NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream
+// subscriber that reuses the SessionCache Redis connection settings, applies
+// updates to store, and optionally tears down active resources for revoked
+// sessions.
+func NewRedisSessionSubscriberWithRevocationHandler(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
+	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, revocationHandler, nil, nil)
+}
+
+// NewRedisSessionSubscriberWithObservability constructs a Redis Stream
+// subscriber that also logs and counts malformed internal session events.
+func NewRedisSessionSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
+	// Reject unusable configuration up front. The checks run in the same order
+	// as before, so the first reported problem is unchanged.
+	switch {
+	case strings.TrimSpace(sessionCfg.Addr) == "":
+		return nil, errors.New("new redis session subscriber: redis addr must not be empty")
+	case sessionCfg.DB < 0:
+		return nil, errors.New("new redis session subscriber: redis db must not be negative")
+	case sessionCfg.LookupTimeout <= 0:
+		return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
+	case strings.TrimSpace(eventsCfg.Stream) == "":
+		return nil, errors.New("new redis session subscriber: stream must not be empty")
+	case eventsCfg.ReadBlockTimeout <= 0:
+		return nil, errors.New("new redis session subscriber: read block timeout must be positive")
+	case store == nil:
+		return nil, errors.New("new redis session subscriber: nil session snapshot store")
+	}
+
+	clientOptions := &redis.Options{
+		Addr:            sessionCfg.Addr,
+		Username:        sessionCfg.Username,
+		Password:        sessionCfg.Password,
+		DB:              sessionCfg.DB,
+		Protocol:        2,
+		DisableIdentity: true,
+	}
+	if sessionCfg.TLSEnabled {
+		clientOptions.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
+	}
+
+	// A nil logger degrades to a no-op logger rather than failing construction.
+	if logger == nil {
+		logger = zap.NewNop()
+	}
+
+	return &RedisSessionSubscriber{
+		client:            redis.NewClient(clientOptions),
+		stream:            eventsCfg.Stream,
+		pingTimeout:       sessionCfg.LookupTimeout,
+		readBlockTimeout:  eventsCfg.ReadBlockTimeout,
+		store:             store,
+		revocationHandler: revocationHandler,
+		logger:            logger.Named("session_subscriber"),
+		metrics:           metrics,
+		started:           make(chan struct{}),
+	}, nil
+}
+
+// Ping verifies that the Redis backend used for session lifecycle events is
+// reachable within the configured timeout budget.
+func (s *RedisSessionSubscriber) Ping(ctx context.Context) error { + if s == nil || s.client == nil { + return errors.New("ping redis session subscriber: nil subscriber") + } + if ctx == nil { + return errors.New("ping redis session subscriber: nil context") + } + + pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout) + defer cancel() + + if err := s.client.Ping(pingCtx).Err(); err != nil { + return fmt.Errorf("ping redis session subscriber: %w", err) + } + + return nil +} + +// Run consumes session lifecycle events until ctx is canceled or Redis returns +// an unexpected error. +func (s *RedisSessionSubscriber) Run(ctx context.Context) error { + if s == nil || s.client == nil { + return errors.New("run redis session subscriber: nil subscriber") + } + if ctx == nil { + return errors.New("run redis session subscriber: nil context") + } + if err := ctx.Err(); err != nil { + return err + } + + lastID, err := s.resolveStartID(ctx) + if err != nil { + return err + } + + s.signalStarted() + + for { + streams, err := s.client.XRead(ctx, &redis.XReadArgs{ + Streams: []string{s.stream, lastID}, + Count: sessionEventReadCount, + Block: s.readBlockTimeout, + }).Result() + switch { + case err == nil: + for _, stream := range streams { + for _, message := range stream.Messages { + s.applyMessage(message) + lastID = message.ID + } + } + continue + case errors.Is(err, redis.Nil): + continue + case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)): + return ctx.Err() + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed): + return fmt.Errorf("run redis session subscriber: %w", err) + default: + return fmt.Errorf("run redis session subscriber: %w", err) + } + } +} + +func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) { + messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result() + 
switch { + case err == nil: + case errors.Is(err, redis.Nil): + return "0-0", nil + default: + return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err) + } + + if len(messages) == 0 { + return "0-0", nil + } + + return messages[0].ID, nil +} + +// Shutdown closes the Redis client so a blocking stream read can terminate +// promptly during gateway shutdown. +func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error { + if ctx == nil { + return errors.New("shutdown redis session subscriber: nil context") + } + + return s.Close() +} + +// Close releases the underlying Redis client resources. +func (s *RedisSessionSubscriber) Close() error { + if s == nil || s.client == nil { + return nil + } + + var err error + s.closeOnce.Do(func() { + err = s.client.Close() + }) + + return err +} + +func (s *RedisSessionSubscriber) signalStarted() { + s.startedOnce.Do(func() { + close(s.started) + }) +} + +func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) { + record, err := decodeSessionRecordSnapshot(message.Values) + if err != nil { + s.logger.Warn("dropped malformed session event", + zap.String("stream", s.stream), + zap.String("message_id", message.ID), + zap.Error(err), + ) + s.metrics.RecordInternalEventDrop(context.Background(), + attribute.String("component", "session_subscriber"), + attribute.String("reason", "malformed_event"), + ) + if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok { + s.store.Delete(deviceSessionID) + } + return + } + + if err := s.store.Upsert(record); err != nil { + s.logger.Warn("dropped session snapshot after store failure", + zap.String("stream", s.stream), + zap.String("message_id", message.ID), + zap.String("device_session_id", record.DeviceSessionID), + zap.Error(err), + ) + s.metrics.RecordInternalEventDrop(context.Background(), + attribute.String("component", "session_subscriber"), + attribute.String("reason", "store_failure"), + ) + 
s.store.Delete(record.DeviceSessionID) + return + } + + if record.Status == session.StatusRevoked && s.revocationHandler != nil { + s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID) + } +} + +func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) { + requiredKeys := map[string]struct{}{ + "device_session_id": {}, + "user_id": {}, + "client_public_key": {}, + "status": {}, + } + optionalKeys := map[string]struct{}{ + "revoked_at_ms": {}, + } + + for key := range values { + if _, ok := requiredKeys[key]; ok { + continue + } + if _, ok := optionalKeys[key]; ok { + continue + } + + return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key) + } + + deviceSessionID, err := requiredStringField(values, "device_session_id") + if err != nil { + return session.Record{}, err + } + userID, err := requiredStringField(values, "user_id") + if err != nil { + return session.Record{}, err + } + clientPublicKey, err := requiredStringField(values, "client_public_key") + if err != nil { + return session.Record{}, err + } + statusValue, err := requiredStringField(values, "status") + if err != nil { + return session.Record{}, err + } + + record := session.Record{ + DeviceSessionID: deviceSessionID, + UserID: userID, + ClientPublicKey: clientPublicKey, + Status: session.Status(statusValue), + } + + if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok { + revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms") + if err != nil { + return session.Record{}, err + } + record.RevokedAtMS = &revokedAtMS + } + + return record, nil +} + +func extractDeviceSessionID(values map[string]any) (string, bool) { + value, ok := values["device_session_id"] + if !ok { + return "", false + } + + deviceSessionID, err := coerceString(value) + if err != nil { + return "", false + } + if strings.TrimSpace(deviceSessionID) == "" { + return "", false + } + + return deviceSessionID, true +} + +func requiredStringField(values 
map[string]any, field string) (string, error) { + value, ok := values[field] + if !ok { + return "", fmt.Errorf("decode session event: missing %s", field) + } + + stringValue, err := coerceString(value) + if err != nil { + return "", fmt.Errorf("decode session event: %s: %w", field, err) + } + if strings.TrimSpace(stringValue) == "" { + return "", fmt.Errorf("decode session event: %s must not be empty", field) + } + + return stringValue, nil +} + +func parseInt64Field(value any, field string) (int64, error) { + stringValue, err := coerceString(value) + if err != nil { + return 0, fmt.Errorf("decode session event: %s: %w", field, err) + } + + parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64) + if err != nil { + return 0, fmt.Errorf("decode session event: %s: %w", field, err) + } + + return parsed, nil +} + +func coerceString(value any) (string, error) { + switch typed := value.(type) { + case string: + return typed, nil + case []byte: + return string(typed), nil + case fmt.Stringer: + return typed.String(), nil + case int: + return strconv.Itoa(typed), nil + case int64: + return strconv.FormatInt(typed, 10), nil + case uint64: + return strconv.FormatUint(typed, 10), nil + default: + return "", fmt.Errorf("unsupported value type %T", value) + } +} diff --git a/gateway/internal/events/subscriber_test.go b/gateway/internal/events/subscriber_test.go new file mode 100644 index 0000000..6193ae1 --- /dev/null +++ b/gateway/internal/events/subscriber_test.go @@ -0,0 +1,366 @@ +package events + +import ( + "context" + "sync" + "testing" + "time" + + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/session" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := session.NewMemoryCache() + subscriber := newTestRedisSessionSubscriber(t, server, store) + 
running := runTestSubscriber(t, subscriber) + defer running.stop(t) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "public-key-123", + "status": string(session.StatusActive), + }) + + require.Eventually(t, func() bool { + record, err := store.Lookup(context.Background(), "device-session-123") + if err != nil { + return false + } + + return record.UserID == "user-123" && record.Status == session.StatusActive + }, time.Second, 10*time.Millisecond) +} + +func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := session.NewMemoryCache() + require.NoError(t, store.Upsert(session.Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: session.StatusActive, + })) + + subscriber := newTestRedisSessionSubscriber(t, server, store) + running := runTestSubscriber(t, subscriber) + defer running.stop(t) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "public-key-123", + "status": string(session.StatusRevoked), + "revoked_at_ms": "123456789", + }) + + require.Eventually(t, func() bool { + record, err := store.Lookup(context.Background(), "device-session-123") + if err != nil || record.RevokedAtMS == nil { + return false + } + + return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789 + }, time.Second, 10*time.Millisecond) +} + +func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := session.NewMemoryCache() + handler := &recordingSessionRevocationHandler{} + subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler) + running := runTestSubscriber(t, subscriber) + 
defer running.stop(t) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "public-key-123", + "status": string(session.StatusRevoked), + "revoked_at_ms": "123456789", + }) + + require.Eventually(t, func() bool { + record, err := store.Lookup(context.Background(), "device-session-123") + if err != nil || record.Status != session.StatusRevoked { + return false + } + + return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations()) + }, time.Second, 10*time.Millisecond) +} + +func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := session.NewMemoryCache() + handler := &recordingSessionRevocationHandler{} + subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler) + running := runTestSubscriber(t, subscriber) + defer running.stop(t) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "public-key-123", + "status": string(session.StatusActive), + }) + + assert.Never(t, func() bool { + return len(handler.revocations()) != 0 + }, 100*time.Millisecond, 10*time.Millisecond) +} + +func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + handler := &recordingSessionRevocationHandler{} + subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, failingSnapshotStore{}, handler) + running := runTestSubscriber(t, subscriber) + defer running.stop(t) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "public-key-123", + "status": string(session.StatusRevoked), + "revoked_at_ms": "123456789", 
+ }) + + assert.Never(t, func() bool { + return len(handler.revocations()) != 0 + }, 100*time.Millisecond, 10*time.Millisecond) +} + +func TestRedisSessionSubscriberLaterEventWins(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := session.NewMemoryCache() + subscriber := newTestRedisSessionSubscriber(t, server, store) + running := runTestSubscriber(t, subscriber) + defer running.stop(t) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "public-key-123", + "status": string(session.StatusActive), + }) + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-456", + "client_public_key": "public-key-456", + "status": string(session.StatusActive), + }) + + require.Eventually(t, func() bool { + record, err := store.Lookup(context.Background(), "device-session-123") + if err != nil { + return false + } + + return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456" + }, time.Second, 10*time.Millisecond) +} + +func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := session.NewMemoryCache() + require.NoError(t, store.Upsert(session.Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: session.StatusActive, + })) + + subscriber := newTestRedisSessionSubscriber(t, server, store) + running := runTestSubscriber(t, subscriber) + defer running.stop(t) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "public-key-123", + "status": "paused", + }) + + require.Eventually(t, func() bool { + _, err := store.Lookup(context.Background(), "device-session-123") + return err != nil + }, 
time.Second, 10*time.Millisecond) + + addSessionEvent(t, server, "gateway:session_events", map[string]string{ + "device_session_id": "device-session-123", + "user_id": "user-456", + "client_public_key": "public-key-456", + "status": string(session.StatusActive), + }) + + require.Eventually(t, func() bool { + record, err := store.Lookup(context.Background(), "device-session-123") + if err != nil { + return false + } + + return record.UserID == "user-456" && record.Status == session.StatusActive + }, time.Second, 10*time.Millisecond) +} + +func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := session.NewMemoryCache() + subscriber := newTestRedisSessionSubscriber(t, server, store) + + ctx, cancel := context.WithCancel(context.Background()) + resultCh := make(chan error, 1) + go func() { + resultCh <- subscriber.Run(ctx) + }() + + select { + case <-subscriber.started: + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not start") + } + + cancel() + require.NoError(t, subscriber.Shutdown(context.Background())) + + select { + case err := <-resultCh: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not stop after shutdown") + } +} + +func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber { + t.Helper() + + return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil) +} + +func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber { + t.Helper() + + subscriber, err := NewRedisSessionSubscriberWithRevocationHandler( + config.SessionCacheRedisConfig{ + Addr: server.Addr(), + LookupTimeout: 250 * time.Millisecond, + }, + config.SessionEventsRedisConfig{ + Stream: 
"gateway:session_events", + ReadBlockTimeout: 25 * time.Millisecond, + }, + store, + revocationHandler, + ) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, subscriber.Close()) + }) + + return subscriber +} + +type recordingSessionRevocationHandler struct { + mu sync.Mutex + revokedIDs []string +} + +func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) { + h.mu.Lock() + h.revokedIDs = append(h.revokedIDs, deviceSessionID) + h.mu.Unlock() +} + +func (h *recordingSessionRevocationHandler) revocations() []string { + h.mu.Lock() + defer h.mu.Unlock() + + return append([]string(nil), h.revokedIDs...) +} + +type failingSnapshotStore struct{} + +func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) { + return session.Record{}, session.ErrNotFound +} + +func (failingSnapshotStore) Upsert(session.Record) error { + return context.DeadlineExceeded +} + +func (failingSnapshotStore) Delete(string) {} + +func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) { + t.Helper() + + values := make([]string, 0, len(fields)*2) + for key, value := range fields { + values = append(values, key, value) + } + + _, err := server.XAdd(stream, "*", values) + require.NoError(t, err) +} + +type runningSubscriber struct { + cancel context.CancelFunc + resultCh chan error + stopOnce bool +} + +func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + resultCh := make(chan error, 1) + go func() { + resultCh <- subscriber.Run(ctx) + }() + + select { + case <-subscriber.started: + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not start") + } + + return runningSubscriber{ + cancel: cancel, + resultCh: resultCh, + } +} + +func (r runningSubscriber) stop(t *testing.T) { + t.Helper() + + r.cancel() + + select { + case err := <-r.resultCh: + 
require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + require.FailNow(t, "subscriber did not stop") + } +} diff --git a/gateway/internal/grpcapi/command_routing.go b/gateway/internal/grpcapi/command_routing.go new file mode 100644 index 0000000..88f025d --- /dev/null +++ b/gateway/internal/grpcapi/command_routing.go @@ -0,0 +1,145 @@ +package grpcapi + +import ( + "bytes" + "context" + "crypto/sha256" + "errors" + "strings" + "time" + + "galaxy/gateway/internal/authn" + "galaxy/gateway/internal/clock" + "galaxy/gateway/internal/downstream" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// commandRoutingService translates the verified authenticated request context +// into an internal downstream command and signs successful unary responses. +type commandRoutingService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + subscribeDelegate gatewayv1.EdgeGatewayServer + router downstream.Router + responseSigner authn.ResponseSigner + clock clock.Clock + downstreamTimeout time.Duration +} + +// ExecuteCommand builds a verified downstream command, routes it by exact +// message_type, executes it, and signs the resulting unary response. 
+func (s commandRoutingService) ExecuteCommand(ctx context.Context, _ *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + command, err := authenticatedCommandFromContext(ctx) + if err != nil { + return nil, err + } + + client, err := s.router.Route(command.MessageType) + switch { + case err == nil: + case errors.Is(err, downstream.ErrRouteNotFound): + return nil, status.Error(codes.Unimplemented, "message_type is not routed") + case errors.Is(err, downstream.ErrDownstreamUnavailable): + return nil, status.Error(codes.Unavailable, "downstream service is unavailable") + default: + return nil, status.Error(codes.Internal, "downstream route resolution failed") + } + + downstreamCtx, cancel := context.WithTimeout(ctx, s.downstreamTimeout) + defer cancel() + + result, err := client.ExecuteCommand(downstreamCtx, command) + switch { + case err == nil: + case errors.Is(err, downstream.ErrDownstreamUnavailable), + errors.Is(err, context.DeadlineExceeded), + errors.Is(err, context.Canceled): + return nil, status.Error(codes.Unavailable, "downstream service is unavailable") + default: + return nil, status.Error(codes.Internal, "downstream execution failed") + } + + if strings.TrimSpace(result.ResultCode) == "" { + return nil, status.Error(codes.Internal, "downstream response is invalid") + } + + responseTimestampMS := s.clock.Now().UTC().UnixMilli() + payloadHash := sha256.Sum256(result.PayloadBytes) + signature, err := s.responseSigner.SignResponse(authn.ResponseSigningFields{ + ProtocolVersion: command.ProtocolVersion, + RequestID: command.RequestID, + TimestampMS: responseTimestampMS, + ResultCode: result.ResultCode, + PayloadHash: payloadHash[:], + }) + if err != nil { + return nil, status.Error(codes.Unavailable, "response signer is unavailable") + } + + return &gatewayv1.ExecuteCommandResponse{ + ProtocolVersion: command.ProtocolVersion, + RequestId: command.RequestID, + TimestampMs: responseTimestampMS, + ResultCode: result.ResultCode, + 
PayloadBytes: bytes.Clone(result.PayloadBytes), + PayloadHash: bytes.Clone(payloadHash[:]), + Signature: signature, + }, nil +} + +// SubscribeEvents delegates to the authenticated streaming service +// implementation selected during server construction. +func (s commandRoutingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + return s.subscribeDelegate.SubscribeEvents(req, stream) +} + +// newCommandRoutingService constructs the final authenticated service that +// owns verified unary routing while preserving the delegated streaming path. +func newCommandRoutingService(subscribeDelegate gatewayv1.EdgeGatewayServer, router downstream.Router, responseSigner authn.ResponseSigner, clk clock.Clock, downstreamTimeout time.Duration) gatewayv1.EdgeGatewayServer { + return commandRoutingService{ + subscribeDelegate: subscribeDelegate, + router: router, + responseSigner: responseSigner, + clock: clk, + downstreamTimeout: downstreamTimeout, + } +} + +func authenticatedCommandFromContext(ctx context.Context) (downstream.AuthenticatedCommand, error) { + envelope, ok := parsedEnvelopeFromContext(ctx) + if !ok { + return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete") + } + + record, ok := resolvedSessionFromContext(ctx) + if !ok { + return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete") + } + + return downstream.AuthenticatedCommand{ + ProtocolVersion: envelope.ProtocolVersion, + UserID: record.UserID, + DeviceSessionID: record.DeviceSessionID, + MessageType: envelope.MessageType, + TimestampMS: envelope.TimestampMS, + RequestID: envelope.RequestID, + TraceID: envelope.TraceID, + PayloadBytes: bytes.Clone(envelope.PayloadBytes), + }, nil +} + +type unavailableResponseSigner struct{} + +func (unavailableResponseSigner) SignResponse(authn.ResponseSigningFields) ([]byte, 
error) { + return nil, errors.New("response signer is unavailable") +} + +func (unavailableResponseSigner) SignEvent(authn.EventSigningFields) ([]byte, error) { + return nil, errors.New("response signer is unavailable") +} + +var _ gatewayv1.EdgeGatewayServer = commandRoutingService{} diff --git a/gateway/internal/grpcapi/command_routing_integration_test.go b/gateway/internal/grpcapi/command_routing_integration_test.go new file mode 100644 index 0000000..da225d8 --- /dev/null +++ b/gateway/internal/grpcapi/command_routing_integration_test.go @@ -0,0 +1,296 @@ +package grpcapi + +import ( + "context" + "crypto/sha256" + "fmt" + "testing" + "time" + + "galaxy/gateway/internal/authn" + "galaxy/gateway/internal/downstream" + "galaxy/gateway/internal/testutil" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) { + t.Parallel() + + signer := newTestEd25519ResponseSigner() + moveClient := &recordingDownstreamClient{ + executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { + assert.Equal(t, downstream.AuthenticatedCommand{ + ProtocolVersion: "v1", + UserID: "user-123", + DeviceSessionID: "device-session-123", + MessageType: "fleet.move", + TimestampMS: testCurrentTime.UnixMilli(), + RequestID: "request-123", + TraceID: "trace-123", + PayloadBytes: []byte("payload"), + }, command) + + return downstream.UnaryResult{ + ResultCode: "accepted", + PayloadBytes: []byte("downstream-response"), + }, nil + }, + } + renameClient := &recordingDownstreamClient{} + + server, runGateway := newTestGateway(t, ServerDependencies{ + Router: downstream.NewStaticRouter(map[string]downstream.Client{ + "fleet.move": moveClient, + "fleet.rename": renameClient, 
+ }), + ResponseSigner: signer, + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.NoError(t, err) + + assert.Equal(t, "v1", response.GetProtocolVersion()) + assert.Equal(t, "request-123", response.GetRequestId()) + assert.Equal(t, testCurrentTime.UnixMilli(), response.GetTimestampMs()) + assert.Equal(t, "accepted", response.GetResultCode()) + assert.Equal(t, []byte("downstream-response"), response.GetPayloadBytes()) + assert.Equal(t, 1, moveClient.executeCalls) + assert.Zero(t, renameClient.executeCalls) + + wantHash := sha256.Sum256([]byte("downstream-response")) + assert.Equal(t, wantHash[:], response.GetPayloadHash()) + require.NoError(t, authn.VerifyPayloadHash(response.GetPayloadBytes(), response.GetPayloadHash())) + require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.GetSignature(), authn.ResponseSigningFields{ + ProtocolVersion: response.GetProtocolVersion(), + RequestID: response.GetRequestId(), + TimestampMS: response.GetTimestampMs(), + ResultCode: response.GetResultCode(), + PayloadHash: response.GetPayloadHash(), + })) +} + +func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) { + t.Parallel() + + server, runGateway := newTestGateway(t, ServerDependencies{ + Router: downstream.NewStaticRouter(nil), + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + ResponseSigner: newTestResponseSigner(), + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, 
conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unimplemented, status.Code(err)) + assert.Equal(t, "message_type is not routed", status.Convert(err).Message()) +} + +func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) { + t.Parallel() + + failingClient := &recordingDownstreamClient{ + executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { + return downstream.UnaryResult{}, fmt.Errorf("rpc transport failed: %w", downstream.ErrDownstreamUnavailable) + }, + } + + server, runGateway := newTestGateway(t, ServerDependencies{ + Router: downstream.NewStaticRouter(map[string]downstream.Client{ + "fleet.move": failingClient, + }), + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + ResponseSigner: newTestResponseSigner(), + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message()) + assert.Equal(t, 1, failingClient.executeCalls) +} + +func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) { + t.Parallel() + + logger := zap.NewNop() + telemetryRuntime := testutil.NewTelemetryRuntime(t, logger) + + var ( + seenSpanContext trace.SpanContext + seenCommand downstream.AuthenticatedCommand + ) + + server, runGateway := newTestGateway(t, ServerDependencies{ + Router: downstream.NewStaticRouter(map[string]downstream.Client{ + "fleet.move": 
&recordingDownstreamClient{ + executeFunc: func(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { + seenSpanContext = trace.SpanContextFromContext(ctx) + seenCommand = command + + return downstream.UnaryResult{ + ResultCode: "accepted", + PayloadBytes: []byte("downstream-response"), + }, nil + }, + }, + }), + ResponseSigner: newTestResponseSigner(), + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + Logger: logger, + Telemetry: telemetryRuntime, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.NoError(t, err) + + assert.True(t, seenSpanContext.IsValid()) + assert.Equal(t, "trace-123", seenCommand.TraceID) +} + +func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) { + t.Parallel() + + started := make(chan struct{}) + release := make(chan struct{}) + + server, runGateway := newTestGateway(t, ServerDependencies{ + Router: downstream.NewStaticRouter(map[string]downstream.Client{ + "fleet.move": &recordingDownstreamClient{ + executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { + close(started) + <-release + + return downstream.UnaryResult{ + ResultCode: "accepted", + PayloadBytes: []byte("downstream-response"), + }, nil + }, + }, + }), + ResponseSigner: newTestResponseSigner(), + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := 
gatewayv1.NewEdgeGatewayClient(conn) + resultCh := make(chan error, 1) + go func() { + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + resultCh <- err + }() + + require.Eventually(t, func() bool { + select { + case <-started: + return true + default: + return false + } + }, time.Second, 10*time.Millisecond, "downstream execution did not start") + + runGateway.cancel() + + require.Never(t, func() bool { + select { + case <-resultCh: + return true + default: + return false + } + }, 100*time.Millisecond, 10*time.Millisecond, "unary request returned before downstream release") + + close(release) + + var err error + require.Eventually(t, func() bool { + select { + case err = <-resultCh: + return true + default: + return false + } + }, time.Second, 10*time.Millisecond, "unary request did not drain before shutdown timeout") + require.NoError(t, err) +} + +func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T) { + t.Parallel() + + logger, logBuffer := testutil.NewObservedLogger(t) + + server, runGateway := newTestGateway(t, ServerDependencies{ + Router: downstream.NewStaticRouter(map[string]downstream.Client{ + "fleet.move": &recordingDownstreamClient{}, + }), + ResponseSigner: newTestResponseSigner(), + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + Logger: logger, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.NoError(t, err) + + logOutput := logBuffer.String() + assert.NotContains(t, logOutput, "payload_hash") + assert.NotContains(t, logOutput, "signature") + assert.NotContains(t, logOutput, `"payload"`) +} diff --git a/gateway/internal/grpcapi/envelope.go 
b/gateway/internal/grpcapi/envelope.go new file mode 100644 index 0000000..885789c --- /dev/null +++ b/gateway/internal/grpcapi/envelope.go @@ -0,0 +1,214 @@ +package grpcapi + +import ( + "bytes" + "context" + "fmt" + + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "buf.build/go/protovalidate" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const supportedProtocolVersion = "v1" + +// parsedEnvelope captures the authenticated transport fields extracted from a +// request envelope after validation succeeds. Later wrappers may enrich this +// structure without changing the raw gRPC request types. +type parsedEnvelope struct { + ProtocolVersion string + DeviceSessionID string + MessageType string + TimestampMS int64 + RequestID string + TraceID string + PayloadBytes []byte + PayloadHash []byte + Signature []byte +} + +// parsedEnvelopeFromContext returns the parsed envelope previously attached to +// ctx by the envelope-validating gRPC service wrapper. +func parsedEnvelopeFromContext(ctx context.Context) (parsedEnvelope, bool) { + if ctx == nil { + return parsedEnvelope{}, false + } + + envelope, ok := ctx.Value(parsedEnvelopeContextKey{}).(parsedEnvelope) + if !ok { + return parsedEnvelope{}, false + } + + return envelope, true +} + +// envelopeValidatingService applies envelope parsing and the protocol gate +// before delegating to the configured service implementation. +type envelopeValidatingService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + delegate gatewayv1.EdgeGatewayServer +} + +// ExecuteCommand validates req and only then forwards it to the configured +// delegate with the parsed envelope attached to ctx. 
+func (s envelopeValidatingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + envelope, err := parseExecuteCommandRequest(req) + if err != nil { + return nil, err + } + + return s.delegate.ExecuteCommand(context.WithValue(ctx, parsedEnvelopeContextKey{}, envelope), req) +} + +// SubscribeEvents validates req and only then forwards it to the configured +// delegate with the parsed envelope attached to the stream context. +func (s envelopeValidatingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + envelope, err := parseSubscribeEventsRequest(req) + if err != nil { + return err + } + + return s.delegate.SubscribeEvents(req, envelopeContextStream{ + ServerStreamingServer: stream, + ctx: context.WithValue(stream.Context(), parsedEnvelopeContextKey{}, envelope), + }) +} + +// parseExecuteCommandRequest validates req according to the request-envelope +// rules and returns a cloned parsed envelope suitable for later auth steps. 
+func parseExecuteCommandRequest(req *gatewayv1.ExecuteCommandRequest) (parsedEnvelope, error) { + if req == nil { + return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil") + } + if err := protovalidate.Validate(req); err != nil { + return parsedEnvelope{}, canonicalExecuteCommandValidationError(req) + } + if req.GetProtocolVersion() != supportedProtocolVersion { + return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion()) + } + + return parsedEnvelope{ + ProtocolVersion: req.GetProtocolVersion(), + DeviceSessionID: req.GetDeviceSessionId(), + MessageType: req.GetMessageType(), + TimestampMS: req.GetTimestampMs(), + RequestID: req.GetRequestId(), + TraceID: req.GetTraceId(), + PayloadBytes: bytes.Clone(req.GetPayloadBytes()), + PayloadHash: bytes.Clone(req.GetPayloadHash()), + Signature: bytes.Clone(req.GetSignature()), + }, nil +} + +// parseSubscribeEventsRequest validates req according to the request-envelope +// rules and returns a cloned parsed envelope suitable for later auth steps. 
+func parseSubscribeEventsRequest(req *gatewayv1.SubscribeEventsRequest) (parsedEnvelope, error) { + if req == nil { + return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil") + } + if err := protovalidate.Validate(req); err != nil { + return parsedEnvelope{}, canonicalSubscribeEventsValidationError(req) + } + if req.GetProtocolVersion() != supportedProtocolVersion { + return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion()) + } + + return parsedEnvelope{ + ProtocolVersion: req.GetProtocolVersion(), + DeviceSessionID: req.GetDeviceSessionId(), + MessageType: req.GetMessageType(), + TimestampMS: req.GetTimestampMs(), + RequestID: req.GetRequestId(), + TraceID: req.GetTraceId(), + PayloadBytes: bytes.Clone(req.GetPayloadBytes()), + PayloadHash: bytes.Clone(req.GetPayloadHash()), + Signature: bytes.Clone(req.GetSignature()), + }, nil +} + +// newEnvelopeValidatingService wraps delegate with the envelope-validation +// gate. +func newEnvelopeValidatingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer { + return envelopeValidatingService{delegate: delegate} +} + +// canonicalExecuteCommandValidationError maps any ExecuteCommand validation +// failure into the stable canonical error chosen by field order. 
+func canonicalExecuteCommandValidationError(req *gatewayv1.ExecuteCommandRequest) error { + switch { + case req.GetProtocolVersion() == "": + return newMalformedEnvelopeError("protocol_version must not be empty") + case req.GetDeviceSessionId() == "": + return newMalformedEnvelopeError("device_session_id must not be empty") + case req.GetMessageType() == "": + return newMalformedEnvelopeError("message_type must not be empty") + case req.GetTimestampMs() <= 0: + return newMalformedEnvelopeError("timestamp_ms must be greater than zero") + case req.GetRequestId() == "": + return newMalformedEnvelopeError("request_id must not be empty") + case len(req.GetPayloadBytes()) == 0: + return newMalformedEnvelopeError("payload_bytes must not be empty") + case len(req.GetPayloadHash()) == 0: + return newMalformedEnvelopeError("payload_hash must not be empty") + case len(req.GetSignature()) == 0: + return newMalformedEnvelopeError("signature must not be empty") + default: + return newMalformedEnvelopeError("request envelope is invalid") + } +} + +// canonicalSubscribeEventsValidationError maps any SubscribeEvents validation +// failure into the stable canonical error chosen by field order. 
+func canonicalSubscribeEventsValidationError(req *gatewayv1.SubscribeEventsRequest) error { + switch { + case req.GetProtocolVersion() == "": + return newMalformedEnvelopeError("protocol_version must not be empty") + case req.GetDeviceSessionId() == "": + return newMalformedEnvelopeError("device_session_id must not be empty") + case req.GetMessageType() == "": + return newMalformedEnvelopeError("message_type must not be empty") + case req.GetTimestampMs() <= 0: + return newMalformedEnvelopeError("timestamp_ms must be greater than zero") + case req.GetRequestId() == "": + return newMalformedEnvelopeError("request_id must not be empty") + case len(req.GetPayloadHash()) == 0: + return newMalformedEnvelopeError("payload_hash must not be empty") + case len(req.GetSignature()) == 0: + return newMalformedEnvelopeError("signature must not be empty") + default: + return newMalformedEnvelopeError("request envelope is invalid") + } +} + +// newMalformedEnvelopeError returns the stable malformed-envelope reject used +// before the gateway performs any auth or routing work. +func newMalformedEnvelopeError(message string) error { + return status.Error(codes.InvalidArgument, message) +} + +// newUnsupportedProtocolVersionError returns the stable reject for a non-empty +// but unsupported protocol_version literal. 
+func newUnsupportedProtocolVersionError(version string) error { + return status.Error(codes.FailedPrecondition, fmt.Sprintf("unsupported protocol_version %q", version)) +} + +type parsedEnvelopeContextKey struct{} + +type envelopeContextStream struct { + grpc.ServerStreamingServer[gatewayv1.GatewayEvent] + ctx context.Context +} + +func (s envelopeContextStream) Context() context.Context { + if s.ctx == nil { + return context.Background() + } + + return s.ctx +} + +var _ gatewayv1.EdgeGatewayServer = envelopeValidatingService{} diff --git a/gateway/internal/grpcapi/envelope_test.go b/gateway/internal/grpcapi/envelope_test.go new file mode 100644 index 0000000..880fb26 --- /dev/null +++ b/gateway/internal/grpcapi/envelope_test.go @@ -0,0 +1,420 @@ +package grpcapi + +import ( + "context" + "testing" + + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +func TestParseExecuteCommandRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mutate func(*gatewayv1.ExecuteCommandRequest) + wantCode codes.Code + wantMessage string + assertValid func(*testing.T, *gatewayv1.ExecuteCommandRequest, parsedEnvelope) + }{ + { + name: "nil request", + wantCode: codes.InvalidArgument, + wantMessage: "request envelope must not be nil", + }, + { + name: "empty protocol version", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.ProtocolVersion = "" + }, + wantCode: codes.InvalidArgument, + wantMessage: "protocol_version must not be empty", + }, + { + name: "empty device session id", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.DeviceSessionId = "" + }, + wantCode: codes.InvalidArgument, + wantMessage: "device_session_id must not be empty", + }, + { + name: "empty message type", + mutate: func(req 
*gatewayv1.ExecuteCommandRequest) { + req.MessageType = "" + }, + wantCode: codes.InvalidArgument, + wantMessage: "message_type must not be empty", + }, + { + name: "zero timestamp", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.TimestampMs = 0 + }, + wantCode: codes.InvalidArgument, + wantMessage: "timestamp_ms must be greater than zero", + }, + { + name: "empty request id", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.RequestId = "" + }, + wantCode: codes.InvalidArgument, + wantMessage: "request_id must not be empty", + }, + { + name: "empty payload bytes", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.PayloadBytes = nil + }, + wantCode: codes.InvalidArgument, + wantMessage: "payload_bytes must not be empty", + }, + { + name: "empty payload hash", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.PayloadHash = nil + }, + wantCode: codes.InvalidArgument, + wantMessage: "payload_hash must not be empty", + }, + { + name: "empty signature", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.Signature = nil + }, + wantCode: codes.InvalidArgument, + wantMessage: "signature must not be empty", + }, + { + name: "unsupported protocol version", + mutate: func(req *gatewayv1.ExecuteCommandRequest) { + req.ProtocolVersion = "v2" + }, + wantCode: codes.FailedPrecondition, + wantMessage: `unsupported protocol_version "v2"`, + }, + { + name: "valid request", + wantCode: codes.OK, + assertValid: func(t *testing.T, req *gatewayv1.ExecuteCommandRequest, envelope parsedEnvelope) { + t.Helper() + + assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion) + assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID) + assert.Equal(t, req.GetMessageType(), envelope.MessageType) + assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS) + assert.Equal(t, req.GetRequestId(), envelope.RequestID) + assert.Equal(t, req.GetTraceId(), envelope.TraceID) + assert.Equal(t, req.GetPayloadBytes(), 
envelope.PayloadBytes) + assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash) + assert.Equal(t, req.GetSignature(), envelope.Signature) + + originalPayloadBytes := append([]byte(nil), req.GetPayloadBytes()...) + originalPayloadHash := append([]byte(nil), req.GetPayloadHash()...) + originalSignature := append([]byte(nil), req.GetSignature()...) + + envelope.PayloadBytes[0] = 'X' + envelope.PayloadHash[0] = 'Y' + envelope.Signature[0] = 'Z' + + assert.Equal(t, originalPayloadBytes, req.GetPayloadBytes()) + assert.Equal(t, originalPayloadHash, req.GetPayloadHash()) + assert.Equal(t, originalSignature, req.GetSignature()) + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var req *gatewayv1.ExecuteCommandRequest + if tt.name != "nil request" { + req = newValidExecuteCommandRequest() + if tt.mutate != nil { + tt.mutate(req) + } + } + + envelope, err := parseExecuteCommandRequest(req) + if tt.wantCode != codes.OK { + require.Error(t, err) + assert.Equal(t, tt.wantCode, status.Code(err)) + assert.Equal(t, tt.wantMessage, status.Convert(err).Message()) + return + } + + require.NoError(t, err) + require.NotNil(t, tt.assertValid) + tt.assertValid(t, req, envelope) + }) + } +} + +func TestParseSubscribeEventsRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mutate func(*gatewayv1.SubscribeEventsRequest) + wantCode codes.Code + wantMessage string + assertValid func(*testing.T, *gatewayv1.SubscribeEventsRequest, parsedEnvelope) + }{ + { + name: "nil request", + wantCode: codes.InvalidArgument, + wantMessage: "request envelope must not be nil", + }, + { + name: "empty protocol version", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.ProtocolVersion = "" + }, + wantCode: codes.InvalidArgument, + wantMessage: "protocol_version must not be empty", + }, + { + name: "empty device session id", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.DeviceSessionId = 
"" + }, + wantCode: codes.InvalidArgument, + wantMessage: "device_session_id must not be empty", + }, + { + name: "empty message type", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.MessageType = "" + }, + wantCode: codes.InvalidArgument, + wantMessage: "message_type must not be empty", + }, + { + name: "zero timestamp", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.TimestampMs = 0 + }, + wantCode: codes.InvalidArgument, + wantMessage: "timestamp_ms must be greater than zero", + }, + { + name: "empty request id", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.RequestId = "" + }, + wantCode: codes.InvalidArgument, + wantMessage: "request_id must not be empty", + }, + { + name: "empty payload hash", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.PayloadHash = nil + }, + wantCode: codes.InvalidArgument, + wantMessage: "payload_hash must not be empty", + }, + { + name: "empty signature", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.Signature = nil + }, + wantCode: codes.InvalidArgument, + wantMessage: "signature must not be empty", + }, + { + name: "unsupported protocol version", + mutate: func(req *gatewayv1.SubscribeEventsRequest) { + req.ProtocolVersion = "v2" + }, + wantCode: codes.FailedPrecondition, + wantMessage: `unsupported protocol_version "v2"`, + }, + { + name: "valid request with empty payload bytes", + wantCode: codes.OK, + assertValid: func(t *testing.T, req *gatewayv1.SubscribeEventsRequest, envelope parsedEnvelope) { + t.Helper() + + assert.Empty(t, req.GetPayloadBytes()) + assert.Empty(t, envelope.PayloadBytes) + assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash) + assert.Equal(t, req.GetSignature(), envelope.Signature) + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var req *gatewayv1.SubscribeEventsRequest + if tt.name != "nil request" { + req = newValidSubscribeEventsRequest() + if tt.mutate != 
nil { + tt.mutate(req) + } + } + + envelope, err := parseSubscribeEventsRequest(req) + if tt.wantCode != codes.OK { + require.Error(t, err) + assert.Equal(t, tt.wantCode, status.Code(err)) + assert.Equal(t, tt.wantMessage, status.Convert(err).Message()) + return + } + + require.NoError(t, err) + require.NotNil(t, tt.assertValid) + tt.assertValid(t, req, envelope) + }) + } +} + +func TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + service := newEnvelopeValidatingService(delegate) + + _, err := service.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{}) + require.Error(t, err) + + assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assert.Zero(t, delegate.executeCalls) +} + +func TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + service := newEnvelopeValidatingService(delegate) + + err := service.SubscribeEvents(&gatewayv1.SubscribeEventsRequest{}, stubGatewayEventStream{}) + require.Error(t, err) + + assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assert.Zero(t, delegate.subscribeCalls) +} + +func TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope(t *testing.T) { + t.Parallel() + + want := newValidExecuteCommandRequest() + delegate := &recordingEdgeGatewayService{ + executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + envelope, ok := parsedEnvelopeFromContext(ctx) + require.True(t, ok) + assert.Equal(t, want.GetRequestId(), envelope.RequestID) + assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID) + assert.Equal(t, want.GetMessageType(), envelope.MessageType) + assert.Equal(t, want.GetPayloadBytes(), envelope.PayloadBytes) + return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil + 
}, + } + service := newEnvelopeValidatingService(delegate) + + response, err := service.ExecuteCommand(context.Background(), want) + require.NoError(t, err) + + assert.Equal(t, want.GetRequestId(), response.GetRequestId()) + assert.Equal(t, 1, delegate.executeCalls) +} + +func TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope(t *testing.T) { + t.Parallel() + + want := newValidSubscribeEventsRequest() + delegate := &recordingEdgeGatewayService{ + subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + envelope, ok := parsedEnvelopeFromContext(stream.Context()) + require.True(t, ok) + assert.Equal(t, want.GetRequestId(), envelope.RequestID) + assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID) + assert.Equal(t, want.GetMessageType(), envelope.MessageType) + assert.Equal(t, want.GetPayloadHash(), envelope.PayloadHash) + assert.Equal(t, want.GetSignature(), envelope.Signature) + return nil + }, + } + service := newEnvelopeValidatingService(delegate) + + err := service.SubscribeEvents(want, stubGatewayEventStream{}) + require.NoError(t, err) + + assert.Equal(t, 1, delegate.subscribeCalls) +} + +type recordingEdgeGatewayService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + executeCalls int + subscribeCalls int + executeCommandFunc func(context.Context, *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) + subscribeEventsFunc func(*gatewayv1.SubscribeEventsRequest, grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error +} + +func (s *recordingEdgeGatewayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + s.executeCalls++ + if s.executeCommandFunc != nil { + return s.executeCommandFunc(ctx, req) + } + + return &gatewayv1.ExecuteCommandResponse{}, nil +} + +func (s *recordingEdgeGatewayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, 
stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + s.subscribeCalls++ + if s.subscribeEventsFunc != nil { + return s.subscribeEventsFunc(req, stream) + } + + return nil +} + +type stubGatewayEventStream struct { + grpc.ServerStream + ctx context.Context +} + +func (s stubGatewayEventStream) Send(*gatewayv1.GatewayEvent) error { + return nil +} + +func (s stubGatewayEventStream) SetHeader(metadata.MD) error { + return nil +} + +func (s stubGatewayEventStream) SendHeader(metadata.MD) error { + return nil +} + +func (s stubGatewayEventStream) SetTrailer(metadata.MD) {} + +func (s stubGatewayEventStream) Context() context.Context { + if s.ctx == nil { + return context.Background() + } + + return s.ctx +} + +func (s stubGatewayEventStream) SendMsg(any) error { + return nil +} + +func (s stubGatewayEventStream) RecvMsg(any) error { + return nil +} diff --git a/gateway/internal/grpcapi/freshness_replay.go b/gateway/internal/grpcapi/freshness_replay.go new file mode 100644 index 0000000..905795b --- /dev/null +++ b/gateway/internal/grpcapi/freshness_replay.go @@ -0,0 +1,95 @@ +package grpcapi + +import ( + "context" + "errors" + "time" + + "galaxy/gateway/internal/clock" + "galaxy/gateway/internal/replay" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const minimumReplayReservationTTL = time.Millisecond + +// freshnessAndReplayService applies freshness and anti-replay checks after +// client-signature verification and before later policy or routing steps run. +type freshnessAndReplayService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + delegate gatewayv1.EdgeGatewayServer + clock clock.Clock + replayStore replay.Store + freshnessWindow time.Duration +} + +// ExecuteCommand verifies request freshness and replay protection before +// delegating to the configured service implementation. 
+func (s freshnessAndReplayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + if err := s.verifyFreshnessAndReplay(ctx); err != nil { + return nil, err + } + + return s.delegate.ExecuteCommand(ctx, req) +} + +// SubscribeEvents verifies request freshness and replay protection before +// delegating to the configured service implementation. +func (s freshnessAndReplayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + if err := s.verifyFreshnessAndReplay(stream.Context()); err != nil { + return err + } + + return s.delegate.SubscribeEvents(req, stream) +} + +// newFreshnessAndReplayService wraps delegate with the freshness and replay +// gate. +func newFreshnessAndReplayService(delegate gatewayv1.EdgeGatewayServer, clk clock.Clock, replayStore replay.Store, freshnessWindow time.Duration) gatewayv1.EdgeGatewayServer { + return freshnessAndReplayService{ + delegate: delegate, + clock: clk, + replayStore: replayStore, + freshnessWindow: freshnessWindow, + } +} + +func (s freshnessAndReplayService) verifyFreshnessAndReplay(ctx context.Context) error { + envelope, ok := parsedEnvelopeFromContext(ctx) + if !ok { + return status.Error(codes.Internal, "authenticated request context is incomplete") + } + + now := s.clock.Now().UTC() + requestTime := time.UnixMilli(envelope.TimestampMS).UTC() + if requestTime.Before(now.Add(-s.freshnessWindow)) || requestTime.After(now.Add(s.freshnessWindow)) { + return status.Error(codes.FailedPrecondition, "request timestamp is outside the freshness window") + } + + ttl := requestTime.Add(s.freshnessWindow).Sub(now) + if ttl < minimumReplayReservationTTL { + ttl = minimumReplayReservationTTL + } + + err := s.replayStore.Reserve(ctx, envelope.DeviceSessionID, envelope.RequestID, ttl) + switch { + case err == nil: + return nil + case errors.Is(err, replay.ErrDuplicate): + return 
status.Error(codes.FailedPrecondition, "request replay detected") + default: + return status.Error(codes.Unavailable, "replay store is unavailable") + } +} + +type unavailableReplayStore struct{} + +func (unavailableReplayStore) Reserve(context.Context, string, string, time.Duration) error { + return errors.New("replay store is unavailable") +} + +var _ gatewayv1.EdgeGatewayServer = freshnessAndReplayService{} diff --git a/gateway/internal/grpcapi/freshness_replay_integration_test.go b/gateway/internal/grpcapi/freshness_replay_integration_test.go new file mode 100644 index 0000000..d0da946 --- /dev/null +++ b/gateway/internal/grpcapi/freshness_replay_integration_test.go @@ -0,0 +1,509 @@ +package grpcapi + +import ( + "context" + "errors" + "io" + "sync" + "testing" + "time" + + "galaxy/gateway/internal/replay" + "galaxy/gateway/internal/session" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestExecuteCommandRejectsStaleTimestamp(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + timestampMS int64 + }{ + { + name: "past window", + timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(), + }, + { + name: "future window", + timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(), + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + 
defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS)) + require.Error(t, err) + assert.Equal(t, codes.FailedPrecondition, status.Code(err)) + assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) + }) + } +} + +func TestSubscribeEventsRejectsStaleTimestamp(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + timestampMS int64 + }{ + { + name: "past window", + timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(), + }, + { + name: "future window", + timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(), + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS)) + require.Error(t, err) + assert.Equal(t, codes.FailedPrecondition, status.Code(err)) + assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message()) + assert.Zero(t, delegate.subscribeCalls) + }) + } +} + +func TestExecuteCommandRejectsReplay(t *testing.T) { + t.Parallel() + + 
delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: replayDuplicateBySessionAndRequest(), + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + req := newValidExecuteCommandRequest() + + _, err := client.ExecuteCommand(context.Background(), req) + require.NoError(t, err) + + _, err = client.ExecuteCommand(context.Background(), req) + require.Error(t, err) + assert.Equal(t, codes.FailedPrecondition, status.Code(err)) + assert.Equal(t, "request replay detected", status.Convert(err).Message()) + assert.Equal(t, 1, delegate.executeCalls) +} + +func TestSubscribeEventsRejectsReplay(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: replayDuplicateBySessionAndRequest(), + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + req := newValidSubscribeEventsRequest() + + stream, err := client.SubscribeEvents(context.Background(), req) + require.NoError(t, err) + event := recvBootstrapEvent(t, stream) + assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli()) + _, err = stream.Recv() + require.ErrorIs(t, 
err, io.EOF) + + err = subscribeEventsError(t, context.Background(), client, req) + require.Error(t, err) + assert.Equal(t, codes.FailedPrecondition, status.Code(err)) + assert.Equal(t, "request replay detected", status.Convert(err).Message()) + assert.Equal(t, 1, delegate.subscribeCalls) +} + +func TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{ + executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil + }, + } + + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{ + lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) { + return newActiveSessionRecordWithSessionID(deviceSessionID), nil + }, + }, + ReplayStore: staticReplayStore{ + reserveFunc: replayDuplicateBySessionAndRequest(), + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-shared")) + require.NoError(t, err) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-456", "request-shared")) + require.NoError(t, err) + + assert.Equal(t, 2, delegate.executeCalls) +} + +func TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{ + subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + return nil + }, + } + + server, runGateway := 
newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{ + lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) { + return newActiveSessionRecordWithSessionID(deviceSessionID), nil + }, + }, + ReplayStore: staticReplayStore{ + reserveFunc: replayDuplicateBySessionAndRequest(), + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-shared")) + require.NoError(t, err) + event := recvBootstrapEvent(t, stream) + assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli()) + _, err = stream.Recv() + require.ErrorIs(t, err, io.EOF) + + stream, err = client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-456", "request-shared")) + require.NoError(t, err) + event = recvBootstrapEvent(t, stream) + assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli()) + _, err = stream.Recv() + require.ErrorIs(t, err, io.EOF) + + assert.Equal(t, 2, delegate.subscribeCalls) +} + +func TestExecuteCommandRejectsReplayStoreUnavailable(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: func(context.Context, string, string, time.Duration) error { + return errors.New("redis down") + }, + }, + }) + defer 
runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + assert.Equal(t, "replay store is unavailable", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) +} + +func TestSubscribeEventsRejectsReplayStoreUnavailable(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: func(context.Context, string, string, time.Duration) error { + return errors.New("redis down") + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + assert.Equal(t, "replay store is unavailable", status.Convert(err).Message()) + assert.Zero(t, delegate.subscribeCalls) +} + +func TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{ + executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil + }, + } + + var reservedDeviceSessionID string + var reservedRequestID string + var reservedTTL 
time.Duration + + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error { + reservedDeviceSessionID = deviceSessionID + reservedRequestID = requestID + reservedTTL = ttl + return nil + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.NoError(t, err) + assert.Equal(t, "request-123", response.GetRequestId()) + assert.Equal(t, "device-session-123", reservedDeviceSessionID) + assert.Equal(t, "request-123", reservedRequestID) + assert.Equal(t, testFreshnessWindow, reservedTTL) + assert.Equal(t, 1, delegate.executeCalls) +} + +func TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{ + subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + return nil + }, + } + + var reservedTTL time.Duration + + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error { + assert.Equal(t, "device-session-123", deviceSessionID) + assert.Equal(t, "request-123", requestID) + reservedTTL = ttl + return nil + }, + }, + }) + defer 
runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest()) + require.NoError(t, err) + event := recvBootstrapEvent(t, stream) + assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli()) + _, err = stream.Recv() + require.ErrorIs(t, err, io.EOF) + assert.Equal(t, testFreshnessWindow, reservedTTL) + assert.Equal(t, 1, delegate.subscribeCalls) +} + +func TestExecuteCommandFutureSkewUsesExtendedReplayTTL(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{ + executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil + }, + } + + var reservedTTL time.Duration + + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error { + reservedTTL = ttl + return nil + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand( + context.Background(), + newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(2*time.Minute).UnixMilli()), + ) + require.NoError(t, err) + assert.Equal(t, 7*time.Minute, reservedTTL) + assert.Equal(t, 1, delegate.executeCalls) +} 
+ +func TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{ + executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil + }, + } + + var reservedTTL time.Duration + + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + ReplayStore: staticReplayStore{ + reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error { + reservedTTL = ttl + return nil + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand( + context.Background(), + newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(-testFreshnessWindow).UnixMilli()), + ) + require.NoError(t, err) + assert.Equal(t, minimumReplayReservationTTL, reservedTTL) + assert.Equal(t, 1, delegate.executeCalls) +} + +func replayDuplicateBySessionAndRequest() func(context.Context, string, string, time.Duration) error { + var ( + mu sync.Mutex + seen = make(map[string]struct{}) + ) + + return func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error { + mu.Lock() + defer mu.Unlock() + + key := deviceSessionID + "\x00" + requestID + if _, ok := seen[key]; ok { + return replay.ErrDuplicate + } + + seen[key] = struct{}{} + return nil + } +} diff --git a/gateway/internal/grpcapi/observability.go b/gateway/internal/grpcapi/observability.go new file mode 100644 index 0000000..0c1463d --- /dev/null +++ 
b/gateway/internal/grpcapi/observability.go @@ -0,0 +1,147 @@ +package grpcapi + +import ( + "context" + "errors" + "path" + "time" + + "galaxy/gateway/internal/logging" + "galaxy/gateway/internal/telemetry" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func observabilityUnaryInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.UnaryServerInterceptor { + if logger == nil { + logger = zap.NewNop() + } + + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + start := time.Now() + resp, err := handler(ctx, req) + + recordGRPCRequest(logger, metrics, ctx, info.FullMethod, req, resp, err, time.Since(start), "unary") + return resp, err + } +} + +func observabilityStreamInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.StreamServerInterceptor { + if logger == nil { + logger = zap.NewNop() + } + + return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + start := time.Now() + wrapped := &observabilityServerStream{ServerStream: stream} + err := handler(srv, wrapped) + + recordGRPCRequest(logger, metrics, stream.Context(), info.FullMethod, wrapped.request, nil, err, time.Since(start), "stream") + return err + } +} + +type observabilityServerStream struct { + grpc.ServerStream + request any +} + +func (s *observabilityServerStream) RecvMsg(m any) error { + err := s.ServerStream.RecvMsg(m) + if err == nil && s.request == nil { + s.request = m + } + + return err +} + +func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx context.Context, fullMethod string, req any, resp any, err error, duration time.Duration, streamKind string) { + rpcMethod := path.Base(fullMethod) + messageType, requestID, traceID := grpcEnvelopeFields(req) + resultCode := 
grpcResultCode(resp) + grpcCode, grpcMessage, outcome := grpcOutcome(err) + rejectReason := telemetry.RejectReason(outcome) + + attrs := []attribute.KeyValue{ + attribute.String("rpc_method", rpcMethod), + attribute.String("message_type", messageType), + attribute.String("edge_outcome", string(outcome)), + } + if resultCode != "" { + attrs = append(attrs, attribute.String("result_code", resultCode)) + } + if rejectReason != "" { + attrs = append(attrs, attribute.String("reject_reason", rejectReason)) + } + metrics.RecordAuthenticatedGRPC(ctx, attrs, duration) + + fields := []zap.Field{ + zap.String("component", "authenticated_grpc"), + zap.String("transport", "grpc"), + zap.String("stream_kind", streamKind), + zap.String("rpc_method", rpcMethod), + zap.String("message_type", messageType), + zap.String("grpc_code", grpcCode.String()), + zap.Float64("duration_ms", float64(duration.Microseconds())/1000), + zap.String("request_id", requestID), + zap.String("trace_id", traceID), + zap.String("peer_ip", peerIPFromContext(ctx)), + zap.String("edge_outcome", string(outcome)), + } + if resultCode != "" { + fields = append(fields, zap.String("result_code", resultCode)) + } + if rejectReason != "" { + fields = append(fields, zap.String("reject_reason", rejectReason)) + } + if grpcMessage != "" { + fields = append(fields, zap.String("grpc_message", grpcMessage)) + } + fields = append(fields, logging.TraceFieldsFromContext(ctx)...) + + switch outcome { + case telemetry.EdgeOutcomeSuccess: + logger.Info("authenticated gRPC request completed", fields...) + case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeDownstreamUnavailable, telemetry.EdgeOutcomeInternalError: + logger.Error("authenticated gRPC request failed", fields...) + default: + logger.Warn("authenticated gRPC request rejected", fields...) 
+ } +} + +func grpcEnvelopeFields(req any) (messageType string, requestID string, traceID string) { + switch typed := req.(type) { + case *gatewayv1.ExecuteCommandRequest: + return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId() + case *gatewayv1.SubscribeEventsRequest: + return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId() + default: + return "", "", "" + } +} + +func grpcResultCode(resp any) string { + typed, ok := resp.(*gatewayv1.ExecuteCommandResponse) + if !ok { + return "" + } + + return typed.GetResultCode() +} + +func grpcOutcome(err error) (codes.Code, string, telemetry.EdgeOutcome) { + switch { + case err == nil: + return codes.OK, "", telemetry.EdgeOutcomeSuccess + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return codes.Canceled, err.Error(), telemetry.EdgeOutcomeSuccess + default: + grpcStatus := status.Convert(err) + return grpcStatus.Code(), grpcStatus.Message(), telemetry.OutcomeFromGRPCStatus(grpcStatus.Code(), grpcStatus.Message()) + } +} diff --git a/gateway/internal/grpcapi/payload_hash.go b/gateway/internal/grpcapi/payload_hash.go new file mode 100644 index 0000000..b48d817 --- /dev/null +++ b/gateway/internal/grpcapi/payload_hash.go @@ -0,0 +1,66 @@ +package grpcapi + +import ( + "context" + "errors" + + "galaxy/gateway/internal/authn" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// payloadHashVerifyingService applies payload-hash verification after session +// lookup and before any later auth or routing step runs. +type payloadHashVerifyingService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + delegate gatewayv1.EdgeGatewayServer +} + +// ExecuteCommand verifies req payload integrity before delegating to the +// configured service implementation. 
+func (s payloadHashVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + if err := verifyPayloadHash(ctx); err != nil { + return nil, err + } + + return s.delegate.ExecuteCommand(ctx, req) +} + +// SubscribeEvents verifies req payload integrity before delegating to the +// configured service implementation. +func (s payloadHashVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + if err := verifyPayloadHash(stream.Context()); err != nil { + return err + } + + return s.delegate.SubscribeEvents(req, stream) +} + +// newPayloadHashVerifyingService wraps delegate with the payload-hash +// verification gate. +func newPayloadHashVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer { + return payloadHashVerifyingService{delegate: delegate} +} + +func verifyPayloadHash(ctx context.Context) error { + envelope, ok := parsedEnvelopeFromContext(ctx) + if !ok { + return status.Error(codes.Internal, "authenticated request context is incomplete") + } + + err := authn.VerifyPayloadHash(envelope.PayloadBytes, envelope.PayloadHash) + switch { + case err == nil: + return nil + case errors.Is(err, authn.ErrInvalidPayloadHash), errors.Is(err, authn.ErrPayloadHashMismatch): + return status.Error(codes.InvalidArgument, err.Error()) + default: + return status.Error(codes.Internal, "payload hash verification failed") + } +} + +var _ gatewayv1.EdgeGatewayServer = payloadHashVerifyingService{} diff --git a/gateway/internal/grpcapi/payload_hash_integration_test.go b/gateway/internal/grpcapi/payload_hash_integration_test.go new file mode 100644 index 0000000..84b1c20 --- /dev/null +++ b/gateway/internal/grpcapi/payload_hash_integration_test.go @@ -0,0 +1,125 @@ +package grpcapi + +import ( + "context" + "crypto/sha256" + "testing" + + "galaxy/gateway/internal/session" + gatewayv1 
"galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestExecuteCommandRejectsPayloadHashWithInvalidLength(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + req := newValidExecuteCommandRequest() + req.PayloadHash = []byte("short") + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), req) + require.Error(t, err) + assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) +} + +func TestExecuteCommandRejectsPayloadHashMismatch(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + req := newValidExecuteCommandRequest() + sum := sha256.Sum256([]byte("other")) + req.PayloadHash = sum[:] + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), req) + require.Error(t, err) + assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assert.Equal(t, 
"payload_hash does not match payload_bytes", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) +} + +func TestSubscribeEventsRejectsPayloadHashWithInvalidLength(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + req := newValidSubscribeEventsRequest() + req.PayloadHash = []byte("short") + + client := gatewayv1.NewEdgeGatewayClient(conn) + err := subscribeEventsError(t, context.Background(), client, req) + require.Error(t, err) + assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message()) + assert.Zero(t, delegate.subscribeCalls) +} + +func TestSubscribeEventsRejectsPayloadHashMismatch(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + req := newValidSubscribeEventsRequest() + sum := sha256.Sum256([]byte("other")) + req.PayloadHash = sum[:] + + client := gatewayv1.NewEdgeGatewayClient(conn) + err := subscribeEventsError(t, context.Background(), client, req) + require.Error(t, err) + assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assert.Equal(t, "payload_hash does not match payload_bytes", 
status.Convert(err).Message()) + assert.Zero(t, delegate.subscribeCalls) +} diff --git a/gateway/internal/grpcapi/push_fanout.go b/gateway/internal/grpcapi/push_fanout.go new file mode 100644 index 0000000..de3a290 --- /dev/null +++ b/gateway/internal/grpcapi/push_fanout.go @@ -0,0 +1,172 @@ +package grpcapi + +import ( + "bytes" + "context" + "crypto/sha256" + "errors" + + "galaxy/gateway/internal/authn" + "galaxy/gateway/internal/clock" + "galaxy/gateway/internal/logging" + "galaxy/gateway/internal/push" + "galaxy/gateway/internal/telemetry" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// NewFanOutPushStreamService constructs the authenticated SubscribeEvents tail +// service that registers active streams in hub and forwards client-facing +// events after the bootstrap event has been sent. +func NewFanOutPushStreamService(hub *push.Hub, responseSigner authn.ResponseSigner, clk clock.Clock, logger *zap.Logger) gatewayv1.EdgeGatewayServer { + if responseSigner == nil { + responseSigner = unavailableResponseSigner{} + } + if clk == nil { + clk = clock.System{} + } + if logger == nil { + logger = zap.NewNop() + } + + return fanOutPushStreamService{ + hub: hub, + responseSigner: responseSigner, + clock: clk, + logger: logger.Named("push_stream"), + } +} + +// fanOutPushStreamService owns the post-bootstrap authenticated push-stream +// lifecycle backed by the in-memory push hub. +type fanOutPushStreamService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + hub *push.Hub + responseSigner authn.ResponseSigner + clock clock.Clock + logger *zap.Logger +} + +// SubscribeEvents registers the verified stream in the push hub and forwards +// matching client-facing events until the stream ends. 
+func (s fanOutPushStreamService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + binding, ok := authenticatedStreamBindingFromContext(stream.Context()) + if !ok { + return status.Error(codes.Internal, "authenticated request context is incomplete") + } + if s.hub == nil { + return status.Error(codes.Internal, "push hub is unavailable") + } + + subscription, err := s.hub.Register(push.StreamBinding{ + UserID: binding.UserID, + DeviceSessionID: binding.DeviceSessionID, + }) + if err != nil { + return status.Error(codes.Internal, "push stream registration failed") + } + defer subscription.Close() + + openFields := []zap.Field{ + zap.String("component", "authenticated_grpc"), + zap.String("transport", "grpc"), + zap.String("rpc_method", authenticatedRPCSubscribeEvents), + zap.String("message_type", binding.MessageType), + zap.String("request_id", binding.RequestID), + zap.String("trace_id", binding.TraceID), + zap.String("device_session_id", binding.DeviceSessionID), + zap.String("user_id", binding.UserID), + } + openFields = append(openFields, logging.TraceFieldsFromContext(stream.Context())...) + s.logger.Info("push stream opened", openFields...) + + for { + select { + case <-stream.Context().Done(): + s.logger.Info("push stream closed", append(openFields, zap.String("edge_outcome", string(mapSubscriptionOutcome(stream.Context().Err()))))...) + return stream.Context().Err() + case <-subscription.Done(): + subscriptionErr := subscription.Err() + s.logger.Warn("push stream closed", append(openFields, + zap.String("edge_outcome", string(mapSubscriptionOutcome(subscriptionErr))), + zap.String("reject_reason", string(mapSubscriptionOutcome(subscriptionErr))), + )...) 
+ return mapSubscriptionError(subscriptionErr) + case event := <-subscription.Events(): + signedEvent, err := s.buildGatewayEvent(event) + if err != nil { + return err + } + if err := stream.Send(signedEvent); err != nil { + return err + } + } + } +} + +func (s fanOutPushStreamService) buildGatewayEvent(event push.Event) (*gatewayv1.GatewayEvent, error) { + timestampMS := s.clock.Now().UTC().UnixMilli() + payloadHash := sha256.Sum256(event.PayloadBytes) + + signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{ + EventType: event.EventType, + EventID: event.EventID, + TimestampMS: timestampMS, + RequestID: event.RequestID, + TraceID: event.TraceID, + PayloadHash: payloadHash[:], + }) + if err != nil { + return nil, status.Error(codes.Unavailable, "response signer is unavailable") + } + + return &gatewayv1.GatewayEvent{ + EventType: event.EventType, + EventId: event.EventID, + TimestampMs: timestampMS, + PayloadBytes: bytes.Clone(event.PayloadBytes), + PayloadHash: bytes.Clone(payloadHash[:]), + Signature: signature, + RequestId: event.RequestID, + TraceId: event.TraceID, + }, nil +} + +func mapSubscriptionError(err error) error { + switch { + case err == nil: + return nil + case errors.Is(err, push.ErrSubscriptionRevoked): + return status.Error(codes.FailedPrecondition, "device session is revoked") + case errors.Is(err, push.ErrSubscriptionOverflow): + return status.Error(codes.ResourceExhausted, "push stream overflowed") + case errors.Is(err, push.ErrHubShuttingDown): + return status.Error(codes.Unavailable, "gateway is shutting down") + default: + return status.Error(codes.Internal, "push stream closed unexpectedly") + } +} + +func mapSubscriptionOutcome(err error) telemetry.EdgeOutcome { + switch { + case err == nil: + return telemetry.EdgeOutcomeSuccess + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return telemetry.EdgeOutcomeSuccess + case errors.Is(err, push.ErrSubscriptionRevoked): + return 
telemetry.EdgeOutcomeRevokedSession + case errors.Is(err, push.ErrSubscriptionOverflow): + return telemetry.EdgeOutcomeRateLimited + case errors.Is(err, push.ErrHubShuttingDown): + return telemetry.EdgeOutcomeGatewayShuttingDown + default: + return telemetry.EdgeOutcomeInternalError + } +} + +var _ gatewayv1.EdgeGatewayServer = fanOutPushStreamService{} diff --git a/gateway/internal/grpcapi/push_stream.go b/gateway/internal/grpcapi/push_stream.go new file mode 100644 index 0000000..404afe6 --- /dev/null +++ b/gateway/internal/grpcapi/push_stream.go @@ -0,0 +1,164 @@ +package grpcapi + +import ( + "bytes" + "context" + "crypto/sha256" + + "galaxy/gateway/internal/authn" + "galaxy/gateway/internal/clock" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + gatewayfbs "galaxy/schema/fbs/gateway" + + flatbuffers "github.com/google/flatbuffers/go" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const serverTimeEventType = "gateway.server_time" + +// authenticatedStreamBinding captures the verified identity bound to one +// authenticated SubscribeEvents stream after the full ingress pipeline +// succeeds. +type authenticatedStreamBinding struct { + UserID string + DeviceSessionID string + MessageType string + RequestID string + TraceID string +} + +// authenticatedStreamBindingFromContext returns the verified stream binding +// previously attached to ctx by the authenticated push-stream service. 
+func authenticatedStreamBindingFromContext(ctx context.Context) (authenticatedStreamBinding, bool) { + if ctx == nil { + return authenticatedStreamBinding{}, false + } + + binding, ok := ctx.Value(authenticatedStreamBindingContextKey{}).(authenticatedStreamBinding) + if !ok { + return authenticatedStreamBinding{}, false + } + + return binding, true +} + +// authenticatedPushStreamService owns SubscribeEvents bootstrap behavior: +// bind the authenticated stream, send the initial signed server-time event, +// and then hand the stream lifecycle to the configured tail delegate. +type authenticatedPushStreamService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + tailDelegate gatewayv1.EdgeGatewayServer + responseSigner authn.ResponseSigner + clock clock.Clock +} + +// SubscribeEvents binds the verified stream identity, sends the initial signed +// server-time event, and then delegates the remaining lifecycle. +func (s authenticatedPushStreamService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + envelope, ok := parsedEnvelopeFromContext(stream.Context()) + if !ok { + return status.Error(codes.Internal, "authenticated request context is incomplete") + } + + record, ok := resolvedSessionFromContext(stream.Context()) + if !ok { + return status.Error(codes.Internal, "authenticated request context is incomplete") + } + + binding := authenticatedStreamBinding{ + UserID: record.UserID, + DeviceSessionID: record.DeviceSessionID, + MessageType: envelope.MessageType, + RequestID: envelope.RequestID, + TraceID: envelope.TraceID, + } + boundStream := authenticatedStreamContextStream{ + ServerStreamingServer: stream, + ctx: context.WithValue( + stream.Context(), + authenticatedStreamBindingContextKey{}, + binding, + ), + } + + serverTimeMS := s.clock.Now().UTC().UnixMilli() + payloadBytes := buildServerTimeEventPayload(serverTimeMS) + payloadHash := sha256.Sum256(payloadBytes) + signature, err := 
s.responseSigner.SignEvent(authn.EventSigningFields{ + EventType: serverTimeEventType, + EventID: envelope.RequestID, + TimestampMS: serverTimeMS, + RequestID: envelope.RequestID, + TraceID: envelope.TraceID, + PayloadHash: payloadHash[:], + }) + if err != nil { + return status.Error(codes.Unavailable, "response signer is unavailable") + } + + if err := boundStream.Send(&gatewayv1.GatewayEvent{ + EventType: serverTimeEventType, + EventId: envelope.RequestID, + TimestampMs: serverTimeMS, + PayloadBytes: bytes.Clone(payloadBytes), + PayloadHash: bytes.Clone(payloadHash[:]), + Signature: signature, + RequestId: envelope.RequestID, + TraceId: envelope.TraceID, + }); err != nil { + return err + } + + return s.tailDelegate.SubscribeEvents(req, boundStream) +} + +func newAuthenticatedPushStreamService(tailDelegate gatewayv1.EdgeGatewayServer, responseSigner authn.ResponseSigner, clk clock.Clock) gatewayv1.EdgeGatewayServer { + if tailDelegate == nil { + tailDelegate = holdOpenSubscribeEventsService{} + } + + return authenticatedPushStreamService{ + tailDelegate: tailDelegate, + responseSigner: responseSigner, + clock: clk, + } +} + +func buildServerTimeEventPayload(serverTimeMS int64) []byte { + builder := flatbuffers.NewBuilder(32) + gatewayfbs.ServerTimeEventStart(builder) + gatewayfbs.ServerTimeEventAddServerTimeMs(builder, serverTimeMS) + eventOffset := gatewayfbs.ServerTimeEventEnd(builder) + gatewayfbs.FinishServerTimeEventBuffer(builder, eventOffset) + + return bytes.Clone(builder.FinishedBytes()) +} + +type authenticatedStreamBindingContextKey struct{} + +type authenticatedStreamContextStream struct { + grpc.ServerStreamingServer[gatewayv1.GatewayEvent] + ctx context.Context +} + +func (s authenticatedStreamContextStream) Context() context.Context { + if s.ctx == nil { + return context.Background() + } + + return s.ctx +} + +type holdOpenSubscribeEventsService struct { + gatewayv1.UnimplementedEdgeGatewayServer +} + +func (holdOpenSubscribeEventsService) 
SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + <-stream.Context().Done() + return stream.Context().Err() +} + +var _ gatewayv1.EdgeGatewayServer = authenticatedPushStreamService{} diff --git a/gateway/internal/grpcapi/rate_limit.go b/gateway/internal/grpcapi/rate_limit.go new file mode 100644 index 0000000..87bad40 --- /dev/null +++ b/gateway/internal/grpcapi/rate_limit.go @@ -0,0 +1,286 @@ +package grpcapi + +import ( + "context" + "errors" + "net" + "strings" + + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/ratelimit" + "galaxy/gateway/internal/session" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +const ( + authenticatedGRPCBaseBucketKeyPrefix = "authenticated_grpc/" + + authenticatedGRPCIPBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "ip=" + authenticatedGRPCSessionBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "session=" + authenticatedGRPCUserBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "user=" + authenticatedGRPCMessageClassBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "message_class=" + + unknownAuthenticatedPeerIP = "unknown" + + authenticatedRPCExecuteCommand = "ExecuteCommand" + authenticatedRPCSubscribeEvents = "SubscribeEvents" +) + +var ( + // ErrAuthenticatedPolicyDenied reports that the authenticated request was + // rejected by later edge policy after transport authenticity succeeded. + ErrAuthenticatedPolicyDenied = errors.New("authenticated request rejected by edge policy") + + // ErrAuthenticatedPolicyUnavailable reports that authenticated policy could + // not be evaluated because its backing dependency is unavailable. 
+ ErrAuthenticatedPolicyUnavailable = errors.New("authenticated request policy is unavailable") +) + +// AuthenticatedRequestLimiter applies authenticated gRPC rate-limit policy to +// one concrete bucket key. +type AuthenticatedRequestLimiter interface { + // Reserve evaluates key under policy and reports whether the request may + // proceed immediately. + Reserve(key string, policy ratelimit.Policy) ratelimit.Decision +} + +// AuthenticatedRequest describes the authenticated request metadata exposed to +// the edge-policy hook. +type AuthenticatedRequest struct { + // RPCMethod identifies the public gRPC method being processed. + RPCMethod string + + // PeerIP is the transport peer IP derived from the gRPC connection. + PeerIP string + + // MessageClass is the stable rate-limit and policy class. The gateway uses + // the full message_type literal because the v1 transport does not yet define + // a coarser authenticated class taxonomy. + MessageClass string + + // Envelope contains the verified transport envelope fields used by later + // edge policy. + Envelope AuthenticatedRequestEnvelope + + // Session contains the authenticated identity resolved from SessionCache. + Session session.Record +} + +// AuthenticatedRequestEnvelope describes the verified request envelope fields +// exposed to the edge-policy hook. +type AuthenticatedRequestEnvelope struct { + // ProtocolVersion is the supported transport protocol version literal. + ProtocolVersion string + + // DeviceSessionID is the authenticated device-session identifier. + DeviceSessionID string + + // MessageType is the verified downstream routing key supplied by the client. + MessageType string + + // TimestampMS is the client timestamp that already passed freshness checks. + TimestampMS int64 + + // RequestID is the authenticated transport request identifier. + RequestID string + + // TraceID is the optional client-supplied correlation identifier. 
+ TraceID string +} + +// AuthenticatedRequestPolicy evaluates later authenticated edge policy after +// transport authenticity and rate-limit checks succeed. +type AuthenticatedRequestPolicy interface { + // Evaluate returns nil when the authenticated request may proceed. It should + // wrap ErrAuthenticatedPolicyDenied for stable reject mapping and + // ErrAuthenticatedPolicyUnavailable when its backing dependency is + // temporarily unavailable. + Evaluate(ctx context.Context, request AuthenticatedRequest) error +} + +type authenticatedRateLimitService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + delegate gatewayv1.EdgeGatewayServer + limiter AuthenticatedRequestLimiter + policy AuthenticatedRequestPolicy + cfg config.AuthenticatedGRPCAntiAbuseConfig +} + +// ExecuteCommand applies authenticated rate limits and edge policy before +// delegating to the configured service implementation. +func (s authenticatedRateLimitService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + if err := s.applyRateLimitsAndPolicy(ctx, authenticatedRPCExecuteCommand); err != nil { + return nil, err + } + + return s.delegate.ExecuteCommand(ctx, req) +} + +// SubscribeEvents applies authenticated rate limits and edge policy before +// delegating to the configured service implementation. +func (s authenticatedRateLimitService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + if err := s.applyRateLimitsAndPolicy(stream.Context(), authenticatedRPCSubscribeEvents); err != nil { + return err + } + + return s.delegate.SubscribeEvents(req, stream) +} + +// newAuthenticatedRateLimitService wraps delegate with the authenticated +// rate-limit and edge-policy gate. 
+func newAuthenticatedRateLimitService(delegate gatewayv1.EdgeGatewayServer, limiter AuthenticatedRequestLimiter, policy AuthenticatedRequestPolicy, cfg config.AuthenticatedGRPCAntiAbuseConfig) gatewayv1.EdgeGatewayServer { + return authenticatedRateLimitService{ + delegate: delegate, + limiter: limiter, + policy: policy, + cfg: cfg, + } +} + +func (s authenticatedRateLimitService) applyRateLimitsAndPolicy(ctx context.Context, rpcMethod string) error { + request, err := authenticatedRequestFromContext(ctx, rpcMethod) + if err != nil { + return err + } + + if err := s.applyRateLimits(request); err != nil { + return err + } + + if err := s.applyPolicy(ctx, request); err != nil { + return err + } + + return nil +} + +func (s authenticatedRateLimitService) applyRateLimits(request AuthenticatedRequest) error { + checks := []struct { + key string + policy config.AuthenticatedRateLimitConfig + }{ + { + key: authenticatedGRPCIPBucketKey(request.PeerIP), + policy: s.cfg.IP, + }, + { + key: authenticatedGRPCSessionBucketKey(request.Envelope.DeviceSessionID), + policy: s.cfg.Session, + }, + { + key: authenticatedGRPCUserBucketKey(request.Session.UserID), + policy: s.cfg.User, + }, + { + key: authenticatedGRPCMessageClassBucketKey(request.MessageClass), + policy: s.cfg.MessageClass, + }, + } + + for _, check := range checks { + decision := s.limiter.Reserve(check.key, ratelimit.Policy{ + Requests: check.policy.Requests, + Window: check.policy.Window, + Burst: check.policy.Burst, + }) + if !decision.Allowed { + return status.Error(codes.ResourceExhausted, "authenticated request rate limit exceeded") + } + } + + return nil +} + +func (s authenticatedRateLimitService) applyPolicy(ctx context.Context, request AuthenticatedRequest) error { + err := s.policy.Evaluate(ctx, request) + switch { + case err == nil: + return nil + case errors.Is(err, ErrAuthenticatedPolicyDenied): + return status.Error(codes.PermissionDenied, "authenticated request rejected by edge policy") + case 
errors.Is(err, ErrAuthenticatedPolicyUnavailable): + return status.Error(codes.Unavailable, "authenticated request policy is unavailable") + default: + return status.Error(codes.Internal, "authenticated request policy evaluation failed") + } +} + +func authenticatedRequestFromContext(ctx context.Context, rpcMethod string) (AuthenticatedRequest, error) { + envelope, ok := parsedEnvelopeFromContext(ctx) + if !ok { + return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete") + } + + record, ok := resolvedSessionFromContext(ctx) + if !ok { + return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete") + } + + return AuthenticatedRequest{ + RPCMethod: rpcMethod, + PeerIP: peerIPFromContext(ctx), + MessageClass: authenticatedMessageClass(envelope.MessageType), + Envelope: AuthenticatedRequestEnvelope{ + ProtocolVersion: envelope.ProtocolVersion, + DeviceSessionID: envelope.DeviceSessionID, + MessageType: envelope.MessageType, + TimestampMS: envelope.TimestampMS, + RequestID: envelope.RequestID, + TraceID: envelope.TraceID, + }, + Session: record, + }, nil +} + +func authenticatedGRPCIPBucketKey(peerIP string) string { + return authenticatedGRPCIPBucketKeySegment + peerIP +} + +func authenticatedGRPCSessionBucketKey(deviceSessionID string) string { + return authenticatedGRPCSessionBucketKeySegment + deviceSessionID +} + +func authenticatedGRPCUserBucketKey(userID string) string { + return authenticatedGRPCUserBucketKeySegment + userID +} + +func authenticatedGRPCMessageClassBucketKey(messageClass string) string { + return authenticatedGRPCMessageClassBucketKeySegment + messageClass +} + +func authenticatedMessageClass(messageType string) string { + return messageType +} + +func peerIPFromContext(ctx context.Context) string { + peerInfo, ok := peer.FromContext(ctx) + if !ok || peerInfo.Addr == nil { + return unknownAuthenticatedPeerIP + } + + value := 
strings.TrimSpace(peerInfo.Addr.String()) + if value == "" { + return unknownAuthenticatedPeerIP + } + + host, _, err := net.SplitHostPort(value) + if err == nil && host != "" { + return host + } + + return value +} + +type noopAuthenticatedRequestPolicy struct{} + +func (noopAuthenticatedRequestPolicy) Evaluate(context.Context, AuthenticatedRequest) error { + return nil +} + +var _ gatewayv1.EdgeGatewayServer = authenticatedRateLimitService{} diff --git a/gateway/internal/grpcapi/rate_limit_integration_test.go b/gateway/internal/grpcapi/rate_limit_integration_test.go new file mode 100644 index 0000000..8d515e2 --- /dev/null +++ b/gateway/internal/grpcapi/rate_limit_integration_test.go @@ -0,0 +1,497 @@ +package grpcapi + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" + "time" + + "galaxy/gateway/internal/app" + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/ratelimit" + "galaxy/gateway/internal/restapi" + "galaxy/gateway/internal/session" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestExecuteCommandRateLimitsByIP(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) { + cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + }), ServerDependencies{ + Service: delegate, + SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}), + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, 
err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")) + require.NoError(t, err) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2")) + require.Error(t, err) + assert.Equal(t, codes.ResourceExhausted, status.Code(err)) + assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message()) + assert.Equal(t, 1, delegate.executeCalls) +} + +func TestExecuteCommandRateLimitsBySession(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) { + cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + }), ServerDependencies{ + Service: delegate, + SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-1"}), + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")) + require.NoError(t, err) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2")) + require.Error(t, err) + assert.Equal(t, codes.ResourceExhausted, status.Code(err)) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3")) + require.NoError(t, err) + + assert.Equal(t, 2, delegate.executeCalls) +} + +func TestExecuteCommandRateLimitsByUser(t 
*testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) { + cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + }), ServerDependencies{ + Service: delegate, + SessionCache: userMappedSessionCache(map[string]string{ + "device-session-1": "user-shared", + "device-session-2": "user-shared", + "device-session-3": "user-other", + }), + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")) + require.NoError(t, err) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2")) + require.Error(t, err) + assert.Equal(t, codes.ResourceExhausted, status.Code(err)) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3")) + require.NoError(t, err) + + assert.Equal(t, 2, delegate.executeCalls) +} + +func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) { + cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + }), ServerDependencies{ + Service: delegate, + SessionCache: userMappedSessionCache(map[string]string{ + "device-session-1": "user-1", + "device-session-2": "user-2", + }), + ReplayStore: 
staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move")) + require.NoError(t, err) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move")) + require.Error(t, err) + assert.Equal(t, codes.ResourceExhausted, status.Code(err)) + + _, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename")) + require.NoError(t, err) + + assert.Equal(t, 2, delegate.executeCalls) +} + +func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) { + t.Parallel() + + policy := &recordingAuthenticatedRequestPolicy{} + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + Policy: policy, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.NoError(t, err) + + require.Len(t, policy.requests, 1) + assert.Equal(t, authenticatedRPCExecuteCommand, policy.requests[0].RPCMethod) + assert.Equal(t, "127.0.0.1", policy.requests[0].PeerIP) + assert.Equal(t, "fleet.move", policy.requests[0].MessageClass) + assert.Equal(t, "device-session-123", policy.requests[0].Envelope.DeviceSessionID) + assert.Equal(t, "request-123", 
policy.requests[0].Envelope.RequestID) + assert.Equal(t, "trace-123", policy.requests[0].Envelope.TraceID) + assert.Equal(t, "user-123", policy.requests[0].Session.UserID) + assert.Equal(t, 1, delegate.executeCalls) +} + +func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}), + ReplayStore: staticReplayStore{}, + Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error { + return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied) + }), + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.Error(t, err) + assert.Equal(t, codes.PermissionDenied, status.Code(err)) + assert.Equal(t, "authenticated request rejected by edge policy", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) +} + +func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) { + cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + }), ServerDependencies{ + Service: delegate, + SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}), + ReplayStore: staticReplayStore{}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, 
		conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	// First stream consumes the rate-limit budget; it must deliver the signed
	// bootstrap event and then close cleanly.
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)

	// Second stream from the same session exceeds the budget and must be
	// rejected before reaching the delegate (subscribeCalls stays at 1).
	err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))
	assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.subscribeCalls)
}

// TestAuthenticatedRateLimitsStayIsolatedFromPublicREST shares one in-memory
// limiter between the public REST server and the authenticated gRPC server and
// verifies that exhausting the public-auth budget does not consume the
// authenticated gRPC budget (and vice versa): the second public request is
// throttled while the subsequent authenticated ExecuteCommand still succeeds.
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
	t.Parallel()

	sharedLimiter := ratelimit.NewInMemory()

	publicCfg := config.DefaultPublicHTTPConfig()
	publicCfg.Addr = unusedTCPAddr(t)
	// One public-auth request per hour so the second call below trips 429.
	publicCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
		Requests: 1,
		Window:   time.Hour,
		Burst:    1,
	}
	publicCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}

	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.Addr = unusedTCPAddr(t)
		// A single authenticated request per IP is enough to prove isolation:
		// it must still be available after the public budget is spent.
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	})

	restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
		AuthService: staticAuthServiceClient{},
		Limiter:     publicLimiterAdapter{limiter: sharedLimiter},
	})
	delegate := &recordingEdgeGatewayService{}
	grpcServer := NewServer(grpcCfg, ServerDependencies{
		Service:        delegate,
		Router:         executeCommandAdapterRouter{service: delegate},
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Limiter:        sharedLimiter,
		Clock:          fixedClock{now: testCurrentTime},
	})

	application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, grpcServer)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	runGateway := runningGateway{cancel: cancel, resultCh: resultCh}
	defer runGateway.stop(t)

	waitForHTTPHealthz(t, "http://"+publicCfg.Addr+"/healthz")
	addr := waitForListenAddr(t, grpcServer)

	firstPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
	secondPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")

	assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
	assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
	require.NoError(t, firstPublic.Body.Close())
	require.NoError(t, secondPublic.Body.Close())

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// The authenticated budget must be untouched by the public 429 above.
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
}

// newAuthenticatedGRPCConfigForTest returns a default authenticated gRPC
// config with an ephemeral listen address, the shared test freshness window,
// and generous (100/hour) limits on every anti-abuse dimension; mutate, when
// non-nil, can tighten individual fields for a specific test.
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
	cfg := config.DefaultAuthenticatedGRPCConfig()
	cfg.Addr = "127.0.0.1:0"
	cfg.FreshnessWindow = testFreshnessWindow
	cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}

	if mutate != nil {
		mutate(&cfg)
	}

	return cfg
}

// newValidExecuteCommandRequestWithMessageType builds a valid request for the
// given session/request IDs, overrides its message_type, and re-signs the
// envelope so the signature still covers the modified field.
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
	req := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
	req.MessageType = messageType
	req.Signature = signRequest(
		req.GetProtocolVersion(),
		req.GetDeviceSessionId(),
		req.GetMessageType(),
		req.GetTimestampMs(),
		req.GetRequestId(),
		req.GetPayloadHash(),
	)

	return req
}

// userMappedSessionCache returns a stub cache that maps the given device
// session IDs to user IDs and reports session.ErrNotFound for anything else.
func userMappedSessionCache(users map[string]string) staticSessionCache {
	return staticSessionCache{
		lookupFunc: func(_ context.Context, deviceSessionID string) (session.Record, error) {
			userID, ok := users[deviceSessionID]
			if !ok {
				return session.Record{}, session.ErrNotFound
			}

			record := newActiveSessionRecordWithSessionID(deviceSessionID)
			record.UserID = userID
			return record, nil
		},
	}
}

// authenticatedRequestPolicyFunc adapts a plain function to the
// AuthenticatedRequestPolicy interface.
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error

func (f authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
	return f(ctx, request)
}

// recordingAuthenticatedRequestPolicy records every evaluated request and
// always allows it.
type recordingAuthenticatedRequestPolicy struct {
	requests []AuthenticatedRequest
}

func (p *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
	p.requests = append(p.requests, request)
	return nil
}

// publicLimiterAdapter bridges the shared ratelimit.Limiter to the public REST
// server's limiter interface so both edges can draw on one limiter instance.
type publicLimiterAdapter struct {
	limiter ratelimit.Limiter
}

func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
	decision := a.limiter.Reserve(key, ratelimit.Policy{
		Requests: policy.Requests,
		Window:   policy.Window,
		Burst:    policy.Burst,
	})

	return restapi.PublicRateLimitDecision{
		Allowed:    decision.Allowed,
		RetryAfter: decision.RetryAfter,
	}
}

// staticAuthServiceClient is a stub auth service returning fixed identifiers.
type staticAuthServiceClient struct{}

func (staticAuthServiceClient) SendEmailCode(context.Context, restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
	return restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}, nil
}

func (staticAuthServiceClient) ConfirmEmailCode(context.Context, restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
	return restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, nil
}

// waitForHTTPHealthz polls url until it answers 200 OK or the 2s deadline
// elapses, failing the test on timeout.
func waitForHTTPHealthz(t *testing.T, url string) {
	t.Helper()

	client := &http.Client{Timeout: 200 * time.Millisecond}
	require.Eventually(t, func() bool {
		response, err := client.Get(url)
		if err != nil {
			return false
		}
		require.NoError(t, response.Body.Close())

		return response.StatusCode == http.StatusOK
	}, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
}

// sendPublicAuthRequest POSTs a minimal JSON email payload to url and returns
// the raw response; the caller owns closing the body.
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
	t.Helper()

	request, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"email":"pilot@example.com"}`))
	require.NoError(t, err)
	request.Header.Set("Content-Type", "application/json")

	response, err := (&http.Client{Timeout: time.Second}).Do(request)
	require.NoError(t, err)

	return response
}

// unusedTCPAddr grabs an ephemeral 127.0.0.1 port, releases it, and returns
// the address. NOTE(review): the port could in principle be reused by another
// process between Close and the server's own Listen — acceptable for tests.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	addr := listener.Addr().String()
	require.NoError(t, listener.Close())

	return addr
}

// Package grpcapi exposes the authenticated gRPC surface of the gateway.
package grpcapi

import (
	"context"
	"errors"
	"fmt"
	"net"
	"sync"

	"galaxy/gateway/internal/authn"
	"galaxy/gateway/internal/clock"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/downstream"
	"galaxy/gateway/internal/push"
	"galaxy/gateway/internal/ratelimit"
	"galaxy/gateway/internal/replay"
	"galaxy/gateway/internal/session"
	"galaxy/gateway/internal/telemetry"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"go.uber.org/zap"
	"google.golang.org/grpc"
)

// ServerDependencies describes the optional collaborators used by the
// authenticated gRPC server. The zero value is valid and keeps the process
// runnable with the built-in unimplemented service stub.
type ServerDependencies struct {
	// Service optionally handles the post-bootstrap SubscribeEvents lifecycle
	// after the initial authenticated service event has been sent. When nil, the
	// gateway keeps authenticated SubscribeEvents streams open until the client
	// cancels them, the server shuts down, or a later stream send fails.
	Service gatewayv1.EdgeGatewayServer

	// Router resolves the exact downstream unary client for the verified
	// message_type value. When nil, the authenticated unary surface uses an
	// empty exact-match router and returns UNIMPLEMENTED for unrouted commands.
	Router downstream.Router

	// ResponseSigner signs authenticated unary responses after downstream
	// execution succeeds. When nil, the unary surface fails closed once it needs
	// to sign a routed response.
	ResponseSigner authn.ResponseSigner

	// SessionCache resolves authenticated device sessions after the envelope
	// gate succeeds. When nil, the authenticated gRPC surface remains runnable
	// but valid envelopes fail closed as session-cache unavailable.
	SessionCache session.Cache

	// Clock provides current server time for freshness checks. When nil, the
	// authenticated gRPC surface uses the system clock.
	Clock clock.Clock

	// ReplayStore reserves authenticated request identifiers after signature
	// verification. When nil, valid requests fail closed as replay-store
	// unavailable.
	ReplayStore replay.Store

	// Limiter applies authenticated rate limits after the request passes the
	// transport authenticity checks. When nil, the authenticated gRPC surface
	// uses a process-local in-memory limiter.
	Limiter AuthenticatedRequestLimiter

	// Policy evaluates later authenticated edge policy after rate limits pass.
	// When nil, the authenticated gRPC surface applies a no-op allow policy.
	Policy AuthenticatedRequestPolicy

	// Logger writes structured logs for authenticated gRPC traffic.
	Logger *zap.Logger

	// Telemetry records low-cardinality gRPC metrics.
	Telemetry *telemetry.Runtime

	// PushHub is the active authenticated push-stream hub. When present, the
	// server closes active streams before GracefulStop during shutdown.
	PushHub *push.Hub
}

// Server owns the authenticated gRPC listener exposed by the gateway.
type Server struct {
	cfg     config.AuthenticatedGRPCConfig
	service gatewayv1.EdgeGatewayServer
	logger  *zap.Logger
	pushHub *push.Hub
	metrics *telemetry.Runtime

	// stateMu guards server/listener, which are set by Run and read by
	// Shutdown and listenAddr from other goroutines.
	stateMu  sync.RWMutex
	server   *grpc.Server
	listener net.Listener
}

// NewServer constructs an authenticated gRPC server for the supplied listener
// configuration and dependency bundle. Nil dependencies are replaced with safe
// defaults so the gateway can expose the documented transport surface with the
// full auth pipeline wired from built-in fallbacks.
//
// The decorator chain below is ordered outermost-first: envelope validation ->
// session lookup -> payload-hash check -> signature check -> freshness/replay
// -> rate limits/policy -> command routing / push stream. Requests that fail an
// earlier gate never reach the later ones.
func NewServer(cfg config.AuthenticatedGRPCConfig, deps ServerDependencies) *Server {
	deps = normalizeServerDependencies(deps)

	finalService := newCommandRoutingService(
		newAuthenticatedPushStreamService(deps.Service, deps.ResponseSigner, deps.Clock),
		deps.Router,
		deps.ResponseSigner,
		deps.Clock,
		cfg.DownstreamTimeout,
	)

	return &Server{
		cfg: cfg,
		service: newEnvelopeValidatingService(
			newSessionLookupService(
				newPayloadHashVerifyingService(
					newSignatureVerifyingService(
						newFreshnessAndReplayService(
							newAuthenticatedRateLimitService(
								finalService,
								deps.Limiter,
								deps.Policy,
								cfg.AntiAbuse,
							),
							deps.Clock,
							deps.ReplayStore,
							cfg.FreshnessWindow,
						),
					),
				),
				deps.SessionCache,
			),
		),
		logger:  deps.Logger.Named("authenticated_grpc"),
		pushHub: deps.PushHub,
		metrics: deps.Telemetry,
	}
}

// Run binds the configured listener and serves the authenticated gRPC surface
// until Shutdown closes the server.
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run authenticated gRPC server: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run authenticated gRPC server: listen on %q: %w", s.cfg.Addr, err)
	}

	grpcServer := grpc.NewServer(
		grpc.ConnectionTimeout(s.cfg.ConnectionTimeout),
		grpc.StatsHandler(otelgrpc.NewServerHandler()),
		grpc.ChainUnaryInterceptor(observabilityUnaryInterceptor(s.logger, s.metrics)),
		grpc.ChainStreamInterceptor(observabilityStreamInterceptor(s.logger, s.metrics)),
	)
	gatewayv1.RegisterEdgeGatewayServer(grpcServer, s.service)

	// Publish server/listener so Shutdown and listenAddr can observe them.
	s.stateMu.Lock()
	s.server = grpcServer
	s.listener = listener
	s.stateMu.Unlock()

	s.logger.Info("authenticated gRPC server started", zap.String("addr", listener.Addr().String()))

	// Clear the published state once Serve returns so a subsequent Shutdown
	// becomes a no-op.
	defer func() {
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()

	err = grpcServer.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, grpc.ErrServerStopped):
		// Normal Shutdown path: GracefulStop/Stop makes Serve return this.
		s.logger.Info("authenticated gRPC server stopped")
		return nil
	default:
		return fmt.Errorf("run authenticated gRPC server: serve on %q: %w", s.cfg.Addr, err)
	}
}

// Shutdown gracefully stops the authenticated gRPC server within ctx. When the
// graceful stop exceeds ctx, the server is force-stopped before returning the
// timeout to the caller.
//
// NOTE(review): if Shutdown races a Run that has not yet published s.server,
// it returns nil without stopping anything — presumably the app layer orders
// Run before Shutdown; confirm against app.Run's lifecycle.
func (s *Server) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown authenticated gRPC server: nil context")
	}

	s.stateMu.RLock()
	server := s.server
	s.stateMu.RUnlock()

	if server == nil {
		return nil
	}

	// Close active push streams first so GracefulStop is not held open by
	// long-lived server-streaming RPCs.
	if s.pushHub != nil {
		s.pushHub.Shutdown()
	}

	stopped := make(chan struct{})
	go func() {
		server.GracefulStop()
		close(stopped)
	}()

	select {
	case <-stopped:
		return nil
	case <-ctx.Done():
		server.Stop()
		<-stopped
		return fmt.Errorf("shutdown authenticated gRPC server: %w", ctx.Err())
	}
}

// listenAddr reports the bound listener address, or "" before Run has bound
// one (used by tests to discover the ephemeral port).
func (s *Server) listenAddr() string {
	s.stateMu.RLock()
	defer s.stateMu.RUnlock()

	if s.listener == nil {
		return ""
	}

	return s.listener.Addr().String()
}

// normalizeServerDependencies replaces every nil dependency with its
// documented fallback (fail-closed stubs for security-critical collaborators,
// permissive no-ops for policy/logging).
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
	if deps.Router == nil {
		deps.Router = downstream.NewStaticRouter(nil)
	}
	if deps.ResponseSigner == nil {
		deps.ResponseSigner = unavailableResponseSigner{}
	}
	if deps.SessionCache == nil {
		deps.SessionCache = unavailableSessionCache{}
	}
	if deps.Clock == nil {
		deps.Clock = clock.System{}
	}
	if deps.ReplayStore == nil {
		deps.ReplayStore = unavailableReplayStore{}
	}
	if deps.Limiter == nil {
		deps.Limiter = ratelimit.NewInMemory()
	}
	if deps.Policy == nil {
		deps.Policy = noopAuthenticatedRequestPolicy{}
	}
	if deps.Logger == nil {
		deps.Logger = zap.NewNop()
	}

	return deps
}
package grpcapi

import (
	"context"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
)

// An empty ExecuteCommand envelope must be rejected with INVALID_ARGUMENT by
// the outermost envelope gate, with no other dependencies configured.
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
}

// An empty SubscribeEvents envelope must be rejected the same way as the
// unary case.
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
}

// A structurally complete envelope carrying an unsupported protocol_version
// fails with FAILED_PRECONDITION and an explicit message.
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v2",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     123456789,
		RequestId:       "request-123",
		PayloadBytes:    []byte("payload"),
		PayloadHash:     []byte("hash"),
		Signature:       []byte("signature"),
	})
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
}

// With no Service/Router supplied, a fully valid envelope passes every auth
// gate but ends at the default empty router: UNIMPLEMENTED.
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unimplemented, status.Code(err))
}

// Without a replay store, a valid unary request fails closed as UNAVAILABLE
// rather than being processed without replay protection.
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}

// A valid subscription receives the signed server-time bootstrap event, then
// the stream stays open until the client cancels, after which Recv reports
// CANCELED.
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
	require.NoError(t, err)

	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	recvResult := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvResult <- recvErr
	}()

	// The stream must NOT close on its own after the bootstrap event.
	require.Never(t, func() bool {
		select {
		case <-recvResult:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")

	cancel()

	var recvErr error
	require.Eventually(t, func() bool {
		select {
		case recvErr = <-recvResult:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "stream did not stop after client cancellation")
	require.Error(t, recvErr)
	assert.Equal(t, codes.Canceled, status.Code(recvErr))
}

// The streaming surface applies the same fail-closed replay-store policy as
// the unary surface.
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}

// After shutdown the listener must be closed: a blocking dial with a short
// deadline fails.
func TestServerLifecycle(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	require.NoError(t, conn.Close())

	runGateway.stop(t)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// NOTE(review): grpc.DialContext/WithBlock are deprecated in recent
	// grpc-go; consider grpc.NewClient + explicit connect — confirm the
	// module's grpc version before migrating.
	_, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.Error(t, err)
}

// runningGateway couples the app's cancel func with its result channel so
// tests can stop the gateway and assert a clean exit.
type runningGateway struct {
	cancel   context.CancelFunc
	resultCh chan error
}

// newTestGateway starts a gateway on an ephemeral port with the shared test
// freshness window and the supplied dependencies.
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()

	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = "127.0.0.1:0"
	grpcCfg.FreshnessWindow = testFreshnessWindow

	return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
}

// newTestGatewayWithGRPCConfig fills in default test dependencies (fixed
// clock, test signer, adapter router when only a Service is given), builds the
// app, and runs it in a background goroutine.
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()

	cfg := config.Config{
		ShutdownTimeout:   time.Second,
		AuthenticatedGRPC: grpcCfg,
	}

	if deps.Clock == nil {
		deps.Clock = fixedClock{now: testCurrentTime}
	}
	if deps.ResponseSigner == nil {
		deps.ResponseSigner = newTestResponseSigner()
	}
	if deps.Router == nil && deps.Service != nil {
		deps.Router = executeCommandAdapterRouter{service: deps.Service}
	}

	server := NewServer(cfg.AuthenticatedGRPC, deps)
	application := app.New(cfg, server)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	return server, runningGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

// stop cancels the gateway and requires it to exit cleanly within 2s.
func (g runningGateway) stop(t *testing.T) {
	t.Helper()

	g.cancel()

	var err error
	require.Eventually(t, func() bool {
		select {
		case err = <-g.resultCh:
			return true
		default:
			return false
		}
	}, 2*time.Second, 10*time.Millisecond, "gateway did not stop after cancellation")
	require.NoError(t, err)
}

// waitForListenAddr polls until Run has bound the ephemeral listener and
// returns its address.
func waitForListenAddr(t *testing.T, server *Server) string {
	t.Helper()

	var addr string
	require.Eventually(t, func() bool {
		addr = server.listenAddr()
		return addr != ""
	}, time.Second, 10*time.Millisecond, "server did not start listening")
	return addr
}

// dialGatewayClient opens an insecure client connection to addr, blocking
// until connected or the 1s deadline fails the test.
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	conn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.NoError(t, err)

	return conn
}

// ---- gateway/internal/grpcapi/session_lookup.go ----

package grpcapi

import (
	"context"
	"errors"

	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
+func resolvedSessionFromContext(ctx context.Context) (session.Record, bool) { + if ctx == nil { + return session.Record{}, false + } + + record, ok := ctx.Value(resolvedSessionContextKey{}).(session.Record) + if !ok { + return session.Record{}, false + } + + return cloneSessionRecord(record), true +} + +// sessionLookupService resolves the authenticated session from SessionCache +// after envelope parsing succeeds and before later auth steps run. +type sessionLookupService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + delegate gatewayv1.EdgeGatewayServer + cache session.Cache +} + +// ExecuteCommand resolves the cached session for req and only then forwards it +// to the configured delegate with the resolved session attached to ctx. +func (s sessionLookupService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + record, err := s.lookupSession(ctx) + if err != nil { + return nil, err + } + + return s.delegate.ExecuteCommand(context.WithValue(ctx, resolvedSessionContextKey{}, cloneSessionRecord(record)), req) +} + +// SubscribeEvents resolves the cached session for req and only then forwards it +// to the configured delegate with the resolved session attached to the stream +// context. +func (s sessionLookupService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + record, err := s.lookupSession(stream.Context()) + if err != nil { + return err + } + + return s.delegate.SubscribeEvents(req, resolvedSessionContextStream{ + ServerStreamingServer: stream, + ctx: context.WithValue(stream.Context(), resolvedSessionContextKey{}, cloneSessionRecord(record)), + }) +} + +// newSessionLookupService wraps delegate with the session-cache lookup gate. 
+func newSessionLookupService(delegate gatewayv1.EdgeGatewayServer, cache session.Cache) gatewayv1.EdgeGatewayServer { + return sessionLookupService{ + delegate: delegate, + cache: cache, + } +} + +func (s sessionLookupService) lookupSession(ctx context.Context) (session.Record, error) { + envelope, ok := parsedEnvelopeFromContext(ctx) + if !ok { + return session.Record{}, status.Error(codes.Internal, "authenticated request context is incomplete") + } + + record, err := s.cache.Lookup(ctx, envelope.DeviceSessionID) + switch { + case err == nil: + case errors.Is(err, session.ErrNotFound): + return session.Record{}, status.Error(codes.Unauthenticated, "unknown device session") + default: + return session.Record{}, status.Error(codes.Unavailable, "session cache is unavailable") + } + + if record.Status == session.StatusRevoked { + return session.Record{}, status.Error(codes.FailedPrecondition, "device session is revoked") + } + + return cloneSessionRecord(record), nil +} + +func cloneSessionRecord(record session.Record) session.Record { + cloned := record + if record.RevokedAtMS != nil { + value := *record.RevokedAtMS + cloned.RevokedAtMS = &value + } + + return cloned +} + +type resolvedSessionContextKey struct{} + +type resolvedSessionContextStream struct { + grpc.ServerStreamingServer[gatewayv1.GatewayEvent] + ctx context.Context +} + +func (s resolvedSessionContextStream) Context() context.Context { + if s.ctx == nil { + return context.Background() + } + + return s.ctx +} + +type unavailableSessionCache struct{} + +func (unavailableSessionCache) Lookup(context.Context, string) (session.Record, error) { + return session.Record{}, errors.New("session cache is unavailable") +} + +var _ gatewayv1.EdgeGatewayServer = sessionLookupService{} diff --git a/gateway/internal/grpcapi/session_lookup_integration_test.go b/gateway/internal/grpcapi/session_lookup_integration_test.go new file mode 100644 index 0000000..08b144a --- /dev/null +++ 
package grpcapi

import (
	"context"
	"errors"
	"io"
	"testing"

	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// An unknown device session is rejected as UNAUTHENTICATED and never reaches
// the delegate service.
func TestExecuteCommandRejectsUnknownSession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return session.Record{}, session.ErrNotFound
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "unknown device session", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}

// Streaming surface applies the same unknown-session rejection as unary.
func TestSubscribeEventsRejectsUnknownSession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return session.Record{}, session.ErrNotFound
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "unknown device session", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}

// A revoked session fails with FAILED_PRECONDITION before the delegate runs.
func TestExecuteCommandRejectsRevokedSession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}

// Streaming surface applies the same revoked-session rejection as unary.
func TestSubscribeEventsRejectsRevokedSession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}

// A cache backend failure maps to UNAVAILABLE (fail closed, no delegate call).
func TestExecuteCommandRejectsSessionCacheUnavailable(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return session.Record{}, errors.New("redis down")
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}

// Streaming surface applies the same cache-unavailable rejection as unary.
func TestSubscribeEventsRejectsSessionCacheUnavailable(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return session.Record{}, errors.New("redis down")
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}

// The delegate must observe the resolved session record on the unary request
// context.
func TestExecuteCommandAttachesResolvedSession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			record, ok := resolvedSessionFromContext(ctx)
			require.True(t, ok)
			assert.Equal(t, newActiveSessionRecord(), record)
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	assert.Equal(t, "request-123", response.GetRequestId())
}

// The delegate must observe the resolved session record on the stream context.
func TestSubscribeEventsAttachesResolvedSession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			record, ok := resolvedSessionFromContext(stream.Context())
			require.True(t, ok)
			assert.Equal(t, newActiveSessionRecord(), record)
			return nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
	require.NoError(t, err)

	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
}

// The delegate must observe the full authenticated stream binding (user,
// session, message type, request and trace IDs) on the stream context.
func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			binding, ok := authenticatedStreamBindingFromContext(stream.Context())
			require.True(t, ok)
			assert.Equal(t, authenticatedStreamBinding{
				UserID:          "user-123",
				DeviceSessionID: "device-session-123",
				MessageType:     "gateway.subscribe",
				RequestID:       "request-123",
				TraceID:         "trace-123",
			}, binding)
			return nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
	require.NoError(t, err)

	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
}

// staticSessionCache delegates Lookup to a test-supplied function.
type staticSessionCache struct {
	lookupFunc func(context.Context, string) (session.Record, error)
}

func (c staticSessionCache) Lookup(ctx context.Context, deviceSessionID string) (session.Record, error) {
	return c.lookupFunc(ctx, deviceSessionID)
}
b/gateway/internal/grpcapi/signature.go new file mode 100644 index 0000000..31c5f6a --- /dev/null +++ b/gateway/internal/grpcapi/signature.go @@ -0,0 +1,80 @@ +package grpcapi + +import ( + "context" + "errors" + + "galaxy/gateway/internal/authn" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// signatureVerifyingService applies client-signature verification after +// payload integrity checks and before later auth or routing steps run. +type signatureVerifyingService struct { + gatewayv1.UnimplementedEdgeGatewayServer + + delegate gatewayv1.EdgeGatewayServer +} + +// ExecuteCommand verifies req client signature before delegating to the +// configured service implementation. +func (s signatureVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { + if err := verifyRequestSignature(ctx); err != nil { + return nil, err + } + + return s.delegate.ExecuteCommand(ctx, req) +} + +// SubscribeEvents verifies req client signature before delegating to the +// configured service implementation. +func (s signatureVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { + if err := verifyRequestSignature(stream.Context()); err != nil { + return err + } + + return s.delegate.SubscribeEvents(req, stream) +} + +// newSignatureVerifyingService wraps delegate with the client-signature +// verification gate. 
+func newSignatureVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer { + return signatureVerifyingService{delegate: delegate} +} + +func verifyRequestSignature(ctx context.Context) error { + envelope, ok := parsedEnvelopeFromContext(ctx) + if !ok { + return status.Error(codes.Internal, "authenticated request context is incomplete") + } + + record, ok := resolvedSessionFromContext(ctx) + if !ok { + return status.Error(codes.Internal, "authenticated request context is incomplete") + } + + err := authn.VerifyRequestSignature(record.ClientPublicKey, envelope.Signature, authn.RequestSigningFields{ + ProtocolVersion: envelope.ProtocolVersion, + DeviceSessionID: envelope.DeviceSessionID, + MessageType: envelope.MessageType, + TimestampMS: envelope.TimestampMS, + RequestID: envelope.RequestID, + PayloadHash: envelope.PayloadHash, + }) + switch { + case err == nil: + return nil + case errors.Is(err, authn.ErrInvalidClientPublicKey): + return status.Error(codes.Unavailable, "session cache is unavailable") + case errors.Is(err, authn.ErrInvalidRequestSignature): + return status.Error(codes.Unauthenticated, "invalid request signature") + default: + return status.Error(codes.Internal, "request signature verification failed") + } +} + +var _ gatewayv1.EdgeGatewayServer = signatureVerifyingService{} diff --git a/gateway/internal/grpcapi/signature_integration_test.go b/gateway/internal/grpcapi/signature_integration_test.go new file mode 100644 index 0000000..3b36911 --- /dev/null +++ b/gateway/internal/grpcapi/signature_integration_test.go @@ -0,0 +1,188 @@ +package grpcapi + +import ( + "context" + "testing" + + "galaxy/gateway/internal/session" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestExecuteCommandRejectsInvalidSignature(t *testing.T) { + t.Parallel() + + delegate := 
&recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + req := newValidExecuteCommandRequest() + req.Signature[0] ^= 0xff + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), req) + require.Error(t, err) + assert.Equal(t, codes.Unauthenticated, status.Code(err)) + assert.Equal(t, "invalid request signature", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) +} + +func TestExecuteCommandRejectsWrongKey(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{ + lookupFunc: func(context.Context, string) (session.Record, error) { + record := newActiveSessionRecord() + record.ClientPublicKey = alternateTestClientPublicKeyBase64() + return record, nil + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unauthenticated, status.Code(err)) + assert.Equal(t, "invalid request signature", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) +} + +func TestExecuteCommandRejectsInvalidCachedPublicKey(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: 
staticSessionCache{ + lookupFunc: func(context.Context, string) (session.Record, error) { + record := newActiveSessionRecord() + record.ClientPublicKey = "%%%not-base64%%%" + return record, nil + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + assert.Equal(t, "session cache is unavailable", status.Convert(err).Message()) + assert.Zero(t, delegate.executeCalls) +} + +func TestSubscribeEventsRejectsInvalidSignature(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + req := newValidSubscribeEventsRequest() + req.Signature[0] ^= 0xff + + client := gatewayv1.NewEdgeGatewayClient(conn) + err := subscribeEventsError(t, context.Background(), client, req) + require.Error(t, err) + assert.Equal(t, codes.Unauthenticated, status.Code(err)) + assert.Equal(t, "invalid request signature", status.Convert(err).Message()) + assert.Zero(t, delegate.subscribeCalls) +} + +func TestSubscribeEventsRejectsWrongKey(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{ + lookupFunc: func(context.Context, string) (session.Record, error) { + record := newActiveSessionRecord() + 
record.ClientPublicKey = alternateTestClientPublicKeyBase64() + return record, nil + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unauthenticated, status.Code(err)) + assert.Equal(t, "invalid request signature", status.Convert(err).Message()) + assert.Zero(t, delegate.subscribeCalls) +} + +func TestSubscribeEventsRejectsInvalidCachedPublicKey(t *testing.T) { + t.Parallel() + + delegate := &recordingEdgeGatewayService{} + server, runGateway := newTestGateway(t, ServerDependencies{ + Service: delegate, + SessionCache: staticSessionCache{ + lookupFunc: func(context.Context, string) (session.Record, error) { + record := newActiveSessionRecord() + record.ClientPublicKey = "%%%not-base64%%%" + return record, nil + }, + }, + }) + defer runGateway.stop(t) + + addr := waitForListenAddr(t, server) + conn := dialGatewayClient(t, addr) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := gatewayv1.NewEdgeGatewayClient(conn) + err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest()) + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + assert.Equal(t, "session cache is unavailable", status.Convert(err).Message()) + assert.Zero(t, delegate.subscribeCalls) +} diff --git a/gateway/internal/grpcapi/test_fixtures_test.go b/gateway/internal/grpcapi/test_fixtures_test.go new file mode 100644 index 0000000..04be95b --- /dev/null +++ b/gateway/internal/grpcapi/test_fixtures_test.go @@ -0,0 +1,298 @@ +package grpcapi + +import ( + "context" + "crypto/ed25519" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "time" + + "galaxy/gateway/internal/authn" + 
"galaxy/gateway/internal/downstream" + "galaxy/gateway/internal/session" + gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" + + gatewayfbs "galaxy/schema/fbs/gateway" + + flatbuffers "github.com/google/flatbuffers/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +var ( + testCurrentTime = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC) + testFreshnessWindow = 5 * time.Minute +) + +func newValidExecuteCommandRequest() *gatewayv1.ExecuteCommandRequest { + return newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-123") +} + +func newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.ExecuteCommandRequest { + return newValidExecuteCommandRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli()) +} + +func newValidExecuteCommandRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.ExecuteCommandRequest { + payloadBytes := []byte("payload") + payloadHash := sha256.Sum256(payloadBytes) + + req := &gatewayv1.ExecuteCommandRequest{ + ProtocolVersion: supportedProtocolVersion, + DeviceSessionId: deviceSessionID, + MessageType: "fleet.move", + TimestampMs: timestampMS, + RequestId: requestID, + PayloadBytes: payloadBytes, + PayloadHash: payloadHash[:], + TraceId: "trace-123", + } + req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash()) + + return req +} + +func newValidSubscribeEventsRequest() *gatewayv1.SubscribeEventsRequest { + return newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-123") +} + +func newValidSubscribeEventsRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest { + return newValidSubscribeEventsRequestWithTimestamp(deviceSessionID, requestID, 
testCurrentTime.UnixMilli()) +} + +func newValidSubscribeEventsRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.SubscribeEventsRequest { + payloadHash := sha256.Sum256(nil) + + req := &gatewayv1.SubscribeEventsRequest{ + ProtocolVersion: supportedProtocolVersion, + DeviceSessionId: deviceSessionID, + MessageType: "gateway.subscribe", + TimestampMs: timestampMS, + RequestId: requestID, + PayloadHash: payloadHash[:], + TraceId: "trace-123", + } + req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash()) + + return req +} + +func newActiveSessionRecord() session.Record { + return newActiveSessionRecordWithSessionID("device-session-123") +} + +func newActiveSessionRecordWithSessionID(deviceSessionID string) session.Record { + return session.Record{ + DeviceSessionID: deviceSessionID, + UserID: "user-123", + ClientPublicKey: testClientPublicKeyBase64(), + Status: session.StatusActive, + } +} + +func newRevokedSessionRecord() session.Record { + revokedAtMS := int64(123456789) + + return session.Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: testClientPublicKeyBase64(), + Status: session.StatusRevoked, + RevokedAtMS: &revokedAtMS, + } +} + +func alternateTestClientPublicKeyBase64() string { + return base64.StdEncoding.EncodeToString(newTestPrivateKey("alternate").Public().(ed25519.PublicKey)) +} + +func testClientPublicKeyBase64() string { + return base64.StdEncoding.EncodeToString(newTestPrivateKey("primary").Public().(ed25519.PublicKey)) +} + +func signRequest(protocolVersion, deviceSessionID, messageType string, timestampMS int64, requestID string, payloadHash []byte) []byte { + return ed25519.Sign(newTestPrivateKey("primary"), authn.BuildRequestSigningInput(authn.RequestSigningFields{ + ProtocolVersion: protocolVersion, + DeviceSessionID: deviceSessionID, + MessageType: 
messageType, + TimestampMS: timestampMS, + RequestID: requestID, + PayloadHash: payloadHash, + })) +} + +func newTestPrivateKey(label string) ed25519.PrivateKey { + seed := sha256.Sum256([]byte("gateway-grpcapi-signature-test-" + label)) + return ed25519.NewKeyFromSeed(seed[:]) +} + +func newTestEd25519ResponseSigner() *authn.Ed25519ResponseSigner { + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: mustMarshalPKCS8PrivateKey(newTestPrivateKey("response-signer")), + }) + + signer, err := authn.ParseEd25519ResponseSignerPEM(pemBytes) + if err != nil { + panic(err) + } + + return signer +} + +func newTestResponseSigner() authn.ResponseSigner { + return newTestEd25519ResponseSigner() +} + +func newTestResponseSignerPublicKey() ed25519.PublicKey { + return newTestEd25519ResponseSigner().PublicKey() +} + +func mustMarshalPKCS8PrivateKey(privateKey ed25519.PrivateKey) []byte { + encoded, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + panic(err) + } + + return encoded +} + +type fixedClock struct { + now time.Time +} + +func (c fixedClock) Now() time.Time { + return c.now +} + +func recvBootstrapEvent(t interface { + require.TestingT + Helper() +}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent { + t.Helper() + + event, err := stream.Recv() + require.NoError(t, err) + + return event +} + +func subscribeEventsError(t interface { + require.TestingT + Helper() +}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error { + t.Helper() + + stream, err := client.SubscribeEvents(ctx, req) + if err != nil { + return err + } + + _, err = stream.Recv() + return err +} + +func assertServerTimeBootstrapEvent(t interface { + require.TestingT + Helper() +}, event *gatewayv1.GatewayEvent, publicKey ed25519.PublicKey, wantRequestID string, wantTraceID string, wantTimestampMS int64) { + t.Helper() + + require.NotNil(t, event) + assert.Equal(t, 
serverTimeEventType, event.GetEventType()) + assert.Equal(t, wantRequestID, event.GetEventId()) + assert.Equal(t, wantRequestID, event.GetRequestId()) + assert.Equal(t, wantTraceID, event.GetTraceId()) + assert.Equal(t, wantTimestampMS, event.GetTimestampMs()) + require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash())) + require.NoError(t, authn.VerifyEventSignature(publicKey, event.GetSignature(), authn.EventSigningFields{ + EventType: event.GetEventType(), + EventID: event.GetEventId(), + TimestampMS: event.GetTimestampMs(), + RequestID: event.GetRequestId(), + TraceID: event.GetTraceId(), + PayloadHash: event.GetPayloadHash(), + })) + + payload := gatewayfbs.GetRootAsServerTimeEvent(event.GetPayloadBytes(), flatbuffers.UOffsetT(0)) + assert.Equal(t, wantTimestampMS, payload.ServerTimeMs()) +} + +type staticReplayStore struct { + reserveFunc func(context.Context, string, string, time.Duration) error +} + +func (s staticReplayStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error { + if s.reserveFunc != nil { + return s.reserveFunc(ctx, deviceSessionID, requestID, ttl) + } + + return nil +} + +type executeCommandAdapterRouter struct { + service gatewayv1.EdgeGatewayServer +} + +func (r executeCommandAdapterRouter) Route(string) (downstream.Client, error) { + return executeCommandAdapterClient{service: r.service}, nil +} + +type executeCommandAdapterClient struct { + service gatewayv1.EdgeGatewayServer +} + +func (c executeCommandAdapterClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { + response, err := c.service.ExecuteCommand(ctx, &gatewayv1.ExecuteCommandRequest{ + ProtocolVersion: command.ProtocolVersion, + DeviceSessionId: command.DeviceSessionID, + MessageType: command.MessageType, + TimestampMs: command.TimestampMS, + RequestId: command.RequestID, + PayloadBytes: command.PayloadBytes, + TraceId: 
command.TraceID, + }) + if err != nil { + return downstream.UnaryResult{}, err + } + + resultCode := response.GetResultCode() + if resultCode == "" { + resultCode = "ok" + } + + return downstream.UnaryResult{ + ResultCode: resultCode, + PayloadBytes: response.GetPayloadBytes(), + }, nil +} + +type recordingDownstreamClient struct { + executeCalls int + commands []downstream.AuthenticatedCommand + executeFunc func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) +} + +func (c *recordingDownstreamClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { + c.executeCalls++ + c.commands = append(c.commands, downstream.AuthenticatedCommand{ + ProtocolVersion: command.ProtocolVersion, + UserID: command.UserID, + DeviceSessionID: command.DeviceSessionID, + MessageType: command.MessageType, + TimestampMS: command.TimestampMS, + RequestID: command.RequestID, + TraceID: command.TraceID, + PayloadBytes: append([]byte(nil), command.PayloadBytes...), + }) + if c.executeFunc != nil { + return c.executeFunc(ctx, command) + } + + return downstream.UnaryResult{ + ResultCode: "ok", + PayloadBytes: []byte("response"), + }, nil +} diff --git a/gateway/internal/logging/logger.go b/gateway/internal/logging/logger.go new file mode 100644 index 0000000..8e964c3 --- /dev/null +++ b/gateway/internal/logging/logger.go @@ -0,0 +1,84 @@ +// Package logging configures the gateway structured logger and provides +// context-aware helpers for attaching OpenTelemetry trace identifiers. +package logging + +import ( + "context" + "strings" + + "galaxy/gateway/internal/config" + + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// New constructs the process-wide JSON logger from cfg. 
+func New(cfg config.LoggingConfig) (*zap.Logger, error) { + level := zap.NewAtomicLevel() + if err := level.UnmarshalText([]byte(strings.TrimSpace(cfg.Level))); err != nil { + return nil, err + } + + zapCfg := zap.NewProductionConfig() + zapCfg.Level = level + zapCfg.Sampling = nil + zapCfg.Encoding = "json" + zapCfg.EncoderConfig.TimeKey = "timestamp" + zapCfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + zapCfg.OutputPaths = []string{"stdout"} + zapCfg.ErrorOutputPaths = []string{"stderr"} + + return zapCfg.Build() +} + +// TraceFieldsFromContext returns zap fields for the active OpenTelemetry span +// when ctx carries a valid span context. +func TraceFieldsFromContext(ctx context.Context) []zap.Field { + if ctx == nil { + return nil + } + + spanContext := trace.SpanContextFromContext(ctx) + if !spanContext.IsValid() { + return nil + } + + return []zap.Field{ + zap.String("otel_trace_id", spanContext.TraceID().String()), + zap.String("otel_span_id", spanContext.SpanID().String()), + } +} + +// Sync flushes logger and ignores the benign stdout or stderr sync errors +// commonly returned by containerized or redirected process outputs. 
+func Sync(logger *zap.Logger) error { + if logger == nil { + return nil + } + + err := logger.Sync() + if err == nil || isIgnorableSyncError(err) { + return nil + } + + return err +} + +func isIgnorableSyncError(err error) bool { + if err == nil { + return false + } + + message := strings.ToLower(err.Error()) + switch { + case strings.Contains(message, "invalid argument"): + return true + case strings.Contains(message, "bad file descriptor"): + return true + case strings.Contains(message, "inappropriate ioctl for device"): + return true + default: + return false + } +} diff --git a/gateway/internal/push/hub.go b/gateway/internal/push/hub.go new file mode 100644 index 0000000..9a4a64c --- /dev/null +++ b/gateway/internal/push/hub.go @@ -0,0 +1,385 @@ +// Package push provides the in-memory hub used to fan out internal +// client-facing events to active authenticated push streams. +package push + +import ( + "bytes" + "errors" + "strings" + "sync" +) + +const defaultSubscriptionQueueCapacity = 64 + +var ( + // ErrSubscriptionOverflow reports that one push stream stopped consuming + // events quickly enough and its bounded queue overflowed. + ErrSubscriptionOverflow = errors.New("push stream overflowed") + + // ErrSubscriptionRevoked reports that the authenticated device session bound + // to the push stream was revoked and the stream must terminate. + ErrSubscriptionRevoked = errors.New("device session is revoked") + + // ErrHubShuttingDown reports that the gateway is shutting down and all + // active push streams must terminate promptly. + ErrHubShuttingDown = errors.New("gateway is shutting down") +) + +// StreamBinding identifies one authenticated push stream tracked by Hub. +type StreamBinding struct { + // UserID is the verified authenticated user bound to the stream. + UserID string + + // DeviceSessionID is the verified authenticated device session bound to the + // stream. 
+ DeviceSessionID string +} + +// Event is the internal client-facing event delivered from internal pub/sub to +// active push streams. +type Event struct { + // UserID identifies the authenticated user that should receive the event. + UserID string + + // DeviceSessionID optionally narrows delivery to one device session. + DeviceSessionID string + + // EventType identifies the stable client-facing event category. + EventType string + + // EventID is the stable event correlation identifier. + EventID string + + // PayloadBytes carries the opaque event payload bytes. + PayloadBytes []byte + + // RequestID optionally correlates the event to an earlier client request. + RequestID string + + // TraceID optionally carries tracing correlation. + TraceID string +} + +// Subscription represents one active push stream registered in Hub. +type Subscription struct { + hub *Hub + id uint64 + binding StreamBinding + events chan Event + done chan struct{} + + closeOnce sync.Once + stateMu sync.RWMutex + err error +} + +// Observer receives push stream lifecycle notifications suitable for metrics +// bookkeeping. +type Observer interface { + // Registered reports one active push stream binding. + Registered(binding StreamBinding) + + // Unregistered reports that binding stopped with err. A nil err means the + // stream ended without a hub-enforced terminal reason. + Unregistered(binding StreamBinding, err error) +} + +// Events returns the ordered event queue for the subscription. +func (s *Subscription) Events() <-chan Event { + if s == nil { + return nil + } + + return s.events +} + +// Done closes when the subscription has been removed from the hub. +func (s *Subscription) Done() <-chan struct{} { + if s == nil { + return nil + } + + return s.done +} + +// Err returns the terminal subscription error, if any. 
+func (s *Subscription) Err() error { + if s == nil { + return nil + } + + s.stateMu.RLock() + defer s.stateMu.RUnlock() + + return s.err +} + +// Close unregisters the subscription from its hub. +func (s *Subscription) Close() { + if s == nil || s.hub == nil { + return + } + + s.hub.unregister(s.id, nil) +} + +func (s *Subscription) enqueue(event Event) bool { + if s == nil { + return true + } + + cloned := cloneEvent(event) + + select { + case <-s.done: + return true + default: + } + + select { + case s.events <- cloned: + return true + case <-s.done: + return true + default: + return false + } +} + +func (s *Subscription) closeWithError(err error) { + if s == nil { + return + } + + s.closeOnce.Do(func() { + s.stateMu.Lock() + s.err = err + s.stateMu.Unlock() + close(s.done) + }) +} + +// Hub tracks active authenticated push streams and fans out client-facing +// events to the matching subscriptions. +type Hub struct { + mu sync.RWMutex + nextID uint64 + queueCapacity int + observer Observer + byID map[uint64]*Subscription + byUser map[string]map[uint64]*Subscription + bySession map[string]map[uint64]*Subscription +} + +// NewHub constructs a push hub with one bounded in-memory queue per +// subscription. Non-positive queueCapacity falls back to the package default. +func NewHub(queueCapacity int) *Hub { + return NewHubWithObserver(queueCapacity, nil) +} + +// NewHubWithObserver constructs a push hub that also reports stream lifecycle +// changes to observer. +func NewHubWithObserver(queueCapacity int, observer Observer) *Hub { + if queueCapacity <= 0 { + queueCapacity = defaultSubscriptionQueueCapacity + } + + return &Hub{ + queueCapacity: queueCapacity, + observer: observer, + byID: make(map[uint64]*Subscription), + byUser: make(map[string]map[uint64]*Subscription), + bySession: make(map[string]map[uint64]*Subscription), + } +} + +// Register adds one authenticated push stream to the hub and returns its +// subscription handle. 
+func (h *Hub) Register(binding StreamBinding) (*Subscription, error) { + if h == nil { + return nil, errors.New("register push subscription: nil hub") + } + + userID := strings.TrimSpace(binding.UserID) + if userID == "" { + return nil, errors.New("register push subscription: user id must not be empty") + } + + deviceSessionID := strings.TrimSpace(binding.DeviceSessionID) + if deviceSessionID == "" { + return nil, errors.New("register push subscription: device session id must not be empty") + } + + h.mu.Lock() + + h.nextID++ + subscription := &Subscription{ + hub: h, + id: h.nextID, + binding: StreamBinding{ + UserID: userID, + DeviceSessionID: deviceSessionID, + }, + events: make(chan Event, h.queueCapacity), + done: make(chan struct{}), + } + h.byID[subscription.id] = subscription + addIndexedSubscription(h.byUser, userID, subscription) + addIndexedSubscription(h.bySession, deviceSessionID, subscription) + h.mu.Unlock() + + if h.observer != nil { + h.observer.Registered(subscription.binding) + } + + return subscription, nil +} + +// Publish fans out event to the matching active subscriptions. When one +// subscription queue overflows, only that subscription is closed. +func (h *Hub) Publish(event Event) { + if h == nil { + return + } + + targets := h.targets(event) + for _, target := range targets { + if target.enqueue(event) { + continue + } + + h.unregister(target.id, ErrSubscriptionOverflow) + } +} + +// RevokeDeviceSession closes all active subscriptions bound to the exact +// authenticated device session identifier. 
+func (h *Hub) RevokeDeviceSession(deviceSessionID string) { + if h == nil { + return + } + + deviceSessionID = strings.TrimSpace(deviceSessionID) + if deviceSessionID == "" { + return + } + + h.mu.RLock() + targets := cloneSubscriptions(h.bySession[deviceSessionID]) + h.mu.RUnlock() + + for _, target := range targets { + h.unregister(target.id, ErrSubscriptionRevoked) + } +} + +// Shutdown closes every active subscription because the gateway is shutting +// down. +func (h *Hub) Shutdown() { + if h == nil { + return + } + + h.mu.RLock() + targets := cloneSubscriptions(h.byID) + h.mu.RUnlock() + + for _, target := range targets { + h.unregister(target.id, ErrHubShuttingDown) + } +} + +func (h *Hub) targets(event Event) []*Subscription { + userID := strings.TrimSpace(event.UserID) + eventType := strings.TrimSpace(event.EventType) + eventID := strings.TrimSpace(event.EventID) + if h == nil || userID == "" || eventType == "" || eventID == "" { + return nil + } + + deviceSessionID := strings.TrimSpace(event.DeviceSessionID) + + h.mu.RLock() + defer h.mu.RUnlock() + + if deviceSessionID == "" { + return cloneSubscriptions(h.byUser[userID]) + } + + sessionMatches := cloneSubscriptions(h.bySession[deviceSessionID]) + filtered := sessionMatches[:0] + for _, subscription := range sessionMatches { + if subscription.binding.UserID == userID { + filtered = append(filtered, subscription) + } + } + + return filtered +} + +func (h *Hub) unregister(id uint64, err error) { + if h == nil || id == 0 { + return + } + + h.mu.Lock() + subscription, ok := h.byID[id] + if !ok { + h.mu.Unlock() + return + } + + delete(h.byID, id) + removeIndexedSubscription(h.byUser, subscription.binding.UserID, id) + removeIndexedSubscription(h.bySession, subscription.binding.DeviceSessionID, id) + h.mu.Unlock() + + subscription.closeWithError(err) + if h.observer != nil { + h.observer.Unregistered(subscription.binding, err) + } +} + +func addIndexedSubscription(index map[string]map[uint64]*Subscription, 
key string, subscription *Subscription) { + if _, ok := index[key]; !ok { + index[key] = make(map[uint64]*Subscription) + } + index[key][subscription.id] = subscription +} + +func removeIndexedSubscription(index map[string]map[uint64]*Subscription, key string, id uint64) { + bucket, ok := index[key] + if !ok { + return + } + + delete(bucket, id) + if len(bucket) == 0 { + delete(index, key) + } +} + +func cloneSubscriptions(bucket map[uint64]*Subscription) []*Subscription { + if len(bucket) == 0 { + return nil + } + + cloned := make([]*Subscription, 0, len(bucket)) + for _, subscription := range bucket { + cloned = append(cloned, subscription) + } + + return cloned +} + +func cloneEvent(event Event) Event { + return Event{ + UserID: event.UserID, + DeviceSessionID: event.DeviceSessionID, + EventType: event.EventType, + EventID: event.EventID, + PayloadBytes: bytes.Clone(event.PayloadBytes), + RequestID: event.RequestID, + TraceID: event.TraceID, + } +} diff --git a/gateway/internal/push/hub_observability_test.go b/gateway/internal/push/hub_observability_test.go new file mode 100644 index 0000000..1ee7e27 --- /dev/null +++ b/gateway/internal/push/hub_observability_test.go @@ -0,0 +1,77 @@ +package push_test + +import ( + "testing" + "time" + + "galaxy/gateway/internal/push" + "galaxy/gateway/internal/telemetry" + "galaxy/gateway/internal/testutil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHubObserverClassifiesClosureReasons(t *testing.T) { + t.Parallel() + + logger, _ := testutil.NewObservedLogger(t) + telemetryRuntime := testutil.NewTelemetryRuntime(t, logger) + hub := push.NewHubWithObserver(1, telemetry.NewPushObserver(telemetryRuntime)) + + overflow, err := hub.Register(push.StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-overflow", + }) + require.NoError(t, err) + revoked, err := hub.Register(push.StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-revoked", + }) + 
require.NoError(t, err) + shutdown, err := hub.Register(push.StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-shutdown", + }) + require.NoError(t, err) + + hub.Publish(push.Event{ + UserID: "user-123", + DeviceSessionID: "device-session-overflow", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) + hub.Publish(push.Event{ + UserID: "user-123", + DeviceSessionID: "device-session-overflow", + EventType: "fleet.updated", + EventID: "event-2", + PayloadBytes: []byte("payload-2"), + }) + hub.RevokeDeviceSession("device-session-revoked") + hub.Shutdown() + + select { + case <-overflow.Done(): + case <-time.After(time.Second): + require.FailNow(t, "overflow subscription did not close") + } + select { + case <-revoked.Done(): + case <-time.After(time.Second): + require.FailNow(t, "revoked subscription did not close") + } + select { + case <-shutdown.Done(): + case <-time.After(time.Second): + require.FailNow(t, "shutdown subscription did not close") + } + + metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler()) + assert.Contains(t, metricsText, `gateway_push_stream_closures_total`) + assert.Contains(t, metricsText, `reason="overflow"`) + assert.Contains(t, metricsText, `reason="revoked"`) + assert.Contains(t, metricsText, `reason="shutdown"`) + assert.Contains(t, metricsText, `gateway_push_active_streams`) +} diff --git a/gateway/internal/push/hub_test.go b/gateway/internal/push/hub_test.go new file mode 100644 index 0000000..cd4a93b --- /dev/null +++ b/gateway/internal/push/hub_test.go @@ -0,0 +1,270 @@ +package push + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHubDeliversSessionTargetedEvent(t *testing.T) { + t.Parallel() + + hub := NewHub(4) + target, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + }) + require.NoError(t, err) + otherSession, err := 
hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-2", + }) + require.NoError(t, err) + unrelatedUser, err := hub.Register(StreamBinding{ + UserID: "user-999", + DeviceSessionID: "device-session-3", + }) + require.NoError(t, err) + + hub.Publish(Event{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) + + assertEvent(t, target.Events(), Event{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) + assertNoEvent(t, otherSession.Events()) + assertNoEvent(t, unrelatedUser.Events()) +} + +func TestHubDeliversUserTargetedEventToAllUserSessions(t *testing.T) { + t.Parallel() + + hub := NewHub(4) + first, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + }) + require.NoError(t, err) + second, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-2", + }) + require.NoError(t, err) + unrelated, err := hub.Register(StreamBinding{ + UserID: "user-999", + DeviceSessionID: "device-session-3", + }) + require.NoError(t, err) + + hub.Publish(Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + RequestID: "request-1", + TraceID: "trace-1", + }) + + want := Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + RequestID: "request-1", + TraceID: "trace-1", + } + assertEvent(t, first.Events(), want) + assertEvent(t, second.Events(), want) + assertNoEvent(t, unrelated.Events()) +} + +func TestSubscriptionCloseUnregistersStream(t *testing.T) { + t.Parallel() + + hub := NewHub(4) + subscription, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + }) + require.NoError(t, err) + + 
subscription.Close() + + select { + case <-subscription.Done(): + case <-time.After(time.Second): + require.FailNow(t, "subscription did not close") + } + + hub.Publish(Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) + + assertNoEvent(t, subscription.Events()) + assert.NoError(t, subscription.Err()) +} + +func TestHubOverflowClosesOnlySlowSubscription(t *testing.T) { + t.Parallel() + + hub := NewHub(1) + slow, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + }) + require.NoError(t, err) + fast, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-2", + }) + require.NoError(t, err) + + hub.Publish(Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) + assertEvent(t, fast.Events(), Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) + + hub.Publish(Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-2", + PayloadBytes: []byte("payload-2"), + }) + + select { + case <-slow.Done(): + case <-time.After(time.Second): + require.FailNow(t, "slow subscription did not close after overflow") + } + + assert.ErrorIs(t, slow.Err(), ErrSubscriptionOverflow) + assertEvent(t, fast.Events(), Event{ + UserID: "user-123", + EventType: "fleet.updated", + EventID: "event-2", + PayloadBytes: []byte("payload-2"), + }) +} + +func TestHubRevokeDeviceSessionClosesOnlyMatchingSubscriptions(t *testing.T) { + t.Parallel() + + hub := NewHub(4) + targetOne, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + }) + require.NoError(t, err) + targetTwo, err := hub.Register(StreamBinding{ + UserID: "user-456", + DeviceSessionID: "device-session-1", + }) + require.NoError(t, err) + otherSession, err := hub.Register(StreamBinding{ + 
UserID: "user-123", + DeviceSessionID: "device-session-2", + }) + require.NoError(t, err) + + hub.RevokeDeviceSession("device-session-1") + + select { + case <-targetOne.Done(): + case <-time.After(time.Second): + require.FailNow(t, "first matching subscription did not close after revoke") + } + + select { + case <-targetTwo.Done(): + case <-time.After(time.Second): + require.FailNow(t, "second matching subscription did not close after revoke") + } + + assert.ErrorIs(t, targetOne.Err(), ErrSubscriptionRevoked) + assert.ErrorIs(t, targetTwo.Err(), ErrSubscriptionRevoked) + + select { + case <-otherSession.Done(): + require.FailNow(t, "unrelated session subscription closed after revoke") + case <-time.After(50 * time.Millisecond): + } + + hub.Publish(Event{ + UserID: "user-123", + DeviceSessionID: "device-session-2", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) + + assertEvent(t, otherSession.Events(), Event{ + UserID: "user-123", + DeviceSessionID: "device-session-2", + EventType: "fleet.updated", + EventID: "event-1", + PayloadBytes: []byte("payload-1"), + }) +} + +func TestHubRevokeDeviceSessionIgnoresUnknownOrEmptySession(t *testing.T) { + t.Parallel() + + hub := NewHub(4) + subscription, err := hub.Register(StreamBinding{ + UserID: "user-123", + DeviceSessionID: "device-session-1", + }) + require.NoError(t, err) + + hub.RevokeDeviceSession("") + hub.RevokeDeviceSession("missing-session") + + select { + case <-subscription.Done(): + require.FailNow(t, "subscription closed for empty or unknown session revoke") + case <-time.After(50 * time.Millisecond): + } +} + +func assertEvent(t *testing.T, eventCh <-chan Event, want Event) { + t.Helper() + + select { + case got := <-eventCh: + assert.Equal(t, want, got) + case <-time.After(time.Second): + require.FailNow(t, "event was not delivered") + } +} + +func assertNoEvent(t *testing.T, eventCh <-chan Event) { + t.Helper() + + select { + case got := <-eventCh: + 
require.FailNowf(t, "unexpected event delivered", "%+v", got) + case <-time.After(50 * time.Millisecond): + } +} diff --git a/gateway/internal/ratelimit/inmemory.go b/gateway/internal/ratelimit/inmemory.go new file mode 100644 index 0000000..52a21b3 --- /dev/null +++ b/gateway/internal/ratelimit/inmemory.go @@ -0,0 +1,136 @@ +// Package ratelimit provides small process-local rate-limit primitives used by +// the gateway edge policy layers. +package ratelimit + +import ( + "sync" + "time" + + "golang.org/x/time/rate" +) + +// Policy describes one token-bucket budget enforced for a concrete key. +type Policy struct { + // Requests is the number of accepted requests replenished per Window. + Requests int + + // Window is the interval over which Requests are replenished. + Window time.Duration + + // Burst is the maximum number of immediately available tokens. + Burst int +} + +// Decision describes the result of one limiter reservation attempt. +type Decision struct { + // Allowed reports whether the request may proceed immediately. + Allowed bool + + // RetryAfter is the minimum delay the caller should wait before retrying + // when Allowed is false. + RetryAfter time.Duration +} + +// Limiter applies a policy to one concrete key. +type Limiter interface { + // Reserve evaluates key under policy and reports whether the request may + // proceed immediately. + Reserve(key string, policy Policy) Decision +} + +// InMemory is a process-local Limiter backed by x/time/rate token buckets. +type InMemory struct { + now func() time.Time + cleanupInterval time.Duration + + mu sync.Mutex + entries map[string]*entry + nextCleanup time.Time +} + +type entry struct { + limiter *rate.Limiter + limit rate.Limit + burst int + expiresAt time.Time +} + +// NewInMemory constructs a process-local limiter suitable for one gateway +// process instance. 
+func NewInMemory() *InMemory { + return &InMemory{ + now: time.Now, + cleanupInterval: time.Minute, + entries: make(map[string]*entry), + } +} + +// Reserve evaluates key against policy and reports whether the request may +// proceed immediately. +func (l *InMemory) Reserve(key string, policy Policy) Decision { + if policy.Requests <= 0 || policy.Window <= 0 || policy.Burst <= 0 { + return Decision{} + } + + now := l.now() + limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds()) + + l.mu.Lock() + defer l.mu.Unlock() + + l.cleanupExpiredBucketsLocked(now) + + current, ok := l.entries[key] + if !ok || current.limit != limit || current.burst != policy.Burst { + current = &entry{ + limiter: rate.NewLimiter(limit, policy.Burst), + limit: limit, + burst: policy.Burst, + } + l.entries[key] = current + } + + current.expiresAt = now.Add(entryTTL(policy.Window)) + + reservation := current.limiter.ReserveN(now, 1) + if !reservation.OK() { + return Decision{ + Allowed: false, + RetryAfter: policy.Window, + } + } + + retryAfter := reservation.DelayFrom(now) + if retryAfter > 0 { + return Decision{ + Allowed: false, + RetryAfter: retryAfter, + } + } + + return Decision{Allowed: true} +} + +func (l *InMemory) cleanupExpiredBucketsLocked(now time.Time) { + if !l.nextCleanup.IsZero() && now.Before(l.nextCleanup) { + return + } + + for key, current := range l.entries { + if !current.expiresAt.After(now) { + delete(l.entries, key) + } + } + + l.nextCleanup = now.Add(l.cleanupInterval) +} + +func entryTTL(window time.Duration) time.Duration { + if window < time.Minute { + return time.Minute + } + + return 2 * window +} + +var _ Limiter = (*InMemory)(nil) diff --git a/gateway/internal/ratelimit/inmemory_test.go b/gateway/internal/ratelimit/inmemory_test.go new file mode 100644 index 0000000..091ad18 --- /dev/null +++ b/gateway/internal/ratelimit/inmemory_test.go @@ -0,0 +1,49 @@ +package ratelimit + +import ( + "testing" + "time" + + 
"github.com/stretchr/testify/assert" +) + +func TestInMemoryReserve(t *testing.T) { + t.Parallel() + + limiter := NewInMemory() + policy := Policy{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + + first := limiter.Reserve("bucket-1", policy) + second := limiter.Reserve("bucket-1", policy) + otherBucket := limiter.Reserve("bucket-2", policy) + + assert.True(t, first.Allowed) + assert.False(t, second.Allowed) + assert.Positive(t, second.RetryAfter) + assert.True(t, otherBucket.Allowed) +} + +func TestInMemoryReserveResetsOnPolicyChange(t *testing.T) { + t.Parallel() + + limiter := NewInMemory() + + initialPolicy := Policy{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + updatedPolicy := Policy{ + Requests: 2, + Window: time.Hour, + Burst: 2, + } + + assert.True(t, limiter.Reserve("bucket-1", initialPolicy).Allowed) + assert.False(t, limiter.Reserve("bucket-1", initialPolicy).Allowed) + assert.True(t, limiter.Reserve("bucket-1", updatedPolicy).Allowed) +} diff --git a/gateway/internal/replay/redis.go b/gateway/internal/replay/redis.go new file mode 100644 index 0000000..0823f1c --- /dev/null +++ b/gateway/internal/replay/redis.go @@ -0,0 +1,131 @@ +package replay + +import ( + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "strings" + "time" + + "galaxy/gateway/internal/config" + + "github.com/redis/go-redis/v9" +) + +// RedisStore implements Store with Redis SETNX reservations over a dedicated +// key namespace. +type RedisStore struct { + client *redis.Client + keyPrefix string + reserveTimeout time.Duration +} + +// NewRedisStore constructs a Redis-backed replay store that reuses the +// SessionCache Redis deployment settings and applies the replay-specific key +// namespace and timeout controls from replayCfg. 
+func NewRedisStore(sessionCfg config.SessionCacheRedisConfig, replayCfg config.ReplayRedisConfig) (*RedisStore, error) { + if strings.TrimSpace(sessionCfg.Addr) == "" { + return nil, errors.New("new redis replay store: redis addr must not be empty") + } + if sessionCfg.DB < 0 { + return nil, errors.New("new redis replay store: redis db must not be negative") + } + if strings.TrimSpace(replayCfg.KeyPrefix) == "" { + return nil, errors.New("new redis replay store: replay key prefix must not be empty") + } + if replayCfg.ReserveTimeout <= 0 { + return nil, errors.New("new redis replay store: reserve timeout must be positive") + } + + options := &redis.Options{ + Addr: sessionCfg.Addr, + Username: sessionCfg.Username, + Password: sessionCfg.Password, + DB: sessionCfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if sessionCfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + + return &RedisStore{ + client: redis.NewClient(options), + keyPrefix: replayCfg.KeyPrefix, + reserveTimeout: replayCfg.ReserveTimeout, + }, nil +} + +// Close releases the underlying Redis client resources. +func (s *RedisStore) Close() error { + if s == nil || s.client == nil { + return nil + } + + return s.client.Close() +} + +// Ping verifies that the configured Redis backend is reachable within the +// replay reserve timeout budget. +func (s *RedisStore) Ping(ctx context.Context) error { + if s == nil || s.client == nil { + return errors.New("ping redis replay store: nil store") + } + if ctx == nil { + return errors.New("ping redis replay store: nil context") + } + + pingCtx, cancel := context.WithTimeout(ctx, s.reserveTimeout) + defer cancel() + + if err := s.client.Ping(pingCtx).Err(); err != nil { + return fmt.Errorf("ping redis replay store: %w", err) + } + + return nil +} + +// Reserve records the authenticated deviceSessionID and requestID pair for +// ttl. It rejects duplicates while the reservation remains active. 
+func (s *RedisStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error { + if s == nil || s.client == nil { + return errors.New("reserve replay request in redis: nil store") + } + if ctx == nil { + return errors.New("reserve replay request in redis: nil context") + } + if strings.TrimSpace(deviceSessionID) == "" { + return errors.New("reserve replay request in redis: empty device session id") + } + if strings.TrimSpace(requestID) == "" { + return errors.New("reserve replay request in redis: empty request id") + } + if ttl <= 0 { + return errors.New("reserve replay request in redis: ttl must be positive") + } + + reserveCtx, cancel := context.WithTimeout(ctx, s.reserveTimeout) + defer cancel() + + reserved, err := s.client.SetNX(reserveCtx, s.reservationKey(deviceSessionID, requestID), "1", ttl).Result() + if err != nil { + return fmt.Errorf("reserve replay request in redis: %w", err) + } + if !reserved { + return fmt.Errorf("reserve replay request in redis: %w", ErrDuplicate) + } + + return nil +} + +func (s *RedisStore) reservationKey(deviceSessionID string, requestID string) string { + return s.keyPrefix + encodeKeyComponent(deviceSessionID) + ":" + encodeKeyComponent(requestID) +} + +func encodeKeyComponent(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +var _ Store = (*RedisStore)(nil) diff --git a/gateway/internal/replay/redis_test.go b/gateway/internal/replay/redis_test.go new file mode 100644 index 0000000..857449d --- /dev/null +++ b/gateway/internal/replay/redis_test.go @@ -0,0 +1,254 @@ +package replay + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "galaxy/gateway/internal/config" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRedisStore(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + tests := []struct { + name string + sessionCfg 
config.SessionCacheRedisConfig + replayCfg config.ReplayRedisConfig + wantErr string + }{ + { + name: "valid config", + sessionCfg: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + DB: 2, + }, + replayCfg: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + }, + { + name: "empty redis addr", + replayCfg: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + wantErr: "redis addr must not be empty", + }, + { + name: "negative redis db", + sessionCfg: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + DB: -1, + }, + replayCfg: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 250 * time.Millisecond, + }, + wantErr: "redis db must not be negative", + }, + { + name: "empty replay key prefix", + sessionCfg: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + }, + replayCfg: config.ReplayRedisConfig{ + ReserveTimeout: 250 * time.Millisecond, + }, + wantErr: "replay key prefix must not be empty", + }, + { + name: "non-positive reserve timeout", + sessionCfg: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + }, + replayCfg: config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + }, + wantErr: "reserve timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store, err := NewRedisStore(tt.sessionCfg, tt.replayCfg) + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + }) + } +} + +func TestRedisStorePing(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestRedisStore(t, server, config.SessionCacheRedisConfig{}, config.ReplayRedisConfig{}) + + require.NoError(t, store.Ping(context.Background())) +} + +func TestRedisStoreReserve(t *testing.T) { + t.Parallel() + + tests := 
[]struct { + name string + sessionCfg config.SessionCacheRedisConfig + replayCfg config.ReplayRedisConfig + deviceSessionID string + requestID string + ttl time.Duration + secondReserve func(*testing.T, Store) + wantErrIs error + wantErrText string + }{ + { + name: "first reservation succeeds", + deviceSessionID: "device-session-123", + requestID: "request-123", + ttl: 5 * time.Second, + }, + { + name: "duplicate reservation is rejected", + deviceSessionID: "device-session-123", + requestID: "request-123", + ttl: 5 * time.Second, + secondReserve: func(t *testing.T, store Store) { + t.Helper() + err := store.Reserve(context.Background(), "device-session-123", "request-123", 5*time.Second) + require.ErrorIs(t, err, ErrDuplicate) + }, + }, + { + name: "same request id in distinct sessions does not collide", + deviceSessionID: "device-session-123", + requestID: "request-123", + ttl: 5 * time.Second, + secondReserve: func(t *testing.T, store Store) { + t.Helper() + require.NoError(t, store.Reserve(context.Background(), "device-session-456", "request-123", 5*time.Second)) + }, + }, + { + name: "empty device session id", + requestID: "request-123", + ttl: 5 * time.Second, + wantErrText: "empty device session id", + }, + { + name: "empty request id", + deviceSessionID: "device-session-123", + ttl: 5 * time.Second, + wantErrText: "empty request id", + }, + { + name: "non-positive ttl", + deviceSessionID: "device-session-123", + requestID: "request-123", + wantErrText: "ttl must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestRedisStore(t, server, tt.sessionCfg, tt.replayCfg) + + err := store.Reserve(context.Background(), tt.deviceSessionID, tt.requestID, tt.ttl) + if tt.wantErrIs != nil || tt.wantErrText != "" { + require.Error(t, err) + if tt.wantErrIs != nil { + require.ErrorIs(t, err, tt.wantErrIs) + } + if tt.wantErrText != "" { + 
require.ErrorContains(t, err, tt.wantErrText) + } + return + } + + require.NoError(t, err) + if tt.secondReserve != nil { + tt.secondReserve(t, store) + } + }) + } +} + +func TestRedisStoreReserveReturnsBackendError(t *testing.T) { + t.Parallel() + + store, err := NewRedisStore( + config.SessionCacheRedisConfig{Addr: unusedTCPAddr(t)}, + config.ReplayRedisConfig{ + KeyPrefix: "gateway:replay:", + ReserveTimeout: 100 * time.Millisecond, + }, + ) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + + err = store.Reserve(context.Background(), "device-session-123", "request-123", 5*time.Second) + require.Error(t, err) + assert.False(t, errors.Is(err, ErrDuplicate)) + assert.ErrorContains(t, err, "reserve replay request in redis") +} + +func newTestRedisStore(t *testing.T, server *miniredis.Miniredis, sessionCfg config.SessionCacheRedisConfig, replayCfg config.ReplayRedisConfig) *RedisStore { + t.Helper() + + if sessionCfg.Addr == "" { + sessionCfg.Addr = server.Addr() + } + if replayCfg.KeyPrefix == "" { + replayCfg.KeyPrefix = "gateway:replay:" + } + if replayCfg.ReserveTimeout == 0 { + replayCfg.ReserveTimeout = 250 * time.Millisecond + } + + store, err := NewRedisStore(sessionCfg, replayCfg) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + + return store +} + +func unusedTCPAddr(t *testing.T) string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + addr := listener.Addr().String() + require.NoError(t, listener.Close()) + + return addr +} diff --git a/gateway/internal/replay/replay.go b/gateway/internal/replay/replay.go new file mode 100644 index 0000000..9f2ef0c --- /dev/null +++ b/gateway/internal/replay/replay.go @@ -0,0 +1,24 @@ +// Package replay defines the authenticated replay-reservation contract used by +// the gateway transport pipeline. 
+package replay + +import ( + "context" + "errors" + "time" +) + +var ( + // ErrDuplicate reports that the request identifier has already been + // reserved for the same device session within the active replay window. + ErrDuplicate = errors.New("replay reservation already exists") +) + +// Store reserves authenticated transport request identifiers for a bounded +// replay window. +type Store interface { + // Reserve marks the deviceSessionID and requestID pair as seen for ttl. + // Implementations must wrap ErrDuplicate when the same pair is reserved + // again before ttl expires. + Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error +} diff --git a/gateway/internal/restapi/observability.go b/gateway/internal/restapi/observability.go new file mode 100644 index 0000000..06448cc --- /dev/null +++ b/gateway/internal/restapi/observability.go @@ -0,0 +1,76 @@ +package restapi + +import ( + "time" + + "galaxy/gateway/internal/logging" + "galaxy/gateway/internal/telemetry" + + "github.com/gin-gonic/gin" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" +) + +func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc { + if logger == nil { + logger = zap.NewNop() + } + + return func(c *gin.Context) { + start := time.Now() + c.Next() + + statusCode := c.Writer.Status() + route := c.FullPath() + if route == "" { + route = c.Request.URL.Path + } + + class, ok := PublicRouteClassFromContext(c.Request.Context()) + if !ok { + class = PublicRouteClassPublicMisc + } + + errorCode, _ := c.Get(publicErrorCodeContextKey) + errorCodeValue, _ := errorCode.(string) + + outcome := telemetry.OutcomeFromPublicErrorCode(statusCode, errorCodeValue) + rejectReason := telemetry.RejectReason(outcome) + duration := time.Since(start) + + attrs := []attribute.KeyValue{ + attribute.String("route_class", string(class)), + attribute.String("route", route), + attribute.String("method", c.Request.Method), + 
attribute.String("edge_outcome", string(outcome)), + } + if rejectReason != "" { + attrs = append(attrs, attribute.String("reject_reason", rejectReason)) + } + metrics.RecordPublicRequest(c.Request.Context(), attrs, duration) + + fields := []zap.Field{ + zap.String("component", "public_http"), + zap.String("transport", "http"), + zap.String("route", route), + zap.String("route_class", string(class)), + zap.String("method", c.Request.Method), + zap.Int("status_code", statusCode), + zap.Float64("duration_ms", float64(duration.Microseconds())/1000), + zap.String("edge_outcome", string(outcome)), + } + if rejectReason != "" { + fields = append(fields, zap.String("reject_reason", rejectReason)) + } + fields = append(fields, logging.TraceFieldsFromContext(c.Request.Context())...) + + switch outcome { + case telemetry.EdgeOutcomeSuccess: + logger.Info("public request completed", fields...) + case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeInternalError: + logger.Error("public request failed", fields...) + default: + logger.Warn("public request rejected", fields...) 
+ } + } +} diff --git a/gateway/internal/restapi/openapi_test.go b/gateway/internal/restapi/openapi_test.go new file mode 100644 index 0000000..1553436 --- /dev/null +++ b/gateway/internal/restapi/openapi_test.go @@ -0,0 +1,30 @@ +package restapi + +import ( + "context" + "path/filepath" + "runtime" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +func TestPublicOpenAPISpecValidates(t *testing.T) { + t.Parallel() + + _, thisFile, _, ok := runtime.Caller(0) + require.True(t, ok) + + specPath := filepath.Join(filepath.Dir(thisFile), "..", "..", "openapi.yaml") + ctx := context.Background() + + loader := openapi3.NewLoader() + doc, err := loader.LoadFromFile(specPath) + require.NoError(t, err) + require.NotNil(t, doc) + require.NotNil(t, doc.Info) + require.Equal(t, "v1", doc.Info.Version) + + require.NoError(t, doc.Validate(ctx)) +} diff --git a/gateway/internal/restapi/public_anti_abuse.go b/gateway/internal/restapi/public_anti_abuse.go new file mode 100644 index 0000000..c390fe6 --- /dev/null +++ b/gateway/internal/restapi/public_anti_abuse.go @@ -0,0 +1,378 @@ +package restapi + +import ( + "bytes" + "errors" + "io" + "math" + "net" + "net/http" + "path" + "strconv" + "strings" + "sync" + "time" + + "galaxy/gateway/internal/config" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +const ( + errorCodeRequestTooLarge = "request_too_large" + errorCodeRateLimited = "rate_limited" + + publicRESTIPBucketKeySegment = "/ip=" +) + +var errRequestBodyTooLarge = errors.New("request body exceeds the configured limit") + +// PublicMalformedRequestReason identifies the stable malformed-request counter +// dimension recorded by the public REST anti-abuse middleware. +type PublicMalformedRequestReason string + +const ( + // PublicMalformedRequestReasonEmptyBody records a missing request body. 
+ PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body" + + // PublicMalformedRequestReasonMalformedJSON records syntactically malformed + // JSON. + PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json" + + // PublicMalformedRequestReasonInvalidJSONValue records JSON values whose + // types do not match the expected request schema. + PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value" + + // PublicMalformedRequestReasonUnknownField records JSON objects with fields + // outside the documented schema. + PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field" + + // PublicMalformedRequestReasonMultipleJSONObjects records requests that + // contain more than one JSON object. + PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects" + + // PublicMalformedRequestReasonOversizedBody records requests whose bodies + // exceed the configured class limit. + PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body" +) + +// PublicRateLimitDecision describes the outcome returned by a public REST +// limiter for one request bucket reservation attempt. +type PublicRateLimitDecision struct { + // Allowed reports whether the request may proceed immediately. + Allowed bool + + // RetryAfter is the minimum delay the client should wait before retrying + // when Allowed is false. + RetryAfter time.Duration +} + +// PublicRequestLimiter applies public REST rate-limit policy to a concrete +// bucket key. +type PublicRequestLimiter interface { + // Reserve evaluates key under policy and returns whether the request may + // proceed immediately. + Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision +} + +// PublicRequestObserver captures low-cardinality public REST anti-abuse +// telemetry. 
+type PublicRequestObserver interface { + // RecordMalformedRequest records one malformed request in class for reason. + RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason) +} + +type noopPublicRequestObserver struct{} + +func (noopPublicRequestObserver) RecordMalformedRequest(PublicRouteClass, PublicMalformedRequestReason) { +} + +type inMemoryPublicRequestLimiter struct { + now func() time.Time + cleanupInterval time.Duration + + mu sync.Mutex + entries map[string]*publicRateLimiterEntry + nextCleanup time.Time +} + +type publicRateLimiterEntry struct { + limiter *rate.Limiter + limit rate.Limit + burst int + expiresAt time.Time +} + +func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter { + return &inMemoryPublicRequestLimiter{ + now: time.Now, + cleanupInterval: time.Minute, + entries: make(map[string]*publicRateLimiterEntry), + } +} + +func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision { + now := l.now() + limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds()) + + l.mu.Lock() + defer l.mu.Unlock() + + l.cleanupExpiredBucketsLocked(now) + + entry, ok := l.entries[key] + if !ok || entry.limit != limit || entry.burst != policy.Burst { + entry = &publicRateLimiterEntry{ + limiter: rate.NewLimiter(limit, policy.Burst), + limit: limit, + burst: policy.Burst, + } + l.entries[key] = entry + } + + entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window)) + + reservation := entry.limiter.ReserveN(now, 1) + if !reservation.OK() { + return PublicRateLimitDecision{ + Allowed: false, + RetryAfter: policy.Window, + } + } + + retryAfter := reservation.DelayFrom(now) + if retryAfter > 0 { + return PublicRateLimitDecision{ + Allowed: false, + RetryAfter: retryAfter, + } + } + + return PublicRateLimitDecision{Allowed: true} +} + +func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) { + if 
!l.nextCleanup.IsZero() && now.Before(l.nextCleanup) {
+		return
+	}
+
+	for key, entry := range l.entries {
+		if !entry.expiresAt.After(now) {
+			delete(l.entries, key)
+		}
+	}
+
+	l.nextCleanup = now.Add(l.cleanupInterval)
+}
+
+// publicRateLimiterEntryTTL returns how long an idle bucket stays resident
+// before cleanup may drop it: twice the policy window, floored at one minute
+// so very short windows do not churn the entries map.
+func publicRateLimiterEntryTTL(window time.Duration) time.Duration {
+	if window < time.Minute {
+		return time.Minute
+	}
+
+	return 2 * window
+}
+
+// withPublicAntiAbuse builds the anti-abuse middleware for public REST
+// routes. Per request it: resolves the route class, enforces the
+// shape-derived method allow-list, buffers the body up to the class size
+// cap, reserves the per-IP rate bucket, and on public auth routes reserves
+// an additional per-identity bucket extracted from the buffered body.
+// Malformed bodies found during identity extraction are only counted via
+// observer here; the route handler still produces the client-facing error.
+func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		class, ok := PublicRouteClassFromContext(c.Request.Context())
+		if !ok {
+			// Unclassified routes fall back to the misc policy.
+			class = PublicRouteClassPublicMisc
+		}
+
+		allowedMethods := allowedMethodsForRequestShape(c.Request)
+		if len(allowedMethods) > 0 && !isAllowedMethod(c.Request.Method, allowedMethods) {
+			c.Header("Allow", strings.Join(allowedMethods, ", "))
+			abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
+			return
+		}
+
+		classPolicy := publicRoutePolicyForClass(policy, class)
+		bodyBytes, err := bufferRequestBody(c.Request, classPolicy.MaxBodyBytes)
+		if err != nil {
+			switch {
+			case errors.Is(err, errRequestBodyTooLarge):
+				observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody)
+				abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit")
+			default:
+				// Read/close failures are not attributable to the client.
+				abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
+			}
+			return
+		}
+
+		clientIP := clientIPFromRemoteAddr(c.Request.RemoteAddr)
+		if decision := limiter.Reserve(publicRESTIPBucketKey(class, clientIP), classPolicy.RateLimit); !decision.Allowed {
+			abortRateLimited(c, decision.RetryAfter)
+			return
+		}
+
+		identity, err := extractPublicAuthIdentity(c.Request.URL.Path, bodyBytes)
+		switch {
+		case err == nil:
+			identityPolicy := publicAuthIdentityPolicyForPath(c.Request.URL.Path, policy)
+			if decision
:= limiter.Reserve(publicAuthIdentityBucketKey(class, identity.kind, identity.value), identityPolicy.RateLimit); !decision.Allowed {
+				abortRateLimited(c, decision.RetryAfter)
+				return
+			}
+		case errors.Is(err, errPublicAuthIdentityNotApplicable):
+			// Route has no identity bucket; nothing further to reserve.
+		default:
+			// Identity extraction failed. Only malformed bodies are counted
+			// for telemetry; validation errors fall through and the handler
+			// produces the client-facing 400.
+			if reason, malformed := malformedRequestReasonFromError(err); malformed {
+				observer.RecordMalformedRequest(class, reason)
+			}
+		}
+
+		c.Next()
+	}
+}
+
+// publicRoutePolicyForClass maps a route class to its configured per-class
+// policy, defaulting to the misc policy for unknown classes.
+func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig {
+	switch class.Normalized() {
+	case PublicRouteClassPublicAuth:
+		return policy.PublicAuth
+	case PublicRouteClassBrowserBootstrap:
+		return policy.BrowserBootstrap
+	case PublicRouteClassBrowserAsset:
+		return policy.BrowserAsset
+	default:
+		return policy.PublicMisc
+	}
+}
+
+// publicAuthIdentityPolicyForPath returns the per-identity throttle policy
+// for the two public auth endpoints; any other path gets the zero value.
+func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig {
+	switch requestPath {
+	case "/api/v1/public/auth/send-email-code":
+		return policy.SendEmailCodeIdentity
+	case "/api/v1/public/auth/confirm-email-code":
+		return policy.ConfirmEmailCodeIdentity
+	default:
+		return config.PublicAuthIdentityPolicyConfig{}
+	}
+}
+
+// allowedMethodsForRequestShape derives the method allow-list from the
+// request shape. A nil result means no restriction is enforced here.
+func allowedMethodsForRequestShape(r *http.Request) []string {
+	switch {
+	case isPublicAuthPath(r.URL.Path):
+		return []string{http.MethodPost}
+	case isProbePath(r.URL.Path):
+		return []string{http.MethodGet}
+	case matchesBrowserAssetRequestShape(r):
+		return []string{http.MethodGet, http.MethodHead}
+	case matchesBrowserBootstrapRequestShape(r):
+		return []string{http.MethodGet, http.MethodHead}
+	default:
+		return nil
+	}
+}
+
+// isAllowedMethod reports whether method appears in allowedMethods.
+func isAllowedMethod(method string, allowedMethods []string) bool {
+	for _, allowedMethod := range allowedMethods {
+		if method == allowedMethod {
+			return true
+		}
+	}
+
+	return false
+}
+
+// isPublicAuthPath reports whether requestPath is one of the two
+// unauthenticated public auth endpoints.
+func isPublicAuthPath(requestPath string) bool {
+	switch requestPath {
+	case "/api/v1/public/auth/send-email-code",
"/api/v1/public/auth/confirm-email-code":
+		return true
+	default:
+		return false
+	}
+}
+
+// isProbePath reports whether requestPath is a liveness/readiness probe.
+func isProbePath(requestPath string) bool {
+	switch requestPath {
+	case "/healthz", "/readyz":
+		return true
+	default:
+		return false
+	}
+}
+
+// matchesBrowserBootstrapRequestShape treats the root path, or any request
+// whose Accept header mentions text/html, as a browser bootstrap request.
+func matchesBrowserBootstrapRequestShape(r *http.Request) bool {
+	if r.URL.Path == "/" {
+		return true
+	}
+
+	return strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html")
+}
+
+// matchesBrowserAssetRequestShape treats anything under /assets/ or with a
+// well-known static-asset extension as a browser asset request.
+func matchesBrowserAssetRequestShape(r *http.Request) bool {
+	if strings.HasPrefix(r.URL.Path, "/assets/") {
+		return true
+	}
+
+	switch strings.ToLower(path.Ext(r.URL.Path)) {
+	case ".js", ".mjs", ".css", ".map", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".json", ".webmanifest":
+		return true
+	default:
+		return false
+	}
+}
+
+// bufferRequestBody reads the request body up to maxBodyBytes, returning
+// errRequestBodyTooLarge when the cap is exceeded (it reads at most one
+// extra byte to detect that). On success it replaces r.Body with a
+// replayable in-memory reader so downstream handlers can decode the body
+// again.
+func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) {
+	if r == nil {
+		return nil, nil
+	}
+
+	if r.Body == nil {
+		r.Body = io.NopCloser(bytes.NewReader(nil))
+		return nil, nil
+	}
+
+	bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, maxBodyBytes+1))
+	closeErr := r.Body.Close()
+	if err != nil {
+		return nil, err
+	}
+	if closeErr != nil {
+		return nil, closeErr
+	}
+	if int64(len(bodyBytes)) > maxBodyBytes {
+		return nil, errRequestBodyTooLarge
+	}
+
+	r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+
+	return bodyBytes, nil
+}
+
+// abortRateLimited writes a 429 with a Retry-After hint derived from
+// retryAfter.
+func abortRateLimited(c *gin.Context, retryAfter time.Duration) {
+	c.Header("Retry-After", retryAfterHeaderValue(retryAfter))
+	abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded")
+}
+
+// retryAfterHeaderValue renders delay as whole seconds (rounded up, minimum
+// 1), matching the delta-seconds form of the Retry-After header.
+func retryAfterHeaderValue(delay time.Duration) string {
+	seconds := int64(math.Ceil(delay.Seconds()))
+	if seconds < 1 {
+		seconds = 1
+	}
+
+	return strconv.FormatInt(seconds, 10)
+}
+
+// clientIPFromRemoteAddr extracts the host part of a host:port remote
+// address; an address without a port is returned as-is, and an empty value
+// maps to "unknown".
+func clientIPFromRemoteAddr(remoteAddr string) string {
+	host, _, err := net.SplitHostPort(strings.TrimSpace(remoteAddr))
+	if err == nil {
+		return host
+	}
+
+	remoteAddr = strings.TrimSpace(remoteAddr)
+	if
remoteAddr == "" {
+		return "unknown"
+	}
+
+	return remoteAddr
+}
+
+// publicRESTIPBucketKey builds the per-IP rate-limit bucket key for class.
+func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string {
+	return class.BaseBucketKey() + publicRESTIPBucketKeySegment + clientIP
+}
+
+// publicAuthIdentityBucketKey builds the per-identity bucket key (for
+// example email or challenge identities on public auth routes).
+// NOTE(review): this uses a literal "/" separator while publicRESTIPBucketKey
+// uses publicRESTIPBucketKeySegment — confirm the two segments are meant to
+// differ so the key namespaces cannot collide.
+func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string {
+	return class.BaseBucketKey() + "/" + identityKind + "=" + identityValue
+}
diff --git a/gateway/internal/restapi/public_anti_abuse_test.go b/gateway/internal/restapi/public_anti_abuse_test.go
new file mode 100644
index 0000000..74f7fef
--- /dev/null
+++ b/gateway/internal/restapi/public_anti_abuse_test.go
@@ -0,0 +1,455 @@
+package restapi
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"galaxy/gateway/internal/config"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestPublicAntiAbuseRejectsOversizedBodies verifies that bodies above the
+// class limit yield 413, never reach the auth service, and are recorded as
+// oversized in telemetry.
+func TestPublicAntiAbuseRejectsOversizedBodies(t *testing.T) {
+	t.Parallel()
+
+	oversizedJSONBody := `{"email":"` + strings.Repeat("a", 8200) + `@example.com"}`
+	oversizedConfirmJSONBody := `{"challenge_id":"` + strings.Repeat("c", 8300) + `","code":"123456","client_public_key":"key"}`
+
+	tests := []struct {
+		name      string
+		method    string
+		target    string
+		body      string
+		wantClass PublicRouteClass
+	}{
+		{
+			name:      "send email",
+			method:    http.MethodPost,
+			target:    "/api/v1/public/auth/send-email-code",
+			body:      oversizedJSONBody,
+			wantClass: PublicRouteClassPublicAuth,
+		},
+		{
+			name:      "confirm email",
+			method:    http.MethodPost,
+			target:    "/api/v1/public/auth/confirm-email-code",
+			body:      oversizedConfirmJSONBody,
+			wantClass: PublicRouteClassPublicAuth,
+		},
+		{
+			name:      "healthz body",
+			method:    http.MethodGet,
+			target:    "/healthz",
+			body:      `x`,
+			wantClass: PublicRouteClassPublicMisc,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			observer := &recordingPublicRequestObserver{}
+			authService := &recordingAuthServiceClient{
sendEmailCodeResult: SendEmailCodeResult{ChallengeID: "challenge-123"}, + confirmEmailCodeResult: ConfirmEmailCodeResult{ + DeviceSessionID: "device-session-123", + }, + } + handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{ + AuthService: authService, + Observer: observer, + }) + + req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.Equal(t, `{"error":{"code":"request_too_large","message":"request body exceeds the configured limit"}}`, recorder.Body.String()) + assert.Equal(t, 0, authService.sendEmailCodeCalls) + assert.Equal(t, 0, authService.confirmEmailCodeCalls) + assert.Equal(t, []malformedObservation{{ + class: tt.wantClass, + reason: PublicMalformedRequestReasonOversizedBody, + }}, observer.snapshot()) + }) + } +} + +func TestPublicAntiAbuseRejectsInvalidMethodsForBrowserShapes(t *testing.T) { + t.Parallel() + + handler := newPublicHandler(ServerDependencies{}) + + tests := []struct { + name string + method string + target string + accept string + wantAllow string + }{ + { + name: "asset path", + method: http.MethodPost, + target: "/assets/app.js", + wantAllow: "GET, HEAD", + }, + { + name: "bootstrap request", + method: http.MethodPost, + target: "/", + accept: "text/html", + wantAllow: "GET, HEAD", + }, + { + name: "head probe rejected", + method: http.MethodHead, + target: "/healthz", + wantAllow: http.MethodGet, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(tt.method, tt.target, nil) + if tt.accept != "" { + req.Header.Set("Accept", tt.accept) + } + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + 
assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow")) + assert.Equal(t, `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`, recorder.Body.String()) + }) + } +} + +func TestPublicAntiAbuseBrowserClassBucketsStayIsolatedFromPublicAuth(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + burstRequest *http.Request + }{ + { + name: "browser asset", + burstRequest: httptest.NewRequest(http.MethodGet, "/assets/app.js", nil), + }, + { + name: "browser bootstrap", + burstRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Accept", "text/html") + return req + }(), + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cfg := config.DefaultPublicHTTPConfig() + cfg.AntiAbuse.BrowserAsset.RateLimit = config.PublicRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + cfg.AntiAbuse.BrowserBootstrap.RateLimit = config.PublicRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{ + Requests: 100, + Window: time.Hour, + Burst: 100, + } + authService := &recordingAuthServiceClient{ + sendEmailCodeResult: SendEmailCodeResult{ + ChallengeID: "challenge-123", + }, + } + handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService}) + + tt.burstRequest.RemoteAddr = "192.0.2.10:1234" + + firstBurst := httptest.NewRecorder() + handler.ServeHTTP(firstBurst, tt.burstRequest.Clone(tt.burstRequest.Context())) + + secondBurst := httptest.NewRecorder() + handler.ServeHTTP(secondBurst, tt.burstRequest.Clone(tt.burstRequest.Context())) + + authReq := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", 
strings.NewReader(`{"email":"pilot@example.com"}`)) + authReq.Header.Set("Content-Type", "application/json") + authReq.RemoteAddr = "192.0.2.10:1234" + authResp := httptest.NewRecorder() + + handler.ServeHTTP(authResp, authReq) + + assert.Equal(t, http.StatusNotFound, firstBurst.Code) + assert.Equal(t, http.StatusTooManyRequests, secondBurst.Code) + assert.Equal(t, http.StatusOK, authResp.Code) + assert.Equal(t, `{"challenge_id":"challenge-123"}`, authResp.Body.String()) + assert.Equal(t, 1, authService.sendEmailCodeCalls) + }) + } +} + +func TestPublicAntiAbuseSendEmailIdentityThrottle(t *testing.T) { + t.Parallel() + + cfg := config.DefaultPublicHTTPConfig() + cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{ + Requests: 100, + Window: time.Hour, + Burst: 100, + } + cfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + + authService := &recordingAuthServiceClient{ + sendEmailCodeResult: SendEmailCodeResult{ + ChallengeID: "challenge-123", + }, + } + handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService}) + + first := sendEmailCodeRequest(`{"email":"pilot@example.com"}`) + second := sendEmailCodeRequest(`{"email":"pilot@example.com"}`) + third := sendEmailCodeRequest(`{"email":"other@example.com"}`) + + firstResp := httptest.NewRecorder() + handler.ServeHTTP(firstResp, first) + + secondResp := httptest.NewRecorder() + handler.ServeHTTP(secondResp, second) + + thirdResp := httptest.NewRecorder() + handler.ServeHTTP(thirdResp, third) + + assert.Equal(t, http.StatusOK, firstResp.Code) + assert.Equal(t, http.StatusTooManyRequests, secondResp.Code) + assert.Equal(t, "3600", secondResp.Header().Get("Retry-After")) + assert.Equal(t, http.StatusOK, thirdResp.Code) + assert.Equal(t, 2, authService.sendEmailCodeCalls) + thirdInput := authService.sendEmailCodeInput + assert.Equal(t, "other@example.com", thirdInput.Email) +} + +func 
TestPublicAntiAbuseConfirmEmailIdentityThrottle(t *testing.T) { + t.Parallel() + + cfg := config.DefaultPublicHTTPConfig() + cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{ + Requests: 100, + Window: time.Hour, + Burst: 100, + } + cfg.AntiAbuse.ConfirmEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{ + Requests: 1, + Window: time.Hour, + Burst: 1, + } + + authService := &recordingAuthServiceClient{ + confirmEmailCodeResult: ConfirmEmailCodeResult{ + DeviceSessionID: "device-session-123", + }, + } + handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService}) + + first := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`) + second := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`) + third := confirmEmailCodeRequest(`{"challenge_id":"challenge-456","code":"123456","client_public_key":"public-key-material"}`) + + firstResp := httptest.NewRecorder() + handler.ServeHTTP(firstResp, first) + + secondResp := httptest.NewRecorder() + handler.ServeHTTP(secondResp, second) + + thirdResp := httptest.NewRecorder() + handler.ServeHTTP(thirdResp, third) + + assert.Equal(t, http.StatusOK, firstResp.Code) + assert.Equal(t, http.StatusTooManyRequests, secondResp.Code) + assert.Equal(t, "3600", secondResp.Header().Get("Retry-After")) + assert.Equal(t, http.StatusOK, thirdResp.Code) + assert.Equal(t, 2, authService.confirmEmailCodeCalls) + assert.Equal(t, "challenge-456", authService.confirmEmailCodeInput.ChallengeID) +} + +func TestPublicAntiAbuseMalformedTelemetry(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body string + wantReason PublicMalformedRequestReason + wantRecords int + }{ + { + name: "empty body", + body: ``, + wantReason: PublicMalformedRequestReasonEmptyBody, + wantRecords: 1, + }, + { + name: "malformed json", + body: `{"email":`, + wantReason: 
PublicMalformedRequestReasonMalformedJSON, + wantRecords: 1, + }, + { + name: "invalid json value", + body: `{"email":123}`, + wantReason: PublicMalformedRequestReasonInvalidJSONValue, + wantRecords: 1, + }, + { + name: "unknown field", + body: `{"email":"pilot@example.com","extra":"x"}`, + wantReason: PublicMalformedRequestReasonUnknownField, + wantRecords: 1, + }, + { + name: "multiple objects", + body: `{"email":"pilot@example.com"}{"email":"pilot@example.com"}`, + wantReason: PublicMalformedRequestReasonMultipleJSONObjects, + wantRecords: 1, + }, + { + name: "validation error does not count as malformed", + body: `{"email":"not-an-email"}`, + wantRecords: 0, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + observer := &recordingPublicRequestObserver{} + authService := &recordingAuthServiceClient{} + handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{ + AuthService: authService, + Observer: observer, + }) + + req := sendEmailCodeRequest(tt.body) + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Equal(t, tt.wantRecords, len(observer.snapshot())) + assert.Equal(t, 0, authService.sendEmailCodeCalls) + if tt.wantRecords == 1 { + assert.Equal(t, malformedObservation{ + class: PublicRouteClassPublicAuth, + reason: tt.wantReason, + }, observer.snapshot()[0]) + } + }) + } +} + +func TestInMemoryPublicRequestLimiterCleansExpiredBuckets(t *testing.T) { + t.Parallel() + + now := time.Unix(1000, 0) + limiter := newInMemoryPublicRequestLimiter() + limiter.now = func() time.Time { + return now + } + limiter.cleanupInterval = time.Second + + policy := config.PublicRateLimitConfig{ + Requests: 1, + Window: time.Minute, + Burst: 1, + } + + firstDecision := limiter.Reserve("bucket-1", policy) + secondDecision := limiter.Reserve("bucket-2", policy) + require.True(t, firstDecision.Allowed) + 
require.True(t, secondDecision.Allowed) + require.Len(t, limiter.entries, 2) + + now = now.Add(3 * time.Minute) + + thirdDecision := limiter.Reserve("bucket-3", policy) + require.True(t, thirdDecision.Allowed) + assert.Len(t, limiter.entries, 1) + _, exists := limiter.entries["bucket-3"] + assert.True(t, exists) +} + +func sendEmailCodeRequest(body string) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "192.0.2.10:1234" + + return req +} + +func confirmEmailCodeRequest(body string) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/confirm-email-code", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "192.0.2.10:1234" + + return req +} + +type malformedObservation struct { + class PublicRouteClass + reason PublicMalformedRequestReason +} + +type recordingPublicRequestObserver struct { + mu sync.Mutex + observations []malformedObservation +} + +func (o *recordingPublicRequestObserver) RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason) { + o.mu.Lock() + defer o.mu.Unlock() + + o.observations = append(o.observations, malformedObservation{ + class: class, + reason: reason, + }) +} + +func (o *recordingPublicRequestObserver) snapshot() []malformedObservation { + o.mu.Lock() + defer o.mu.Unlock() + + return append([]malformedObservation(nil), o.observations...) 
+} diff --git a/gateway/internal/restapi/public_auth.go b/gateway/internal/restapi/public_auth.go new file mode 100644 index 0000000..ef30632 --- /dev/null +++ b/gateway/internal/restapi/public_auth.go @@ -0,0 +1,446 @@ +package restapi + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/mail" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +var errPublicAuthIdentityNotApplicable = errors.New("public auth identity does not apply to this route") + +type malformedJSONRequestError struct { + message string + reason PublicMalformedRequestReason +} + +func (e *malformedJSONRequestError) Error() string { + if e == nil { + return "" + } + + return e.message +} + +type publicAuthIdentity struct { + kind string + value string +} + +// AuthServiceClient defines the consumer-side contract used by public auth +// REST handlers to delegate unauthenticated authentication commands to the +// Auth / Session Service. +type AuthServiceClient interface { + // SendEmailCode starts a login challenge for input.Email and returns the + // challenge identifier that the client must later confirm. + SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) + + // ConfirmEmailCode completes a previously issued challenge, registers + // input.ClientPublicKey for the new device session, and returns the created + // device session identifier. + ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) +} + +// SendEmailCodeInput describes the public REST and adapter payload used to +// request a login code for a single e-mail address. +type SendEmailCodeInput struct { + // Email is the single client e-mail address that should receive the login + // code challenge. + Email string `json:"email"` +} + +// SendEmailCodeResult describes the public REST and adapter payload returned +// after the Auth / Session Service creates a login challenge. 
+type SendEmailCodeResult struct { + // ChallengeID identifies the issued challenge that must be confirmed by the + // client in the next public auth step. + ChallengeID string `json:"challenge_id"` +} + +// ConfirmEmailCodeInput describes the public REST and adapter payload used to +// complete a previously issued login challenge. +type ConfirmEmailCodeInput struct { + // ChallengeID identifies the challenge previously returned by + // SendEmailCode. + ChallengeID string `json:"challenge_id"` + + // Code is the verification code delivered to the client by the Auth / + // Session Service. + Code string `json:"code"` + + // ClientPublicKey is the standard base64-encoded raw 32-byte Ed25519 public + // key that should be registered for the created device session. + ClientPublicKey string `json:"client_public_key"` +} + +// ConfirmEmailCodeResult describes the public REST and adapter payload +// returned after the Auth / Session Service creates a device session. +type ConfirmEmailCodeResult struct { + // DeviceSessionID is the stable identifier of the created device session. + DeviceSessionID string `json:"device_session_id"` +} + +// AuthServiceError allows an auth adapter to project a stable public REST +// error without teaching the gateway transport layer about upstream business +// rules. +type AuthServiceError struct { + // StatusCode is the HTTP status that the public REST handler should expose. + StatusCode int + + // Code is the stable edge-level error code written into the JSON envelope. + Code string + + // Message is the human-readable client-safe error description. + Message string +} + +// Error returns a readable representation of the projected auth service error. 
+func (e *AuthServiceError) Error() string { + if e == nil { + return "" + } + + switch { + case strings.TrimSpace(e.Code) == "" && strings.TrimSpace(e.Message) == "": + return http.StatusText(e.normalizedStatusCode()) + case strings.TrimSpace(e.Code) == "": + return e.Message + case strings.TrimSpace(e.Message) == "": + return e.Code + default: + return e.Code + ": " + e.Message + } +} + +func (e *AuthServiceError) normalizedStatusCode() int { + if e == nil || e.StatusCode < 400 || e.StatusCode > 599 { + return http.StatusInternalServerError + } + + return e.StatusCode +} + +func (e *AuthServiceError) normalizedCode() string { + if e == nil { + return errorCodeInternalError + } + + code := strings.TrimSpace(e.Code) + if code == "" { + switch e.normalizedStatusCode() { + case http.StatusServiceUnavailable: + return errorCodeServiceUnavailable + case http.StatusBadRequest: + return errorCodeInvalidRequest + default: + return errorCodeInternalError + } + } + + return code +} + +func (e *AuthServiceError) normalizedMessage() string { + if e == nil { + return "internal server error" + } + + message := strings.TrimSpace(e.Message) + if message == "" { + switch e.normalizedStatusCode() { + case http.StatusServiceUnavailable: + return "auth service is unavailable" + case http.StatusBadRequest: + return "request is invalid" + default: + return "internal server error" + } + } + + return message +} + +// unavailableAuthServiceClient keeps the public auth surface mounted until a +// concrete upstream adapter is wired into the gateway process. 
+type unavailableAuthServiceClient struct{}
+
+func (unavailableAuthServiceClient) SendEmailCode(context.Context, SendEmailCodeInput) (SendEmailCodeResult, error) {
+	return SendEmailCodeResult{}, &AuthServiceError{
+		StatusCode: http.StatusServiceUnavailable,
+		Code:       errorCodeServiceUnavailable,
+		Message:    "auth service is unavailable",
+	}
+}
+
+func (unavailableAuthServiceClient) ConfirmEmailCode(context.Context, ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
+	return ConfirmEmailCodeResult{}, &AuthServiceError{
+		StatusCode: http.StatusServiceUnavailable,
+		Code:       errorCodeServiceUnavailable,
+		Message:    "auth service is unavailable",
+	}
+}
+
+// handleSendEmailCode builds the POST /api/v1/public/auth/send-email-code
+// handler: decode and validate the body, call the auth service under a
+// per-call timeout, validate the upstream result, and return the challenge
+// identifier as JSON.
+func handleSendEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var input SendEmailCodeInput
+		if err := decodeJSONRequest(c.Request, &input); err != nil {
+			abortInvalidRequest(c, err.Error())
+			return
+		}
+		if err := validateSendEmailCodeInput(&input); err != nil {
+			abortInvalidRequest(c, err.Error())
+			return
+		}
+
+		callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
+		defer cancel()
+
+		result, err := authService.SendEmailCode(callCtx, input)
+		if err != nil {
+			// An upstream timeout is surfaced as 503, not 500: the edge is
+			// healthy, the dependency is not.
+			if errors.Is(err, context.DeadlineExceeded) {
+				abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
+				return
+			}
+			abortWithAuthServiceError(c, err)
+			return
+		}
+		// A structurally empty upstream result is a contract violation and is
+		// hidden behind a generic 500.
+		if err := validateSendEmailCodeResult(&result); err != nil {
+			abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
+			return
+		}
+
+		c.JSON(http.StatusOK, result)
+	}
+}
+
+// handleConfirmEmailCode builds the POST
+// /api/v1/public/auth/confirm-email-code handler, mirroring
+// handleSendEmailCode for the challenge-confirmation step.
+func handleConfirmEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var input ConfirmEmailCodeInput
+		if err := decodeJSONRequest(c.Request, &input); err != nil {
+			abortInvalidRequest(c, err.Error())
+			return
+		}
+		if err := validateConfirmEmailCodeInput(&input); err !=
nil { + abortInvalidRequest(c, err.Error()) + return + } + + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + result, err := authService.ConfirmEmailCode(callCtx, input) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable") + return + } + abortWithAuthServiceError(c, err) + return + } + if err := validateConfirmEmailCodeResult(&result); err != nil { + abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error") + return + } + + c.JSON(http.StatusOK, result) + } +} + +func abortInvalidRequest(c *gin.Context, message string) { + abortWithError(c, http.StatusBadRequest, errorCodeInvalidRequest, message) +} + +func abortWithAuthServiceError(c *gin.Context, err error) { + var authErr *AuthServiceError + if errors.As(err, &authErr) { + abortWithError(c, authErr.normalizedStatusCode(), authErr.normalizedCode(), authErr.normalizedMessage()) + return + } + + abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error") +} + +func decodeJSONRequest(r *http.Request, target any) error { + if r == nil || r.Body == nil { + return &malformedJSONRequestError{ + message: "request body must not be empty", + reason: PublicMalformedRequestReasonEmptyBody, + } + } + + return decodeJSONReader(r.Body, target) +} + +func decodeJSONBytes(bodyBytes []byte, target any) error { + return decodeJSONReader(bytes.NewReader(bodyBytes), target) +} + +func decodeJSONReader(reader io.Reader, target any) error { + decoder := json.NewDecoder(reader) + decoder.DisallowUnknownFields() + + if err := decoder.Decode(target); err != nil { + return describeJSONDecodeError(err) + } + + if err := decoder.Decode(&struct{}{}); err != nil { + if errors.Is(err, io.EOF) { + return nil + } + + return &malformedJSONRequestError{ + message: "request body must contain a single JSON 
object", + reason: PublicMalformedRequestReasonMultipleJSONObjects, + } + } + + return &malformedJSONRequestError{ + message: "request body must contain a single JSON object", + reason: PublicMalformedRequestReasonMultipleJSONObjects, + } +} + +func describeJSONDecodeError(err error) error { + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + + switch { + case errors.Is(err, io.EOF): + return &malformedJSONRequestError{ + message: "request body must not be empty", + reason: PublicMalformedRequestReasonEmptyBody, + } + case errors.As(err, &syntaxErr): + return &malformedJSONRequestError{ + message: "request body contains malformed JSON", + reason: PublicMalformedRequestReasonMalformedJSON, + } + case errors.Is(err, io.ErrUnexpectedEOF): + return &malformedJSONRequestError{ + message: "request body contains malformed JSON", + reason: PublicMalformedRequestReasonMalformedJSON, + } + case errors.As(err, &typeErr): + if strings.TrimSpace(typeErr.Field) != "" { + return &malformedJSONRequestError{ + message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field), + reason: PublicMalformedRequestReasonInvalidJSONValue, + } + } + return &malformedJSONRequestError{ + message: "request body contains an invalid JSON value", + reason: PublicMalformedRequestReasonInvalidJSONValue, + } + case strings.HasPrefix(err.Error(), "json: unknown field "): + return &malformedJSONRequestError{ + message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")), + reason: PublicMalformedRequestReasonUnknownField, + } + default: + return &malformedJSONRequestError{ + message: "request body contains invalid JSON", + reason: PublicMalformedRequestReasonMalformedJSON, + } + } +} + +func validateSendEmailCodeInput(input *SendEmailCodeInput) error { + input.Email = strings.TrimSpace(input.Email) + if input.Email == "" { + return errors.New("email must not be empty") + } + + parsedAddress, err := 
mail.ParseAddress(input.Email) + if err != nil || parsedAddress.Name != "" || parsedAddress.Address != input.Email { + return errors.New("email must be a single valid email address") + } + + return nil +} + +func validateSendEmailCodeResult(result *SendEmailCodeResult) error { + result.ChallengeID = strings.TrimSpace(result.ChallengeID) + if result.ChallengeID == "" { + return errors.New("auth service returned an empty challenge_id") + } + + return nil +} + +func validateConfirmEmailCodeInput(input *ConfirmEmailCodeInput) error { + input.ChallengeID = strings.TrimSpace(input.ChallengeID) + if input.ChallengeID == "" { + return errors.New("challenge_id must not be empty") + } + + input.Code = strings.TrimSpace(input.Code) + if input.Code == "" { + return errors.New("code must not be empty") + } + + input.ClientPublicKey = strings.TrimSpace(input.ClientPublicKey) + if input.ClientPublicKey == "" { + return errors.New("client_public_key must not be empty") + } + + return nil +} + +func validateConfirmEmailCodeResult(result *ConfirmEmailCodeResult) error { + result.DeviceSessionID = strings.TrimSpace(result.DeviceSessionID) + if result.DeviceSessionID == "" { + return errors.New("auth service returned an empty device_session_id") + } + + return nil +} + +func malformedRequestReasonFromError(err error) (PublicMalformedRequestReason, bool) { + var malformedErr *malformedJSONRequestError + if !errors.As(err, &malformedErr) { + return "", false + } + + return malformedErr.reason, true +} + +func extractPublicAuthIdentity(requestPath string, bodyBytes []byte) (publicAuthIdentity, error) { + switch requestPath { + case "/api/v1/public/auth/send-email-code": + var input SendEmailCodeInput + if err := decodeJSONBytes(bodyBytes, &input); err != nil { + return publicAuthIdentity{}, err + } + if err := validateSendEmailCodeInput(&input); err != nil { + return publicAuthIdentity{}, err + } + + return publicAuthIdentity{ + kind: "email", + value: input.Email, + }, nil + case 
"/api/v1/public/auth/confirm-email-code": + var input ConfirmEmailCodeInput + if err := decodeJSONBytes(bodyBytes, &input); err != nil { + return publicAuthIdentity{}, err + } + if err := validateConfirmEmailCodeInput(&input); err != nil { + return publicAuthIdentity{}, err + } + + return publicAuthIdentity{ + kind: "challenge", + value: input.ChallengeID, + }, nil + default: + return publicAuthIdentity{}, errPublicAuthIdentityNotApplicable + } +} diff --git a/gateway/internal/restapi/public_auth_test.go b/gateway/internal/restapi/public_auth_test.go new file mode 100644 index 0000000..ae68a43 --- /dev/null +++ b/gateway/internal/restapi/public_auth_test.go @@ -0,0 +1,377 @@ +package restapi + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/testutil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSendEmailCodeHandlerSuccess(t *testing.T) { + t.Parallel() + + authService := &recordingAuthServiceClient{ + sendEmailCodeResult: SendEmailCodeResult{ + ChallengeID: "challenge-123", + }, + } + handler := newPublicHandler(ServerDependencies{AuthService: authService}) + + req := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/send-email-code", + strings.NewReader(`{"email":" pilot@example.com "}`), + ) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.Equal(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String()) + assert.Equal(t, 1, authService.sendEmailCodeCalls) + assert.Equal(t, 0, authService.confirmEmailCodeCalls) + assert.Equal(t, SendEmailCodeInput{Email: "pilot@example.com"}, authService.sendEmailCodeInput) + assert.True(t, authService.sendEmailCodeRouteClassOK) + 
assert.Equal(t, PublicRouteClassPublicAuth, authService.sendEmailCodeRouteClass) +} + +func TestConfirmEmailCodeHandlerSuccess(t *testing.T) { + t.Parallel() + + authService := &recordingAuthServiceClient{ + confirmEmailCodeResult: ConfirmEmailCodeResult{ + DeviceSessionID: "device-session-123", + }, + } + handler := newPublicHandler(ServerDependencies{AuthService: authService}) + + req := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/confirm-email-code", + strings.NewReader(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`), + ) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.Equal(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String()) + assert.Equal(t, 0, authService.sendEmailCodeCalls) + assert.Equal(t, 1, authService.confirmEmailCodeCalls) + assert.Equal(t, ConfirmEmailCodeInput{ + ChallengeID: "challenge-123", + Code: "123456", + ClientPublicKey: "public-key-material", + }, authService.confirmEmailCodeInput) + assert.True(t, authService.confirmEmailCodeRouteClassOK) + assert.Equal(t, PublicRouteClassPublicAuth, authService.confirmEmailCodeRouteClass) +} + +func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target string + body string + wantStatus int + wantBody string + wantSendCalls int + wantConfirmCalls int + }{ + { + name: "send email malformed json", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`, + wantSendCalls: 0, + wantConfirmCalls: 0, + }, + { + name: "send email validation error", + target: "/api/v1/public/auth/send-email-code", + 
body: `{"email":"not-an-email"}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`, + wantSendCalls: 0, + wantConfirmCalls: 0, + }, + { + name: "confirm email empty code", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"code must not be empty"}}`, + wantSendCalls: 0, + wantConfirmCalls: 0, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + authService := &recordingAuthServiceClient{} + handler := newPublicHandler(ServerDependencies{AuthService: authService}) + + req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.Equal(t, tt.wantBody, recorder.Body.String()) + assert.Equal(t, tt.wantSendCalls, authService.sendEmailCodeCalls) + assert.Equal(t, tt.wantConfirmCalls, authService.confirmEmailCodeCalls) + }) + } +} + +func TestPublicAuthHandlersMapAdapterErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target string + body string + authClient *recordingAuthServiceClient + wantStatus int + wantBody string + }{ + { + name: "auth service projected bad request", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + authClient: &recordingAuthServiceClient{ + confirmEmailCodeErr: &AuthServiceError{ + StatusCode: http.StatusBadRequest, + Code: errorCodeInvalidRequest, + Message: "confirmation code is invalid", + }, + }, 
+ wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"confirmation code is invalid"}}`, + }, + { + name: "auth service projected custom too many requests", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"pilot@example.com"}`, + authClient: &recordingAuthServiceClient{ + sendEmailCodeErr: &AuthServiceError{ + StatusCode: http.StatusTooManyRequests, + Code: "upstream_rate_limited", + Message: "too many attempts for this email", + }, + }, + wantStatus: http.StatusTooManyRequests, + wantBody: `{"error":{"code":"upstream_rate_limited","message":"too many attempts for this email"}}`, + }, + { + name: "auth service projected gateway normalizes blank gateway error fields", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + authClient: &recordingAuthServiceClient{ + confirmEmailCodeErr: &AuthServiceError{ + StatusCode: http.StatusBadGateway, + }, + }, + wantStatus: http.StatusBadGateway, + wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`, + }, + { + name: "unexpected auth service error", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"pilot@example.com"}`, + authClient: &recordingAuthServiceClient{ + sendEmailCodeErr: errors.New("boom"), + }, + wantStatus: http.StatusInternalServerError, + wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := newPublicHandler(ServerDependencies{AuthService: tt.authClient}) + req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, 
recorder.Header().Get("Content-Type")) + assert.Equal(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestDefaultAuthServiceReturnsServiceUnavailable(t *testing.T) { + t.Parallel() + + handler := newPublicHandler(ServerDependencies{}) + + tests := []struct { + name string + method string + target string + body string + wantStatus int + wantBody string + }{ + { + name: "send email code", + method: http.MethodPost, + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"pilot@example.com"}`, + wantStatus: http.StatusServiceUnavailable, + wantBody: `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`, + }, + { + name: "confirm email code", + method: http.MethodPost, + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + wantStatus: http.StatusServiceUnavailable, + wantBody: `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`, + }, + { + name: "healthz remains available", + method: http.MethodGet, + target: "/healthz", + wantStatus: http.StatusOK, + wantBody: `{"status":"ok"}`, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body)) + if tt.body != "" { + req.Header.Set("Content-Type", "application/json") + } + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.Equal(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) { + t.Parallel() + + authService := &recordingAuthServiceClient{ + sendEmailCodeErr: context.DeadlineExceeded, + } + cfg := config.DefaultPublicHTTPConfig() + cfg.AuthUpstreamTimeout = 5 * time.Millisecond + handler := 
newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService}) + + req := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/send-email-code", + strings.NewReader(`{"email":"pilot@example.com"}`), + ) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) + assert.Equal(t, `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`, recorder.Body.String()) +} + +func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) { + t.Parallel() + + logger, buffer := testutil.NewObservedLogger(t) + handler := newPublicHandler(ServerDependencies{ + Logger: logger, + AuthService: &recordingAuthServiceClient{ + confirmEmailCodeResult: ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, + }, + }) + + req := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/confirm-email-code", + strings.NewReader(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`), + ) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) + + logOutput := buffer.String() + assert.NotContains(t, logOutput, "challenge-123") + assert.NotContains(t, logOutput, "123456") + assert.NotContains(t, logOutput, "public-key-material") + assert.NotContains(t, logOutput, "pilot@example.com") +} + +// recordingAuthServiceClient captures handler inputs and route classification +// so tests can assert the exact adapter delegation contract. 
+type recordingAuthServiceClient struct { + sendEmailCodeResult SendEmailCodeResult + sendEmailCodeErr error + sendEmailCodeInput SendEmailCodeInput + sendEmailCodeRouteClass PublicRouteClass + sendEmailCodeRouteClassOK bool + sendEmailCodeCalls int + + confirmEmailCodeResult ConfirmEmailCodeResult + confirmEmailCodeErr error + confirmEmailCodeInput ConfirmEmailCodeInput + confirmEmailCodeRouteClass PublicRouteClass + confirmEmailCodeRouteClassOK bool + confirmEmailCodeCalls int +} + +func (c *recordingAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) { + c.sendEmailCodeCalls++ + c.sendEmailCodeInput = input + + c.sendEmailCodeRouteClass, c.sendEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx) + + return c.sendEmailCodeResult, c.sendEmailCodeErr +} + +func (c *recordingAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) { + c.confirmEmailCodeCalls++ + c.confirmEmailCodeInput = input + + c.confirmEmailCodeRouteClass, c.confirmEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx) + + return c.confirmEmailCodeResult, c.confirmEmailCodeErr +} diff --git a/gateway/internal/restapi/server.go b/gateway/internal/restapi/server.go new file mode 100644 index 0000000..28b0701 --- /dev/null +++ b/gateway/internal/restapi/server.go @@ -0,0 +1,388 @@ +// Package restapi exposes the unauthenticated public REST surface of the +// gateway. 
+package restapi + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + + "galaxy/gateway/internal/config" + "galaxy/gateway/internal/telemetry" + + "github.com/gin-gonic/gin" + "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" + "go.uber.org/zap" +) + +const ( + jsonContentType = "application/json; charset=utf-8" + + errorCodeInvalidRequest = "invalid_request" + errorCodeNotFound = "not_found" + errorCodeMethodNotAllowed = "method_not_allowed" + errorCodeInternalError = "internal_error" + errorCodeServiceUnavailable = "service_unavailable" + + publicRESTBaseBucketKeyPrefix = "public_rest/class=" +) + +// PublicRouteClass identifies the public traffic class assigned to an incoming +// REST request before route handling and edge policy evaluation. +type PublicRouteClass string + +const ( + // PublicRouteClassPublicAuth identifies public authentication commands. + PublicRouteClassPublicAuth PublicRouteClass = "public_auth" + + // PublicRouteClassBrowserBootstrap identifies browser bootstrap traffic such + // as the main document request. + PublicRouteClassBrowserBootstrap PublicRouteClass = "browser_bootstrap" + + // PublicRouteClassBrowserAsset identifies browser asset requests. + PublicRouteClassBrowserAsset PublicRouteClass = "browser_asset" + + // PublicRouteClassPublicMisc identifies public traffic that does not match a + // more specific class. + PublicRouteClassPublicMisc PublicRouteClass = "public_misc" +) + +var configureGinModeOnce sync.Once + +// Normalized returns c when it belongs to the stable public route class set. +// Unknown or empty values collapse to PublicRouteClassPublicMisc so edge policy +// code can rely on a fixed anti-abuse namespace. 
+func (c PublicRouteClass) Normalized() PublicRouteClass { + switch c { + case PublicRouteClassPublicAuth, + PublicRouteClassBrowserBootstrap, + PublicRouteClassBrowserAsset, + PublicRouteClassPublicMisc: + return c + default: + return PublicRouteClassPublicMisc + } +} + +// BaseBucketKey returns the canonical base rate-limit namespace for c. The key +// stays scoped only by the normalized public route class; callers may append +// subject dimensions such as IP or identity without redefining the class +// namespace. +func (c PublicRouteClass) BaseBucketKey() string { + return publicRESTBaseBucketKeyPrefix + string(c.Normalized()) +} + +// PublicTrafficClassifier maps public REST requests to the public anti-abuse +// class used by the gateway edge. The server normalizes classifier outputs to +// the stable class set before storing them in request context. +type PublicTrafficClassifier interface { + Classify(*http.Request) PublicRouteClass +} + +// ServerDependencies describes the optional collaborators used by the public +// REST server. The zero value is valid and keeps the process runnable with the +// built-in defaults. +type ServerDependencies struct { + // Classifier assigns the public anti-abuse class before route handling. + // When nil, the gateway default classifier is used. + Classifier PublicTrafficClassifier + + // AuthService delegates public auth commands to the Auth / Session Service. + // When nil, public auth routes remain mounted and return a stable + // service-unavailable response. + AuthService AuthServiceClient + + // Limiter applies the public REST rate-limit policy. When nil, a default + // process-local in-memory limiter is used. + Limiter PublicRequestLimiter + + // Observer records malformed-request telemetry for the public REST layer. + // When nil, a no-op observer is used. + Observer PublicRequestObserver + + // Logger writes structured transport logs for public REST traffic. When nil, + // a no-op logger is used. 
+ Logger *zap.Logger + + // Telemetry records low-cardinality edge metrics. When nil, metrics are + // disabled. + Telemetry *telemetry.Runtime +} + +// Server owns the public unauthenticated REST listener exposed by the gateway. +type Server struct { + cfg config.PublicHTTPConfig + + handler http.Handler + logger *zap.Logger + + stateMu sync.RWMutex + server *http.Server + listener net.Listener +} + +// NewServer constructs a public REST server for the supplied listener +// configuration and dependency bundle. Nil dependencies are replaced with safe +// defaults so the gateway can still expose the documented public surface. +func NewServer(cfg config.PublicHTTPConfig, deps ServerDependencies) *Server { + deps = normalizeServerDependencies(deps) + + return &Server{ + cfg: cfg, + handler: newPublicHandlerWithConfig(cfg, deps), + logger: deps.Logger.Named("public_http"), + } +} + +// Run binds the configured listener and serves the public REST surface until +// Shutdown closes the server. 
+func (s *Server) Run(ctx context.Context) error { + if ctx == nil { + return errors.New("run public REST server: nil context") + } + if err := ctx.Err(); err != nil { + return err + } + + listener, err := net.Listen("tcp", s.cfg.Addr) + if err != nil { + return fmt.Errorf("run public REST server: listen on %q: %w", s.cfg.Addr, err) + } + + server := &http.Server{ + Handler: s.handler, + ReadHeaderTimeout: s.cfg.ReadHeaderTimeout, + ReadTimeout: s.cfg.ReadTimeout, + IdleTimeout: s.cfg.IdleTimeout, + } + + s.stateMu.Lock() + s.server = server + s.listener = listener + s.stateMu.Unlock() + + s.logger.Info("public REST server started", zap.String("addr", listener.Addr().String())) + + defer func() { + s.stateMu.Lock() + s.server = nil + s.listener = nil + s.stateMu.Unlock() + }() + + err = server.Serve(listener) + switch { + case err == nil: + return nil + case errors.Is(err, http.ErrServerClosed): + s.logger.Info("public REST server stopped") + return nil + default: + return fmt.Errorf("run public REST server: serve on %q: %w", s.cfg.Addr, err) + } +} + +// Shutdown gracefully stops the public REST server within ctx. +func (s *Server) Shutdown(ctx context.Context) error { + if ctx == nil { + return errors.New("shutdown public REST server: nil context") + } + + s.stateMu.RLock() + server := s.server + s.stateMu.RUnlock() + + if server == nil { + return nil + } + + if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("shutdown public REST server: %w", err) + } + + return nil +} + +// PublicRouteClassFromContext returns the previously classified normalized +// public route class stored in ctx. 
+func PublicRouteClassFromContext(ctx context.Context) (PublicRouteClass, bool) { + if ctx == nil { + return "", false + } + + class, ok := ctx.Value(publicRouteClassContextKey{}).(PublicRouteClass) + if !ok { + return "", false + } + + return class.Normalized(), true +} + +type publicRouteClassContextKey struct{} + +type defaultPublicTrafficClassifier struct{} + +// Classify maps the incoming request into a stable public route class that can +// later drive anti-abuse policy and rate limiting. +func (defaultPublicTrafficClassifier) Classify(r *http.Request) PublicRouteClass { + switch { + case isPublicAuthRequest(r): + return PublicRouteClassPublicAuth + case isBrowserBootstrapRequest(r): + return PublicRouteClassBrowserBootstrap + case isBrowserAssetRequest(r): + return PublicRouteClassBrowserAsset + default: + return PublicRouteClassPublicMisc + } +} + +func normalizeServerDependencies(deps ServerDependencies) ServerDependencies { + if deps.Classifier == nil { + deps.Classifier = defaultPublicTrafficClassifier{} + } + if deps.AuthService == nil { + deps.AuthService = unavailableAuthServiceClient{} + } + if deps.Limiter == nil { + deps.Limiter = newInMemoryPublicRequestLimiter() + } + if deps.Observer == nil { + deps.Observer = noopPublicRequestObserver{} + } + if deps.Logger == nil { + deps.Logger = zap.NewNop() + } + + return deps +} + +func newPublicHandler(deps ServerDependencies) http.Handler { + return newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), deps) +} + +func newPublicHandlerWithConfig(cfg config.PublicHTTPConfig, deps ServerDependencies) http.Handler { + configureGinModeOnce.Do(func() { + gin.SetMode(gin.ReleaseMode) + }) + + deps = normalizeServerDependencies(deps) + + router := gin.New() + router.HandleMethodNotAllowed = true + router.Use(gin.CustomRecovery(func(c *gin.Context, _ any) { + abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error") + })) + 
router.Use(otelgin.Middleware("galaxy-edge-gateway-public")) + router.Use(withPublicObservability(deps.Logger.Named("public_http"), deps.Telemetry)) + router.Use(withPublicRouteClass(deps.Classifier)) + router.Use(withPublicAntiAbuse(cfg.AntiAbuse, deps.Limiter, deps.Observer)) + + router.GET("/healthz", handleHealthz) + router.GET("/readyz", handleReadyz) + router.POST("/api/v1/public/auth/send-email-code", handleSendEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout)) + router.POST("/api/v1/public/auth/confirm-email-code", handleConfirmEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout)) + + router.NoMethod(func(c *gin.Context) { + allowMethods := allowedMethodsForPath(c.Request.URL.Path) + if allowMethods != "" { + c.Header("Allow", allowMethods) + } + + abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route") + }) + router.NoRoute(func(c *gin.Context) { + abortWithError(c, http.StatusNotFound, errorCodeNotFound, "resource was not found") + }) + + return router +} + +func handleHealthz(c *gin.Context) { + c.JSON(http.StatusOK, statusResponse{Status: "ok"}) +} + +func handleReadyz(c *gin.Context) { + c.JSON(http.StatusOK, statusResponse{Status: "ready"}) +} + +func withPublicRouteClass(classifier PublicTrafficClassifier) gin.HandlerFunc { + return func(c *gin.Context) { + class := classifier.Classify(c.Request).Normalized() + ctx := context.WithValue(c.Request.Context(), publicRouteClassContextKey{}, class) + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} + +func isPublicAuthRequest(r *http.Request) bool { + return r.Method == http.MethodPost && isPublicAuthPath(r.URL.Path) +} + +func isBrowserBootstrapRequest(r *http.Request) bool { + if r.Method == http.MethodGet && r.URL.Path == "/" { + return true + } + + return matchesBrowserBootstrapRequestShape(r) +} + +func isBrowserAssetRequest(r *http.Request) bool { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + return 
false + } + + return matchesBrowserAssetRequestShape(r) +} + +type statusResponse struct { + Status string `json:"status"` +} + +type errorResponse struct { + Error errorBody `json:"error"` +} + +type errorBody struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func abortWithError(c *gin.Context, statusCode int, code string, message string) { + if c != nil { + c.Set(publicErrorCodeContextKey, code) + } + c.AbortWithStatusJSON(statusCode, errorResponse{ + Error: errorBody{ + Code: code, + Message: message, + }, + }) +} + +const publicErrorCodeContextKey = "public_error_code" + +func allowedMethodsForPath(requestPath string) string { + switch requestPath { + case "/healthz", "/readyz": + return http.MethodGet + case "/api/v1/public/auth/send-email-code", "/api/v1/public/auth/confirm-email-code": + return http.MethodPost + default: + return "" + } +} + +func (s *Server) listenAddr() string { + s.stateMu.RLock() + defer s.stateMu.RUnlock() + + if s.listener == nil { + return "" + } + + return s.listener.Addr().String() +} diff --git a/gateway/internal/restapi/server_test.go b/gateway/internal/restapi/server_test.go new file mode 100644 index 0000000..e21fd3f --- /dev/null +++ b/gateway/internal/restapi/server_test.go @@ -0,0 +1,459 @@ +package restapi + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "galaxy/gateway/internal/app" + "galaxy/gateway/internal/config" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPublicHandlerHealthEndpoints(t *testing.T) { + t.Parallel() + + handler := newPublicHandler(ServerDependencies{}) + + tests := []struct { + name string + method string + target string + wantStatus int + wantType string + wantBody string + wantAllow string + }{ + { + name: "healthz", + method: http.MethodGet, + target: "/healthz", + wantStatus: http.StatusOK, + wantType: jsonContentType, + wantBody: 
`{"status":"ok"}`, + }, + { + name: "readyz", + method: http.MethodGet, + target: "/readyz", + wantStatus: http.StatusOK, + wantType: jsonContentType, + wantBody: `{"status":"ready"}`, + }, + { + name: "wrong method on known route", + method: http.MethodPost, + target: "/healthz", + wantStatus: http.StatusMethodNotAllowed, + wantType: jsonContentType, + wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`, + wantAllow: http.MethodGet, + }, + { + name: "unknown route", + method: http.MethodGet, + target: "/unknown", + wantStatus: http.StatusNotFound, + wantType: jsonContentType, + wantBody: `{"error":{"code":"not_found","message":"resource was not found"}}`, + }, + { + name: "wrong method on public auth route", + method: http.MethodGet, + target: "/api/v1/public/auth/send-email-code", + wantStatus: http.StatusMethodNotAllowed, + wantType: jsonContentType, + wantBody: `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`, + wantAllow: http.MethodPost, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(tt.method, tt.target, nil) + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, tt.wantType, recorder.Header().Get("Content-Type")) + assert.Equal(t, tt.wantBody, recorder.Body.String()) + assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow")) + }) + } +} + +func TestDefaultPublicTrafficClassifier(t *testing.T) { + t.Parallel() + + classifier := defaultPublicTrafficClassifier{} + + tests := []struct { + name string + method string + target string + accept string + wantClass PublicRouteClass + }{ + { + name: "public auth route", + method: http.MethodPost, + target: "/api/v1/public/auth/send-email-code", + wantClass: PublicRouteClassPublicAuth, + }, + { + name: "public auth confirm route", + 
method: http.MethodPost, + target: "/api/v1/public/auth/confirm-email-code", + wantClass: PublicRouteClassPublicAuth, + }, + { + name: "browser bootstrap route", + method: http.MethodGet, + target: "/", + wantClass: PublicRouteClassBrowserBootstrap, + }, + { + name: "browser asset route", + method: http.MethodGet, + target: "/assets/app.js", + wantClass: PublicRouteClassBrowserAsset, + }, + { + name: "browser asset head request", + method: http.MethodHead, + target: "/assets/app.js", + wantClass: PublicRouteClassBrowserAsset, + }, + { + name: "browser asset extension request", + method: http.MethodGet, + target: "/manifest.webmanifest", + wantClass: PublicRouteClassBrowserAsset, + }, + { + name: "public misc route", + method: http.MethodPost, + target: "/api/v1/public/unknown", + wantClass: PublicRouteClassPublicMisc, + }, + { + name: "html accept bootstrap route", + method: http.MethodGet, + target: "/app", + accept: "application/json, text/html;q=0.9", + wantClass: PublicRouteClassBrowserBootstrap, + }, + { + name: "public auth wins over browser accept header", + method: http.MethodPost, + target: "/api/v1/public/auth/confirm-email-code", + accept: "text/html", + wantClass: PublicRouteClassPublicAuth, + }, + { + name: "probe with html accept is bootstrap", + method: http.MethodGet, + target: "/healthz", + accept: "text/html", + wantClass: PublicRouteClassBrowserBootstrap, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(tt.method, tt.target, nil) + if tt.accept != "" { + req.Header.Set("Accept", tt.accept) + } + + assert.Equal(t, tt.wantClass, classifier.Classify(req)) + }) + } +} + +func TestPublicRouteClassNormalized(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input PublicRouteClass + want PublicRouteClass + }{ + { + name: "public auth", + input: PublicRouteClassPublicAuth, + want: PublicRouteClassPublicAuth, + }, + { + name: "browser bootstrap", + 
input: PublicRouteClassBrowserBootstrap, + want: PublicRouteClassBrowserBootstrap, + }, + { + name: "browser asset", + input: PublicRouteClassBrowserAsset, + want: PublicRouteClassBrowserAsset, + }, + { + name: "public misc", + input: PublicRouteClassPublicMisc, + want: PublicRouteClassPublicMisc, + }, + { + name: "unknown collapses to misc", + input: PublicRouteClass("unexpected"), + want: PublicRouteClassPublicMisc, + }, + { + name: "empty collapses to misc", + input: PublicRouteClass(""), + want: PublicRouteClassPublicMisc, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, tt.input.Normalized()) + }) + } +} + +func TestPublicRouteClassBaseBucketKeyIsolation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + class PublicRouteClass + wantKey string + }{ + { + name: "public auth", + class: PublicRouteClassPublicAuth, + wantKey: "public_rest/class=public_auth", + }, + { + name: "browser bootstrap", + class: PublicRouteClassBrowserBootstrap, + wantKey: "public_rest/class=browser_bootstrap", + }, + { + name: "browser asset", + class: PublicRouteClassBrowserAsset, + wantKey: "public_rest/class=browser_asset", + }, + { + name: "public misc", + class: PublicRouteClassPublicMisc, + wantKey: "public_rest/class=public_misc", + }, + { + name: "unknown collapses to misc namespace", + class: PublicRouteClass("unexpected"), + wantKey: "public_rest/class=public_misc", + }, + } + + seenKeys := make(map[string]PublicRouteClass, len(tests)) + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.wantKey, tt.class.BaseBucketKey()) + }) + + normalizedClass := tt.class.Normalized() + if normalizedClass == PublicRouteClassPublicMisc && tt.class != PublicRouteClassPublicMisc { + continue + } + + if previousClass, exists := seenKeys[tt.wantKey]; exists { + require.FailNowf(t, "bucket key collision", "class %q collides 
with %q on key %q", tt.class, previousClass, tt.wantKey) + } + + seenKeys[tt.wantKey] = tt.class + } + + assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserBootstrap.BaseBucketKey()) + assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserAsset.BaseBucketKey()) +} + +func TestWithPublicRouteClassStoresClassInContext(t *testing.T) { + t.Parallel() + + router := gin.New() + router.Use(withPublicRouteClass(staticClassifier{class: PublicRouteClassBrowserAsset})) + router.GET("/assets/app.js", func(c *gin.Context) { + class, ok := PublicRouteClassFromContext(c.Request.Context()) + require.True(t, ok) + assert.Equal(t, PublicRouteClassBrowserAsset, class) + + c.JSON(http.StatusOK, statusResponse{Status: "ok"}) + }) + + req := httptest.NewRequest(http.MethodGet, "/assets/app.js", nil) + recorder := httptest.NewRecorder() + + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, `{"status":"ok"}`, recorder.Body.String()) +} + +func TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + class PublicRouteClass + }{ + { + name: "unknown class", + class: PublicRouteClass("unexpected"), + }, + { + name: "empty class", + class: PublicRouteClass(""), + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + router := gin.New() + router.Use(withPublicRouteClass(staticClassifier{class: tt.class})) + router.GET("/", func(c *gin.Context) { + class, ok := PublicRouteClassFromContext(c.Request.Context()) + require.True(t, ok) + assert.Equal(t, PublicRouteClassPublicMisc, class) + + c.JSON(http.StatusOK, statusResponse{Status: "ok"}) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + recorder := httptest.NewRecorder() + + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, 
`{"status":"ok"}`, recorder.Body.String()) + }) + } +} + +func TestServerLifecycle(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + ShutdownTimeout: time.Second, + PublicHTTP: func() config.PublicHTTPConfig { + publicHTTPCfg := config.DefaultPublicHTTPConfig() + publicHTTPCfg.Addr = "127.0.0.1:0" + publicHTTPCfg.AntiAbuse.PublicMisc.RateLimit = config.PublicRateLimitConfig{ + Requests: 1000, + Window: time.Minute, + Burst: 1000, + } + return publicHTTPCfg + }(), + } + + server := NewServer(cfg.PublicHTTP, ServerDependencies{}) + application := app.New(cfg, server) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan error, 1) + go func() { + resultCh <- application.Run(ctx) + }() + + addr := waitForListenAddr(t, server) + waitForHealthResponse(t, addr) + + cancel() + + select { + case err := <-resultCh: + require.NoError(t, err) + case <-time.After(2 * time.Second): + require.FailNow(t, "Run() did not return after cancellation") + } +} + +type staticClassifier struct { + class PublicRouteClass +} + +func (c staticClassifier) Classify(*http.Request) PublicRouteClass { + return c.class +} + +func waitForListenAddr(t *testing.T, server *Server) string { + t.Helper() + + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if addr := server.listenAddr(); addr != "" { + return addr + } + + time.Sleep(10 * time.Millisecond) + } + + require.FailNow(t, "server did not start listening") + return "" +} + +func waitForHealthResponse(t *testing.T, addr string) { + t.Helper() + + client := &http.Client{Timeout: 100 * time.Millisecond} + url := "http://" + addr + "/healthz" + deadline := time.Now().Add(time.Second) + + for time.Now().Before(deadline) { + resp, err := client.Get(url) + if err != nil { + time.Sleep(10 * time.Millisecond) + continue + } + + body, readErr := io.ReadAll(resp.Body) + closeErr := resp.Body.Close() + require.NoError(t, readErr) + require.NoError(t, closeErr) + 
require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, `{"status":"ok"}`, strings.TrimSpace(string(body))) + + return + } + + require.FailNowf(t, "health check timed out", "url=%s", url) +} diff --git a/gateway/internal/session/memory.go b/gateway/internal/session/memory.go new file mode 100644 index 0000000..7de963c --- /dev/null +++ b/gateway/internal/session/memory.go @@ -0,0 +1,88 @@ +package session + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" +) + +// MemoryCache stores session record snapshots in process-local memory. It is +// intended for the authenticated gateway hot path and deliberately keeps no +// TTL or size-based eviction policy. +type MemoryCache struct { + mu sync.RWMutex + records map[string]Record +} + +// NewMemoryCache constructs an empty process-local session snapshot store. +func NewMemoryCache() *MemoryCache { + return &MemoryCache{ + records: make(map[string]Record), + } +} + +// Lookup resolves deviceSessionID from the process-local snapshot map. +func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) { + if c == nil { + return Record{}, errors.New("lookup session from in-memory cache: nil cache") + } + if ctx == nil || fmt.Sprint(ctx) == "context.TODO" { + return Record{}, errors.New("lookup session from in-memory cache: nil context") + } + if strings.TrimSpace(deviceSessionID) == "" { + return Record{}, errors.New("lookup session from in-memory cache: empty device session id") + } + + c.mu.RLock() + record, ok := c.records[deviceSessionID] + c.mu.RUnlock() + if !ok { + return Record{}, fmt.Errorf("lookup session from in-memory cache: %w", ErrNotFound) + } + + return cloneRecord(record), nil +} + +// Upsert stores record in the process-local snapshot map after validating the +// same session invariants expected from the Redis-backed cache. 
+func (c *MemoryCache) Upsert(record Record) error { + if c == nil { + return errors.New("upsert session into in-memory cache: nil cache") + } + if err := validateRecord(record.DeviceSessionID, record); err != nil { + return fmt.Errorf("upsert session into in-memory cache: %w", err) + } + + cloned := cloneRecord(record) + + c.mu.Lock() + c.records[record.DeviceSessionID] = cloned + c.mu.Unlock() + + return nil +} + +// Delete removes the local snapshot for deviceSessionID when one exists. +func (c *MemoryCache) Delete(deviceSessionID string) { + if c == nil || strings.TrimSpace(deviceSessionID) == "" { + return + } + + c.mu.Lock() + delete(c.records, deviceSessionID) + c.mu.Unlock() +} + +func cloneRecord(record Record) Record { + cloned := record + if record.RevokedAtMS != nil { + value := *record.RevokedAtMS + cloned.RevokedAtMS = &value + } + + return cloned +} + +var _ SnapshotStore = (*MemoryCache)(nil) diff --git a/gateway/internal/session/readthrough.go b/gateway/internal/session/readthrough.go new file mode 100644 index 0000000..570eb7c --- /dev/null +++ b/gateway/internal/session/readthrough.go @@ -0,0 +1,68 @@ +package session + +import ( + "context" + "errors" + "fmt" +) + +// ReadThroughCache resolves authenticated sessions from a process-local +// SnapshotStore first and falls back to another Cache only on a local miss. +type ReadThroughCache struct { + local SnapshotStore + fallback Cache +} + +// NewReadThroughCache constructs a hot-path cache that seeds local snapshots +// from fallback on demand. 
+func NewReadThroughCache(local SnapshotStore, fallback Cache) (*ReadThroughCache, error) { + if local == nil { + return nil, errors.New("new read-through session cache: nil local cache") + } + if fallback == nil { + return nil, errors.New("new read-through session cache: nil fallback cache") + } + + return &ReadThroughCache{ + local: local, + fallback: fallback, + }, nil +} + +// Lookup resolves deviceSessionID from local first, then performs one fallback +// lookup on a local miss and seeds the local cache with the returned snapshot. +func (c *ReadThroughCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) { + if c == nil { + return Record{}, errors.New("lookup session from read-through cache: nil cache") + } + + record, err := c.local.Lookup(ctx, deviceSessionID) + switch { + case err == nil: + return record, nil + case !errors.Is(err, ErrNotFound): + return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err) + } + + record, err = c.fallback.Lookup(ctx, deviceSessionID) + if err != nil { + return Record{}, err + } + + if err := c.local.Upsert(record); err != nil { + return Record{}, fmt.Errorf("lookup session from read-through cache: seed local cache: %w", err) + } + + return cloneRecord(record), nil +} + +// Local returns the mutable process-local snapshot store used by c. 
+func (c *ReadThroughCache) Local() SnapshotStore { + if c == nil { + return nil + } + + return c.local +} + +var _ Cache = (*ReadThroughCache)(nil) diff --git a/gateway/internal/session/readthrough_test.go b/gateway/internal/session/readthrough_test.go new file mode 100644 index 0000000..e4339a2 --- /dev/null +++ b/gateway/internal/session/readthrough_test.go @@ -0,0 +1,176 @@ +package session + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) { + t.Parallel() + + cache := NewMemoryCache() + revokedAtMS := int64(123456789) + + require.NoError(t, cache.Upsert(Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: StatusRevoked, + RevokedAtMS: &revokedAtMS, + })) + + record, err := cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + require.NotNil(t, record.RevokedAtMS) + + *record.RevokedAtMS = 1 + + stored, err := cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + require.NotNil(t, stored.RevokedAtMS) + assert.Equal(t, revokedAtMS, *stored.RevokedAtMS) +} + +func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) { + t.Parallel() + + local := NewMemoryCache() + require.NoError(t, local.Upsert(Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: StatusActive, + })) + + fallback := &recordingCache{ + lookupFunc: func(context.Context, string) (Record, error) { + return Record{}, errors.New("fallback should not be called") + }, + } + + cache, err := NewReadThroughCache(local, fallback) + require.NoError(t, err) + + record, err := cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + assert.Equal(t, Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: 
"public-key-123", + Status: StatusActive, + }, record) + assert.Equal(t, 0, fallback.lookupCalls) +} + +func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) { + t.Parallel() + + local := NewMemoryCache() + fallback := &recordingCache{ + lookupFunc: func(context.Context, string) (Record, error) { + return Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: StatusActive, + }, nil + }, + } + + cache, err := NewReadThroughCache(local, fallback) + require.NoError(t, err) + + record, err := cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + assert.Equal(t, 1, fallback.lookupCalls) + assert.Equal(t, "user-123", record.UserID) + + record, err = cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + assert.Equal(t, 1, fallback.lookupCalls) + assert.Equal(t, "user-123", record.UserID) +} + +func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) { + t.Parallel() + + revokedAtMS := int64(123456789) + local := NewMemoryCache() + fallback := &recordingCache{ + lookupFunc: func(context.Context, string) (Record, error) { + return Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: StatusRevoked, + RevokedAtMS: &revokedAtMS, + }, nil + }, + } + + cache, err := NewReadThroughCache(local, fallback) + require.NoError(t, err) + + record, err := cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + require.NotNil(t, record.RevokedAtMS) + assert.Equal(t, StatusRevoked, record.Status) + assert.Equal(t, 1, fallback.lookupCalls) + + record, err = cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + require.NotNil(t, record.RevokedAtMS) + assert.Equal(t, StatusRevoked, record.Status) + assert.Equal(t, revokedAtMS, *record.RevokedAtMS) + assert.Equal(t, 1, fallback.lookupCalls) +} + +func 
TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) { + t.Parallel() + + revokedAtMS := int64(123456789) + local := NewMemoryCache() + fallback := &recordingCache{ + lookupFunc: func(context.Context, string) (Record, error) { + return Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: StatusRevoked, + RevokedAtMS: &revokedAtMS, + }, nil + }, + } + + cache, err := NewReadThroughCache(local, fallback) + require.NoError(t, err) + + record, err := cache.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + require.NotNil(t, record.RevokedAtMS) + + *record.RevokedAtMS = 1 + + stored, err := local.Lookup(context.Background(), "device-session-123") + require.NoError(t, err) + require.NotNil(t, stored.RevokedAtMS) + assert.Equal(t, revokedAtMS, *stored.RevokedAtMS) +} + +type recordingCache struct { + lookupCalls int + lookupFunc func(context.Context, string) (Record, error) +} + +func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) { + c.lookupCalls++ + if c.lookupFunc != nil { + return c.lookupFunc(ctx, deviceSessionID) + } + + return Record{}, errors.New("lookup is not implemented") +} diff --git a/gateway/internal/session/redis.go b/gateway/internal/session/redis.go new file mode 100644 index 0000000..73df8dd --- /dev/null +++ b/gateway/internal/session/redis.go @@ -0,0 +1,192 @@ +package session + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + "time" + + "galaxy/gateway/internal/config" + + "github.com/redis/go-redis/v9" +) + +// RedisCache implements Cache with Redis GET lookups over strict JSON session +// records. 
+type RedisCache struct { + client *redis.Client + keyPrefix string + lookupTimeout time.Duration +} + +type redisRecord struct { + DeviceSessionID string `json:"device_session_id"` + UserID string `json:"user_id"` + ClientPublicKey string `json:"client_public_key"` + Status Status `json:"status"` + RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"` +} + +// NewRedisCache constructs a Redis-backed SessionCache from cfg. The returned +// cache is read-only from the gateway perspective and does not write or mutate +// Redis state. +func NewRedisCache(cfg config.SessionCacheRedisConfig) (*RedisCache, error) { + if strings.TrimSpace(cfg.Addr) == "" { + return nil, errors.New("new redis session cache: redis addr must not be empty") + } + if cfg.DB < 0 { + return nil, errors.New("new redis session cache: redis db must not be negative") + } + if cfg.LookupTimeout <= 0 { + return nil, errors.New("new redis session cache: lookup timeout must be positive") + } + + options := &redis.Options{ + Addr: cfg.Addr, + Username: cfg.Username, + Password: cfg.Password, + DB: cfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if cfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + + return &RedisCache{ + client: redis.NewClient(options), + keyPrefix: cfg.KeyPrefix, + lookupTimeout: cfg.LookupTimeout, + }, nil +} + +// Close releases the underlying Redis client resources. +func (c *RedisCache) Close() error { + if c == nil || c.client == nil { + return nil + } + + return c.client.Close() +} + +// Ping verifies that the configured Redis backend is reachable within the +// cache lookup timeout budget. 
+func (c *RedisCache) Ping(ctx context.Context) error { + if c == nil || c.client == nil { + return errors.New("ping redis session cache: nil cache") + } + if ctx == nil { + return errors.New("ping redis session cache: nil context") + } + + pingCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout) + defer cancel() + + if err := c.client.Ping(pingCtx).Err(); err != nil { + return fmt.Errorf("ping redis session cache: %w", err) + } + + return nil +} + +// Lookup resolves deviceSessionID from Redis, validates the cached JSON +// payload strictly, and returns the decoded session record. +func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) { + if c == nil || c.client == nil { + return Record{}, errors.New("lookup session from redis: nil cache") + } + if ctx == nil || fmt.Sprint(ctx) == "context.TODO" { + return Record{}, errors.New("lookup session from redis: nil context") + } + if strings.TrimSpace(deviceSessionID) == "" { + return Record{}, errors.New("lookup session from redis: empty device session id") + } + + lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout) + defer cancel() + + payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes() + switch { + case errors.Is(err, redis.Nil): + return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound) + case err != nil: + return Record{}, fmt.Errorf("lookup session from redis: %w", err) + } + + record, err := decodeRedisRecord(deviceSessionID, payload) + if err != nil { + return Record{}, fmt.Errorf("lookup session from redis: %w", err) + } + + return record, nil +} + +func (c *RedisCache) lookupKey(deviceSessionID string) string { + return c.keyPrefix + deviceSessionID +} + +func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) { + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.DisallowUnknownFields() + + var stored redisRecord + if err := decoder.Decode(&stored); err != nil { + return 
Record{}, fmt.Errorf("decode redis session record: %w", err) + } + if err := decoder.Decode(&struct{}{}); err != io.EOF { + if err == nil { + return Record{}, errors.New("decode redis session record: unexpected trailing JSON input") + } + return Record{}, fmt.Errorf("decode redis session record: %w", err) + } + + record := Record{ + DeviceSessionID: stored.DeviceSessionID, + UserID: stored.UserID, + ClientPublicKey: stored.ClientPublicKey, + Status: stored.Status, + RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS), + } + + if err := validateRecord(expectedDeviceSessionID, record); err != nil { + return Record{}, err + } + + return record, nil +} + +func validateRecord(expectedDeviceSessionID string, record Record) error { + if record.DeviceSessionID == "" { + return errors.New("session record device_session_id must not be empty") + } + if record.DeviceSessionID != expectedDeviceSessionID { + return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID) + } + if record.UserID == "" { + return errors.New("session record user_id must not be empty") + } + if record.ClientPublicKey == "" { + return errors.New("session record client_public_key must not be empty") + } + if !record.Status.IsKnown() { + return fmt.Errorf("session record status %q is unsupported", record.Status) + } + + return nil +} + +func cloneOptionalInt64(value *int64) *int64 { + if value == nil { + return nil + } + + cloned := *value + return &cloned +} + +var _ Cache = (*RedisCache)(nil) diff --git a/gateway/internal/session/redis_test.go b/gateway/internal/session/redis_test.go new file mode 100644 index 0000000..a0ca24c --- /dev/null +++ b/gateway/internal/session/redis_test.go @@ -0,0 +1,331 @@ +package session + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "galaxy/gateway/internal/config" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +func TestNewRedisCache(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + tests := []struct { + name string + cfg config.SessionCacheRedisConfig + wantErr string + }{ + { + name: "valid config", + cfg: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + DB: 2, + KeyPrefix: "gateway:session:", + LookupTimeout: 250 * time.Millisecond, + }, + }, + { + name: "empty addr", + cfg: config.SessionCacheRedisConfig{ + LookupTimeout: 250 * time.Millisecond, + }, + wantErr: "redis addr must not be empty", + }, + { + name: "negative db", + cfg: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + DB: -1, + LookupTimeout: 250 * time.Millisecond, + }, + wantErr: "redis db must not be negative", + }, + { + name: "non-positive lookup timeout", + cfg: config.SessionCacheRedisConfig{ + Addr: server.Addr(), + }, + wantErr: "lookup timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cache, err := NewRedisCache(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cache.Close()) + }) + }) + } +} + +func TestRedisCachePing(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{}) + + require.NoError(t, cache.Ping(context.Background())) +} + +func TestRedisCacheLookup(t *testing.T) { + t.Parallel() + + revokedAtMS := int64(123456789) + + tests := []struct { + name string + cfg config.SessionCacheRedisConfig + requestID string + seed func(*testing.T, *miniredis.Miniredis, config.SessionCacheRedisConfig) + want Record + wantErrIs error + wantErrText string + assertErrText string + }{ + { + name: "active cache hit", + requestID: "device-session-123", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "gateway:session:", + }, + seed: 
func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) { + t.Helper() + setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-123", redisRecord{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: StatusActive, + }) + }, + want: Record{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-123", + Status: StatusActive, + }, + }, + { + name: "missing session", + requestID: "device-session-404", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "gateway:session:", + }, + wantErrIs: ErrNotFound, + assertErrText: "session cache record not found", + }, + { + name: "revoked session", + requestID: "device-session-revoked", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "gateway:session:", + }, + seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) { + t.Helper() + setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-revoked", redisRecord{ + DeviceSessionID: "device-session-revoked", + UserID: "user-777", + ClientPublicKey: "public-key-777", + Status: StatusRevoked, + RevokedAtMS: &revokedAtMS, + }) + }, + want: Record{ + DeviceSessionID: "device-session-revoked", + UserID: "user-777", + ClientPublicKey: "public-key-777", + Status: StatusRevoked, + RevokedAtMS: &revokedAtMS, + }, + }, + { + name: "malformed json", + requestID: "device-session-bad-json", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "gateway:session:", + }, + seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) { + t.Helper() + server.Set(cfg.KeyPrefix+"device-session-bad-json", "{") + }, + wantErrText: "decode redis session record", + }, + { + name: "unknown status", + requestID: "device-session-unknown-status", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "gateway:session:", + }, + seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) 
{ + t.Helper() + setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-unknown-status", redisRecord{ + DeviceSessionID: "device-session-unknown-status", + UserID: "user-1", + ClientPublicKey: "public-key-1", + Status: Status("paused"), + }) + }, + wantErrText: `status "paused" is unsupported`, + }, + { + name: "missing required field", + requestID: "device-session-missing-user", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "gateway:session:", + }, + seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) { + t.Helper() + setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-missing-user", redisRecord{ + DeviceSessionID: "device-session-missing-user", + ClientPublicKey: "public-key-1", + Status: StatusActive, + }) + }, + wantErrText: "user_id must not be empty", + }, + { + name: "device session id mismatch", + requestID: "device-session-requested", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "gateway:session:", + }, + seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) { + t.Helper() + setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-requested", redisRecord{ + DeviceSessionID: "device-session-other", + UserID: "user-1", + ClientPublicKey: "public-key-1", + Status: StatusActive, + }) + }, + wantErrText: `does not match requested "device-session-requested"`, + }, + { + name: "key prefix is honored", + requestID: "device-session-prefixed", + cfg: config.SessionCacheRedisConfig{ + KeyPrefix: "custom:session:", + }, + seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) { + t.Helper() + setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-prefixed", redisRecord{ + DeviceSessionID: "device-session-prefixed", + UserID: "user-prefixed", + ClientPublicKey: "public-key-prefixed", + Status: StatusActive, + }) + setRedisSessionRecord(t, server, "gateway:session:device-session-prefixed", redisRecord{ + 
DeviceSessionID: "device-session-prefixed", + UserID: "wrong-user", + ClientPublicKey: "wrong-key", + Status: StatusRevoked, + }) + }, + want: Record{ + DeviceSessionID: "device-session-prefixed", + UserID: "user-prefixed", + ClientPublicKey: "public-key-prefixed", + Status: StatusActive, + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + cfg := tt.cfg + cfg.Addr = server.Addr() + cfg.DB = 0 + cfg.LookupTimeout = 250 * time.Millisecond + + if tt.seed != nil { + tt.seed(t, server, cfg) + } + + cache := newTestRedisCache(t, server, cfg) + record, err := cache.Lookup(context.Background(), tt.requestID) + if tt.wantErrIs != nil || tt.wantErrText != "" { + require.Error(t, err) + if tt.wantErrIs != nil { + assert.ErrorIs(t, err, tt.wantErrIs) + } + if tt.wantErrText != "" { + assert.ErrorContains(t, err, tt.wantErrText) + } + if tt.assertErrText != "" { + assert.ErrorContains(t, err, tt.assertErrText) + } + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, record) + }) + } +} + +func newTestRedisCache(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) *RedisCache { + t.Helper() + + if cfg.Addr == "" { + cfg.Addr = server.Addr() + } + if cfg.LookupTimeout == 0 { + cfg.LookupTimeout = 250 * time.Millisecond + } + + cache, err := NewRedisCache(cfg) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, cache.Close()) + }) + + return cache +} + +func setRedisSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record redisRecord) { + t.Helper() + + payload, err := json.Marshal(record) + require.NoError(t, err) + + server.Set(key, string(payload)) +} + +func TestRedisCacheLookupNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{}) + + _, err := cache.Lookup(context.TODO(), "device-session-123") + require.Error(t, 
err) + assert.False(t, errors.Is(err, ErrNotFound)) + assert.ErrorContains(t, err, "nil context") +} diff --git a/gateway/internal/session/session.go b/gateway/internal/session/session.go new file mode 100644 index 0000000..766823f --- /dev/null +++ b/gateway/internal/session/session.go @@ -0,0 +1,80 @@ +// Package session defines the authenticated session-cache contract used by the +// gateway hot path. +package session + +import ( + "context" + "errors" +) + +var ( + // ErrNotFound reports that SessionCache does not currently contain the + // requested device session identifier. + ErrNotFound = errors.New("session cache record not found") +) + +// Cache resolves authenticated device-session state from the gateway hot-path +// cache. +type Cache interface { + // Lookup returns the cached record for deviceSessionID. Implementations must + // wrap ErrNotFound when the cache does not contain the requested record. + Lookup(ctx context.Context, deviceSessionID string) (Record, error) +} + +// SnapshotStore stores mutable session record snapshots inside one gateway +// process and exposes the same read contract as Cache for the hot path. +type SnapshotStore interface { + Cache + + // Upsert stores record under record.DeviceSessionID, replacing any previous + // snapshot for that session. + Upsert(record Record) error + + // Delete removes the local snapshot for deviceSessionID when it exists. + Delete(deviceSessionID string) +} + +// Status identifies the cached lifecycle state of a device session. +type Status string + +const ( + // StatusActive reports that the cached device session may continue through + // later authenticated gateway checks. + StatusActive Status = "active" + + // StatusRevoked reports that the cached device session has been revoked and + // must be rejected before later auth steps run. + StatusRevoked Status = "revoked" +) + +// Record is the minimum authenticated session state required by the gateway +// before signature verification begins. 
+type Record struct { + // DeviceSessionID is the stable device-session identifier resolved from the + // hot-path cache. + DeviceSessionID string + + // UserID is the authenticated user identity bound to DeviceSessionID. + UserID string + + // ClientPublicKey is the standard base64-encoded raw Ed25519 public key + // material used for request-signature verification. + ClientPublicKey string + + // Status reports whether the cached session is active or revoked. + Status Status + + // RevokedAtMS optionally records when the device session was revoked. + RevokedAtMS *int64 +} + +// IsKnown reports whether s is one of the session states supported by the +// gateway. +func (s Status) IsKnown() bool { + switch s { + case StatusActive, StatusRevoked: + return true + default: + return false + } +} diff --git a/gateway/internal/telemetry/outcome.go b/gateway/internal/telemetry/outcome.go new file mode 100644 index 0000000..1251cb3 --- /dev/null +++ b/gateway/internal/telemetry/outcome.go @@ -0,0 +1,102 @@ +// Package telemetry provides shared edge observability helpers used by the +// gateway transports and internal event consumers. +package telemetry + +import ( + "net/http" + "strings" + + "google.golang.org/grpc/codes" +) + +// EdgeOutcome is the stable low-cardinality outcome vocabulary shared by REST, +// gRPC, push shutdown, and observability backends. 
+type EdgeOutcome string + +const ( + EdgeOutcomeSuccess EdgeOutcome = "success" + EdgeOutcomeMalformedRequest EdgeOutcome = "malformed_request" + EdgeOutcomeRequestTooLarge EdgeOutcome = "request_too_large" + EdgeOutcomeUnsupportedProtocol EdgeOutcome = "unsupported_protocol" + EdgeOutcomeUnknownSession EdgeOutcome = "unknown_session" + EdgeOutcomeRevokedSession EdgeOutcome = "revoked_session" + EdgeOutcomeInvalidSignature EdgeOutcome = "invalid_signature" + EdgeOutcomeStaleRequest EdgeOutcome = "stale_request" + EdgeOutcomeReplayDetected EdgeOutcome = "replay_detected" + EdgeOutcomeRateLimited EdgeOutcome = "rate_limited" + EdgeOutcomePolicyDenied EdgeOutcome = "policy_denied" + EdgeOutcomeDownstreamUnavailable EdgeOutcome = "downstream_unavailable" + EdgeOutcomeBackendUnavailable EdgeOutcome = "backend_unavailable" + EdgeOutcomeInternalError EdgeOutcome = "internal_error" + EdgeOutcomeGatewayShuttingDown EdgeOutcome = "gateway_shutting_down" +) + +// RejectReason returns the stable reject reason for outcome. Success does not +// produce a reject reason. +func RejectReason(outcome EdgeOutcome) string { + if outcome == EdgeOutcomeSuccess { + return "" + } + + return string(outcome) +} + +// OutcomeFromPublicErrorCode maps the stable public REST error envelope into +// the shared edge-outcome vocabulary. 
+func OutcomeFromPublicErrorCode(statusCode int, code string) EdgeOutcome { + switch strings.TrimSpace(code) { + case "": + if statusCode < http.StatusBadRequest { + return EdgeOutcomeSuccess + } + return EdgeOutcomeInternalError + case "invalid_request", "method_not_allowed", "not_found": + return EdgeOutcomeMalformedRequest + case "request_too_large": + return EdgeOutcomeRequestTooLarge + case "rate_limited": + return EdgeOutcomeRateLimited + case "service_unavailable": + return EdgeOutcomeBackendUnavailable + default: + if statusCode >= http.StatusInternalServerError { + return EdgeOutcomeInternalError + } + return EdgeOutcomeMalformedRequest + } +} + +// OutcomeFromGRPCStatus maps the stable authenticated gRPC reject contract +// into the shared edge-outcome vocabulary. +func OutcomeFromGRPCStatus(code codes.Code, message string) EdgeOutcome { + switch { + case code == codes.OK: + return EdgeOutcomeSuccess + case code == codes.InvalidArgument: + return EdgeOutcomeMalformedRequest + case code == codes.FailedPrecondition && strings.Contains(message, "unsupported protocol_version"): + return EdgeOutcomeUnsupportedProtocol + case code == codes.Unauthenticated && message == "unknown device session": + return EdgeOutcomeUnknownSession + case code == codes.FailedPrecondition && message == "device session is revoked": + return EdgeOutcomeRevokedSession + case code == codes.Unauthenticated && message == "invalid request signature": + return EdgeOutcomeInvalidSignature + case code == codes.FailedPrecondition && message == "request timestamp is outside the freshness window": + return EdgeOutcomeStaleRequest + case code == codes.FailedPrecondition && message == "request replay detected": + return EdgeOutcomeReplayDetected + case code == codes.ResourceExhausted && message == "authenticated request rate limit exceeded": + return EdgeOutcomeRateLimited + case code == codes.PermissionDenied && message == "authenticated request rejected by edge policy": + return 
EdgeOutcomePolicyDenied + case code == codes.Unavailable && message == "downstream service is unavailable": + return EdgeOutcomeDownstreamUnavailable + case code == codes.Unavailable && message == "gateway is shutting down": + return EdgeOutcomeGatewayShuttingDown + case code == codes.Unavailable: + return EdgeOutcomeBackendUnavailable + default: + return EdgeOutcomeInternalError + } +} diff --git a/gateway/internal/telemetry/push.go b/gateway/internal/telemetry/push.go new file mode 100644 index 0000000..d646efd --- /dev/null +++ b/gateway/internal/telemetry/push.go @@ -0,0 +1,48 @@ +package telemetry + +import ( + "context" + "errors" + + "galaxy/gateway/internal/push" + + "go.opentelemetry.io/otel/attribute" +) + +// PushObserver adapts Runtime to the push.Observer interface. +type PushObserver struct { + runtime *Runtime +} + +// NewPushObserver constructs a push stream observer backed by runtime. +func NewPushObserver(runtime *Runtime) *PushObserver { + return &PushObserver{runtime: runtime} +} + +// Registered records one active push stream. +func (o *PushObserver) Registered(_ push.StreamBinding) { + if o == nil || o.runtime == nil { + return + } + + o.runtime.AddActivePushStream(context.Background(), 1) +} + +// Unregistered records one active-stream decrement and one closure reason for +// hub-enforced shutdown, overflow, or revocation. 
+func (o *PushObserver) Unregistered(_ push.StreamBinding, err error) { + if o == nil || o.runtime == nil { + return + } + + o.runtime.AddActivePushStream(context.Background(), -1) + + switch { + case errors.Is(err, push.ErrSubscriptionOverflow): + o.runtime.RecordPushStreamClosure(context.Background(), attribute.String("reason", "overflow")) + case errors.Is(err, push.ErrSubscriptionRevoked): + o.runtime.RecordPushStreamClosure(context.Background(), attribute.String("reason", "revoked")) + case errors.Is(err, push.ErrHubShuttingDown): + o.runtime.RecordPushStreamClosure(context.Background(), attribute.String("reason", "shutdown")) + } +} diff --git a/gateway/internal/telemetry/runtime.go b/gateway/internal/telemetry/runtime.go new file mode 100644 index 0000000..549fbb6 --- /dev/null +++ b/gateway/internal/telemetry/runtime.go @@ -0,0 +1,254 @@ +package telemetry + +import ( + "context" + "errors" + "net/http" + "os" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + otelprom "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/propagation" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.uber.org/zap" +) + +const defaultServiceName = "galaxy-edge-gateway" + +// Runtime owns the shared OpenTelemetry providers, the Prometheus metrics +// handler, and the custom low-cardinality edge instruments. +type Runtime struct { + logger *zap.Logger + + tracerProvider *sdktrace.TracerProvider + meterProvider *sdkmetric.MeterProvider + promHandler http.Handler + + // Public REST instruments. 
+ publicRequests metric.Int64Counter + publicDuration metric.Float64Histogram + + // Authenticated gRPC instruments. + grpcRequests metric.Int64Counter + grpcDuration metric.Float64Histogram + + // Push instruments. + pushActiveStreams metric.Int64UpDownCounter + pushStreamClosers metric.Int64Counter + + // Internal event consumer instruments. + internalEventDrops metric.Int64Counter +} + +// New constructs the gateway telemetry runtime, registers global providers, +// and returns the Prometheus handler used by the admin listener. +func New(ctx context.Context, logger *zap.Logger) (*Runtime, error) { + if logger == nil { + logger = zap.NewNop() + } + + serviceName := strings.TrimSpace(os.Getenv("OTEL_SERVICE_NAME")) + if serviceName == "" { + serviceName = defaultServiceName + } + + res, err := resource.New( + ctx, + resource.WithAttributes(attribute.String("service.name", serviceName)), + ) + if err != nil { + return nil, err + } + + tracerProvider, err := newTracerProvider(ctx, res) + if err != nil { + return nil, err + } + + registry := prometheus.NewRegistry() + exporter, err := otelprom.New(otelprom.WithRegisterer(registry)) + if err != nil { + return nil, err + } + + meterProvider := sdkmetric.NewMeterProvider( + sdkmetric.WithResource(res), + sdkmetric.WithReader(exporter), + ) + + otel.SetTracerProvider(tracerProvider) + otel.SetMeterProvider(meterProvider) + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + )) + + meter := meterProvider.Meter("galaxy/gateway") + + publicRequests, err := meter.Int64Counter("gateway.public_http.requests") + if err != nil { + return nil, err + } + publicDuration, err := meter.Float64Histogram("gateway.public_http.duration", metric.WithUnit("ms")) + if err != nil { + return nil, err + } + grpcRequests, err := meter.Int64Counter("gateway.authenticated_grpc.requests") + if err != nil { + return nil, err + } + grpcDuration, err := 
meter.Float64Histogram("gateway.authenticated_grpc.duration", metric.WithUnit("ms")) + if err != nil { + return nil, err + } + pushActiveStreams, err := meter.Int64UpDownCounter("gateway.push.active_streams") + if err != nil { + return nil, err + } + pushStreamClosers, err := meter.Int64Counter("gateway.push.stream_closures") + if err != nil { + return nil, err + } + internalEventDrops, err := meter.Int64Counter("gateway.internal_event_drops") + if err != nil { + return nil, err + } + + return &Runtime{ + logger: logger, + tracerProvider: tracerProvider, + meterProvider: meterProvider, + promHandler: promhttp.HandlerFor(registry, promhttp.HandlerOpts{}), + publicRequests: publicRequests, + publicDuration: publicDuration, + grpcRequests: grpcRequests, + grpcDuration: grpcDuration, + pushActiveStreams: pushActiveStreams, + pushStreamClosers: pushStreamClosers, + internalEventDrops: internalEventDrops, + }, nil +} + +// Handler returns the Prometheus handler that should be mounted on the admin +// listener. +func (r *Runtime) Handler() http.Handler { + if r == nil || r.promHandler == nil { + return http.NotFoundHandler() + } + + return r.promHandler +} + +// Shutdown flushes the configured telemetry providers. +func (r *Runtime) Shutdown(ctx context.Context) error { + if r == nil { + return nil + } + + var shutdownErr error + if r.meterProvider != nil { + shutdownErr = errors.Join(shutdownErr, r.meterProvider.Shutdown(ctx)) + } + if r.tracerProvider != nil { + shutdownErr = errors.Join(shutdownErr, r.tracerProvider.Shutdown(ctx)) + } + + return shutdownErr +} + +// RecordPublicRequest records one public REST request outcome. +func (r *Runtime) RecordPublicRequest(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) { + if r == nil { + return + } + + options := metric.WithAttributes(attrs...) 
+ r.publicRequests.Add(ctx, 1, options) + r.publicDuration.Record(ctx, duration.Seconds()*1000, options) +} + +// RecordAuthenticatedGRPC records one authenticated gRPC request or stream +// outcome. +func (r *Runtime) RecordAuthenticatedGRPC(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) { + if r == nil { + return + } + + options := metric.WithAttributes(attrs...) + r.grpcRequests.Add(ctx, 1, options) + r.grpcDuration.Record(ctx, duration.Seconds()*1000, options) +} + +// AddActivePushStream records one active-stream delta. +func (r *Runtime) AddActivePushStream(ctx context.Context, delta int64, attrs ...attribute.KeyValue) { + if r == nil { + return + } + + r.pushActiveStreams.Add(ctx, delta, metric.WithAttributes(attrs...)) +} + +// RecordPushStreamClosure records one push-stream closure reason. +func (r *Runtime) RecordPushStreamClosure(ctx context.Context, attrs ...attribute.KeyValue) { + if r == nil { + return + } + + r.pushStreamClosers.Add(ctx, 1, metric.WithAttributes(attrs...)) +} + +// RecordInternalEventDrop records one malformed or rejected internal event. 
+func (r *Runtime) RecordInternalEventDrop(ctx context.Context, attrs ...attribute.KeyValue) { + if r == nil { + return + } + + r.internalEventDrops.Add(ctx, 1, metric.WithAttributes(attrs...)) +} + +func newTracerProvider(ctx context.Context, res *resource.Resource) (*sdktrace.TracerProvider, error) { + exporterName := strings.TrimSpace(os.Getenv("OTEL_TRACES_EXPORTER")) + if exporterName == "" || exporterName == "none" { + return sdktrace.NewTracerProvider(sdktrace.WithResource(res)), nil + } + + if exporterName != "otlp" { + return nil, errors.New("unsupported OTEL_TRACES_EXPORTER value") + } + + protocol := strings.TrimSpace(os.Getenv("OTEL_EXPORTER_OTLP_TRACES_PROTOCOL")) + if protocol == "" { + protocol = strings.TrimSpace(os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL")) + } + + var ( + exporter sdktrace.SpanExporter + err error + ) + switch protocol { + case "", "http/protobuf": + exporter, err = otlptracehttp.New(ctx) + case "grpc": + exporter, err = otlptracegrpc.New(ctx) + default: + return nil, errors.New("unsupported OTEL exporter protocol") + } + if err != nil { + return nil, err + } + + return sdktrace.NewTracerProvider( + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(res), + ), nil +} diff --git a/gateway/internal/testutil/observability.go b/gateway/internal/testutil/observability.go new file mode 100644 index 0000000..ebf2cf1 --- /dev/null +++ b/gateway/internal/testutil/observability.go @@ -0,0 +1,94 @@ +package testutil + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "galaxy/gateway/internal/logging" + "galaxy/gateway/internal/telemetry" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// LogBuffer is a concurrency-safe in-memory buffer used by observability +// tests. +type LogBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +// Write appends p to the buffer. 
+func (b *LogBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + + return b.buf.Write(p) +} + +// String returns the current buffer contents. +func (b *LogBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + + return b.buf.String() +} + +// NewObservedLogger constructs a JSON zap logger that writes into an in-memory +// buffer suitable for log assertions. +func NewObservedLogger(t *testing.T) (*zap.Logger, *LogBuffer) { + t.Helper() + + buffer := &LogBuffer{} + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "timestamp" + encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + core := zapcore.NewCore( + zapcore.NewJSONEncoder(encoderConfig), + zapcore.Lock(zapcore.AddSync(buffer)), + zap.DebugLevel, + ) + + logger := zap.New(core) + t.Cleanup(func() { + require.NoError(t, logging.Sync(logger)) + }) + + return logger, buffer +} + +// NewTelemetryRuntime constructs a telemetry runtime for tests and shuts it +// down automatically. +func NewTelemetryRuntime(t *testing.T, logger *zap.Logger) *telemetry.Runtime { + t.Helper() + + runtime, err := telemetry.New(context.Background(), logger) + require.NoError(t, err) + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, runtime.Shutdown(ctx)) + }) + + return runtime +} + +// ScrapeMetrics returns the Prometheus exposition produced by handler. 
+func ScrapeMetrics(t *testing.T, handler http.Handler) string { + t.Helper() + + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) + + return recorder.Body.String() +} diff --git a/gateway/openapi.yaml b/gateway/openapi.yaml new file mode 100644 index 0000000..3a542f8 --- /dev/null +++ b/gateway/openapi.yaml @@ -0,0 +1,462 @@ +openapi: 3.0.3 +info: + title: Galaxy Edge Gateway Public REST API + version: v1 + description: | + This specification documents the implemented `galaxy/gateway` v1 public + REST surface. + + Implemented endpoints: + - `GET /healthz` + - `GET /readyz` + - `POST /api/v1/public/auth/send-email-code` + - `POST /api/v1/public/auth/confirm-email-code` + + This specification intentionally excludes the private operational admin + listener and its `GET /metrics` endpoint. That endpoint is documented in + `README.md` because it is not part of the public REST contract. 
+ + Common runtime behavior: + - requests are unauthenticated; + - unknown routes return `404` with the JSON error envelope; + - unsupported methods on implemented routes and browser-shaped public paths + return `405` with the same JSON error envelope and an `Allow` header; + - request classification happens before route handling and depends on the + incoming method, path, and selected headers; + - the only stable public route classes are `public_auth`, + `browser_bootstrap`, `browser_asset`, and `public_misc`; + - any unsupported or empty classifier result is normalized to + `public_misc`; + - public REST policy derives its base bucket namespace from the normalized + class as `public_rest/class=`; + - per-IP public REST rate limits use only `RemoteAddr`; `X-Forwarded-For` + and `Forwarded` are intentionally ignored; + - `public_auth` additionally applies normalized identity buckets by + `email` for `send-email-code` and by `challenge_id` for + `confirm-email-code`; + - oversized request bodies are rejected with `413 request_too_large`; + - public REST rate limits reject with `429 rate_limited` and a + `Retry-After` header; + - public auth routes delegate through `AuthServiceClient`; + - the default `cmd/gateway` wiring keeps the auth routes mounted and + returns `503 service_unavailable` until a concrete upstream auth adapter + is configured; + - injected public auth adapters may also project client-safe `4xx/5xx` + `AuthServiceError` envelopes, which the gateway preserves after + normalizing blank or invalid fields. +servers: + - url: http://localhost:8080 + description: | + Example local public REST listener. The actual address is configured by + `GATEWAY_PUBLIC_HTTP_ADDR`. +tags: + - name: Probes + description: Unauthenticated public probe endpoints served by the gateway. + - name: PublicAuth + description: | + Unauthenticated public auth endpoints delegated to the Auth / Session + Service through `AuthServiceClient`. 
+paths: + /healthz: + get: + tags: + - Probes + operationId: getHealthz + summary: Public liveness probe + description: | + Returns a deterministic JSON payload confirming that the public REST + listener is alive and able to answer requests. + security: [] + x-public-route-classification-note: | + Typical probe requests are classified as `public_misc`. + Requests that match browser bootstrap rules, for example because they + advertise `Accept: text/html`, are classified as `browser_bootstrap` + before the route handler runs. + responses: + "200": + description: Public REST listener is alive. + content: + application/json: + schema: + $ref: "#/components/schemas/HealthzResponse" + examples: + ok: + value: + status: ok + "413": + $ref: "#/components/responses/RequestTooLargeError" + "429": + $ref: "#/components/responses/RateLimitedError" + "500": + $ref: "#/components/responses/InternalError" + /readyz: + get: + tags: + - Probes + operationId: getReadyz + summary: Public readiness probe + description: | + Returns a deterministic JSON payload confirming that the process is + ready to accept public REST traffic. Readiness is local-process only + and does not reflect downstream dependencies. + security: [] + x-public-route-classification-note: | + Typical probe requests are classified as `public_misc`. + Requests that match browser bootstrap rules, for example because they + advertise `Accept: text/html`, are classified as `browser_bootstrap` + before the route handler runs. + responses: + "200": + description: Public REST listener is ready to accept traffic. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ReadyzResponse" + examples: + ready: + value: + status: ready + "413": + $ref: "#/components/responses/RequestTooLargeError" + "429": + $ref: "#/components/responses/RateLimitedError" + "500": + $ref: "#/components/responses/InternalError" + /api/v1/public/auth/send-email-code: + post: + tags: + - PublicAuth + operationId: sendEmailCode + summary: Start a public e-mail login challenge + description: | + Accepts a single client e-mail address and delegates the command to the + Auth / Session Service. The response returns an opaque `challenge_id` + that must later be confirmed through + `POST /api/v1/public/auth/confirm-email-code`. + + This route is unauthenticated and classified as `public_auth`. + Public REST anti-abuse applies a per-IP bucket derived from + `RemoteAddr` and an additional normalized identity bucket derived from + `email`. + + In the default `cmd/gateway` process wiring the upstream auth adapter + is intentionally absent, so this route returns `503 + service_unavailable` until a concrete `AuthServiceClient` is injected. + When an injected adapter returns a client-safe `AuthServiceError`, the + gateway preserves that projected `4xx/5xx` status and serialized error + envelope after normalizing blank or invalid fields. + security: [] + x-public-route-classification-note: | + This route is always classified as `public_auth`. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/SendEmailCodeRequest" + examples: + default: + value: + email: pilot@example.com + responses: + "200": + description: The login challenge was accepted by the Auth / Session Service. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/SendEmailCodeResponse" + examples: + accepted: + value: + challenge_id: challenge-123 + "400": + $ref: "#/components/responses/InvalidRequestError" + "413": + $ref: "#/components/responses/RequestTooLargeError" + "405": + $ref: "#/components/responses/MethodNotAllowedError" + "429": + $ref: "#/components/responses/RateLimitedError" + "500": + $ref: "#/components/responses/InternalError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" + default: + $ref: "#/components/responses/ProjectedAuthServiceError" + /api/v1/public/auth/confirm-email-code: + post: + tags: + - PublicAuth + operationId: confirmEmailCode + summary: Confirm a public e-mail login challenge + description: | + Completes a previously issued `challenge_id`, sends the verification + `code`, and registers the standard base64-encoded raw 32-byte Ed25519 + `client_public_key` for the new device session. The response returns + the created `device_session_id`. + + This route is unauthenticated and classified as `public_auth`. + Public REST anti-abuse applies a per-IP bucket derived from + `RemoteAddr` and an additional normalized identity bucket derived from + `challenge_id`. + + In the default `cmd/gateway` process wiring the upstream auth adapter + is intentionally absent, so this route returns `503 + service_unavailable` until a concrete `AuthServiceClient` is injected. + When an injected adapter returns a client-safe `AuthServiceError`, the + gateway preserves that projected `4xx/5xx` status and serialized error + envelope after normalizing blank or invalid fields. + security: [] + x-public-route-classification-note: | + This route is always classified as `public_auth`. 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ConfirmEmailCodeRequest" + examples: + default: + value: + challenge_id: challenge-123 + code: "123456" + client_public_key: base64-encoded-raw-ed25519-public-key + responses: + "200": + description: The device session was created by the Auth / Session Service. + content: + application/json: + schema: + $ref: "#/components/schemas/ConfirmEmailCodeResponse" + examples: + accepted: + value: + device_session_id: device-session-123 + "400": + $ref: "#/components/responses/InvalidRequestError" + "413": + $ref: "#/components/responses/RequestTooLargeError" + "405": + $ref: "#/components/responses/MethodNotAllowedError" + "429": + $ref: "#/components/responses/RateLimitedError" + "500": + $ref: "#/components/responses/InternalError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" + default: + $ref: "#/components/responses/ProjectedAuthServiceError" +components: + schemas: + HealthzResponse: + type: object + additionalProperties: false + required: + - status + properties: + status: + type: string + description: Deterministic liveness marker. + enum: + - ok + ReadyzResponse: + type: object + additionalProperties: false + required: + - status + properties: + status: + type: string + description: Deterministic readiness marker. + enum: + - ready + SendEmailCodeRequest: + type: object + additionalProperties: false + required: + - email + properties: + email: + type: string + description: Single client e-mail address that should receive the login code. + format: email + SendEmailCodeResponse: + type: object + additionalProperties: false + required: + - challenge_id + properties: + challenge_id: + type: string + description: Opaque challenge identifier returned by the Auth / Session Service. 
+ ConfirmEmailCodeRequest: + type: object + additionalProperties: false + required: + - challenge_id + - code + - client_public_key + properties: + challenge_id: + type: string + description: Opaque challenge identifier previously returned by send-email-code. + code: + type: string + description: Verification code delivered to the client. + client_public_key: + type: string + description: Standard base64-encoded raw 32-byte Ed25519 public key registered for the new device session. + ConfirmEmailCodeResponse: + type: object + additionalProperties: false + required: + - device_session_id + properties: + device_session_id: + type: string + description: Stable identifier of the created device session. + ErrorResponse: + type: object + additionalProperties: false + required: + - error + properties: + error: + $ref: "#/components/schemas/ErrorBody" + ErrorBody: + type: object + additionalProperties: false + required: + - code + - message + properties: + code: + type: string + description: | + Stable gateway-generated or client-safe auth-adapter-projected + error code. Gateway-generated values include `invalid_request`, + `not_found`, `method_not_allowed`, `request_too_large`, + `rate_limited`, `internal_error`, and `service_unavailable`. + message: + type: string + description: Human-readable client-safe error description. + headers: + Allow: + description: Comma-separated list of allowed methods for the target route. + schema: + type: string + example: GET + Retry-After: + description: Seconds until the client should retry a rejected rate-limited request. + schema: + type: string + example: "3600" + responses: + InvalidRequestError: + description: Request body or field values are invalid for the target public auth route. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + invalidRequest: + value: + error: + code: invalid_request + message: email must be a single valid email address + NotFoundError: + description: Request path is not implemented on the current public REST surface. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + notFound: + value: + error: + code: not_found + message: resource was not found + MethodNotAllowedError: + description: Request method is not allowed for an implemented route. + headers: + Allow: + $ref: "#/components/headers/Allow" + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + methodNotAllowed: + value: + error: + code: method_not_allowed + message: request method is not allowed for this route + RequestTooLargeError: + description: Request body exceeds the configured public REST body limit for the route class. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + requestTooLarge: + value: + error: + code: request_too_large + message: request body exceeds the configured limit + RateLimitedError: + description: Request is rejected by the public REST anti-abuse rate limiter. + headers: + Retry-After: + $ref: "#/components/headers/Retry-After" + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + rateLimited: + value: + error: + code: rate_limited + message: request rate limit exceeded + InternalError: + description: Internal gateway error while processing the request. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + internalError: + value: + error: + code: internal_error + message: internal server error + ServiceUnavailableError: + description: | + The public route is mounted, but the configured or default auth adapter + cannot currently serve the request. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + unavailable: + value: + error: + code: service_unavailable + message: auth service is unavailable + ProjectedAuthServiceError: + description: | + Client-safe `4xx/5xx` error envelope projected by an injected public + auth adapter through `AuthServiceError`. The gateway preserves the + projected status and serialized envelope after normalizing blank or + invalid fields. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + projectedRateLimit: + value: + error: + code: upstream_rate_limited + message: too many attempts for this email diff --git a/gateway/proto/galaxy/gateway/v1/edge_gateway.pb.go b/gateway/proto/galaxy/gateway/v1/edge_gateway.pb.go new file mode 100644 index 0000000..a4861f9 --- /dev/null +++ b/gateway/proto/galaxy/gateway/v1/edge_gateway.pb.go @@ -0,0 +1,545 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: galaxy/gateway/v1/edge_gateway.proto + +package gatewayv1 + +import ( + _ "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ExecuteCommandRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // protocol_version identifies the request envelope version. The gateway + // accepts only the literal "v1" after required-field validation succeeds. 
+ ProtocolVersion string `protobuf:"bytes,1,opt,name=protocol_version,json=protocolVersion,proto3" json:"protocol_version,omitempty"` + DeviceSessionId string `protobuf:"bytes,2,opt,name=device_session_id,json=deviceSessionId,proto3" json:"device_session_id,omitempty"` + MessageType string `protobuf:"bytes,3,opt,name=message_type,json=messageType,proto3" json:"message_type,omitempty"` + TimestampMs int64 `protobuf:"varint,4,opt,name=timestamp_ms,json=timestampMs,proto3" json:"timestamp_ms,omitempty"` + RequestId string `protobuf:"bytes,5,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + PayloadBytes []byte `protobuf:"bytes,6,opt,name=payload_bytes,json=payloadBytes,proto3" json:"payload_bytes,omitempty"` + // payload_hash is the raw 32-byte SHA-256 digest of payload_bytes. + PayloadHash []byte `protobuf:"bytes,7,opt,name=payload_hash,json=payloadHash,proto3" json:"payload_hash,omitempty"` + Signature []byte `protobuf:"bytes,8,opt,name=signature,proto3" json:"signature,omitempty"` + TraceId string `protobuf:"bytes,9,opt,name=trace_id,json=traceId,proto3" json:"trace_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecuteCommandRequest) Reset() { + *x = ExecuteCommandRequest{} + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecuteCommandRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecuteCommandRequest) ProtoMessage() {} + +func (x *ExecuteCommandRequest) ProtoReflect() protoreflect.Message { + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecuteCommandRequest.ProtoReflect.Descriptor instead. 
+func (*ExecuteCommandRequest) Descriptor() ([]byte, []int) { + return file_galaxy_gateway_v1_edge_gateway_proto_rawDescGZIP(), []int{0} +} + +func (x *ExecuteCommandRequest) GetProtocolVersion() string { + if x != nil { + return x.ProtocolVersion + } + return "" +} + +func (x *ExecuteCommandRequest) GetDeviceSessionId() string { + if x != nil { + return x.DeviceSessionId + } + return "" +} + +func (x *ExecuteCommandRequest) GetMessageType() string { + if x != nil { + return x.MessageType + } + return "" +} + +func (x *ExecuteCommandRequest) GetTimestampMs() int64 { + if x != nil { + return x.TimestampMs + } + return 0 +} + +func (x *ExecuteCommandRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *ExecuteCommandRequest) GetPayloadBytes() []byte { + if x != nil { + return x.PayloadBytes + } + return nil +} + +func (x *ExecuteCommandRequest) GetPayloadHash() []byte { + if x != nil { + return x.PayloadHash + } + return nil +} + +func (x *ExecuteCommandRequest) GetSignature() []byte { + if x != nil { + return x.Signature + } + return nil +} + +func (x *ExecuteCommandRequest) GetTraceId() string { + if x != nil { + return x.TraceId + } + return "" +} + +type ExecuteCommandResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + ProtocolVersion string `protobuf:"bytes,1,opt,name=protocol_version,json=protocolVersion,proto3" json:"protocol_version,omitempty"` + RequestId string `protobuf:"bytes,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + TimestampMs int64 `protobuf:"varint,3,opt,name=timestamp_ms,json=timestampMs,proto3" json:"timestamp_ms,omitempty"` + ResultCode string `protobuf:"bytes,4,opt,name=result_code,json=resultCode,proto3" json:"result_code,omitempty"` + PayloadBytes []byte `protobuf:"bytes,5,opt,name=payload_bytes,json=payloadBytes,proto3" json:"payload_bytes,omitempty"` + PayloadHash []byte `protobuf:"bytes,6,opt,name=payload_hash,json=payloadHash,proto3" 
json:"payload_hash,omitempty"` + Signature []byte `protobuf:"bytes,7,opt,name=signature,proto3" json:"signature,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecuteCommandResponse) Reset() { + *x = ExecuteCommandResponse{} + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecuteCommandResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecuteCommandResponse) ProtoMessage() {} + +func (x *ExecuteCommandResponse) ProtoReflect() protoreflect.Message { + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecuteCommandResponse.ProtoReflect.Descriptor instead. +func (*ExecuteCommandResponse) Descriptor() ([]byte, []int) { + return file_galaxy_gateway_v1_edge_gateway_proto_rawDescGZIP(), []int{1} +} + +func (x *ExecuteCommandResponse) GetProtocolVersion() string { + if x != nil { + return x.ProtocolVersion + } + return "" +} + +func (x *ExecuteCommandResponse) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *ExecuteCommandResponse) GetTimestampMs() int64 { + if x != nil { + return x.TimestampMs + } + return 0 +} + +func (x *ExecuteCommandResponse) GetResultCode() string { + if x != nil { + return x.ResultCode + } + return "" +} + +func (x *ExecuteCommandResponse) GetPayloadBytes() []byte { + if x != nil { + return x.PayloadBytes + } + return nil +} + +func (x *ExecuteCommandResponse) GetPayloadHash() []byte { + if x != nil { + return x.PayloadHash + } + return nil +} + +func (x *ExecuteCommandResponse) GetSignature() []byte { + if x != nil { + return x.Signature + } + return nil +} + +type SubscribeEventsRequest struct { + 
state protoimpl.MessageState `protogen:"open.v1"` + // protocol_version identifies the request envelope version. The gateway + // accepts only the literal "v1" after required-field validation succeeds. + ProtocolVersion string `protobuf:"bytes,1,opt,name=protocol_version,json=protocolVersion,proto3" json:"protocol_version,omitempty"` + DeviceSessionId string `protobuf:"bytes,2,opt,name=device_session_id,json=deviceSessionId,proto3" json:"device_session_id,omitempty"` + MessageType string `protobuf:"bytes,3,opt,name=message_type,json=messageType,proto3" json:"message_type,omitempty"` + TimestampMs int64 `protobuf:"varint,4,opt,name=timestamp_ms,json=timestampMs,proto3" json:"timestamp_ms,omitempty"` + RequestId string `protobuf:"bytes,5,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // payload_hash is the raw 32-byte SHA-256 digest of payload_bytes. Empty + // payloads must use the SHA-256 digest of the empty byte slice. + PayloadHash []byte `protobuf:"bytes,6,opt,name=payload_hash,json=payloadHash,proto3" json:"payload_hash,omitempty"` + Signature []byte `protobuf:"bytes,7,opt,name=signature,proto3" json:"signature,omitempty"` + PayloadBytes []byte `protobuf:"bytes,8,opt,name=payload_bytes,json=payloadBytes,proto3" json:"payload_bytes,omitempty"` + TraceId string `protobuf:"bytes,9,opt,name=trace_id,json=traceId,proto3" json:"trace_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SubscribeEventsRequest) Reset() { + *x = SubscribeEventsRequest{} + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SubscribeEventsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SubscribeEventsRequest) ProtoMessage() {} + +func (x *SubscribeEventsRequest) ProtoReflect() protoreflect.Message { + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[2] + if x != nil { 
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SubscribeEventsRequest.ProtoReflect.Descriptor instead. +func (*SubscribeEventsRequest) Descriptor() ([]byte, []int) { + return file_galaxy_gateway_v1_edge_gateway_proto_rawDescGZIP(), []int{2} +} + +func (x *SubscribeEventsRequest) GetProtocolVersion() string { + if x != nil { + return x.ProtocolVersion + } + return "" +} + +func (x *SubscribeEventsRequest) GetDeviceSessionId() string { + if x != nil { + return x.DeviceSessionId + } + return "" +} + +func (x *SubscribeEventsRequest) GetMessageType() string { + if x != nil { + return x.MessageType + } + return "" +} + +func (x *SubscribeEventsRequest) GetTimestampMs() int64 { + if x != nil { + return x.TimestampMs + } + return 0 +} + +func (x *SubscribeEventsRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *SubscribeEventsRequest) GetPayloadHash() []byte { + if x != nil { + return x.PayloadHash + } + return nil +} + +func (x *SubscribeEventsRequest) GetSignature() []byte { + if x != nil { + return x.Signature + } + return nil +} + +func (x *SubscribeEventsRequest) GetPayloadBytes() []byte { + if x != nil { + return x.PayloadBytes + } + return nil +} + +func (x *SubscribeEventsRequest) GetTraceId() string { + if x != nil { + return x.TraceId + } + return "" +} + +type GatewayEvent struct { + state protoimpl.MessageState `protogen:"open.v1"` + EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"` + EventId string `protobuf:"bytes,2,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"` + TimestampMs int64 `protobuf:"varint,3,opt,name=timestamp_ms,json=timestampMs,proto3" json:"timestamp_ms,omitempty"` + PayloadBytes []byte `protobuf:"bytes,4,opt,name=payload_bytes,json=payloadBytes,proto3" 
json:"payload_bytes,omitempty"` + PayloadHash []byte `protobuf:"bytes,5,opt,name=payload_hash,json=payloadHash,proto3" json:"payload_hash,omitempty"` + Signature []byte `protobuf:"bytes,6,opt,name=signature,proto3" json:"signature,omitempty"` + RequestId string `protobuf:"bytes,7,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + TraceId string `protobuf:"bytes,8,opt,name=trace_id,json=traceId,proto3" json:"trace_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GatewayEvent) Reset() { + *x = GatewayEvent{} + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GatewayEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GatewayEvent) ProtoMessage() {} + +func (x *GatewayEvent) ProtoReflect() protoreflect.Message { + mi := &file_galaxy_gateway_v1_edge_gateway_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GatewayEvent.ProtoReflect.Descriptor instead. 
+func (*GatewayEvent) Descriptor() ([]byte, []int) { + return file_galaxy_gateway_v1_edge_gateway_proto_rawDescGZIP(), []int{3} +} + +func (x *GatewayEvent) GetEventType() string { + if x != nil { + return x.EventType + } + return "" +} + +func (x *GatewayEvent) GetEventId() string { + if x != nil { + return x.EventId + } + return "" +} + +func (x *GatewayEvent) GetTimestampMs() int64 { + if x != nil { + return x.TimestampMs + } + return 0 +} + +func (x *GatewayEvent) GetPayloadBytes() []byte { + if x != nil { + return x.PayloadBytes + } + return nil +} + +func (x *GatewayEvent) GetPayloadHash() []byte { + if x != nil { + return x.PayloadHash + } + return nil +} + +func (x *GatewayEvent) GetSignature() []byte { + if x != nil { + return x.Signature + } + return nil +} + +func (x *GatewayEvent) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *GatewayEvent) GetTraceId() string { + if x != nil { + return x.TraceId + } + return "" +} + +var File_galaxy_gateway_v1_edge_gateway_proto protoreflect.FileDescriptor + +const file_galaxy_gateway_v1_edge_gateway_proto_rawDesc = "" + + "\n" + + "$galaxy/gateway/v1/edge_gateway.proto\x12\x11galaxy.gateway.v1\x1a\x1bbuf/validate/validate.proto\"\x9c\x03\n" + + "\x15ExecuteCommandRequest\x122\n" + + "\x10protocol_version\x18\x01 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\x0fprotocolVersion\x123\n" + + "\x11device_session_id\x18\x02 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\x0fdeviceSessionId\x12*\n" + + "\fmessage_type\x18\x03 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\vmessageType\x12*\n" + + "\ftimestamp_ms\x18\x04 \x01(\x03B\a\xbaH\x04\"\x02 \x00R\vtimestampMs\x12&\n" + + "\n" + + "request_id\x18\x05 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\trequestId\x12,\n" + + "\rpayload_bytes\x18\x06 \x01(\fB\a\xbaH\x04z\x02\x10\x01R\fpayloadBytes\x12*\n" + + "\fpayload_hash\x18\a \x01(\fB\a\xbaH\x04z\x02\x10\x01R\vpayloadHash\x12%\n" + + "\tsignature\x18\b \x01(\fB\a\xbaH\x04z\x02\x10\x01R\tsignature\x12\x19\n" + + 
"\btrace_id\x18\t \x01(\tR\atraceId\"\x8c\x02\n" + + "\x16ExecuteCommandResponse\x12)\n" + + "\x10protocol_version\x18\x01 \x01(\tR\x0fprotocolVersion\x12\x1d\n" + + "\n" + + "request_id\x18\x02 \x01(\tR\trequestId\x12!\n" + + "\ftimestamp_ms\x18\x03 \x01(\x03R\vtimestampMs\x12\x1f\n" + + "\vresult_code\x18\x04 \x01(\tR\n" + + "resultCode\x12#\n" + + "\rpayload_bytes\x18\x05 \x01(\fR\fpayloadBytes\x12!\n" + + "\fpayload_hash\x18\x06 \x01(\fR\vpayloadHash\x12\x1c\n" + + "\tsignature\x18\a \x01(\fR\tsignature\"\x94\x03\n" + + "\x16SubscribeEventsRequest\x122\n" + + "\x10protocol_version\x18\x01 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\x0fprotocolVersion\x123\n" + + "\x11device_session_id\x18\x02 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\x0fdeviceSessionId\x12*\n" + + "\fmessage_type\x18\x03 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\vmessageType\x12*\n" + + "\ftimestamp_ms\x18\x04 \x01(\x03B\a\xbaH\x04\"\x02 \x00R\vtimestampMs\x12&\n" + + "\n" + + "request_id\x18\x05 \x01(\tB\a\xbaH\x04r\x02\x10\x01R\trequestId\x12*\n" + + "\fpayload_hash\x18\x06 \x01(\fB\a\xbaH\x04z\x02\x10\x01R\vpayloadHash\x12%\n" + + "\tsignature\x18\a \x01(\fB\a\xbaH\x04z\x02\x10\x01R\tsignature\x12#\n" + + "\rpayload_bytes\x18\b \x01(\fR\fpayloadBytes\x12\x19\n" + + "\btrace_id\x18\t \x01(\tR\atraceId\"\x8b\x02\n" + + "\fGatewayEvent\x12\x1d\n" + + "\n" + + "event_type\x18\x01 \x01(\tR\teventType\x12\x19\n" + + "\bevent_id\x18\x02 \x01(\tR\aeventId\x12!\n" + + "\ftimestamp_ms\x18\x03 \x01(\x03R\vtimestampMs\x12#\n" + + "\rpayload_bytes\x18\x04 \x01(\fR\fpayloadBytes\x12!\n" + + "\fpayload_hash\x18\x05 \x01(\fR\vpayloadHash\x12\x1c\n" + + "\tsignature\x18\x06 \x01(\fR\tsignature\x12\x1d\n" + + "\n" + + "request_id\x18\a \x01(\tR\trequestId\x12\x19\n" + + "\btrace_id\x18\b \x01(\tR\atraceId2\xd5\x01\n" + + "\vEdgeGateway\x12e\n" + + "\x0eExecuteCommand\x12(.galaxy.gateway.v1.ExecuteCommandRequest\x1a).galaxy.gateway.v1.ExecuteCommandResponse\x12_\n" + + 
"\x0fSubscribeEvents\x12).galaxy.gateway.v1.SubscribeEventsRequest\x1a\x1f.galaxy.gateway.v1.GatewayEvent0\x01B2Z0galaxy/gateway/proto/galaxy/gateway/v1;gatewayv1b\x06proto3" + +var ( + file_galaxy_gateway_v1_edge_gateway_proto_rawDescOnce sync.Once + file_galaxy_gateway_v1_edge_gateway_proto_rawDescData []byte +) + +func file_galaxy_gateway_v1_edge_gateway_proto_rawDescGZIP() []byte { + file_galaxy_gateway_v1_edge_gateway_proto_rawDescOnce.Do(func() { + file_galaxy_gateway_v1_edge_gateway_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_galaxy_gateway_v1_edge_gateway_proto_rawDesc), len(file_galaxy_gateway_v1_edge_gateway_proto_rawDesc))) + }) + return file_galaxy_gateway_v1_edge_gateway_proto_rawDescData +} + +var file_galaxy_gateway_v1_edge_gateway_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_galaxy_gateway_v1_edge_gateway_proto_goTypes = []any{ + (*ExecuteCommandRequest)(nil), // 0: galaxy.gateway.v1.ExecuteCommandRequest + (*ExecuteCommandResponse)(nil), // 1: galaxy.gateway.v1.ExecuteCommandResponse + (*SubscribeEventsRequest)(nil), // 2: galaxy.gateway.v1.SubscribeEventsRequest + (*GatewayEvent)(nil), // 3: galaxy.gateway.v1.GatewayEvent +} +var file_galaxy_gateway_v1_edge_gateway_proto_depIdxs = []int32{ + 0, // 0: galaxy.gateway.v1.EdgeGateway.ExecuteCommand:input_type -> galaxy.gateway.v1.ExecuteCommandRequest + 2, // 1: galaxy.gateway.v1.EdgeGateway.SubscribeEvents:input_type -> galaxy.gateway.v1.SubscribeEventsRequest + 1, // 2: galaxy.gateway.v1.EdgeGateway.ExecuteCommand:output_type -> galaxy.gateway.v1.ExecuteCommandResponse + 3, // 3: galaxy.gateway.v1.EdgeGateway.SubscribeEvents:output_type -> galaxy.gateway.v1.GatewayEvent + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { 
file_galaxy_gateway_v1_edge_gateway_proto_init() } +func file_galaxy_gateway_v1_edge_gateway_proto_init() { + if File_galaxy_gateway_v1_edge_gateway_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_galaxy_gateway_v1_edge_gateway_proto_rawDesc), len(file_galaxy_gateway_v1_edge_gateway_proto_rawDesc)), + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_galaxy_gateway_v1_edge_gateway_proto_goTypes, + DependencyIndexes: file_galaxy_gateway_v1_edge_gateway_proto_depIdxs, + MessageInfos: file_galaxy_gateway_v1_edge_gateway_proto_msgTypes, + }.Build() + File_galaxy_gateway_v1_edge_gateway_proto = out.File + file_galaxy_gateway_v1_edge_gateway_proto_goTypes = nil + file_galaxy_gateway_v1_edge_gateway_proto_depIdxs = nil +} diff --git a/gateway/proto/galaxy/gateway/v1/edge_gateway.proto b/gateway/proto/galaxy/gateway/v1/edge_gateway.proto new file mode 100644 index 0000000..56ecd9e --- /dev/null +++ b/gateway/proto/galaxy/gateway/v1/edge_gateway.proto @@ -0,0 +1,64 @@ +syntax = "proto3"; + +package galaxy.gateway.v1; + +option go_package = "galaxy/gateway/proto/galaxy/gateway/v1;gatewayv1"; + +import "buf/validate/validate.proto"; + +service EdgeGateway { + rpc ExecuteCommand(ExecuteCommandRequest) returns (ExecuteCommandResponse); + rpc SubscribeEvents(SubscribeEventsRequest) returns (stream GatewayEvent); +} + +message ExecuteCommandRequest { + // protocol_version identifies the request envelope version. The gateway + // accepts only the literal "v1" after required-field validation succeeds. 
+ string protocol_version = 1 [(buf.validate.field).string.min_len = 1]; + string device_session_id = 2 [(buf.validate.field).string.min_len = 1]; + string message_type = 3 [(buf.validate.field).string.min_len = 1]; + int64 timestamp_ms = 4 [(buf.validate.field).int64.gt = 0]; + string request_id = 5 [(buf.validate.field).string.min_len = 1]; + bytes payload_bytes = 6 [(buf.validate.field).bytes.min_len = 1]; + // payload_hash is the raw 32-byte SHA-256 digest of payload_bytes. + bytes payload_hash = 7 [(buf.validate.field).bytes.min_len = 1]; + bytes signature = 8 [(buf.validate.field).bytes.min_len = 1]; + string trace_id = 9; +} + +message ExecuteCommandResponse { + string protocol_version = 1; + string request_id = 2; + int64 timestamp_ms = 3; + string result_code = 4; + bytes payload_bytes = 5; + bytes payload_hash = 6; + bytes signature = 7; +} + +message SubscribeEventsRequest { + // protocol_version identifies the request envelope version. The gateway + // accepts only the literal "v1" after required-field validation succeeds. + string protocol_version = 1 [(buf.validate.field).string.min_len = 1]; + string device_session_id = 2 [(buf.validate.field).string.min_len = 1]; + string message_type = 3 [(buf.validate.field).string.min_len = 1]; + int64 timestamp_ms = 4 [(buf.validate.field).int64.gt = 0]; + string request_id = 5 [(buf.validate.field).string.min_len = 1]; + // payload_hash is the raw 32-byte SHA-256 digest of payload_bytes. Empty + // payloads must use the SHA-256 digest of the empty byte slice. 
+ bytes payload_hash = 6 [(buf.validate.field).bytes.min_len = 1]; + bytes signature = 7 [(buf.validate.field).bytes.min_len = 1]; + bytes payload_bytes = 8; + string trace_id = 9; +} + +message GatewayEvent { + string event_type = 1; + string event_id = 2; + int64 timestamp_ms = 3; + bytes payload_bytes = 4; + bytes payload_hash = 5; + bytes signature = 6; + string request_id = 7; + string trace_id = 8; +} diff --git a/gateway/proto/galaxy/gateway/v1/edge_gateway_grpc.pb.go b/gateway/proto/galaxy/gateway/v1/edge_gateway_grpc.pb.go new file mode 100644 index 0000000..efe4b2b --- /dev/null +++ b/gateway/proto/galaxy/gateway/v1/edge_gateway_grpc.pb.go @@ -0,0 +1,163 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc (unknown) +// source: galaxy/gateway/v1/edge_gateway.proto + +package gatewayv1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + EdgeGateway_ExecuteCommand_FullMethodName = "/galaxy.gateway.v1.EdgeGateway/ExecuteCommand" + EdgeGateway_SubscribeEvents_FullMethodName = "/galaxy.gateway.v1.EdgeGateway/SubscribeEvents" +) + +// EdgeGatewayClient is the client API for EdgeGateway service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+type EdgeGatewayClient interface { + ExecuteCommand(ctx context.Context, in *ExecuteCommandRequest, opts ...grpc.CallOption) (*ExecuteCommandResponse, error) + SubscribeEvents(ctx context.Context, in *SubscribeEventsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GatewayEvent], error) +} + +type edgeGatewayClient struct { + cc grpc.ClientConnInterface +} + +func NewEdgeGatewayClient(cc grpc.ClientConnInterface) EdgeGatewayClient { + return &edgeGatewayClient{cc} +} + +func (c *edgeGatewayClient) ExecuteCommand(ctx context.Context, in *ExecuteCommandRequest, opts ...grpc.CallOption) (*ExecuteCommandResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ExecuteCommandResponse) + err := c.cc.Invoke(ctx, EdgeGateway_ExecuteCommand_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *edgeGatewayClient) SubscribeEvents(ctx context.Context, in *SubscribeEventsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GatewayEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &EdgeGateway_ServiceDesc.Streams[0], EdgeGateway_SubscribeEvents_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[SubscribeEventsRequest, GatewayEvent]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type EdgeGateway_SubscribeEventsClient = grpc.ServerStreamingClient[GatewayEvent] + +// EdgeGatewayServer is the server API for EdgeGateway service. +// All implementations must embed UnimplementedEdgeGatewayServer +// for forward compatibility. 
+type EdgeGatewayServer interface { + ExecuteCommand(context.Context, *ExecuteCommandRequest) (*ExecuteCommandResponse, error) + SubscribeEvents(*SubscribeEventsRequest, grpc.ServerStreamingServer[GatewayEvent]) error + mustEmbedUnimplementedEdgeGatewayServer() +} + +// UnimplementedEdgeGatewayServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedEdgeGatewayServer struct{} + +func (UnimplementedEdgeGatewayServer) ExecuteCommand(context.Context, *ExecuteCommandRequest) (*ExecuteCommandResponse, error) { + return nil, status.Error(codes.Unimplemented, "method ExecuteCommand not implemented") +} +func (UnimplementedEdgeGatewayServer) SubscribeEvents(*SubscribeEventsRequest, grpc.ServerStreamingServer[GatewayEvent]) error { + return status.Error(codes.Unimplemented, "method SubscribeEvents not implemented") +} +func (UnimplementedEdgeGatewayServer) mustEmbedUnimplementedEdgeGatewayServer() {} +func (UnimplementedEdgeGatewayServer) testEmbeddedByValue() {} + +// UnsafeEdgeGatewayServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to EdgeGatewayServer will +// result in compilation errors. +type UnsafeEdgeGatewayServer interface { + mustEmbedUnimplementedEdgeGatewayServer() +} + +func RegisterEdgeGatewayServer(s grpc.ServiceRegistrar, srv EdgeGatewayServer) { + // If the following call panics, it indicates UnimplementedEdgeGatewayServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. 
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&EdgeGateway_ServiceDesc, srv) +} + +func _EdgeGateway_ExecuteCommand_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ExecuteCommandRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EdgeGatewayServer).ExecuteCommand(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EdgeGateway_ExecuteCommand_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EdgeGatewayServer).ExecuteCommand(ctx, req.(*ExecuteCommandRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EdgeGateway_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(SubscribeEventsRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(EdgeGatewayServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeEventsRequest, GatewayEvent]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type EdgeGateway_SubscribeEventsServer = grpc.ServerStreamingServer[GatewayEvent] + +// EdgeGateway_ServiceDesc is the grpc.ServiceDesc for EdgeGateway service. 
+// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var EdgeGateway_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "galaxy.gateway.v1.EdgeGateway", + HandlerType: (*EdgeGatewayServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ExecuteCommand", + Handler: _EdgeGateway_ExecuteCommand_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "SubscribeEvents", + Handler: _EdgeGateway_SubscribeEvents_Handler, + ServerStreams: true, + }, + }, + Metadata: "galaxy/gateway/v1/edge_gateway.proto", +} diff --git a/go.work.sum b/go.work.sum index b0b6e05..105382c 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,43 +1,165 @@ +buf.build/go/hyperpb v0.1.3/go.mod h1:IHXAM5qnS0/Fsnd7/HGDghFNvUET646WoHmq1FDZXIE= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= github.com/akavel/rsrc v0.10.2/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew 
v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jackmordaunt/icns/v2 v2.2.6/go.mod h1:DqlVnR5iafSphrId7aSD06r3jg0KRC9V6lEBBp504ZQ= github.com/jordanlewis/gcassert v0.0.0-20250430164644-389ef753e22e/go.mod h1:ZybsQk6DWyN5t7An1MuPm1gtSZ1xDaTXS9ZjIOxvQrk= github.com/josephspurrier/goversioninfo v1.4.0/go.mod h1:JWzv5rKQr+MmW+LvM412ToT/IkYDZjaclF2pKDss8IY= +github.com/josharian/intern v1.0.0 
h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lucor/goinfo v0.9.0/go.mod h1:L6m6tN5Rlova5Z83h1ZaKsMP1iiaoZ9vGTNzu5QKOD4= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mcuadros/go-version v0.0.0-20190830083331-035f6764e8d2/go.mod h1:76rfSfYPWj01Z85hUf/ituArm797mNKcvINh1OlsZKo= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c h1:7ACFcSaQsrWtrH4WHHfUqE1C+f8r2uv8KGaW0jTNjus= +github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c/go.mod h1:JKox4Gszkxt57kj27u7rvi7IFoIULvCZHUsBTUmQM/s= +github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b h1:vivRhVUAa9t1q0Db4ZmezBP8pWQWnXHFokZj0AOea2g= +github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= +github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= +github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod 
h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/timandy/routine v1.1.6/go.mod h1:kXslgIosdY8LW0byTyPnenDgn4/azt2euufAq9rK51w= github.com/urfave/cli/v2 v2.4.0/go.mod h1:NX9W0zmTvedE5oDoOMs2RTC8RvdK98NTYZE5LbaEYPg= +github.com/woodsbury/decimal128 v1.3.0 h1:8pffMNWIlC0O5vbyHWFZAt5yWvWcrHA+3ovIIjVWss0= +github.com/woodsbury/decimal128 v1.3.0/go.mod h1:C5UTmyTjW3JftjUFzOVhC20BEQa2a4ZKOB5I6Zjb+ds= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0/go.mod h1:t/OGqzHBa5v6RHZwrDBJ2OirWc+4q/w2fTbLZwAKjTk= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod 
h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.20.0/go.mod 
h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0/go.mod 
h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8= golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4/go.mod h1:g5NllXBEermZrmR51cJDQxmJUHUOfRAaNyWBM+R+548= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 
+golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= +golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= golang.org/x/tools/go/vcs v0.1.0-deprecated/go.mod h1:zUrvATBAvEI9535oC0yWYsLsHIV4Z7g63sNPVMtuBy8= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod 
h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/pkg/schema/fbs/gateway.fbs b/pkg/schema/fbs/gateway.fbs new file mode 100644 index 0000000..78b07f5 --- /dev/null +++ b/pkg/schema/fbs/gateway.fbs @@ -0,0 +1,9 @@ +// gateway contains shared FlatBuffers payloads used by the gateway edge +// transport. +namespace gateway; + +table ServerTimeEvent { + server_time_ms:int64; +} + +root_type ServerTimeEvent; diff --git a/pkg/schema/fbs/gateway/ServerTimeEvent.go b/pkg/schema/fbs/gateway/ServerTimeEvent.go new file mode 100644 index 0000000..26b6eac --- /dev/null +++ b/pkg/schema/fbs/gateway/ServerTimeEvent.go @@ -0,0 +1,64 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. 
+ +package gateway + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type ServerTimeEvent struct { + _tab flatbuffers.Table +} + +func GetRootAsServerTimeEvent(buf []byte, offset flatbuffers.UOffsetT) *ServerTimeEvent { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &ServerTimeEvent{} + x.Init(buf, n+offset) + return x +} + +func FinishServerTimeEventBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.Finish(offset) +} + +func GetSizePrefixedRootAsServerTimeEvent(buf []byte, offset flatbuffers.UOffsetT) *ServerTimeEvent { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &ServerTimeEvent{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func FinishSizePrefixedServerTimeEventBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.FinishSizePrefixed(offset) +} + +func (rcv *ServerTimeEvent) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *ServerTimeEvent) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *ServerTimeEvent) ServerTimeMs() int64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.GetInt64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *ServerTimeEvent) MutateServerTimeMs(n int64) bool { + return rcv._tab.MutateInt64Slot(4, n) +} + +func ServerTimeEventStart(builder *flatbuffers.Builder) { + builder.StartObject(1) +} +func ServerTimeEventAddServerTimeMs(builder *flatbuffers.Builder, serverTimeMs int64) { + builder.PrependInt64Slot(0, serverTimeMs, 0) +} +func ServerTimeEventEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +}