feat: authsession service

+2 −2
@@ -37,7 +37,7 @@ flowchart LR

 ## Main Components

-### 1. Edge Gateway
+### 1. [Edge Gateway](./gateway/README.md)

 The gateway is the only public entry point for client traffic.

@@ -58,7 +58,7 @@ Responsibilities:

 The gateway must not implement domain-specific business logic.

-### 2. Auth / Session Service
+### 2. [Auth / Session Service](./authsession/README.md)

 This service owns authentication and device session lifecycle.

+1152 (file diff suppressed because it is too large)

@@ -0,0 +1,459 @@
# Auth / Session Service

## Run and Dependencies

`cmd/authsession` starts two HTTP listeners:

- public REST on `AUTHSESSION_PUBLIC_HTTP_ADDR` with default `:8080`
- trusted internal REST on `AUTHSESSION_INTERNAL_HTTP_ADDR` with default `:8081`

Startup requires:

- one reachable Redis deployment configured by `AUTHSESSION_REDIS_ADDR`

That Redis deployment is used for:

- source-of-truth challenges
- source-of-truth device sessions
- dynamic active-session limit config
- gateway session projection cache and stream updates
- send-email-code resend throttling

Optional integrations:

- `AUTHSESSION_USER_SERVICE_MODE=stub|rest`
- `AUTHSESSION_MAIL_SERVICE_MODE=stub|rest`
- OTLP telemetry through standard `OTEL_*` variables
- stdout telemetry through `AUTHSESSION_OTEL_STDOUT_TRACES_ENABLED` and
  `AUTHSESSION_OTEL_STDOUT_METRICS_ENABLED`

Operational caveats:

- the service exposes no `/healthz`, `/readyz`, or `/metrics` endpoints
- user-service and mail-service default to in-process stub adapters until
  `rest` mode is configured
- startup performs bounded Redis `PING` checks for every Redis-backed adapter
  and fails fast if Redis or runtime config is invalid

Additional module docs:

- [Public REST contract](api/public-openapi.yaml)
- [Internal REST contract](api/internal-openapi.yaml)
- [Documentation index](docs/README.md)
- [Edge Gateway README](../gateway/README.md)
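The listener-address resolution above can be sketched in a few lines (illustrative Python, not part of the Go service; only the environment-variable names and defaults come from this README):

```python
import os

# Documented listener defaults; the env vars override them when set.
LISTENER_DEFAULTS = {
    "AUTHSESSION_PUBLIC_HTTP_ADDR": ":8080",
    "AUTHSESSION_INTERNAL_HTTP_ADDR": ":8081",
}

def resolve_listen_addrs(environ=None):
    """Resolve both listener addresses, falling back to documented defaults."""
    env = os.environ if environ is None else environ
    return {key: env.get(key, default) for key, default in LISTENER_DEFAULTS.items()}
```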
## Purpose

`Auth / Session Service` owns e-mail-code authentication and the lifecycle of
device sessions.

It is the source of truth for:

- authentication challenges
- device sessions
- revoke and block state
- publication of session lifecycle updates consumed by
  [`Edge Gateway`](../gateway/README.md)

The service is intentionally not on the hot path for every authenticated
request. The gateway authenticates steady-state requests from its own cache,
kept current by session-lifecycle updates, rather than making a synchronous
round-trip back to this service for each command.
## Responsibilities

The service is responsible for:

- public auth commands:
  - `send-email-code`
  - `confirm-email-code`
- creating device sessions after successful confirmation
- registering the client public key for a newly created session
- revoking one device session
- revoking all sessions of one user
- blocking a user or e-mail subject for future auth flows
- persisting source-of-truth session state
- projecting session state into gateway-consumable Redis data
- exposing a trusted internal REST API for read, revoke, and block operations

The service is not responsible for:

- verifying authenticated transport signatures on every business request
- gateway anti-replay for authenticated command traffic
- downstream business authorization
- direct push delivery to clients
- long-lived hot-path session caching inside gateway
- mail-service implementation details beyond the mail-delivery contract
## Position in the System

```mermaid
flowchart LR
    Client["Client"]
    Gateway["Edge Gateway"]
    Auth["Auth / Session Service"]
    User["User Service"]
    Mail["Mail Service"]
    Redis["Redis"]
    Business["Business Services"]

    Client --> Gateway
    Gateway --> Auth
    Gateway --> Business
    Auth --> User
    Auth --> Mail
    Auth --> Redis
    Redis --> Gateway
```
## Main Principles

- public auth stays synchronous:
  - `send-email-code` returns `challenge_id`
  - `confirm-email-code` returns a ready `device_session_id`
  - no pending async session-provisioning stage exists
- source-of-truth session state and gateway-facing projection remain separate
- Redis is the initial backend, but the domain and service layers stay
  storage agnostic behind ports
- `send-email-code` stays success-shaped for existing, new, blocked, and
  throttled e-mail flows
- `confirm-email-code` supports short-window idempotent retry for the same
  confirmed challenge and the same `client_public_key`
- active-session limits are configuration driven:
  - absent limit means disabled
  - limit overflow rejects new session creation explicitly
  - the service does not evict existing sessions to make room
## Gateway-Facing Public Contract

Gateway already exposes the public REST auth surface and delegates it to this
service:

- `POST /api/v1/public/auth/send-email-code`
- `POST /api/v1/public/auth/confirm-email-code`

The effective DTO contract is:

| Operation | Request | Success response |
| --- | --- | --- |
| `POST /api/v1/public/auth/send-email-code` | `{ "email": string }` | `{ "challenge_id": string }` |
| `POST /api/v1/public/auth/confirm-email-code` | `{ "challenge_id": string, "code": string, "client_public_key": string }` | `{ "device_session_id": string }` |

`client_public_key` is the standard base64-encoded raw 32-byte Ed25519 public
key registered for the created device session.

Public boundary rules:

- requests and responses are JSON only
- request DTOs reject unknown fields
- empty bodies, malformed JSON, trailing JSON input, and unknown fields return
  `400 invalid_request`
- surrounding ASCII and Unicode whitespace is trimmed from input string fields
  before validation
- `send-email-code` remains success-shaped for existing, new, blocked, and
  throttled e-mail paths
- `confirm-email-code` returns a ready `device_session_id` synchronously on
  success
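The documented `client_public_key` shape check — standard base64 decoding to exactly 32 raw bytes — can be sketched as follows (illustrative Python, not the service's Go implementation; verifying that the bytes form a valid Ed25519 point is out of scope here):

```python
import base64
import binascii

def parse_client_public_key(value):
    """Return the 32 raw key bytes, or None when the documented shape check
    fails: the value must be strict standard base64 that decodes to exactly
    32 bytes."""
    try:
        raw = base64.b64decode(value, validate=True)
    except (binascii.Error, ValueError):
        return None
    return raw if len(raw) == 32 else None
```

A failed check maps to the public `400 invalid_client_public_key` error in the contract table below.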
Stable public business-error contract:

| HTTP status | `error.code` | Stable `error.message` |
| --- | --- | --- |
| `400` | `invalid_request` | field-specific validation detail |
| `400` | `invalid_code` | `confirmation code is invalid` |
| `400` | `invalid_client_public_key` | `client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key` |
| `403` | `blocked_by_policy` | `authentication is blocked by policy` |
| `404` | `challenge_not_found` | `challenge not found` |
| `409` | `session_limit_exceeded` | `active session limit would be exceeded` |
| `410` | `challenge_expired` | `challenge expired` |
| `503` | `service_unavailable` | `service is unavailable` |

The public error envelope is always:

```json
{
  "error": {
    "code": "string",
    "message": "string"
  }
}
```
## Trusted Internal API

The trusted internal REST surface lives under `/api/v1/internal` and is
documented in [`api/internal-openapi.yaml`](api/internal-openapi.yaml).

Implemented endpoints:

- `GET /api/v1/internal/sessions/{device_session_id}`
- `GET /api/v1/internal/users/{user_id}/sessions`
- `POST /api/v1/internal/sessions/{device_session_id}/revoke`
- `POST /api/v1/internal/users/{user_id}/sessions/revoke-all`
- `POST /api/v1/internal/user-blocks`

Key internal API properties:

- all bodies are JSON only
- `ListUserSessions` is newest-first and unpaginated in v1
- revoke and block mutations require audit metadata as `reason_code` and
  `actor`
- `BlockUser` accepts exactly one of `user_id` or `email`
- mutating operations are idempotent and return explicit acknowledgement
  payloads rather than empty `204` responses

Stable internal error surface:

| HTTP status | `error.code` | Stable `error.message` |
| --- | --- | --- |
| `400` | `invalid_request` | field-specific validation detail |
| `404` | `session_not_found` | `session not found` |
| `404` | `subject_not_found` | `subject not found` |
| `500` | `internal_error` | `internal server error` |
| `503` | `service_unavailable` | `service is unavailable` |
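The internal status-to-envelope mapping above can be sketched directly from the table (illustrative Python; the codes, statuses, and envelope shape come from this README):

```python
# Stable internal error codes and their HTTP statuses, per the table above.
INTERNAL_STATUS_BY_CODE = {
    "invalid_request": 400,
    "session_not_found": 404,
    "subject_not_found": 404,
    "internal_error": 500,
    "service_unavailable": 503,
}

def error_envelope(code, message):
    """Build the documented (status, JSON envelope) pair for one error code."""
    return INTERNAL_STATUS_BY_CODE[code], {"error": {"code": code, "message": message}}
```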
## Challenge Model

A challenge represents one short-lived public e-mail-code flow.

Core fields:

- `challenge_id`
- normalized e-mail
- hashed confirmation code
- `status`
- `delivery_state`
- creation and expiration timestamps
- send and confirm attempt counters
- minimal abuse metadata
- optional confirmation metadata used for idempotent retry
### Challenge States

Supported `challenge.Status` values:

- `pending_send`
- `sent`
- `delivery_suppressed`
- `delivery_throttled`
- `confirmed_pending_expire`
- `expired`
- `failed`
- `cancelled`

Supported `challenge.DeliveryState` values:

- `pending`
- `sent`
- `suppressed`
- `throttled`
- `failed`
Policy rules:

- initial challenge TTL is `5m`
- confirmed-challenge retention for idempotent retry is `5m`
- max invalid confirm attempts is `5`
- every `send-email-code` call creates a fresh challenge
- resend throttling is e-mail scoped with a fixed `1m` cooldown
- a throttled send still creates a fresh challenge in
  `status=delivery_throttled` and `delivery_state=throttled`
- throttled sends do not call `UserDirectory` and do not call `MailSender`
- blocked sends outside the throttle path become `delivery_suppressed`
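The resend-throttle rule reduces to one time comparison (illustrative Python sketch; only the `1m` cooldown and its e-mail scope come from this README):

```python
from datetime import timedelta

# Documented e-mail-scoped resend cooldown.
RESEND_COOLDOWN = timedelta(minutes=1)

def is_throttled(last_send_at, now):
    """True when a previous send for this e-mail is inside the 1m cooldown.

    A throttled send still creates a fresh challenge in
    status=delivery_throttled and skips the UserDirectory and MailSender calls.
    """
    return last_send_at is not None and now - last_send_at < RESEND_COOLDOWN
```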
Fresh confirm semantics:

- only `sent` and `delivery_suppressed` accept a first successful confirm
- `pending_send`, `delivery_throttled`, `failed`, and `cancelled` return
  `invalid_code`
- expired challenges return `challenge_expired` while the Redis grace window
  keeps the record present, then `challenge_not_found` after cleanup removes
  the key

Idempotent retry semantics:

- a repeated confirm with the same `challenge_id`, valid `code`, and identical
  `client_public_key` on `confirmed_pending_expire` returns the same
  `device_session_id`
- the same confirmed challenge with a different `client_public_key` fails as
  `invalid_code`
- idempotent retry republishes the stored gateway session view
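The idempotent-retry decision above can be sketched as a small dispatch (illustrative Python; field names like `confirmed_public_key` are hypothetical, while the status and error codes come from this README):

```python
def retry_confirm(challenge, client_public_key):
    """Idempotent-retry decision sketch for confirm-email-code.

    Only a challenge already in confirmed_pending_expire can replay; the
    replay requires the identical client_public_key, otherwise the call
    fails closed as invalid_code.
    """
    if challenge["status"] != "confirmed_pending_expire":
        return ("fresh", None)  # not a retry; run the fresh-confirm path
    if challenge["confirmed_public_key"] == client_public_key:
        return ("replay", challenge["device_session_id"])
    return ("invalid_code", None)
```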
## Device Session And Revoke Model

A device session is created only after successful confirmation.

Core fields:

- `device_session_id`
- `user_id`
- parsed client public key
- `status`
- `created_at`
- optional revocation metadata

Supported session states:

- `active`
- `revoked`

Built-in revoke reason codes:

- `device_logout`
- `logout_all`
- `admin_revoke`
- `user_blocked`
- `confirm_race_repair` for best-effort cleanup of superseded sessions created
  during a confirm race
Revoke behavior is intentionally separated by use case:

- revoke one device session
- revoke all sessions of one user
- block a subject and revoke active sessions implied by that subject

Internal mutation responses report only sessions changed by the current call,
so repeated idempotent operations may return:

- `already_revoked` with `affected_session_count=0`
- `no_active_sessions` with `affected_session_count=0`
- `already_blocked` with `affected_session_count=0`
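The idempotent single-session revoke can be sketched as follows (illustrative Python; the outcome names and zero-count repeat behavior come from this README, the dict shape is an assumption):

```python
def revoke_session(session, reason_code, actor):
    """Idempotent single-session revoke sketch.

    The acknowledgement reports only sessions changed by this call, so a
    repeated revoke returns already_revoked with affected_session_count=0.
    """
    if session["status"] == "revoked":
        return {"outcome": "already_revoked", "affected_session_count": 0}
    session["status"] = "revoked"
    session["revoke_reason_code"] = reason_code
    session["revoke_actor"] = actor
    return {"outcome": "revoked", "affected_session_count": 1}
```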
## User Resolution And Session Limits

`Auth / Session Service` does not own durable user records. It delegates to
`UserDirectory` for:

- resolve-by-email without mutation
- ensure existing-or-created user during confirm
- existence checks for stable `user_id`
- block-by-user-id and block-by-email operations

Supported user-resolution outcomes:

- `existing`
- `creatable`
- `blocked`

Supported ensure-user outcomes:

- `existing`
- `created`
- `blocked`

Session-limit rules:

- the value is loaded from a shared config provider
- absent value means the limit is disabled
- active sessions are counted before creating a new one
- limit overflow returns `session_limit_exceeded`
- the service never silently revokes an existing session to satisfy the limit
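The session-limit rules collapse to one predicate (illustrative Python sketch; the absent-means-disabled and no-eviction rules come from this README):

```python
def can_create_session(active_count, limit):
    """Session-limit rule sketch: an absent limit disables the check, and
    overflow is rejected explicitly (session_limit_exceeded); existing
    sessions are never evicted to make room for a new one."""
    if limit is None:
        return True
    return active_count < limit
```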
## Gateway Projection Model

Gateway-facing session projection is separate from source-of-truth
`devicesession.Session`.

Each successful projection publish writes:

- one Redis KV snapshot under
  `<gateway_session_cache_key_prefix><device_session_id>`
- one full-snapshot Redis Stream event under the session-events stream

The default gateway-facing namespaces are:

- cache key prefix: `gateway:session:`
- session-events stream: `gateway:session_events`

Projected fields are intentionally limited to what gateway consumes:

- `device_session_id`
- `user_id`
- `client_public_key`
- `status`
- optional `revoked_at_ms`

Revoke reason and actor metadata stay in authsession source of truth and are
not projected to gateway.
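The projection can be sketched as a pure function from a source-of-truth session to the cache key and limited view (illustrative Python; the key prefix and projected field list come from this README, the input dict shape is an assumption):

```python
CACHE_PREFIX = "gateway:session:"

def project(session):
    """Build the gateway-facing (cache key, view) pair.

    Only the fields gateway consumes are copied; revoke reason and actor
    metadata intentionally stay behind in the source of truth.
    """
    view = {
        "device_session_id": session["device_session_id"],
        "user_id": session["user_id"],
        "client_public_key": session["client_public_key"],
        "status": session["status"],
    }
    if session.get("revoked_at_ms") is not None:
        view["revoked_at_ms"] = session["revoked_at_ms"]
    return CACHE_PREFIX + session["device_session_id"], view
```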
## Consistency Model

Source of truth is written first. Gateway projection is published only after
the source-of-truth write succeeds.

Caller-visible rules:

- if projection publication does not reach its required success threshold, the
  public or internal call returns `service_unavailable`
- already-written source-of-truth state is intentionally preserved
- the documented repair path is to repeat the same confirm or revoke command

Projection publish rules:

- request-path projection publish uses a bounded retry loop with `3` total
  attempts
- repeated publishes are safe because the cache snapshot is overwritten and
  duplicate full-snapshot stream events remain valid under gateway's
  later-event-wins model
- `confirm-email-code` rereads the stored session after the challenge CAS
  succeeds and republishes that current view, so a concurrent revoke or block
  is not overwritten by a stale active projection
- idempotent confirm retry also republishes the stored session view
- best-effort cleanup of superseded confirm-race sessions uses the same
  publish helper but is not part of the caller-visible success contract
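The bounded request-path publish loop can be sketched as follows (illustrative Python; only the 3-attempt bound and the service_unavailable-on-exhaustion behavior come from this README):

```python
def publish_with_retry(publish, view, attempts=3):
    """Bounded request-path projection publish: up to 3 total attempts.

    On exhaustion the caller surfaces service_unavailable while the already
    written source-of-truth state is preserved; repeating the same command
    is the documented repair path.
    """
    for _ in range(attempts):
        if publish(view):
            return True
    return False
```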
## Runtime Summary

Runtime wiring is implemented in [`internal/app`](internal/app) and
[`cmd/authsession`](cmd/authsession/main.go).

Process-local collaborators:

- system UTC clock
- crypto-random `challenge_id` and `device_session_id` generators
- crypto-random 6-digit confirmation-code generator
- bcrypt-backed code hashing
- structured logging through `zap`
- process telemetry through OpenTelemetry

Redis-backed adapters:

- challenge store
- session store
- session-limit config provider
- gateway projection publisher
- send-email-code abuse protector

External service adapters:

- user-service:
  - default `stub`
  - optional REST adapter with one retry for read-style methods on transport
    errors and HTTP `502`, `503`, or `504`
  - mutation methods do not auto-retry
- mail-service:
  - default `stub`
  - optional REST adapter with no automatic retry on transport or upstream
    failure, to avoid duplicate deliveries

Listener defaults:

- public HTTP: `:8080`
- internal HTTP: `:8081`
- read-header timeout: `2s`
- read timeout: `10s`
- idle timeout: `1m`
- per-request use-case timeout: `3s`

For detailed runtime behavior, configuration groups, operational notes, and
examples, see [`docs/README.md`](docs/README.md).
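The asymmetric retry policy for the two REST adapters can be sketched as a predicate (illustrative Python; the retryable statuses, the single-retry bound, and the no-retry rule for mutations and mail come from this README):

```python
# Upstream statuses the user-service REST adapter treats as retryable.
RETRYABLE_STATUS = {502, 503, 504}

def should_retry(method_kind, status, attempt):
    """Retry-policy sketch for the optional user-service REST adapter.

    Read-style methods get exactly one retry on transport errors (modeled as
    status=None) or HTTP 502/503/504; mutation methods never auto-retry. The
    mail-service adapter retries nothing, to avoid duplicate deliveries.
    """
    if method_kind != "read" or attempt >= 1:
        return False
    return status is None or status in RETRYABLE_STATUS
```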
## Non-Goals

- making authsession a hot synchronous dependency for every authenticated
  gateway command
- moving business authorization into authsession
- exposing revoke or read operations as public unauthenticated routes
- introducing short-lived access-token or refresh-token flows
- adding pending async session provisioning after confirm

@@ -0,0 +1,456 @@
openapi: 3.0.3
info:
  title: Galaxy Auth / Session Service Internal API
  version: v1
  description: |
    This specification documents the implemented `galaxy/authsession` v1
    trusted internal REST contract.

    Contract rules:
    - the internal surface lives under `/api/v1/internal`;
    - all request and response bodies are JSON only;
    - read operations return canonical session DTO wrappers;
    - mutating operations return explicit `200` JSON acknowledgements;
    - mutation requests carry audit metadata as `reason_code` and `actor`;
    - `BlockUser` accepts exactly one of `user_id` or `email`;
    - `ListUserSessions` is newest-first and unpaginated in v1.
tags:
  - name: InternalAuthSession
    description: Trusted internal session read, revoke, and block operations.
paths:
  /api/v1/internal/sessions/{device_session_id}:
    get:
      tags:
        - InternalAuthSession
      operationId: getSession
      summary: Read one device session
      parameters:
        - $ref: "#/components/parameters/DeviceSessionID"
      responses:
        "200":
          description: The requested device session.
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/GetSessionResponse"
        "404":
          $ref: "#/components/responses/SessionNotFoundError"
        "500":
          $ref: "#/components/responses/InternalError"
        "503":
          $ref: "#/components/responses/ServiceUnavailableError"
  /api/v1/internal/users/{user_id}/sessions:
    get:
      tags:
        - InternalAuthSession
      operationId: listUserSessions
      summary: List all active and revoked sessions of one user
      description: |
        Returns the full v1 session list for one user. Results are ordered
        from newest to oldest and are intentionally unpaginated in v1.
      parameters:
        - $ref: "#/components/parameters/UserID"
      responses:
        "200":
          description: |
            Sessions belonging to the requested user. Returns an empty array
            when the user has no stored sessions, including unknown `user_id`
            values.
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ListUserSessionsResponse"
        "500":
          $ref: "#/components/responses/InternalError"
        "503":
          $ref: "#/components/responses/ServiceUnavailableError"
  /api/v1/internal/sessions/{device_session_id}/revoke:
    post:
      tags:
        - InternalAuthSession
      operationId: revokeDeviceSession
      summary: Revoke one device session
      parameters:
        - $ref: "#/components/parameters/DeviceSessionID"
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/RevokeDeviceSessionRequest"
      responses:
        "200":
          description: Explicit idempotent acknowledgement of the revoke result.
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/RevokeDeviceSessionResponse"
        "400":
          $ref: "#/components/responses/InvalidRequestError"
        "404":
          $ref: "#/components/responses/SessionNotFoundError"
        "500":
          $ref: "#/components/responses/InternalError"
        "503":
          $ref: "#/components/responses/ServiceUnavailableError"
  /api/v1/internal/users/{user_id}/sessions/revoke-all:
    post:
      tags:
        - InternalAuthSession
      operationId: revokeAllUserSessions
      summary: Revoke all sessions of one user
      parameters:
        - $ref: "#/components/parameters/UserID"
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/RevokeAllUserSessionsRequest"
      responses:
        "200":
          description: Explicit idempotent acknowledgement of the bulk revoke result.
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/RevokeAllUserSessionsResponse"
        "400":
          $ref: "#/components/responses/InvalidRequestError"
        "404":
          $ref: "#/components/responses/SubjectNotFoundError"
        "500":
          $ref: "#/components/responses/InternalError"
        "503":
          $ref: "#/components/responses/ServiceUnavailableError"
  /api/v1/internal/user-blocks:
    post:
      tags:
        - InternalAuthSession
      operationId: blockUser
      summary: Block future auth flow for one subject and revoke active sessions
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/BlockUserRequest"
      responses:
        "200":
          description: Explicit idempotent acknowledgement of the block result.
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/BlockUserResponse"
        "400":
          $ref: "#/components/responses/InvalidRequestError"
        "404":
          $ref: "#/components/responses/SubjectNotFoundError"
        "500":
          $ref: "#/components/responses/InternalError"
        "503":
          $ref: "#/components/responses/ServiceUnavailableError"
components:
  parameters:
    DeviceSessionID:
      name: device_session_id
      in: path
      required: true
      description: Stable identifier of one device session.
      schema:
        type: string
    UserID:
      name: user_id
      in: path
      required: true
      description: Stable identifier of one user.
      schema:
        type: string
  schemas:
    Actor:
      type: object
      additionalProperties: false
      required:
        - type
      properties:
        type:
          type: string
          description: Machine-readable actor type such as `system`, `service`, or `admin`.
        id:
          type: string
          description: Optional stable identifier of the initiating actor.
    ErrorResponse:
      type: object
      additionalProperties: false
      required:
        - error
      properties:
        error:
          $ref: "#/components/schemas/ErrorBody"
    ErrorBody:
      type: object
      additionalProperties: false
      required:
        - code
        - message
      properties:
        code:
          type: string
          description: Stable internal API error code.
        message:
          type: string
          description: Human-readable error description safe for trusted internal callers.
    Session:
      type: object
      additionalProperties: false
      required:
        - device_session_id
        - user_id
        - client_public_key
        - status
        - created_at
      properties:
        device_session_id:
          type: string
        user_id:
          type: string
        client_public_key:
          type: string
          description: Standard base64-encoded raw 32-byte Ed25519 public key of the device session.
        status:
          type: string
          enum:
            - active
            - revoked
        created_at:
          type: string
          format: date-time
          description: RFC3339 UTC timestamp when the session was created.
        revoked_at:
          type: string
          format: date-time
          nullable: true
          description: RFC3339 UTC timestamp when the session was revoked.
        revoke_reason_code:
          type: string
          nullable: true
          description: Machine-readable revoke reason code when the session is revoked.
        revoke_actor_type:
          type: string
          nullable: true
          description: Actor type that initiated the revoke.
        revoke_actor_id:
          type: string
          nullable: true
          description: Optional stable actor identifier that initiated the revoke.
    GetSessionResponse:
      type: object
      additionalProperties: false
      required:
        - session
      properties:
        session:
          $ref: "#/components/schemas/Session"
    ListUserSessionsResponse:
      type: object
      additionalProperties: false
      required:
        - sessions
      properties:
        sessions:
          type: array
          description: Full newest-first session list for the requested user.
          items:
            $ref: "#/components/schemas/Session"
    RevokeDeviceSessionRequest:
      type: object
      additionalProperties: false
      required:
        - reason_code
        - actor
      properties:
        reason_code:
          type: string
          description: Machine-readable revoke reason code.
        actor:
          $ref: "#/components/schemas/Actor"
    RevokeDeviceSessionResponse:
      type: object
      additionalProperties: false
      required:
        - outcome
        - device_session_id
        - affected_session_count
      properties:
        outcome:
          type: string
          enum:
            - revoked
            - already_revoked
        device_session_id:
          type: string
        affected_session_count:
          type: integer
          format: int64
          minimum: 0
    RevokeAllUserSessionsRequest:
      type: object
      additionalProperties: false
      required:
        - reason_code
        - actor
      properties:
        reason_code:
          type: string
          description: Machine-readable bulk revoke reason code.
        actor:
          $ref: "#/components/schemas/Actor"
    RevokeAllUserSessionsResponse:
      type: object
      additionalProperties: false
      required:
        - outcome
        - user_id
        - affected_session_count
        - affected_device_session_ids
      properties:
        outcome:
          type: string
          enum:
            - revoked
            - no_active_sessions
        user_id:
          type: string
        affected_session_count:
          type: integer
          format: int64
          minimum: 0
        affected_device_session_ids:
          type: array
          items:
            type: string
    BlockUserRequest:
      oneOf:
        - $ref: "#/components/schemas/BlockUserByUserIDRequest"
        - $ref: "#/components/schemas/BlockUserByEmailRequest"
    BlockUserByUserIDRequest:
      type: object
      additionalProperties: false
      required:
        - user_id
        - reason_code
        - actor
      properties:
        user_id:
          type: string
        reason_code:
          type: string
          description: Machine-readable block reason code.
        actor:
          $ref: "#/components/schemas/Actor"
    BlockUserByEmailRequest:
      type: object
      additionalProperties: false
      required:
        - email
        - reason_code
        - actor
      properties:
        email:
          type: string
          format: email
        reason_code:
          type: string
          description: Machine-readable block reason code.
        actor:
          $ref: "#/components/schemas/Actor"
    BlockUserResponse:
      type: object
      additionalProperties: false
      required:
        - outcome
        - subject_kind
        - subject_value
        - affected_session_count
        - affected_device_session_ids
      properties:
        outcome:
          type: string
          enum:
            - blocked
            - already_blocked
        subject_kind:
          type: string
          enum:
            - user_id
            - email
        subject_value:
          type: string
        affected_session_count:
          type: integer
          format: int64
          minimum: 0
        affected_device_session_ids:
          type: array
          items:
            type: string
  responses:
    InvalidRequestError:
      description: Request path, parameters, or body fields are invalid.
      content:
        application/json:
          schema:
            $ref: "#/components/schemas/ErrorResponse"
          examples:
            invalidRequest:
              value:
                error:
                  code: invalid_request
                  message: reason_code must not be empty
    SessionNotFoundError:
      description: The referenced device session does not exist.
      content:
        application/json:
          schema:
            $ref: "#/components/schemas/ErrorResponse"
          examples:
            sessionNotFound:
              value:
                error:
                  code: session_not_found
                  message: session not found
    SubjectNotFoundError:
      description: The referenced internal block or bulk-revoke subject does not exist.
      content:
        application/json:
          schema:
            $ref: "#/components/schemas/ErrorResponse"
          examples:
            subjectNotFound:
              value:
                error:
                  code: subject_not_found
                  message: subject not found
    ServiceUnavailableError:
      description: A required dependency is temporarily unavailable.
      content:
        application/json:
          schema:
            $ref: "#/components/schemas/ErrorResponse"
          examples:
            unavailable:
              value:
                error:
                  code: service_unavailable
                  message: service is unavailable
    InternalError:
      description: Unexpected internal failure while processing the request.
      content:
        application/json:
          schema:
            $ref: "#/components/schemas/ErrorResponse"
          examples:
            internal:
              value:
                error:
                  code: internal_error
                  message: internal server error

@@ -0,0 +1,284 @@
openapi: 3.0.3
info:
  title: Galaxy Auth / Session Service Public API
  version: v1
  description: |
    This specification documents the implemented `galaxy/authsession` v1
    public REST contract for the e-mail-code flow consumed by
    `galaxy/gateway`.

    Implemented public operations:
    - `POST /api/v1/public/auth/send-email-code`
    - `POST /api/v1/public/auth/confirm-email-code`

    Contract rules:
    - requests and responses are JSON only;
    - request schemas reject unknown fields via `additionalProperties: false`;
    - empty bodies, malformed JSON, multiple JSON objects, and unknown fields
      are rejected as `400 invalid_request`;
    - surrounding ASCII/Unicode whitespace is trimmed from input string fields
      before validation;
    - `send-email-code` remains success-shaped for existing, new, and blocked
      e-mail addresses;
    - `confirm-email-code` returns a ready `device_session_id` synchronously on
      success.
tags:
  - name: PublicAuth
    description: Public unauthenticated e-mail-code authentication endpoints.
paths:
/api/v1/public/auth/send-email-code:
|
||||
post:
|
||||
tags:
|
||||
- PublicAuth
|
||||
operationId: sendEmailCode
|
||||
summary: Start a public e-mail login challenge
|
||||
description: |
|
||||
Accepts one client e-mail address and starts the public challenge flow.
|
||||
The outward result remains success-shaped even when the underlying
|
||||
policy suppresses mail delivery for anti-enumeration purposes.
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/SendEmailCodeRequest"
|
||||
examples:
|
||||
default:
|
||||
value:
|
||||
email: pilot@example.com
|
||||
responses:
|
||||
"200":
|
||||
description: The login challenge was accepted.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/SendEmailCodeResponse"
|
||||
examples:
|
||||
accepted:
|
||||
value:
|
||||
challenge_id: challenge-123
|
||||
"400":
|
||||
$ref: "#/components/responses/SendEmailCodeBadRequestError"
|
||||
"503":
|
||||
$ref: "#/components/responses/ServiceUnavailableError"
|
||||
/api/v1/public/auth/confirm-email-code:
|
||||
post:
|
||||
tags:
|
||||
- PublicAuth
|
||||
operationId: confirmEmailCode
|
||||
summary: Confirm a public e-mail login challenge
|
||||
description: |
|
||||
Completes a previously issued `challenge_id`, validates the submitted
|
||||
verification code, registers the standard base64-encoded raw 32-byte
|
||||
Ed25519 `client_public_key`, and returns the created
|
||||
`device_session_id`.
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ConfirmEmailCodeRequest"
|
||||
examples:
|
||||
default:
|
||||
value:
|
||||
challenge_id: challenge-123
|
||||
code: "123456"
|
||||
client_public_key: 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=
|
||||
responses:
|
||||
"200":
|
||||
description: The device session was created and is ready for use.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ConfirmEmailCodeResponse"
|
||||
examples:
|
||||
accepted:
|
||||
value:
|
||||
device_session_id: device-session-123
|
||||
"400":
|
||||
$ref: "#/components/responses/ConfirmEmailCodeBadRequestError"
|
||||
"403":
|
||||
$ref: "#/components/responses/BlockedByPolicyError"
|
||||
"404":
|
||||
$ref: "#/components/responses/ChallengeNotFoundError"
|
||||
"409":
|
||||
$ref: "#/components/responses/SessionLimitExceededError"
|
||||
"410":
|
||||
$ref: "#/components/responses/ChallengeExpiredError"
|
||||
"503":
|
||||
$ref: "#/components/responses/ServiceUnavailableError"
|
||||
components:
|
||||
schemas:
|
||||
SendEmailCodeRequest:
|
||||
type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- email
|
||||
properties:
|
||||
email:
|
||||
type: string
|
||||
description: Single client e-mail address that should receive the login code.
|
||||
format: email
|
||||
SendEmailCodeResponse:
|
||||
type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- challenge_id
|
||||
properties:
|
||||
challenge_id:
|
||||
type: string
|
||||
description: Opaque challenge identifier returned by the Auth / Session Service.
|
||||
ConfirmEmailCodeRequest:
|
||||
type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- challenge_id
|
||||
- code
|
||||
- client_public_key
|
||||
properties:
|
||||
challenge_id:
|
||||
type: string
|
||||
description: Opaque challenge identifier previously returned by send-email-code.
|
||||
code:
|
||||
type: string
|
||||
description: Verification code delivered to the client.
|
||||
client_public_key:
|
||||
type: string
|
||||
description: Standard base64-encoded raw 32-byte Ed25519 public key registered for the new device session.
|
||||
ConfirmEmailCodeResponse:
|
||||
type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- device_session_id
|
||||
properties:
|
||||
device_session_id:
|
||||
type: string
|
||||
description: Stable identifier of the created device session.
|
||||
ErrorResponse:
|
||||
type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- error
|
||||
properties:
|
||||
error:
|
||||
$ref: "#/components/schemas/ErrorBody"
|
||||
ErrorBody:
|
||||
type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- code
|
||||
- message
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
description: |
|
||||
Stable gateway-generated or client-safe auth-adapter-projected
|
||||
error code. Gateway-generated values include `invalid_request`,
|
||||
`not_found`, `method_not_allowed`, `request_too_large`,
|
||||
`rate_limited`, `internal_error`, and `service_unavailable`.
|
||||
message:
|
||||
type: string
|
||||
description: Human-readable client-safe error description.
|
||||
responses:
|
||||
SendEmailCodeBadRequestError:
|
||||
description: |
|
||||
Request body or field values are invalid. This includes empty bodies,
|
||||
malformed JSON, multiple JSON objects, unknown fields, and invalid
|
||||
`email`.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ErrorResponse"
|
||||
examples:
|
||||
invalidRequest:
|
||||
value:
|
||||
error:
|
||||
code: invalid_request
|
||||
message: email must be a single valid email address
|
||||
ConfirmEmailCodeBadRequestError:
|
||||
description: |
|
||||
Request body or field values are invalid. This includes malformed
|
||||
request payloads, invalid confirmation codes, and malformed
|
||||
`client_public_key` values.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ErrorResponse"
|
||||
examples:
|
||||
invalidRequest:
|
||||
value:
|
||||
error:
|
||||
code: invalid_request
|
||||
message: challenge_id must not be empty
|
||||
invalidCode:
|
||||
value:
|
||||
error:
|
||||
code: invalid_code
|
||||
message: confirmation code is invalid
|
||||
invalidClientPublicKey:
|
||||
value:
|
||||
error:
|
||||
code: invalid_client_public_key
|
||||
message: client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key
|
||||
ChallengeNotFoundError:
|
||||
description: The referenced challenge does not exist.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ErrorResponse"
|
||||
examples:
|
||||
notFound:
|
||||
value:
|
||||
error:
|
||||
code: challenge_not_found
|
||||
message: challenge not found
|
||||
ChallengeExpiredError:
|
||||
description: The referenced challenge has expired and can no longer be confirmed.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ErrorResponse"
|
||||
examples:
|
||||
expired:
|
||||
value:
|
||||
error:
|
||||
code: challenge_expired
|
||||
message: challenge expired
|
||||
BlockedByPolicyError:
|
||||
description: The auth flow is denied by account or registration policy.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ErrorResponse"
|
||||
examples:
|
||||
blocked:
|
||||
value:
|
||||
error:
|
||||
code: blocked_by_policy
|
||||
message: authentication is blocked by policy
|
||||
SessionLimitExceededError:
|
||||
description: Creating another active device session would exceed the configured limit.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ErrorResponse"
|
||||
examples:
|
||||
limitExceeded:
|
||||
value:
|
||||
error:
|
||||
code: session_limit_exceeded
|
||||
message: active session limit would be exceeded
|
||||
ServiceUnavailableError:
|
||||
description: The service is temporarily unable to serve the request safely.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ErrorResponse"
|
||||
examples:
|
||||
unavailable:
|
||||
value:
|
||||
error:
|
||||
code: service_unavailable
|
||||
message: service is unavailable
|
||||
@@ -0,0 +1,72 @@
package main

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"
	"syscall"

	"galaxy/authsession/internal/app"
	"galaxy/authsession/internal/config"
	"galaxy/authsession/internal/logging"
	"galaxy/authsession/internal/telemetry"
)

func main() {
	if err := run(); err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "authsession: %v\n", err)
		os.Exit(1)
	}
}

func run() error {
	cfg, err := config.LoadFromEnv()
	if err != nil {
		return err
	}

	logger, err := logging.New(cfg.Logging.Level)
	if err != nil {
		return fmt.Errorf("build logger: %w", err)
	}
	defer func() {
		_ = logging.Sync(logger)
	}()

	rootCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	telemetryRuntime, err := telemetry.NewProcess(rootCtx, telemetry.ProcessConfig{
		ServiceName:          cfg.Telemetry.ServiceName,
		TracesExporter:       cfg.Telemetry.TracesExporter,
		MetricsExporter:      cfg.Telemetry.MetricsExporter,
		TracesProtocol:       cfg.Telemetry.TracesProtocol,
		MetricsProtocol:      cfg.Telemetry.MetricsProtocol,
		StdoutTracesEnabled:  cfg.Telemetry.StdoutTracesEnabled,
		StdoutMetricsEnabled: cfg.Telemetry.StdoutMetricsEnabled,
	}, logger)
	if err != nil {
		return fmt.Errorf("build telemetry runtime: %w", err)
	}
	defer func() {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout)
		defer cancel()
		_ = telemetryRuntime.Shutdown(shutdownCtx)
	}()

	runtime, err := app.NewRuntime(rootCtx, cfg, logger, telemetryRuntime)
	if err != nil {
		return err
	}
	defer func() {
		_ = runtime.Close()
	}()

	if err := runtime.App.Run(rootCtx); err != nil && !errors.Is(err, context.Canceled) {
		return err
	}

	return nil
}
@@ -0,0 +1,418 @@
package authsession

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"path/filepath"
	"runtime"
	"slices"
	"testing"

	"galaxy/authsession/internal/service/shared"

	"github.com/getkin/kin-openapi/openapi3"
	"github.com/stretchr/testify/require"
)

func TestPublicOpenAPISpecValidates(t *testing.T) {
	t.Parallel()

	loadSpec(t, "api", "public-openapi.yaml")
}

func TestInternalOpenAPISpecValidates(t *testing.T) {
	t.Parallel()

	loadSpec(t, "api", "internal-openapi.yaml")
}

func TestPublicOpenAPISpecMatchesGatewayPublicAuthContract(t *testing.T) {
	t.Parallel()

	authDoc := loadSpec(t, "api", "public-openapi.yaml")
	gatewayDoc := loadSpec(t, "..", "gateway", "openapi.yaml")
	authErrorEnvelope := componentSchemaRef(t, authDoc, "ErrorResponse")
	gatewayProjectedEnvelope := defaultResponseSchemaRef(t, getOperation(t, gatewayDoc, "/api/v1/public/auth/send-email-code", http.MethodPost))
	const errorResponseRef = "#/components/schemas/ErrorResponse"

	paths := []string{
		"/api/v1/public/auth/send-email-code",
		"/api/v1/public/auth/confirm-email-code",
	}

	for _, path := range paths {
		authOperation := getOperation(t, authDoc, path, http.MethodPost)
		gatewayOperation := getOperation(t, gatewayDoc, path, http.MethodPost)

		if authOperation.OperationID != gatewayOperation.OperationID {
			require.Failf(t, "test failed", "operation %s: got operationId %q, want %q", path, authOperation.OperationID, gatewayOperation.OperationID)
		}

		compareSchemaRefs(
			t,
			requestSchemaRef(t, authOperation),
			requestSchemaRef(t, gatewayOperation),
			"path "+path+" request schema",
		)
		compareSchemaRefs(
			t,
			responseSchemaRef(t, authOperation, http.StatusOK),
			responseSchemaRef(t, gatewayOperation, http.StatusOK),
			"path "+path+" success response schema",
		)

		for _, status := range publicErrorStatuses(path) {
			assertSchemaRef(t, responseSchemaRef(t, authOperation, status), errorResponseRef, "path "+path+" error response "+http.StatusText(status)+" envelope")
		}
	}

	compareSchemaRefs(
		t,
		authErrorEnvelope,
		componentSchemaRef(t, gatewayDoc, "ErrorResponse"),
		"ErrorResponse schema",
	)
	compareSchemaRefs(
		t,
		componentSchemaRef(t, authDoc, "ErrorBody"),
		componentSchemaRef(t, gatewayDoc, "ErrorBody"),
		"ErrorBody schema",
	)
	assertSchemaRef(t, gatewayProjectedEnvelope, errorResponseRef, "projected gateway auth error envelope")
}

func TestPublicOpenAPISpecErrorExamplesMatchStablePublicErrors(t *testing.T) {
	t.Parallel()

	doc := loadSpec(t, "api", "public-openapi.yaml")

	tests := []struct {
		name         string
		responseName string
		exampleName  string
		projection   shared.PublicErrorProjection
	}{
		{
			name:         "send invalid request",
			responseName: "SendEmailCodeBadRequestError",
			exampleName:  "invalidRequest",
			projection:   shared.ProjectPublicError(shared.InvalidRequest("email must be a single valid email address")),
		},
		{
			name:         "confirm invalid request",
			responseName: "ConfirmEmailCodeBadRequestError",
			exampleName:  "invalidRequest",
			projection:   shared.ProjectPublicError(shared.InvalidRequest("challenge_id must not be empty")),
		},
		{
			name:         "confirm invalid code",
			responseName: "ConfirmEmailCodeBadRequestError",
			exampleName:  "invalidCode",
			projection:   shared.ProjectPublicError(shared.InvalidCode()),
		},
		{
			name:         "confirm invalid client public key",
			responseName: "ConfirmEmailCodeBadRequestError",
			exampleName:  "invalidClientPublicKey",
			projection:   shared.ProjectPublicError(shared.InvalidClientPublicKey()),
		},
		{
			name:         "challenge not found",
			responseName: "ChallengeNotFoundError",
			exampleName:  "notFound",
			projection:   shared.ProjectPublicError(shared.ChallengeNotFound()),
		},
		{
			name:         "challenge expired",
			responseName: "ChallengeExpiredError",
			exampleName:  "expired",
			projection:   shared.ProjectPublicError(shared.ChallengeExpired()),
		},
		{
			name:         "blocked by policy",
			responseName: "BlockedByPolicyError",
			exampleName:  "blocked",
			projection:   shared.ProjectPublicError(shared.BlockedByPolicy()),
		},
		{
			name:         "session limit exceeded",
			responseName: "SessionLimitExceededError",
			exampleName:  "limitExceeded",
			projection:   shared.ProjectPublicError(shared.SessionLimitExceeded()),
		},
		{
			name:         "service unavailable",
			responseName: "ServiceUnavailableError",
			exampleName:  "unavailable",
			projection:   shared.ProjectPublicError(shared.ServiceUnavailable(nil)),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			got := responseExampleValue(t, doc, tt.responseName, tt.exampleName)
			want := map[string]any{
				"error": map[string]any{
					"code":    tt.projection.Code,
					"message": tt.projection.Message,
				},
			}

			require.JSONEq(t, string(mustJSON(t, want)), string(mustJSON(t, got)))
		})
	}
}

func TestInternalOpenAPISpecFreezesMutationContracts(t *testing.T) {
	t.Parallel()

	doc := loadSpec(t, "api", "internal-openapi.yaml")

	blockUser := componentSchemaRef(t, doc, "BlockUserRequest")
	if got := len(blockUser.Value.OneOf); got != 2 {
		require.Failf(t, "test failed", "BlockUserRequest oneOf length = %d, want 2", got)
	}

	refs := []string{
		blockUser.Value.OneOf[0].Ref,
		blockUser.Value.OneOf[1].Ref,
	}
	slices.Sort(refs)
	wantRefs := []string{
		"#/components/schemas/BlockUserByEmailRequest",
		"#/components/schemas/BlockUserByUserIDRequest",
	}
	if !slices.Equal(refs, wantRefs) {
		require.Failf(t, "test failed", "BlockUserRequest oneOf refs = %v, want %v", refs, wantRefs)
	}

	assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserByUserIDRequest"), "reason_code", "actor", "user_id")
	assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserByEmailRequest"), "reason_code", "actor", "email")
	assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "device_session_id", "affected_session_count")
	assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "user_id", "affected_session_count", "affected_device_session_ids")
	assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "subject_kind", "subject_value", "affected_session_count", "affected_device_session_ids")

	assertStringEnum(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "revoked", "already_revoked")
	assertStringEnum(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "revoked", "no_active_sessions")
	assertStringEnum(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "blocked", "already_blocked")
}

func loadSpec(t *testing.T, pathElems ...string) *openapi3.T {
	t.Helper()

	_, thisFile, _, ok := runtime.Caller(0)
	if !ok {
		require.FailNow(t, "runtime.Caller failed")
	}

	specPath := filepath.Join(append([]string{filepath.Dir(thisFile)}, pathElems...)...)
	loader := openapi3.NewLoader()
	doc, err := loader.LoadFromFile(specPath)
	if err != nil {
		require.Failf(t, "test failed", "load spec %s: %v", specPath, err)
	}
	if doc == nil {
		require.Failf(t, "test failed", "load spec %s: returned nil document", specPath)
	}
	if doc.Info == nil {
		require.Failf(t, "test failed", "load spec %s: missing info section", specPath)
	}
	if doc.Info.Version != "v1" {
		require.Failf(t, "test failed", "spec %s version = %q, want v1", specPath, doc.Info.Version)
	}
	if err := doc.Validate(context.Background()); err != nil {
		require.Failf(t, "test failed", "validate spec %s: %v", specPath, err)
	}

	return doc
}

func getOperation(t *testing.T, doc *openapi3.T, path string, method string) *openapi3.Operation {
	t.Helper()

	if doc.Paths == nil {
		require.Failf(t, "test failed", "spec is missing paths while looking up %s %s", method, path)
	}
	pathItem := doc.Paths.Value(path)
	if pathItem == nil {
		require.Failf(t, "test failed", "spec is missing path %s", path)
	}
	operation := pathItem.GetOperation(method)
	if operation == nil {
		require.Failf(t, "test failed", "spec is missing %s operation for path %s", method, path)
	}

	return operation
}

func requestSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef {
	t.Helper()

	if operation.RequestBody == nil || operation.RequestBody.Value == nil {
		require.FailNow(t, "operation is missing request body")
	}
	mediaType := operation.RequestBody.Value.Content.Get("application/json")
	if mediaType == nil || mediaType.Schema == nil {
		require.FailNow(t, "operation is missing application/json request schema")
	}

	return mediaType.Schema
}

func responseSchemaRef(t *testing.T, operation *openapi3.Operation, status int) *openapi3.SchemaRef {
	t.Helper()

	if operation.Responses == nil {
		require.Failf(t, "test failed", "operation is missing responses for status %d", status)
	}
	response := operation.Responses.Status(status)
	if response == nil || response.Value == nil {
		require.Failf(t, "test failed", "operation is missing response for status %d", status)
	}
	mediaType := response.Value.Content.Get("application/json")
	if mediaType == nil || mediaType.Schema == nil {
		require.Failf(t, "test failed", "operation response %d is missing application/json schema", status)
	}

	return mediaType.Schema
}

func defaultResponseSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef {
	t.Helper()

	if operation.Responses == nil {
		require.FailNow(t, "operation is missing default responses")
	}
	response := operation.Responses.Default()
	if response == nil || response.Value == nil {
		require.FailNow(t, "operation is missing default response")
	}
	mediaType := response.Value.Content.Get("application/json")
	if mediaType == nil || mediaType.Schema == nil {
		require.FailNow(t, "operation default response is missing application/json schema")
	}

	return mediaType.Schema
}

func componentSchemaRef(t *testing.T, doc *openapi3.T, name string) *openapi3.SchemaRef {
	t.Helper()

	if doc.Components == nil {
		require.Failf(t, "test failed", "spec is missing components while looking up schema %s", name)
	}
	schema := doc.Components.Schemas[name]
	if schema == nil || schema.Value == nil {
		require.Failf(t, "test failed", "spec is missing schema %s", name)
	}

	return schema
}

func responseExampleValue(t *testing.T, doc *openapi3.T, responseName string, exampleName string) any {
	t.Helper()

	if doc.Components == nil {
		require.Failf(t, "test failed", "spec is missing components while looking up response %s", responseName)
	}
	response := doc.Components.Responses[responseName]
	if response == nil || response.Value == nil {
		require.Failf(t, "test failed", "spec is missing response %s", responseName)
	}
	mediaType := response.Value.Content.Get("application/json")
	if mediaType == nil {
		require.Failf(t, "test failed", "response %s is missing application/json content", responseName)
	}
	example := mediaType.Examples[exampleName]
	if example == nil || example.Value == nil {
		require.Failf(t, "test failed", "response %s is missing example %s", responseName, exampleName)
	}

	return example.Value.Value
}

func compareSchemaRefs(t *testing.T, got *openapi3.SchemaRef, want *openapi3.SchemaRef, name string) {
	t.Helper()

	gotJSON := mustJSON(t, got)
	wantJSON := mustJSON(t, want)
	if !bytes.Equal(gotJSON, wantJSON) {
		require.Failf(t, "test failed", "%s mismatch:\n got: %s\nwant: %s", name, gotJSON, wantJSON)
	}
}

func assertSchemaRef(t *testing.T, schemaRef *openapi3.SchemaRef, want string, name string) {
	t.Helper()

	if schemaRef.Ref != want {
		require.Failf(t, "test failed", "%s ref = %q, want %q", name, schemaRef.Ref, want)
	}
}

func assertRequiredFields(t *testing.T, schemaRef *openapi3.SchemaRef, fields ...string) {
	t.Helper()

	required := append([]string(nil), schemaRef.Value.Required...)
	slices.Sort(required)
	want := append([]string(nil), fields...)
	slices.Sort(want)
	if !slices.Equal(required, want) {
		require.Failf(t, "test failed", "schema required fields = %v, want %v", required, want)
	}
}

func assertStringEnum(t *testing.T, schemaRef *openapi3.SchemaRef, property string, values ...string) {
	t.Helper()

	prop := schemaRef.Value.Properties[property]
	if prop == nil || prop.Value == nil {
		require.Failf(t, "test failed", "schema is missing property %s", property)
	}

	got := make([]string, 0, len(prop.Value.Enum))
	for _, raw := range prop.Value.Enum {
		value, ok := raw.(string)
		if !ok {
			require.Failf(t, "test failed", "property %s enum contains non-string value %T", property, raw)
		}
		got = append(got, value)
	}

	if !slices.Equal(got, values) {
		require.Failf(t, "test failed", "property %s enum = %v, want %v", property, got, values)
	}
}

func mustJSON(t *testing.T, value any) []byte {
	t.Helper()

	data, err := json.Marshal(value)
	if err != nil {
		require.Failf(t, "test failed", "marshal JSON: %v", err)
	}

	return data
}

func publicErrorStatuses(path string) []int {
	switch path {
	case "/api/v1/public/auth/send-email-code":
		return []int{http.StatusBadRequest, http.StatusServiceUnavailable}
	case "/api/v1/public/auth/confirm-email-code":
		return []int{
			http.StatusBadRequest,
			http.StatusForbidden,
			http.StatusNotFound,
			http.StatusConflict,
			http.StatusGone,
			http.StatusServiceUnavailable,
		}
	default:
		panic("unexpected public auth path: " + path)
	}
}
@@ -0,0 +1,22 @@
# Auth / Session Service Docs

This directory keeps service-local documentation that is too detailed for the
root architecture document and too operational for the OpenAPI specs.

Sections:

- [Runtime and components](runtime.md)
- [Auth, revoke, and repair flows](flows.md)
- [Operator runbook](runbook.md)
- [Configuration and contract examples](examples.md)

Primary references:

- [`../README.md`](../README.md) for service scope, contracts, and core domain
  rules
- [`../api/public-openapi.yaml`](../api/public-openapi.yaml) for the public
  REST contract
- [`../api/internal-openapi.yaml`](../api/internal-openapi.yaml) for the
  trusted internal REST contract
- [`../../gateway/README.md`](../../gateway/README.md) for the downstream
  consumer of authsession's public DTOs and Redis session projection
@@ -0,0 +1,194 @@
# Configuration And Contract Examples

The examples below are illustrative. Values such as keys, codes, and IDs are
placeholders unless explicitly stated otherwise.

## Example Environment

Minimal local-development shape:

```dotenv
AUTHSESSION_REDIS_ADDR=127.0.0.1:6379
AUTHSESSION_PUBLIC_HTTP_ADDR=:8080
AUTHSESSION_INTERNAL_HTTP_ADDR=:8081

AUTHSESSION_USER_SERVICE_MODE=stub
AUTHSESSION_MAIL_SERVICE_MODE=stub

OTEL_SERVICE_NAME=galaxy-authsession
OTEL_TRACES_EXPORTER=none
OTEL_METRICS_EXPORTER=none
```

Example REST-backed integration shape:

```dotenv
AUTHSESSION_REDIS_ADDR=127.0.0.1:6379

AUTHSESSION_USER_SERVICE_MODE=rest
AUTHSESSION_USER_SERVICE_BASE_URL=http://127.0.0.1:8091
AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT=1s

AUTHSESSION_MAIL_SERVICE_MODE=rest
AUTHSESSION_MAIL_SERVICE_BASE_URL=http://127.0.0.1:8092
AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT=1s
```

## Public Auth HTTP Examples

Start an e-mail challenge:

```bash
curl -X POST http://127.0.0.1:8080/api/v1/public/auth/send-email-code \
  -H 'Content-Type: application/json' \
  -d '{"email":"pilot@example.com"}'
```

Example response:

```json
{
  "challenge_id": "challenge-123"
}
```

Confirm the challenge and register the device public key:

```bash
curl -X POST http://127.0.0.1:8080/api/v1/public/auth/confirm-email-code \
  -H 'Content-Type: application/json' \
  -d '{
    "challenge_id": "challenge-123",
    "code": "123456",
    "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no="
  }'
```

Example response:

```json
{
  "device_session_id": "device-session-123"
}
```

Stable public error example:

```json
{
  "error": {
    "code": "challenge_expired",
    "message": "challenge expired"
  }
}
```

## Trusted Internal HTTP Examples

Read one session:

```bash
curl http://127.0.0.1:8081/api/v1/internal/sessions/device-session-123
```

Example response:

```json
{
  "session": {
    "device_session_id": "device-session-123",
    "user_id": "user-123",
    "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=",
    "status": "active",
    "created_at": "2026-04-05T12:00:00Z"
  }
}
```

Revoke one session:

```bash
curl -X POST http://127.0.0.1:8081/api/v1/internal/sessions/device-session-123/revoke \
  -H 'Content-Type: application/json' \
  -d '{"reason_code":"admin_revoke","actor":{"type":"system"}}'
```

Example response:

```json
{
  "outcome": "revoked",
  "device_session_id": "device-session-123",
  "affected_session_count": 1
}
```

Block by e-mail:

```bash
curl -X POST http://127.0.0.1:8081/api/v1/internal/user-blocks \
  -H 'Content-Type: application/json' \
  -d '{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin","id":"admin-1"}}'
```

Example response:

```json
{
  "outcome": "blocked",
  "subject_kind": "email",
  "subject_value": "pilot@example.com",
  "affected_session_count": 0,
  "affected_device_session_ids": []
}
```

## Redis Projection Examples

### Gateway Session Cache Record

Example Redis key and JSON value written by authsession for gateway:

```text
gateway:session:device-session-123
```

```json
{
  "device_session_id": "device-session-123",
  "user_id": "user-123",
  "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=",
  "status": "active"
}
```

### Gateway Session-Event Stream Entry

Active snapshot:

```bash
redis-cli XADD gateway:session_events '*' \
  device_session_id device-session-123 \
  user_id user-123 \
  client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \
  status active
```

Revoked snapshot:

```bash
redis-cli XADD gateway:session_events '*' \
  device_session_id device-session-123 \
  user_id user-123 \
  client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \
  status revoked \
  revoked_at_ms 1775121700000
```

Notes:

- projected field values are strings in the Redis Stream payload
- `revoked_at_ms` is written only for revoked snapshots
- duplicate full-snapshot stream events are acceptable
- the cache snapshot and stream event intentionally omit revoke reason and
  actor metadata because gateway does not consume them
@@ -0,0 +1,119 @@
# Auth, Revoke, and Repair Flows

## Public Auth Flow

```mermaid
sequenceDiagram
    participant Client
    participant Gateway
    participant Auth
    participant Abuse as Resend throttle
    participant User as UserDirectory
    participant Mail as MailSender
    participant Challenge as ChallengeStore
    participant Session as SessionStore
    participant Config as ConfigProvider
    participant Projection as Gateway projection publisher

    Client->>Gateway: POST /api/v1/public/auth/send-email-code
    Gateway->>Auth: POST /api/v1/public/auth/send-email-code
    Auth->>Abuse: check and reserve cooldown
    alt throttled
        Abuse-->>Auth: throttled
        Auth->>Challenge: create delivery_throttled challenge
        Auth-->>Gateway: 200 {challenge_id}
    else allowed
        Abuse-->>Auth: allowed
        Auth->>User: ResolveByEmail(email)
        User-->>Auth: existing / creatable / blocked
        Auth->>Challenge: create pending challenge
        alt blocked
            Auth->>Challenge: mark delivery_suppressed
        else not blocked
            Auth->>Mail: SendLoginCode(email, code)
            Mail-->>Auth: sent / suppressed / failure
            Auth->>Challenge: persist final delivery outcome
        end
        Auth-->>Gateway: 200 {challenge_id}
    end

    Client->>Gateway: POST /api/v1/public/auth/confirm-email-code
    Gateway->>Auth: POST /api/v1/public/auth/confirm-email-code
    Auth->>Challenge: load and validate challenge
    Auth->>User: EnsureUserByEmail(email)
    User-->>Auth: existing / created / blocked
    Auth->>Config: LoadSessionLimit()
    Auth->>Session: CountActiveByUserID(user_id)
    Auth->>Session: create device session
    Auth->>Challenge: CAS to confirmed_pending_expire
    Auth->>Session: reread current stored session view
    Auth->>Projection: publish gateway snapshot
    Auth-->>Gateway: 200 {device_session_id}
```

## Revoke and Block Flow

```mermaid
sequenceDiagram
    participant Caller as Trusted internal caller
    participant Auth
    participant User as UserDirectory
    participant Session as SessionStore
    participant Projection as Gateway projection publisher
    participant Gateway

    Caller->>Auth: revoke or block request
    alt block by user or email
        Auth->>User: apply block mutation
        User-->>Auth: blocked / already_blocked
    end
    Auth->>Session: revoke one or many sessions
    Session-->>Auth: updated source-of-truth sessions
    loop each affected session
        Auth->>Projection: publish revoked snapshot
    end
    Auth-->>Caller: 200 acknowledgement
    Projection-->>Gateway: revoked session snapshot
```

## Projection Repair On Retry

Projection writes happen after source-of-truth updates. If the projection publish fails after state is already stored, the caller sees `service_unavailable`; the repair path is to repeat the same request.

```mermaid
sequenceDiagram
    participant Client
    participant Auth
    participant Challenge as ChallengeStore
    participant Session as SessionStore
    participant Projection as Gateway projection publisher

    Client->>Auth: confirm-email-code
    Auth->>Challenge: validate challenge
    Auth->>Session: create session
    Auth->>Challenge: persist confirmed_pending_expire
    Auth->>Projection: publish snapshot
    Projection-->>Auth: failure
    Auth-->>Client: 503 service_unavailable

    Client->>Auth: repeat same confirm-email-code
    Auth->>Challenge: load confirmed_pending_expire challenge
    Auth->>Session: load stored session from confirmation metadata
    Auth->>Projection: republish current stored session view
    Projection-->>Auth: success
    Auth-->>Client: 200 {device_session_id}
```

## Confirm-Race Cleanup

Concurrent identical confirms are allowed to race at the store level, but the service converges them back to one surviving active session.

- the winning CAS stores challenge confirmation metadata and publishes the surviving session snapshot
- a superseded session created by a losing racing request is revoked best-effort with `reason_code=confirm_race_repair`
- cleanup uses the same projection helper, but cleanup failure is not part of the caller-visible success contract
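The convergence rule above can be sketched as a compare-and-swap on the challenge status: only the request that moves the challenge from `pending` to `confirmed_pending_expire` wins, and later identical confirms observe the winner's stored session instead of their own candidate. The types and field names here are illustrative, not the service's actual store code:

```go
package main

import (
	"fmt"
	"sync"
)

// challengeCAS sketches the compare-and-swap that converges racing confirms.
// A losing request's candidate session would be revoked best-effort with
// reason_code=confirm_race_repair; here we only model who wins.
type challengeCAS struct {
	mu      sync.Mutex
	status  string
	session string // device session stored by the winning confirm
}

func (c *challengeCAS) confirm(candidateSession string) (won bool, surviving string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.status == "pending" {
		c.status = "confirmed_pending_expire"
		c.session = candidateSession
		return true, candidateSession
	}
	// loser: return the already-stored winner so the caller still gets
	// one stable device_session_id
	return false, c.session
}

func main() {
	c := &challengeCAS{status: "pending"}
	_, first := c.confirm("device-session-a")
	won, second := c.confirm("device-session-b")
	fmt.Println(first, won, second)
}
```

Both callers end up holding the same surviving session id, which is why a repeated confirm is a safe repair action.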
@@ -0,0 +1,157 @@
# Operator Runbook

This runbook covers the checks that matter most during startup, steady-state verification, shutdown, and common authsession incidents.

## Startup Checks

Before starting the process, confirm:

- `AUTHSESSION_REDIS_ADDR` points to the Redis deployment used for authsession source-of-truth data, resend throttling, and gateway projection
- the configured Redis ACL, DB, TLS, and key-prefix settings match the target environment
- if `AUTHSESSION_USER_SERVICE_MODE=rest`, both `AUTHSESSION_USER_SERVICE_BASE_URL` and `AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT` are configured
- if `AUTHSESSION_MAIL_SERVICE_MODE=rest`, both `AUTHSESSION_MAIL_SERVICE_BASE_URL` and `AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT` are configured
- gateway and authsession agree on:
  - the `gateway:session:` cache key prefix
  - the `gateway:session_events` stream name

At startup the process performs bounded `PING` checks for:

- challenge store
- session store
- config provider
- gateway projection publisher
- resend-throttle protector

Startup fails fast if any of those checks fail.

Expected listener state after a healthy start:

- public HTTP on `AUTHSESSION_PUBLIC_HTTP_ADDR` or default `:8080`
- internal HTTP on `AUTHSESSION_INTERNAL_HTTP_ADDR` or default `:8081`

Known startup caveats:

- there is no health, readiness, or metrics endpoint to probe directly
- the stub user-service and mail-service modes are valid only for development and isolated testing, not for real environments
## Steady-State Verification

Because the service intentionally exposes no `/healthz` or `/readyz`, practical verification is:

1. confirm the process emitted startup logs for both listeners
2. open a TCP connection to the configured public and internal listener addresses
3. send one smoke request to the public auth surface and one to the trusted internal surface when a non-destructive path is available
4. confirm Redis connectivity and namespace configuration out of band

Recommended smoke requests:

- public: malformed `send-email-code` request and expect `400 invalid_request`
- internal: `GET /api/v1/internal/users/{unknown}/sessions` and expect `200` with an empty list
## Shutdown

The process handles `SIGINT` and `SIGTERM`.

Shutdown behavior:

- the per-component shutdown budget is controlled by `AUTHSESSION_SHUTDOWN_TIMEOUT`
- both HTTP listeners are stopped through the coordinated app shutdown
- Redis and HTTP-client resources are closed after the app stops
- telemetry providers are flushed and shut down after the process begins exiting

During planned restarts:

1. send `SIGTERM`
2. wait for the listener shutdown logs
3. restart the process with the same Redis configuration
4. re-run the steady-state verification steps above
## Incident Triage

### Confirm Returns `503` But A Later Retry Succeeds

Interpret this as a projection-publication failure after source-of-truth state was already written.

Check:

1. whether the challenge moved to `confirmed_pending_expire`
2. whether the created session exists in source of truth
3. whether Redis was reachable for gateway projection writes at the time of failure
4. whether a repeated identical confirm repaired the gateway projection

Expected behavior:

- the first request returns `503 service_unavailable`
- the same confirm retried during the idempotency window returns the same `device_session_id`
### Revocation Does Not Reach Gateway

If a revoked session still authenticates through gateway:

1. verify the authsession source-of-truth record is revoked
2. verify a gateway projection snapshot was written under `gateway:session:<device_session_id>`
3. verify a matching snapshot event was appended to `gateway:session_events`
4. verify gateway is pointed at the same Redis address, DB, and stream name
5. check whether a later active snapshot overwrote the revoked view
### Send Flow Is Unexpectedly Throttled

If repeated `send-email-code` calls return challenge ids but no mail is sent:

1. check the resend-throttle key namespace
2. confirm the same normalized e-mail address is being reused
3. verify the requests are inside the fixed `1m` cooldown window
4. confirm authsession is creating `delivery_throttled` challenges rather than `delivery_suppressed` ones

Expected throttled behavior:

- a fresh `challenge_id` is still returned
- `UserDirectory` is not called
- `MailSender` is not called
### User-Service Or Mail-Service REST Failures

If `rest` mode is enabled and calls begin failing:

1. verify the configured base URL
2. verify outbound connectivity from the authsession process
3. confirm request timeouts are large enough for the environment
4. for user-service reads, remember the client retries only once on transport errors and `502`/`503`/`504`
5. for mail-service sends, remember the client never auto-retries

Observed behavior:

- public auth flows usually surface these failures as `503 service_unavailable`
- internal revoke and block flows surface them as `503 service_unavailable`
### Expired Challenge Questions

When callers report mixed `challenge_expired` and `challenge_not_found` responses:

- `challenge_expired` means the record still exists and has crossed the expiration boundary
- `challenge_not_found` means the record is absent, including after Redis TTL cleanup removes it

That difference is expected and should not be treated as contract drift.
@@ -0,0 +1,176 @@
# Runtime and Components

The diagram below focuses on the deployed `galaxy/authsession` process and its runtime dependencies.

```mermaid
flowchart LR
    subgraph Clients
        Gateway["Edge Gateway"]
        Internal["Trusted internal callers"]
    end

    subgraph Authsession["Auth / Session Service process"]
        PublicHTTP["Public HTTP listener\n/api/v1/public/auth/*"]
        InternalHTTP["Trusted internal listener\n/api/v1/internal/*"]
        Services["Application services"]
        Runtime["Clock, IDs, code generation, hashing"]
        Telemetry["Logs, traces, metrics"]
    end

    Redis["Redis\nchallenges + sessions + config + projection + throttle"]
    User["User Service\nstub or REST"]
    Mail["Mail Service\nstub or REST"]
    GatewayCache["Gateway session cache\nand session-events stream"]

    Gateway --> PublicHTTP
    Internal --> InternalHTTP
    PublicHTTP --> Services
    InternalHTTP --> Services
    Services --> Runtime
    Services --> Redis
    Services --> User
    Services --> Mail
    Services --> GatewayCache
    PublicHTTP --> Telemetry
    InternalHTTP --> Telemetry
```
## Listeners

`authsession` exposes exactly two HTTP listeners:

| Listener | Default addr | Purpose |
| --- | --- | --- |
| Public HTTP | `:8080` | Unauthenticated public auth routes consumed directly or through gateway |
| Internal HTTP | `:8081` | Trusted read, revoke, and block operations |

Shared listener defaults:

- read-header timeout: `2s`
- read timeout: `10s`
- idle timeout: `1m`
- per-request application timeout: `3s`

Intentional omissions:

- no `/healthz`
- no `/readyz`
- no `/metrics`
- no separate admin listener
## Startup Wiring

`cmd/authsession` loads process config, builds the logger and telemetry runtime, then assembles the application through `internal/app.NewRuntime`.

`NewRuntime` wires:

- the Redis-backed `ChallengeStore`
- the Redis-backed `SessionStore`
- the Redis-backed `ConfigProvider`
- the Redis-backed gateway `ProjectionPublisher`
- the Redis-backed resend-throttle `SendEmailCodeAbuseProtector`
- local runtime helpers for clock, ID generation, code generation, and code hashing
- the user-service adapter selected by `AUTHSESSION_USER_SERVICE_MODE`
- the mail-service adapter selected by `AUTHSESSION_MAIL_SERVICE_MODE`
- the public and internal HTTP servers

Before startup completes, the process performs bounded `PING` checks for every Redis-backed adapter listed above. Startup fails fast if any Redis-backed dependency is unavailable or misconfigured.
## Redis Namespaces

Default Redis naming:

- challenges: `authsession:challenge:`
- sessions: `authsession:session:`
- user-to-session index: `authsession:user-sessions:`
- user-to-active-session index: `authsession:user-active-sessions:`
- session limit key: `authsession:config:active-session-limit`
- send-email-code throttle keys: `authsession:send-email-code-throttle:`
- gateway session cache keys: `gateway:session:`
- gateway session-events stream: `gateway:session_events`

The authsession process owns the source-of-truth namespaces and writes the gateway-facing projection namespaces as a derived integration view.
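The full keys are plain prefix-plus-identifier concatenations. The helper functions below illustrate the default layout; the service itself derives the same shapes from its `AUTHSESSION_REDIS_*_KEY_PREFIX` configuration rather than hard-coding them:

```go
package main

import "fmt"

// Key composition under the default namespaces listed above.
// Illustrative sketch, not the service's key-building code.
func sessionKey(deviceSessionID string) string {
	return "authsession:session:" + deviceSessionID
}

func userActiveSessionsKey(userID string) string {
	return "authsession:user-active-sessions:" + userID
}

func gatewaySessionCacheKey(deviceSessionID string) string {
	return "gateway:session:" + deviceSessionID
}

func main() {
	fmt.Println(sessionKey("device-session-123"))            // source of truth
	fmt.Println(userActiveSessionsKey("user-123"))           // active-session index
	fmt.Println(gatewaySessionCacheKey("device-session-123")) // derived projection
}
```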
## Configuration Groups

Required for all process starts:

- `AUTHSESSION_REDIS_ADDR`

Core process config:

- `AUTHSESSION_SHUTDOWN_TIMEOUT`
- `AUTHSESSION_LOG_LEVEL`

Public HTTP config:

- `AUTHSESSION_PUBLIC_HTTP_ADDR`
- `AUTHSESSION_PUBLIC_HTTP_READ_HEADER_TIMEOUT`
- `AUTHSESSION_PUBLIC_HTTP_READ_TIMEOUT`
- `AUTHSESSION_PUBLIC_HTTP_IDLE_TIMEOUT`
- `AUTHSESSION_PUBLIC_HTTP_REQUEST_TIMEOUT`

Internal HTTP config:

- `AUTHSESSION_INTERNAL_HTTP_ADDR`
- `AUTHSESSION_INTERNAL_HTTP_READ_HEADER_TIMEOUT`
- `AUTHSESSION_INTERNAL_HTTP_READ_TIMEOUT`
- `AUTHSESSION_INTERNAL_HTTP_IDLE_TIMEOUT`
- `AUTHSESSION_INTERNAL_HTTP_REQUEST_TIMEOUT`

Redis connectivity and namespace config:

- `AUTHSESSION_REDIS_USERNAME`
- `AUTHSESSION_REDIS_PASSWORD`
- `AUTHSESSION_REDIS_DB`
- `AUTHSESSION_REDIS_TLS_ENABLED`
- `AUTHSESSION_REDIS_OPERATION_TIMEOUT`
- `AUTHSESSION_REDIS_CHALLENGE_KEY_PREFIX`
- `AUTHSESSION_REDIS_SESSION_KEY_PREFIX`
- `AUTHSESSION_REDIS_USER_SESSIONS_KEY_PREFIX`
- `AUTHSESSION_REDIS_USER_ACTIVE_SESSIONS_KEY_PREFIX`
- `AUTHSESSION_REDIS_SESSION_LIMIT_KEY`
- `AUTHSESSION_REDIS_GATEWAY_SESSION_CACHE_KEY_PREFIX`
- `AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM`
- `AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM_MAX_LEN`
- `AUTHSESSION_REDIS_SEND_EMAIL_CODE_THROTTLE_KEY_PREFIX`

User-service integration:

- `AUTHSESSION_USER_SERVICE_MODE=stub|rest`
- `AUTHSESSION_USER_SERVICE_BASE_URL`
- `AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT`

Mail-service integration:

- `AUTHSESSION_MAIL_SERVICE_MODE=stub|rest`
- `AUTHSESSION_MAIL_SERVICE_BASE_URL`
- `AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT`

Telemetry:

- `OTEL_SERVICE_NAME`
- `OTEL_TRACES_EXPORTER`
- `OTEL_METRICS_EXPORTER`
- `OTEL_EXPORTER_OTLP_PROTOCOL`
- `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL`
- `OTEL_EXPORTER_OTLP_METRICS_PROTOCOL`
- `AUTHSESSION_OTEL_STDOUT_TRACES_ENABLED`
- `AUTHSESSION_OTEL_STDOUT_METRICS_ENABLED`
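As one concrete starting point, a development environment enabling both REST integrations might look like the fragment below. The host names and `2s` timeouts are placeholder values chosen for illustration, not defaults shipped by the service:

```text
AUTHSESSION_REDIS_ADDR=redis:6379
AUTHSESSION_PUBLIC_HTTP_ADDR=:8080
AUTHSESSION_INTERNAL_HTTP_ADDR=:8081
AUTHSESSION_USER_SERVICE_MODE=rest
AUTHSESSION_USER_SERVICE_BASE_URL=http://user-service:8080
AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT=2s
AUTHSESSION_MAIL_SERVICE_MODE=rest
AUTHSESSION_MAIL_SERVICE_BASE_URL=http://mail-service:8080
AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT=2s
```

Leaving the two `*_MODE` variables unset keeps the default `stub` adapters, which require no base URLs.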
## Runtime Notes

- user-service and mail-service default to `stub`, which keeps local startup backward-compatible and does not require external URLs
- read-style user-service REST methods retry once on transport errors and HTTP `502`, `503`, or `504`
- user-service mutation methods do not auto-retry
- mail-service REST requests do not auto-retry, to avoid duplicate delivery
- authsession exports telemetry through OTel providers only; it does not serve Prometheus text exposition directly
@@ -0,0 +1,720 @@
package authsession

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"testing"
	"time"

	"galaxy/authsession/internal/adapters/mail"
	"galaxy/authsession/internal/adapters/redis/challengestore"
	"galaxy/authsession/internal/adapters/redis/configprovider"
	"galaxy/authsession/internal/adapters/redis/projectionpublisher"
	"galaxy/authsession/internal/adapters/redis/sessionstore"
	"galaxy/authsession/internal/adapters/userservice"
	"galaxy/authsession/internal/api/internalhttp"
	"galaxy/authsession/internal/api/publichttp"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/service/blockuser"
	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/getsession"
	"galaxy/authsession/internal/service/listusersessions"
	"galaxy/authsession/internal/service/revokeallusersessions"
	"galaxy/authsession/internal/service/revokedevicesession"
	"galaxy/authsession/internal/service/sendemailcode"
	"galaxy/authsession/internal/testkit"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
)

const (
	gatewayCompatibilityChallengeKeyPrefix          = "authsession:challenge:"
	gatewayCompatibilitySessionKeyPrefix            = "authsession:session:"
	gatewayCompatibilityUserSessionsKeyPrefix       = "authsession:user-sessions:"
	gatewayCompatibilityUserActiveKeyPrefix         = "authsession:user-active-sessions:"
	gatewayCompatibilitySessionLimitKey             = "authsession:config:active-session-limit"
	gatewayCompatibilitySessionCacheKeyPrefix       = "gateway:session:"
	gatewayCompatibilitySessionEventsStream         = "gateway:session_events"
	gatewayCompatibilityStreamMaxLen          int64 = 128

	gatewayCompatibilityEmail = "pilot@example.com"
	gatewayCompatibilityCode  = "123456"
)

var gatewayCompatibilityClientPublicKey = mustGatewayCompatibilityClientPublicKeyBase64()
func TestGatewayCompatibilityConfirmReturnsGatewayReadableSessionProjection(t *testing.T) {
	t.Parallel()

	app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{})

	sendResponse := gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, sendResponse.StatusCode)

	var sendBody struct {
		ChallengeID string `json:"challenge_id"`
	}
	require.NoError(t, json.Unmarshal([]byte(sendResponse.Body), &sendBody))
	assert.Equal(t, "challenge-1", sendBody.ChallengeID)

	attempts := app.mailSender.RecordedAttempts()
	require.Len(t, attempts, 1)

	confirmResponse := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      sendBody.ChallengeID,
		"code":              attempts[0].Input.Code,
		"client_public_key": gatewayCompatibilityClientPublicKey,
	})
	assert.Equal(t, http.StatusOK, confirmResponse.StatusCode)

	var confirmBody struct {
		DeviceSessionID string `json:"device_session_id"`
	}
	require.NoError(t, json.Unmarshal([]byte(confirmResponse.Body), &confirmBody))
	assert.Equal(t, "device-session-1", confirmBody.DeviceSessionID)

	record := app.mustReadGatewayCacheRecord(t, confirmBody.DeviceSessionID)
	assert.Equal(t, gatewayCacheRecord{
		DeviceSessionID: "device-session-1",
		UserID:          "user-1",
		ClientPublicKey: gatewayCompatibilityClientPublicKey,
		Status:          "active",
	}, record)

	events := app.mustReadGatewaySessionEvents(t, confirmBody.DeviceSessionID)
	require.NotEmpty(t, events)
	assert.Equal(t, gatewaySessionEventRecord{
		DeviceSessionID: "device-session-1",
		UserID:          "user-1",
		ClientPublicKey: gatewayCompatibilityClientPublicKey,
		Status:          "active",
	}, events[len(events)-1])
}

func TestGatewayCompatibilityRevokePublishesRevokedGatewayProjection(t *testing.T) {
	t.Parallel()

	app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{})

	sessionID := app.createSessionThroughPublicFlow(t)

	revokeResponse := gatewayCompatibilityPostJSON(
		t,
		app.internalBaseURL+"/api/v1/internal/sessions/"+sessionID+"/revoke",
		`{"reason_code":"admin_revoke","actor":{"type":"system"}}`,
	)
	assert.Equal(t, http.StatusOK, revokeResponse.StatusCode)
	assert.JSONEq(t, `{"outcome":"revoked","device_session_id":"`+sessionID+`","affected_session_count":1}`, revokeResponse.Body)

	record := app.mustReadGatewayCacheRecord(t, sessionID)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, gatewayCacheRecord{
		DeviceSessionID: sessionID,
		UserID:          "user-1",
		ClientPublicKey: gatewayCompatibilityClientPublicKey,
		Status:          "revoked",
		RevokedAtMS:     int64Pointer(app.now.UnixMilli()),
	}, record)

	events := app.mustReadGatewaySessionEvents(t, sessionID)
	require.NotEmpty(t, events)
	last := events[len(events)-1]
	require.NotNil(t, last.RevokedAtMS)
	assert.Equal(t, gatewaySessionEventRecord{
		DeviceSessionID: sessionID,
		UserID:          "user-1",
		ClientPublicKey: gatewayCompatibilityClientPublicKey,
		Status:          "revoked",
		RevokedAtMS:     int64Pointer(app.now.UnixMilli()),
	}, last)
}

func TestGatewayCompatibilityRepeatedConfirmReturnsSameSessionID(t *testing.T) {
	t.Parallel()

	app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{})

	sendResponse := gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, sendResponse.StatusCode)

	var sendBody struct {
		ChallengeID string `json:"challenge_id"`
	}
	require.NoError(t, json.Unmarshal([]byte(sendResponse.Body), &sendBody))

	attempts := app.mailSender.RecordedAttempts()
	require.Len(t, attempts, 1)

	requestBody := map[string]string{
		"challenge_id":      sendBody.ChallengeID,
		"code":              attempts[0].Input.Code,
		"client_public_key": gatewayCompatibilityClientPublicKey,
	}

	first := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", requestBody)
	second := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", requestBody)
	assert.Equal(t, http.StatusOK, first.StatusCode)
	assert.Equal(t, http.StatusOK, second.StatusCode)

	var firstBody struct {
		DeviceSessionID string `json:"device_session_id"`
	}
	var secondBody struct {
		DeviceSessionID string `json:"device_session_id"`
	}
	require.NoError(t, json.Unmarshal([]byte(first.Body), &firstBody))
	require.NoError(t, json.Unmarshal([]byte(second.Body), &secondBody))
	assert.Equal(t, firstBody.DeviceSessionID, secondBody.DeviceSessionID)

	record := app.mustReadGatewayCacheRecord(t, firstBody.DeviceSessionID)
	assert.Equal(t, gatewayCacheRecord{
		DeviceSessionID: firstBody.DeviceSessionID,
		UserID:          "user-1",
		ClientPublicKey: gatewayCompatibilityClientPublicKey,
		Status:          "active",
	}, record)
}

func TestGatewayCompatibilityBlockedEmailSendRemainsSuccessShaped(t *testing.T) {
	t.Parallel()

	app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{
		SeedBlockedEmail: true,
	})

	response := gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, response.StatusCode)

	var body map[string]string
	require.NoError(t, json.Unmarshal([]byte(response.Body), &body))
	assert.Equal(t, map[string]string{"challenge_id": "challenge-1"}, body)
}

func TestGatewayCompatibilitySessionLimitExceededReturnsStableClientError(t *testing.T) {
	t.Parallel()

	limit := 1
	app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{
		SeedExistingUser: true,
		SessionLimit:     &limit,
		SeedActiveSessions: []devicesession.Session{
			gatewayCompatibilityActiveSession(
				t,
				"device-session-existing",
				"user-1",
				gatewayCompatibilityClientPublicKey,
				time.Date(2026, 4, 5, 11, 58, 0, 0, time.UTC),
			),
		},
	})

	sendResponse := gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, sendResponse.StatusCode)

	attempts := app.mailSender.RecordedAttempts()
	require.Len(t, attempts, 1)

	confirmResponse := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      "challenge-1",
		"code":              attempts[0].Input.Code,
		"client_public_key": gatewayCompatibilityClientPublicKey,
	})
	assert.Equal(t, http.StatusConflict, confirmResponse.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`, confirmResponse.Body)
}

func TestGatewayCompatibilityMalformedClientPublicKeyReturnsStableError(t *testing.T) {
	t.Parallel()

	app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{})

	response := gatewayCompatibilityPostJSON(
		t,
		app.publicBaseURL+"/api/v1/public/auth/confirm-email-code",
		`{"challenge_id":"challenge-123","code":"123456","client_public_key":"invalid"}`,
	)
	assert.Equal(t, http.StatusBadRequest, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key"}}`, response.Body)
}
type gatewayCompatibilityOptions struct {
	SeedBlockedEmail   bool
	SeedExistingUser   bool
	SessionLimit       *int
	SeedActiveSessions []devicesession.Session
}

// gatewayCompatibilityHarness owns one gateway-focused integration test setup
// with real HTTP servers and real Redis-backed authsession adapters.
type gatewayCompatibilityHarness struct {
	publicBaseURL   string
	internalBaseURL string
	mailSender      *mail.StubSender
	redisClient     *redis.Client
	now             time.Time
}

func newGatewayCompatibilityHarness(t *testing.T, options gatewayCompatibilityOptions) gatewayCompatibilityHarness {
	t.Helper()

	now := time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)
	redisServer := miniredis.RunT(t)
	redisClient := redis.NewClient(&redis.Options{
		Addr:            redisServer.Addr(),
		Protocol:        2,
		DisableIdentity: true,
	})
	t.Cleanup(func() {
		assert.NoError(t, redisClient.Close())
	})

	if options.SessionLimit != nil {
		redisServer.Set(gatewayCompatibilitySessionLimitKey, strconv.Itoa(*options.SessionLimit))
	}

	challengeStore, err := challengestore.New(challengestore.Config{
		Addr:             redisServer.Addr(),
		DB:               0,
		KeyPrefix:        gatewayCompatibilityChallengeKeyPrefix,
		OperationTimeout: 250 * time.Millisecond,
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, challengeStore.Close())
	})

	sessionStore, err := sessionstore.New(sessionstore.Config{
		Addr:                        redisServer.Addr(),
		DB:                          0,
		SessionKeyPrefix:            gatewayCompatibilitySessionKeyPrefix,
		UserSessionsKeyPrefix:       gatewayCompatibilityUserSessionsKeyPrefix,
		UserActiveSessionsKeyPrefix: gatewayCompatibilityUserActiveKeyPrefix,
		OperationTimeout:            250 * time.Millisecond,
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, sessionStore.Close())
	})

	configStore, err := configprovider.New(configprovider.Config{
		Addr:             redisServer.Addr(),
		DB:               0,
		SessionLimitKey:  gatewayCompatibilitySessionLimitKey,
		OperationTimeout: 250 * time.Millisecond,
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, configStore.Close())
	})

	publisher, err := projectionpublisher.New(projectionpublisher.Config{
		Addr:                  redisServer.Addr(),
		DB:                    0,
		SessionCacheKeyPrefix: gatewayCompatibilitySessionCacheKeyPrefix,
		SessionEventsStream:   gatewayCompatibilitySessionEventsStream,
		StreamMaxLen:          gatewayCompatibilityStreamMaxLen,
		OperationTimeout:      250 * time.Millisecond,
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, publisher.Close())
	})

	userDirectory := &userservice.StubDirectory{}
	if options.SeedBlockedEmail {
		require.NoError(t, userDirectory.SeedBlockedEmail(common.Email(gatewayCompatibilityEmail), "policy_blocked"))
	}
	if options.SeedExistingUser {
		require.NoError(t, userDirectory.SeedExisting(common.Email(gatewayCompatibilityEmail), common.UserID("user-1")))
	}
	for _, session := range options.SeedActiveSessions {
		require.NoError(t, sessionStore.Create(context.Background(), session))
	}

	mailSender := &mail.StubSender{}
	idGenerator := &testkit.SequenceIDGenerator{}
	codeGenerator := testkit.FixedCodeGenerator{Code: gatewayCompatibilityCode}
	codeHasher := testkit.DeterministicCodeHasher{}
	clock := testkit.FixedClock{Time: now}

	sendEmailCodeService, err := sendemailcode.NewWithObservability(
		challengeStore,
		userDirectory,
		idGenerator,
		codeGenerator,
		codeHasher,
		mailSender,
		nil,
		clock,
		zap.NewNop(),
		nil,
	)
	require.NoError(t, err)

	confirmEmailCodeService, err := confirmemailcode.NewWithObservability(
		challengeStore,
		sessionStore,
		userDirectory,
		configStore,
		publisher,
		idGenerator,
		codeHasher,
		clock,
		zap.NewNop(),
		nil,
	)
	require.NoError(t, err)

	getSessionService, err := getsession.New(sessionStore)
	require.NoError(t, err)
	listUserSessionsService, err := listusersessions.New(sessionStore)
	require.NoError(t, err)
	revokeDeviceSessionService, err := revokedevicesession.NewWithObservability(sessionStore, publisher, clock, zap.NewNop(), nil)
	require.NoError(t, err)
	revokeAllUserSessionsService, err := revokeallusersessions.NewWithObservability(sessionStore, userDirectory, publisher, clock, zap.NewNop(), nil)
	require.NoError(t, err)
	blockUserService, err := blockuser.NewWithObservability(userDirectory, sessionStore, publisher, clock, zap.NewNop(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicCfg := publichttp.DefaultConfig()
|
||||
publicCfg.Addr = gatewayCompatibilityFreeAddr(t)
|
||||
publicServer, err := publichttp.NewServer(publicCfg, publichttp.Dependencies{
|
||||
SendEmailCode: sendEmailCodeService,
|
||||
ConfirmEmailCode: confirmEmailCodeService,
|
||||
Logger: zap.NewNop(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
internalCfg := internalhttp.DefaultConfig()
|
||||
internalCfg.Addr = gatewayCompatibilityFreeAddr(t)
|
||||
internalServer, err := internalhttp.NewServer(internalCfg, internalhttp.Dependencies{
|
||||
GetSession: getSessionService,
|
||||
ListUserSessions: listUserSessionsService,
|
||||
RevokeDeviceSession: revokeDeviceSessionService,
|
||||
RevokeAllUserSessions: revokeAllUserSessionsService,
|
||||
BlockUser: blockUserService,
|
||||
Logger: zap.NewNop(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
gatewayCompatibilityRunServer(t, publicServer.Run, publicServer.Shutdown, publicCfg.Addr)
|
||||
gatewayCompatibilityRunServer(t, internalServer.Run, internalServer.Shutdown, internalCfg.Addr)
|
||||
|
||||
return gatewayCompatibilityHarness{
|
||||
publicBaseURL: "http://" + publicCfg.Addr,
|
||||
internalBaseURL: "http://" + internalCfg.Addr,
|
||||
mailSender: mailSender,
|
||||
redisClient: redisClient,
|
||||
now: now,
|
||||
}
|
||||
}
|
||||
|
||||
func (h gatewayCompatibilityHarness) createSessionThroughPublicFlow(t *testing.T) string {
	t.Helper()

	sendResponse := gatewayCompatibilityPostJSON(t, h.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, sendResponse.StatusCode)

	var sendBody struct {
		ChallengeID string `json:"challenge_id"`
	}
	require.NoError(t, json.Unmarshal([]byte(sendResponse.Body), &sendBody))

	attempts := h.mailSender.RecordedAttempts()
	require.Len(t, attempts, 1)

	confirmResponse := gatewayCompatibilityPostJSONValue(t, h.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      sendBody.ChallengeID,
		"code":              attempts[0].Input.Code,
		"client_public_key": gatewayCompatibilityClientPublicKey,
	})
	assert.Equal(t, http.StatusOK, confirmResponse.StatusCode)

	var confirmBody struct {
		DeviceSessionID string `json:"device_session_id"`
	}
	require.NoError(t, json.Unmarshal([]byte(confirmResponse.Body), &confirmBody))

	return confirmBody.DeviceSessionID
}

// gatewayCacheRecord mirrors the strict gateway Redis session-cache wire
// contract used on the authenticated hot path.
type gatewayCacheRecord struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	ClientPublicKey string `json:"client_public_key"`
	Status          string `json:"status"`
	RevokedAtMS     *int64 `json:"revoked_at_ms,omitempty"`
}

func (h gatewayCompatibilityHarness) mustReadGatewayCacheRecord(t *testing.T, deviceSessionID string) gatewayCacheRecord {
	t.Helper()

	payload, err := h.redisClient.Get(context.Background(), gatewayCompatibilitySessionCacheKeyPrefix+deviceSessionID).Bytes()
	require.NoError(t, err)

	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	var record gatewayCacheRecord
	require.NoError(t, decoder.Decode(&record))

	err = decoder.Decode(&struct{}{})
	require.ErrorIs(t, err, io.EOF)

	require.NotEmpty(t, record.DeviceSessionID)
	require.Equal(t, deviceSessionID, record.DeviceSessionID)
	require.NotEmpty(t, record.UserID)
	require.NotEmpty(t, record.ClientPublicKey)
	require.Contains(t, []string{"active", "revoked"}, record.Status)

	return record
}

// gatewaySessionEventRecord mirrors the strict gateway Redis Stream event
// contract for full session snapshots.
type gatewaySessionEventRecord struct {
	DeviceSessionID string
	UserID          string
	ClientPublicKey string
	Status          string
	RevokedAtMS     *int64
}

func (h gatewayCompatibilityHarness) mustReadGatewaySessionEvents(t *testing.T, deviceSessionID string) []gatewaySessionEventRecord {
	t.Helper()

	entries, err := h.redisClient.XRange(context.Background(), gatewayCompatibilitySessionEventsStream, "-", "+").Result()
	require.NoError(t, err)

	records := make([]gatewaySessionEventRecord, 0, len(entries))
	for _, entry := range entries {
		record := decodeGatewaySessionEvent(t, entry.Values)
		if record.DeviceSessionID == deviceSessionID {
			records = append(records, record)
		}
	}
	require.NotEmpty(t, records)

	return records
}

func decodeGatewaySessionEvent(t *testing.T, values map[string]any) gatewaySessionEventRecord {
	t.Helper()

	requiredKeys := map[string]struct{}{
		"device_session_id": {},
		"user_id":           {},
		"client_public_key": {},
		"status":            {},
	}
	optionalKeys := map[string]struct{}{
		"revoked_at_ms": {},
	}

	for key := range values {
		if _, ok := requiredKeys[key]; ok {
			continue
		}
		if _, ok := optionalKeys[key]; ok {
			continue
		}

		require.Failf(t, "test failed", "decode gateway session event: unsupported field %q", key)
	}

	record := gatewaySessionEventRecord{
		DeviceSessionID: gatewayCompatibilityRequiredStringField(t, values, "device_session_id"),
		UserID:          gatewayCompatibilityRequiredStringField(t, values, "user_id"),
		ClientPublicKey: gatewayCompatibilityRequiredStringField(t, values, "client_public_key"),
		Status:          gatewayCompatibilityRequiredStringField(t, values, "status"),
	}
	require.Contains(t, []string{"active", "revoked"}, record.Status)

	if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
		parsed := gatewayCompatibilityParseInt64Field(t, rawRevokedAtMS, "revoked_at_ms")
		record.RevokedAtMS = &parsed
	}

	return record
}

func gatewayCompatibilityRequiredStringField(t *testing.T, values map[string]any, field string) string {
	t.Helper()

	value, ok := values[field]
	require.Truef(t, ok, "decode gateway session event: missing %s", field)

	stringValue := gatewayCompatibilityCoerceString(t, value, field)
	require.NotEmptyf(t, strings.TrimSpace(stringValue), "decode gateway session event: %s must not be empty", field)

	return stringValue
}

func gatewayCompatibilityParseInt64Field(t *testing.T, value any, field string) int64 {
	t.Helper()

	stringValue := gatewayCompatibilityCoerceString(t, value, field)
	parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64)
	require.NoErrorf(t, err, "decode gateway session event: %s", field)

	return parsed
}

func gatewayCompatibilityCoerceString(t *testing.T, value any, field string) string {
	t.Helper()

	switch typed := value.(type) {
	case string:
		return typed
	case []byte:
		return string(typed)
	case fmt.Stringer:
		return typed.String()
	case int:
		return strconv.Itoa(typed)
	case int64:
		return strconv.FormatInt(typed, 10)
	case uint64:
		return strconv.FormatUint(typed, 10)
	default:
		require.Failf(t, "test failed", "decode gateway session event: %s: unsupported value type %T", field, value)
		return ""
	}
}

func gatewayCompatibilityRunServer(
	t *testing.T,
	run func(context.Context) error,
	shutdown func(context.Context) error,
	addr string,
) {
	t.Helper()

	errCh := make(chan error, 1)
	go func() {
		errCh <- run(context.Background())
	}()

	gatewayCompatibilityWaitForTCP(t, addr)
	t.Cleanup(func() {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		assert.NoError(t, shutdown(shutdownCtx))
		assert.NoError(t, <-errCh)
	})
}

func gatewayCompatibilityWaitForTCP(t *testing.T, addr string) {
	t.Helper()

	require.Eventually(t, func() bool {
		conn, err := net.DialTimeout("tcp", addr, 50*time.Millisecond)
		if err != nil {
			return false
		}
		_ = conn.Close()
		return true
	}, 5*time.Second, 25*time.Millisecond)
}

func gatewayCompatibilityFreeAddr(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, listener.Close())
	}()

	return listener.Addr().String()
}

type gatewayCompatibilityHTTPResponse struct {
	StatusCode int
	Body       string
}

func gatewayCompatibilityPostJSON(t *testing.T, url string, body string) gatewayCompatibilityHTTPResponse {
	t.Helper()

	request, err := http.NewRequest(http.MethodPost, url, bytes.NewBufferString(body))
	require.NoError(t, err)
	request.Header.Set("Content-Type", "application/json")

	response, err := http.DefaultClient.Do(request)
	require.NoError(t, err)
	defer response.Body.Close()

	payload, err := io.ReadAll(response.Body)
	require.NoError(t, err)

	return gatewayCompatibilityHTTPResponse{
		StatusCode: response.StatusCode,
		Body:       string(payload),
	}
}

func gatewayCompatibilityPostJSONValue(t *testing.T, url string, value any) gatewayCompatibilityHTTPResponse {
	t.Helper()

	payload, err := json.Marshal(value)
	require.NoError(t, err)

	return gatewayCompatibilityPostJSON(t, url, string(payload))
}

func gatewayCompatibilityActiveSession(
	t *testing.T,
	deviceSessionID string,
	userID string,
	clientPublicKeyBase64 string,
	createdAt time.Time,
) devicesession.Session {
	t.Helper()

	keyBytes, err := base64.StdEncoding.DecodeString(clientPublicKeyBase64)
	require.NoError(t, err)

	clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey(keyBytes))
	require.NoError(t, err)

	session := devicesession.Session{
		ID:              common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID(userID),
		ClientPublicKey: clientPublicKey,
		Status:          devicesession.StatusActive,
		CreatedAt:       createdAt,
	}
	require.NoError(t, session.Validate())

	return session
}

func mustGatewayCompatibilityClientPublicKeyBase64() string {
	key := make([]byte, ed25519.PublicKeySize)
	for index := range key {
		key[index] = byte(index + 1)
	}

	return base64.StdEncoding.EncodeToString(key)
}

func int64Pointer(value int64) *int64 {
	return &value
}
@@ -0,0 +1,84 @@
module galaxy/authsession

go 1.26.0

require (
	github.com/alicebob/miniredis/v2 v2.37.0
	github.com/getkin/kin-openapi v0.134.0
	github.com/gin-gonic/gin v1.12.0
	github.com/redis/go-redis/v9 v9.18.0
	github.com/stretchr/testify v1.11.1
	go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0
	go.opentelemetry.io/otel v1.42.0
	go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0
	go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0
	go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0
	go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0
	go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0
	go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0
	go.opentelemetry.io/otel/metric v1.42.0
	go.opentelemetry.io/otel/sdk v1.42.0
	go.opentelemetry.io/otel/sdk/metric v1.42.0
	go.opentelemetry.io/otel/trace v1.42.0
	go.uber.org/zap v1.27.1
	golang.org/x/crypto v0.49.0
)

require (
	github.com/bytedance/gopkg v0.1.3 // indirect
	github.com/bytedance/sonic v1.15.0 // indirect
	github.com/bytedance/sonic/loader v0.5.0 // indirect
	github.com/cenkalti/backoff/v5 v5.0.3 // indirect
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/cloudwego/base64x v0.1.6 // indirect
	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
	github.com/gabriel-vasile/mimetype v1.4.13 // indirect
	github.com/gin-contrib/sse v1.1.0 // indirect
	github.com/go-logr/logr v1.4.3 // indirect
	github.com/go-logr/stdr v1.2.2 // indirect
	github.com/go-openapi/jsonpointer v0.21.0 // indirect
	github.com/go-openapi/swag v0.23.0 // indirect
	github.com/go-playground/locales v0.14.1 // indirect
	github.com/go-playground/universal-translator v0.18.1 // indirect
	github.com/go-playground/validator/v10 v10.30.1 // indirect
	github.com/goccy/go-json v0.10.5 // indirect
	github.com/goccy/go-yaml v1.19.2 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
	github.com/josharian/intern v1.0.0 // indirect
	github.com/json-iterator/go v1.1.12 // indirect
	github.com/klauspost/cpuid/v2 v2.3.0 // indirect
	github.com/leodido/go-urn v1.4.0 // indirect
	github.com/mailru/easyjson v0.7.7 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
	github.com/modern-go/reflect2 v1.0.2 // indirect
	github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
	github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c // indirect
	github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b // indirect
	github.com/pelletier/go-toml/v2 v2.2.4 // indirect
	github.com/perimeterx/marshmallow v1.1.5 // indirect
	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
	github.com/quic-go/qpack v0.6.0 // indirect
	github.com/quic-go/quic-go v0.59.0 // indirect
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
	github.com/ugorji/go/codec v1.3.1 // indirect
	github.com/woodsbury/decimal128 v1.3.0 // indirect
	github.com/yuin/gopher-lua v1.1.1 // indirect
	go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
	go.opentelemetry.io/auto/sdk v1.2.1 // indirect
	go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 // indirect
	go.opentelemetry.io/proto/otlp v1.10.0 // indirect
	go.uber.org/atomic v1.11.0 // indirect
	go.uber.org/multierr v1.10.0 // indirect
	golang.org/x/arch v0.24.0 // indirect
	golang.org/x/net v0.52.0 // indirect
	golang.org/x/sys v0.42.0 // indirect
	golang.org/x/text v0.35.0 // indirect
	google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
	google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
	google.golang.org/grpc v1.80.0 // indirect
	google.golang.org/protobuf v1.36.11 // indirect
	gopkg.in/yaml.v3 v3.0.1 // indirect
)
@@ -0,0 +1,196 @@
github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68=
github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/getkin/kin-openapi v0.134.0 h1:/L5+1+kfe6dXh8Ot/wqiTgUkjOIEJiC0bbYVziHB8rU=
github.com/getkin/kin-openapi v0.134.0/go.mod h1:wK6ZLG/VgoETO9pcLJ/VmAtIcl/DNlMayNTb716EUxE=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c h1:7ACFcSaQsrWtrH4WHHfUqE1C+f8r2uv8KGaW0jTNjus=
github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c/go.mod h1:JKox4Gszkxt57kj27u7rvi7IFoIULvCZHUsBTUmQM/s=
github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b h1:vivRhVUAa9t1q0Db4ZmezBP8pWQWnXHFokZj0AOea2g=
github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
github.com/woodsbury/decimal128 v1.3.0 h1:8pffMNWIlC0O5vbyHWFZAt5yWvWcrHA+3ovIIjVWss0=
github.com/woodsbury/decimal128 v1.3.0/go.mod h1:C5UTmyTjW3JftjUFzOVhC20BEQa2a4ZKOB5I6Zjb+ds=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0 h1:E7DmskpIO7ZR6QI6zKSEKIDNUYoKw9oHXP23gzbCdU0=
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0/go.mod h1:WB2cS9y+AwqqKhoo9gw6/ZxlSjFBUQGZ8BQOaD3FVXM=
go.opentelemetry.io/contrib/propagators/b3 v1.42.0 h1:B2Pew5ufEtgkjLF+tSkXjgYZXQr9m7aCm1wLKB0URbU=
go.opentelemetry.io/contrib/propagators/b3 v1.42.0/go.mod h1:iPgUcSEF5DORW6+yNbdw/YevUy+QqJ508ncjhrRSCjc=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 h1:MdKucPl/HbzckWWEisiNqMPhRrAOQX8r4jTuGr636gk=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0/go.mod h1:RolT8tWtfHcjajEH5wFIZ4Dgh5jpPdFXYV9pTAk/qjc=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 h1:H7O6RlGOMTizyl3R08Kn5pdM06bnH8oscSj7o11tmLA=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0/go.mod h1:mBFWu/WOVDkWWsR7Tx7h6EpQB8wsv7P0Yrh0Pb7othc=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 h1:THuZiwpQZuHPul65w4WcwEnkX2QIuMT+UFoOrygtoJw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0/go.mod h1:J2pvYM5NGHofZ2/Ru6zw/TNWnEQp5crgyDeSrYpXkAw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 h1:zWWrB1U6nqhS/k6zYB74CjRpuiitRtLLi68VcgmOEto=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0/go.mod h1:2qXPNBX1OVRC0IwOnfo1ljoid+RD0QK3443EaqVlsOU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 h1:uLXP+3mghfMf7XmV4PkGfFhFKuNWoCvvx5wP/wOXo0o=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0/go.mod h1:v0Tj04armyT59mnURNUJf7RCKcKzq+lgJs6QSjHjaTc=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0 h1:lSZHgNHfbmQTPfuTmWVkEu8J8qXaQwuV30pjCcAUvP8=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0/go.mod h1:so9ounLcuoRDu033MW/E0AD4hhUjVqswrMF5FoZlBcw=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0 h1:s/1iRkCKDfhlh1JF26knRneorus8aOwVIDhvYx9WoDw=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0/go.mod h1:UI3wi0FXg1Pofb8ZBiBLhtMzgoTm1TYkMvn71fAqDzs=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
|
||||
golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,56 @@
// Package antiabuse provides runtime in-process adapters for auth-specific
// public abuse controls.
package antiabuse

import (
	"context"
	"fmt"
	"sync"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"
)

// SendEmailCodeProtector is a concurrency-safe in-process resend-throttle
// adapter for public send-email-code attempts.
type SendEmailCodeProtector struct {
	mu            sync.Mutex
	reservedUntil map[common.Email]time.Time
}

// CheckAndReserve applies the fixed Stage-17 resend cooldown using input.Now
// as the authoritative decision timestamp.
func (p *SendEmailCodeProtector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
	if ctx == nil {
		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: nil context")
	}
	if err := ctx.Err(); err != nil {
		return ports.SendEmailCodeAbuseResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err)
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	if p.reservedUntil == nil {
		p.reservedUntil = make(map[common.Email]time.Time)
	}

	reservedUntil, exists := p.reservedUntil[input.Email]
	if exists && input.Now.Before(reservedUntil) {
		return ports.SendEmailCodeAbuseResult{
			Outcome: ports.SendEmailCodeAbuseOutcomeThrottled,
		}, nil
	}

	p.reservedUntil[input.Email] = input.Now.UTC().Add(challenge.ResendThrottleCooldown)
	return ports.SendEmailCodeAbuseResult{
		Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
	}, nil
}

var _ ports.SendEmailCodeAbuseProtector = (*SendEmailCodeProtector)(nil)
@@ -0,0 +1,64 @@
package antiabuse

import (
	"context"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestSendEmailCodeProtectorCheckAndReserve(t *testing.T) {
	t.Parallel()

	protector := &SendEmailCodeProtector{}
	email := common.Email("pilot@example.com")
	now := time.Unix(10, 0).UTC()

	result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now,
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)

	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(30 * time.Second),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)

	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(time.Minute),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
}

func TestSendEmailCodeProtectorNilOrCanceledContext(t *testing.T) {
	t.Parallel()

	protector := &SendEmailCodeProtector{}
	_, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{
		Email: common.Email("pilot@example.com"),
		Now:   time.Unix(10, 0).UTC(),
	})
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err = protector.CheckAndReserve(ctx, ports.SendEmailCodeAbuseInput{
		Email: common.Email("pilot@example.com"),
		Now:   time.Unix(10, 0).UTC(),
	})
	require.Error(t, err)
	assert.ErrorIs(t, err, context.Canceled)
}
@@ -0,0 +1,206 @@
// Package contracttest provides reusable adapter conformance suites that
// exercise storage-agnostic port contracts without depending on one concrete
// backend implementation.
package contracttest

import (
	"context"
	"crypto/ed25519"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// ChallengeStoreFactory constructs a fresh ChallengeStore instance suitable
// for one isolated contract subtest.
type ChallengeStoreFactory func(t *testing.T) ports.ChallengeStore

// RunChallengeStoreContractTests executes the backend-agnostic ChallengeStore
// contract suite against newStore.
func RunChallengeStoreContractTests(t *testing.T, newStore ChallengeStoreFactory) {
	t.Helper()

	t.Run("create and get", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractConfirmedChallenge(t, time.Unix(1_775_130_000, 0).UTC())

		require.NoError(t, store.Create(context.Background(), record))

		got, err := store.Get(context.Background(), record.ID)
		require.NoError(t, err)
		assert.Equal(t, record, got)
	})

	t.Run("get not found", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)

		_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})

	t.Run("create conflict", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractPendingChallenge(time.Unix(1_775_130_100, 0).UTC())

		require.NoError(t, store.Create(context.Background(), record))

		err := store.Create(context.Background(), record)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrConflict)
	})

	t.Run("compare and swap success", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		now := time.Unix(1_775_130_200, 0).UTC()
		previous := contractPendingChallenge(now)
		next := previous
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		next.Attempts.Send = 1
		next.Abuse.LastAttemptAt = contractTimePointer(now.Add(time.Minute))
		require.NoError(t, next.Validate())

		require.NoError(t, store.Create(context.Background(), previous))
		require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))

		got, err := store.Get(context.Background(), previous.ID)
		require.NoError(t, err)
		assert.Equal(t, next, got)
	})

	t.Run("compare and swap conflict", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		now := time.Unix(1_775_130_300, 0).UTC()
		stored := contractPendingChallenge(now)
		previous := stored
		previous.Attempts.Send = 99
		require.NoError(t, previous.Validate())
		next := stored
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		require.NoError(t, next.Validate())

		require.NoError(t, store.Create(context.Background(), stored))

		err := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrConflict)
	})

	t.Run("compare and swap not found", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		now := time.Unix(1_775_130_400, 0).UTC()
		previous := contractPendingChallenge(now)
		next := previous
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		require.NoError(t, next.Validate())

		err := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})

	t.Run("get returns defensive copies", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractConfirmedChallenge(t, time.Unix(1_775_130_500, 0).UTC())

		require.NoError(t, store.Create(context.Background(), record))

		got, err := store.Get(context.Background(), record.ID)
		require.NoError(t, err)
		require.NotEmpty(t, got.CodeHash)
		got.CodeHash[0] = 0xFF
		if got.Confirmation != nil {
			keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
			if len(keyBytes) > 0 {
				keyBytes[0] = 0xFE
			}
		}

		again, err := store.Get(context.Background(), record.ID)
		require.NoError(t, err)
		assert.Equal(t, record.CodeHash, again.CodeHash)
		require.NotNil(t, again.Confirmation)
		assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
	})
}

func contractPendingChallenge(now time.Time) challenge.Challenge {
	record := challenge.Challenge{
		ID:            common.ChallengeID("challenge-pending"),
		Email:         common.Email("pilot@example.com"),
		CodeHash:      []byte("hashed-pending-code"),
		Status:        challenge.StatusPendingSend,
		DeliveryState: challenge.DeliveryPending,
		CreatedAt:     now,
		ExpiresAt:     now.Add(challenge.InitialTTL),
	}
	if err := record.Validate(); err != nil {
		panic(err)
	}

	return record
}

func contractConfirmedChallenge(t *testing.T, now time.Time) challenge.Challenge {
	t.Helper()

	clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
		0, 1, 2, 3, 4, 5, 6, 7,
		8, 9, 10, 11, 12, 13, 14, 15,
		16, 17, 18, 19, 20, 21, 22, 23,
		24, 25, 26, 27, 28, 29, 30, 31,
	})
	require.NoError(t, err)

	record := challenge.Challenge{
		ID:            common.ChallengeID("challenge-confirmed"),
		Email:         common.Email("pilot@example.com"),
		CodeHash:      []byte("hashed-code"),
		Status:        challenge.StatusConfirmedPendingExpire,
		DeliveryState: challenge.DeliverySent,
		CreatedAt:     now,
		ExpiresAt:     now.Add(challenge.ConfirmedRetention),
		Attempts: challenge.AttemptCounters{
			Send:    1,
			Confirm: 2,
		},
		Abuse: challenge.AbuseMetadata{
			LastAttemptAt: contractTimePointer(now.Add(30 * time.Second)),
		},
		Confirmation: &challenge.Confirmation{
			SessionID:       common.DeviceSessionID("device-session-1"),
			ClientPublicKey: clientPublicKey,
			ConfirmedAt:     now.Add(time.Minute),
		},
	}
	require.NoError(t, record.Validate())

	return record
}

func contractTimePointer(value time.Time) *time.Time {
	return &value
}
@@ -0,0 +1,65 @@
package contracttest

import (
	"context"
	"testing"

	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// ConfigProviderHarnessFactory constructs a fresh semantic ConfigProvider
// harness suitable for one isolated contract subtest.
type ConfigProviderHarnessFactory func(t *testing.T) ConfigProviderHarness

// ConfigProviderHarness bundles one semantic ConfigProvider instance with the
// seed hooks needed by the backend-agnostic contract suite.
type ConfigProviderHarness struct {
	// Provider is the semantic ConfigProvider under test.
	Provider ports.ConfigProvider

	// SeedDisabled prepares storage so LoadSessionLimit observes "limit absent".
	SeedDisabled func(t *testing.T)

	// SeedLimit prepares storage so LoadSessionLimit observes a valid positive
	// configured limit.
	SeedLimit func(t *testing.T, limit int)
}

// RunConfigProviderContractTests executes the backend-agnostic ConfigProvider
// semantic contract suite against newHarness.
func RunConfigProviderContractTests(t *testing.T, newHarness ConfigProviderHarnessFactory) {
	t.Helper()

	t.Run("limit absent means disabled", func(t *testing.T) {
		t.Parallel()

		harness := newHarness(t)
		require.NotNil(t, harness.Provider)
		require.NotNil(t, harness.SeedDisabled)

		harness.SeedDisabled(t)

		got, err := harness.Provider.LoadSessionLimit(context.Background())
		require.NoError(t, err)
		assert.Equal(t, ports.SessionLimitConfig{}, got)
	})

	t.Run("valid positive limit means configured", func(t *testing.T) {
		t.Parallel()

		harness := newHarness(t)
		require.NotNil(t, harness.Provider)
		require.NotNil(t, harness.SeedLimit)

		want := 5
		harness.SeedLimit(t, want)

		got, err := harness.Provider.LoadSessionLimit(context.Background())
		require.NoError(t, err)
		require.NotNil(t, got.ActiveSessionLimit)
		assert.Equal(t, want, *got.ActiveSessionLimit)
	})
}
@@ -0,0 +1,283 @@
package contracttest

import (
	"context"
	"crypto/ed25519"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// SessionStoreFactory constructs a fresh SessionStore instance suitable for
// one isolated contract subtest.
type SessionStoreFactory func(t *testing.T) ports.SessionStore

// RunSessionStoreContractTests executes the backend-agnostic SessionStore
// contract suite against newStore.
func RunSessionStoreContractTests(t *testing.T, newStore SessionStoreFactory) {
	t.Helper()

	t.Run("create and get", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())

		require.NoError(t, store.Create(context.Background(), record))

		got, err := store.Get(context.Background(), record.ID)
		require.NoError(t, err)
		assert.Equal(t, record, got)
	})

	t.Run("create conflict", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_050, 0).UTC())

		require.NoError(t, store.Create(context.Background(), record))

		err := store.Create(context.Background(), record)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrConflict)
	})

	t.Run("get not found", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)

		_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})

	t.Run("list by user id returns newest first", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		older := contractActiveSession(t, "device-session-old", "user-1", time.Unix(10, 0).UTC())
		newer := contractActiveSession(t, "device-session-new", "user-1", time.Unix(20, 0).UTC())
		revoked := contractRevokedSession(t, "device-session-revoked", "user-1", time.Unix(15, 0).UTC())
		otherUser := contractActiveSession(t, "device-session-other", "user-2", time.Unix(30, 0).UTC())

		for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
			require.NoError(t, store.Create(context.Background(), record))
		}

		got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
		require.NoError(t, err)
		require.Len(t, got, 3)
		assert.Equal(
			t,
			[]common.DeviceSessionID{newer.ID, revoked.ID, older.ID},
			[]common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID},
		)
	})

	t.Run("list by user id returns empty slice for unknown user", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)

		got, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
		require.NoError(t, err)
		require.NotNil(t, got)
		assert.Empty(t, got)
	})

	t.Run("count active by user id", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		activeOne := contractActiveSession(t, "device-session-1", "user-1", time.Unix(40, 0).UTC())
		activeTwo := contractActiveSession(t, "device-session-2", "user-1", time.Unix(50, 0).UTC())
		revoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(60, 0).UTC())
		otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(70, 0).UTC())

		for _, record := range []devicesession.Session{activeOne, activeTwo, revoked, otherUser} {
			require.NoError(t, store.Create(context.Background(), record))
		}

		count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
		require.NoError(t, err)
		assert.Equal(t, 2, count)
	})

	t.Run("revoke active session", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		revocation := contractRevocation(time.Unix(200, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", "")
		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: record.ID,
			Revocation:      revocation,
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
		require.NotNil(t, result.Session.Revocation)
		assert.Equal(t, revocation, *result.Session.Revocation)

		count, err := store.CountActiveByUserID(context.Background(), record.UserID)
		require.NoError(t, err)
		assert.Zero(t, count)
	})

	t.Run("revoke already revoked preserves stored revocation", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractRevokedSession(t, "device-session-2", "user-1", time.Unix(110, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: record.ID,
			Revocation:      contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1"),
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
		require.NotNil(t, result.Session.Revocation)
		assert.Equal(t, *record.Revocation, *result.Session.Revocation)
	})

	t.Run("revoke not found", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)

		_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: common.DeviceSessionID("missing-session"),
			Revocation:      contractRevocation(time.Unix(210, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", ""),
		})
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})

	t.Run("revoke all by user id revokes active sessions newest first", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		older := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
		newer := contractActiveSession(t, "device-session-2", "user-1", time.Unix(200, 0).UTC())
		alreadyRevoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(150, 0).UTC())
		otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(250, 0).UTC())

		for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
			require.NoError(t, store.Create(context.Background(), record))
		}

		revocation := contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1")
		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID:     common.UserID("user-1"),
			Revocation: revocation,
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
		require.Len(t, result.Sessions, 2)
		assert.Equal(
			t,
			[]common.DeviceSessionID{newer.ID, older.ID},
			[]common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID},
		)
		assert.Equal(t, revocation, *result.Sessions[0].Revocation)
		assert.Equal(t, revocation, *result.Sessions[1].Revocation)

		count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
		require.NoError(t, err)
		assert.Zero(t, count)
	})

	t.Run("revoke all by user id reports no active sessions", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractRevokedSession(t, "device-session-5", "user-1", time.Unix(120, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID:     common.UserID("user-1"),
			Revocation: contractRevocation(time.Unix(400, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", ""),
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
		require.NotNil(t, result.Sessions)
		assert.Empty(t, result.Sessions)
	})

	t.Run("get returns defensive copies", func(t *testing.T) {
		t.Parallel()

		store := newStore(t)
		record := contractRevokedSession(t, "device-session-copy", "user-1", time.Unix(130, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		got, err := store.Get(context.Background(), record.ID)
		require.NoError(t, err)
		require.NotNil(t, got.Revocation)
		got.Revocation.ActorID = "mutated"

		again, err := store.Get(context.Background(), record.ID)
		require.NoError(t, err)
		require.NotNil(t, again.Revocation)
		assert.Equal(t, record, again)
	})
}

func contractActiveSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	t.Helper()

	clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
		0, 1, 2, 3, 4, 5, 6, 7,
		8, 9, 10, 11, 12, 13, 14, 15,
		16, 17, 18, 19, 20, 21, 22, 23,
		24, 25, 26, 27, 28, 29, 30, 31,
	})
	require.NoError(t, err)

	record := devicesession.Session{
		ID:              common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID(userID),
		ClientPublicKey: clientPublicKey,
		Status:          devicesession.StatusActive,
		CreatedAt:       createdAt,
	}
	require.NoError(t, record.Validate())

	return record
}

func contractRevokedSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	t.Helper()

	record := contractActiveSession(t, deviceSessionID, userID, createdAt)
	revocation := contractRevocation(createdAt.Add(time.Minute), devicesession.RevokeReasonDeviceLogout, "user", "user-actor")
	record.Status = devicesession.StatusRevoked
	record.Revocation = &revocation
	require.NoError(t, record.Validate())

	return record
}

func contractRevocation(at time.Time, reasonCode common.RevokeReasonCode, actorType string, actorID string) devicesession.Revocation {
	record := devicesession.Revocation{
		At:         at,
		ReasonCode: reasonCode,
		ActorType:  common.RevokeActorType(actorType),
		ActorID:    actorID,
	}
	if err := record.Validate(); err != nil {
		panic(err)
	}

	return record
}
@@ -0,0 +1,139 @@
// Package local provides small in-process runtime implementations for
// authsession ports that do not require network dependencies.
package local

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"math/big"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"golang.org/x/crypto/bcrypt"
)

const (
	challengeIDPrefix     = "challenge-"
	deviceSessionIDPrefix = "device-session-"
	codeDigits            = 6
)

// Clock implements ports.Clock using the local system clock in UTC.
type Clock struct{}

// Now returns the current system time normalized to UTC.
func (Clock) Now() time.Time {
	return time.Now().UTC()
}

// IDGenerator implements ports.IDGenerator with cryptographically random
// opaque identifiers.
type IDGenerator struct{}

// NewChallengeID returns a fresh random challenge identifier.
func (IDGenerator) NewChallengeID() (common.ChallengeID, error) {
	value, err := newOpaqueIDString(challengeIDPrefix)
	if err != nil {
		return "", err
	}

	return common.ChallengeID(value), nil
}

// NewDeviceSessionID returns a fresh random device-session identifier.
func (IDGenerator) NewDeviceSessionID() (common.DeviceSessionID, error) {
	value, err := newOpaqueIDString(deviceSessionIDPrefix)
	if err != nil {
		return "", err
	}

	return common.DeviceSessionID(value), nil
}

// CodeGenerator implements ports.CodeGenerator with random 6-digit decimal
// confirmation codes.
type CodeGenerator struct{}

// Generate returns one fresh random 6-digit decimal code.
func (CodeGenerator) Generate() (string, error) {
	var builder strings.Builder
	builder.Grow(codeDigits)

	for idx := 0; idx < codeDigits; idx++ {
		digit, err := rand.Int(rand.Reader, big.NewInt(10))
		if err != nil {
			return "", fmt.Errorf("generate confirmation code: %w", err)
		}
		builder.WriteByte(byte('0' + digit.Int64()))
	}

	return builder.String(), nil
}

// CodeHasher implements ports.CodeHasher with bcrypt-backed hashes.
type CodeHasher struct{}

// Hash returns the bcrypt hash of code.
func (CodeHasher) Hash(code string) ([]byte, error) {
	if err := validateCode(code); err != nil {
		return nil, err
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("hash confirmation code: %w", err)
	}

	return hash, nil
}

// Compare reports whether hash matches code.
func (CodeHasher) Compare(hash []byte, code string) (bool, error) {
	if err := validateCode(code); err != nil {
		return false, err
	}
	if len(hash) == 0 {
		return false, nil
	}

	err := bcrypt.CompareHashAndPassword(hash, []byte(code))
	switch err {
	case nil:
		return true, nil
	case bcrypt.ErrMismatchedHashAndPassword:
		return false, nil
	default:
		return false, fmt.Errorf("compare confirmation code hash: %w", err)
	}
}

func newOpaqueIDString(prefix string) (string, error) {
	randomBytes := make([]byte, 16)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", fmt.Errorf("generate opaque identifier: %w", err)
	}

	return prefix + base64.RawURLEncoding.EncodeToString(randomBytes), nil
}

func validateCode(code string) error {
	switch {
	case strings.TrimSpace(code) == "":
		return fmt.Errorf("code must not be empty")
	case strings.TrimSpace(code) != code:
		return fmt.Errorf("code must not contain surrounding whitespace")
	default:
		return nil
	}
}

var (
	_ ports.Clock         = Clock{}
	_ ports.IDGenerator   = IDGenerator{}
	_ ports.CodeGenerator = CodeGenerator{}
	_ ports.CodeHasher    = CodeHasher{}
)
@@ -0,0 +1,60 @@
package local

import (
	"regexp"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestClockNowReturnsUTC(t *testing.T) {
	t.Parallel()

	now := Clock{}.Now()

	assert.Equal(t, time.UTC, now.Location())
}

func TestIDGeneratorProducesValidOpaqueIDs(t *testing.T) {
	t.Parallel()

	generator := IDGenerator{}

	challengeID, err := generator.NewChallengeID()
	require.NoError(t, err)
	require.NoError(t, challengeID.Validate())
	assert.Regexp(t, regexp.MustCompile(`^challenge-[A-Za-z0-9_-]+$`), challengeID.String())

	deviceSessionID, err := generator.NewDeviceSessionID()
	require.NoError(t, err)
	require.NoError(t, deviceSessionID.Validate())
	assert.Regexp(t, regexp.MustCompile(`^device-session-[A-Za-z0-9_-]+$`), deviceSessionID.String())
}

func TestCodeGeneratorProducesSixDigitNumericCodes(t *testing.T) {
	t.Parallel()

	code, err := CodeGenerator{}.Generate()
	require.NoError(t, err)
	assert.Regexp(t, regexp.MustCompile(`^\d{6}$`), code)
}

func TestCodeHasherHashesAndComparesCodes(t *testing.T) {
	t.Parallel()

	hasher := CodeHasher{}

	hash, err := hasher.Hash("123456")
	require.NoError(t, err)
	require.NotEmpty(t, hash)

	match, err := hasher.Compare(hash, "123456")
	require.NoError(t, err)
	assert.True(t, match)

	match, err = hasher.Compare(hash, "000000")
	require.NoError(t, err)
	assert.False(t, match)
}
@@ -0,0 +1,182 @@
// Package mail provides runtime mail-delivery adapters for the auth/session
// service.
package mail

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"galaxy/authsession/internal/ports"
)

const sendLoginCodePath = "/api/v1/internal/login-code-deliveries"

// Config configures one HTTP-based mail-delivery client.
type Config struct {
	// BaseURL is the absolute base URL of the internal mail-service HTTP API.
	BaseURL string

	// RequestTimeout bounds each outbound mail-service request.
	RequestTimeout time.Duration
}

// RESTClient implements ports.MailSender over the frozen internal REST mail
// contract.
type RESTClient struct {
	baseURL        string
	requestTimeout time.Duration
	httpClient     *http.Client
}

// NewRESTClient constructs a REST-backed MailSender adapter from cfg.
func NewRESTClient(cfg Config) (*RESTClient, error) {
	transport := http.DefaultTransport.(*http.Transport).Clone()

	return newRESTClient(cfg, &http.Client{Transport: transport})
}

func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) {
	switch {
	case strings.TrimSpace(cfg.BaseURL) == "":
		return nil, errors.New("new mail service REST client: base URL must not be empty")
	case cfg.RequestTimeout <= 0:
		return nil, errors.New("new mail service REST client: request timeout must be positive")
	case httpClient == nil:
		return nil, errors.New("new mail service REST client: http client must not be nil")
	}

	parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
	if err != nil {
		return nil, fmt.Errorf("new mail service REST client: parse base URL: %w", err)
	}
	if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
		return nil, errors.New("new mail service REST client: base URL must be absolute")
	}

	return &RESTClient{
		baseURL:        parsedBaseURL.String(),
		requestTimeout: cfg.RequestTimeout,
		httpClient:     httpClient,
	}, nil
}

// Close releases idle HTTP connections owned by the client transport.
func (c *RESTClient) Close() error {
	if c == nil || c.httpClient == nil {
		return nil
	}

	type idleCloser interface {
		CloseIdleConnections()
	}

	if transport, ok := c.httpClient.Transport.(idleCloser); ok {
		transport.CloseIdleConnections()
	}

	return nil
}

// SendLoginCode submits one delivery request to the internal mail service
// without retrying transport or upstream failures.
func (c *RESTClient) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
	if err := validateRESTContext(ctx, "send login code"); err != nil {
		return ports.SendLoginCodeResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
	}

	payload, statusCode, err := c.doRequest(ctx, "send login code", map[string]string{
		"email": input.Email.String(),
		"code":  input.Code,
	})
	if err != nil {
		return ports.SendLoginCodeResult{}, err
	}
	if statusCode != http.StatusOK {
		return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: unexpected HTTP status %d", statusCode)
	}

	var response struct {
		Outcome ports.SendLoginCodeOutcome `json:"outcome"`
	}
	if err := decodeJSONPayload(payload, &response); err != nil {
		return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
	}

	result := ports.SendLoginCodeResult{Outcome: response.Outcome}
	if err := result.Validate(); err != nil {
		return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
	}

	return result, nil
}

func (c *RESTClient) doRequest(ctx context.Context, operation string, requestBody any) ([]byte, int, error) {
	bodyBytes, err := json.Marshal(requestBody)
	if err != nil {
		return nil, 0, fmt.Errorf("%s: marshal request body: %w", operation, err)
	}

	attemptCtx, cancel := context.WithTimeout(ctx, c.requestTimeout)
	defer cancel()

	request, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, c.baseURL+sendLoginCodePath, bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, 0, fmt.Errorf("%s: build request: %w", operation, err)
	}
	request.Header.Set("Content-Type", "application/json")

	response, err := c.httpClient.Do(request)
	if err != nil {
		return nil, 0, fmt.Errorf("%s: %w", operation, err)
	}
	defer response.Body.Close()

	payload, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("%s: read response body: %w", operation, err)
	}

	return payload, response.StatusCode, nil
}

func decodeJSONPayload(payload []byte, target any) error {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	if err := decoder.Decode(target); err != nil {
		return fmt.Errorf("decode response body: %w", err)
	}
	if err := decoder.Decode(&struct{}{}); err != io.EOF {
		if err == nil {
			return errors.New("decode response body: unexpected trailing JSON input")
		}

		return fmt.Errorf("decode response body: %w", err)
	}

	return nil
}

func validateRESTContext(ctx context.Context, operation string) error {
	if ctx == nil {
		return fmt.Errorf("%s: nil context", operation)
	}
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("%s: %w", operation, err)
	}

	return nil
}

var _ ports.MailSender = (*RESTClient)(nil)
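The strict decoding used by `decodeJSONPayload` (DisallowUnknownFields plus an explicit probe for trailing tokens) needs only the standard library; a minimal self-contained sketch of the same pattern, with illustrative names:

```go
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
)

// decodeStrict decodes exactly one JSON value into target, rejecting
// unknown fields and any trailing input, as decodeJSONPayload does.
func decodeStrict(payload []byte, target any) error {
	dec := json.NewDecoder(bytes.NewReader(payload))
	dec.DisallowUnknownFields()
	if err := dec.Decode(target); err != nil {
		return err
	}
	// A second Decode must hit io.EOF; anything else means the payload
	// carried more than one JSON value.
	if err := dec.Decode(&struct{}{}); err != io.EOF {
		if err == nil {
			return errors.New("unexpected trailing JSON input")
		}
		return err
	}
	return nil
}

func main() {
	var out struct {
		Outcome string `json:"outcome"`
	}
	fmt.Println(decodeStrict([]byte(`{"outcome":"sent"}`), &out))       // accepted
	fmt.Println(decodeStrict([]byte(`{"outcome":"sent","x":1}`), &out)) // unknown-field error
	fmt.Println(decodeStrict([]byte(`{"outcome":"sent"}{}`), &out))     // trailing-input error
}
```

Plain `json.Unmarshal` would silently ignore unknown fields, and a bare `Decoder.Decode` stops at the first value, which is why both checks are needed to pin the contract down.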
@@ -0,0 +1,394 @@
package mail

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewRESTClient(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				BaseURL:        "http://127.0.0.1:8080",
				RequestTimeout: time.Second,
			},
		},
		{
			name: "empty base url",
			cfg: Config{
				RequestTimeout: time.Second,
			},
			wantErr: "base URL must not be empty",
		},
		{
			name: "relative base url",
			cfg: Config{
				BaseURL:        "/relative",
				RequestTimeout: time.Second,
			},
			wantErr: "base URL must be absolute",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				BaseURL: "http://127.0.0.1:8080",
			},
			wantErr: "request timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			client, err := NewRESTClient(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			assert.NoError(t, client.Close())
		})
	}
}

func TestRESTClientSendLoginCodeSuccessCases(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		response    string
		wantOutcome ports.SendLoginCodeOutcome
	}{
		{
			name:        "sent",
			response:    `{"outcome":"sent"}`,
			wantOutcome: ports.SendLoginCodeOutcomeSent,
		},
		{
			name:        "suppressed",
			response:    `{"outcome":"suppressed"}`,
			wantOutcome: ports.SendLoginCodeOutcomeSuppressed,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var requestsMu sync.Mutex
			var requests []capturedRequest
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				requestsMu.Lock()
				requests = append(requests, captureRequest(t, r))
				requestsMu.Unlock()

				writeJSON(t, w, http.StatusOK, json.RawMessage(tt.response))
			}))
			defer server.Close()

			client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

			result, err := client.SendLoginCode(context.Background(), validInput())
			require.NoError(t, err)
			assert.Equal(t, tt.wantOutcome, result.Outcome)

			requestsMu.Lock()
			defer requestsMu.Unlock()

			require.Len(t, requests, 1)
			assert.Equal(t, http.MethodPost, requests[0].Method)
			assert.Equal(t, sendLoginCodePath, requests[0].Path)
			assert.Equal(t, "application/json", requests[0].ContentType)
			assert.JSONEq(t, `{"email":"pilot@example.com","code":"654321"}`, requests[0].Body)
		})
	}
}

func TestRESTClientPreservesNormalizedEmailAndCodeExactly(t *testing.T) {
	t.Parallel()

	var captured string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured = captureRequest(t, r).Body
		writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

	result, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
		Email: common.Email("Pilot+Alias@Example.com"),
		Code:  "123456",
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)
	assert.JSONEq(t, `{"email":"Pilot+Alias@Example.com","code":"123456"}`, captured)
}

func TestRESTClientSendLoginCodeDoesNotRetry(t *testing.T) {
	t.Parallel()

	t.Run("no retry on 503", func(t *testing.T) {
		t.Parallel()

		var calls atomic.Int64
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			calls.Add(1)
			http.Error(w, "temporary", http.StatusServiceUnavailable)
		}))
		defer server.Close()

		client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

		_, err := client.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorContains(t, err, "unexpected HTTP status 503")
		assert.EqualValues(t, 1, calls.Load())
	})

	t.Run("no retry on transport failure", func(t *testing.T) {
		t.Parallel()

		var calls atomic.Int64
		client, err := newRESTClient(Config{
			BaseURL:        "http://127.0.0.1:8080",
			RequestTimeout: 250 * time.Millisecond,
		}, &http.Client{
			Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
				calls.Add(1)
				return nil, errors.New("temporary transport failure")
			}),
		})
		require.NoError(t, err)

		_, err = client.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorContains(t, err, "temporary transport failure")
		assert.EqualValues(t, 1, calls.Load())
	})
}

func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		statusCode  int
		body        string
		wantErrText string
	}{
		{
			name:        "rejects unknown field",
			statusCode:  http.StatusOK,
			body:        `{"outcome":"sent","extra":true}`,
			wantErrText: "decode response body",
		},
		{
			name:        "rejects unsupported outcome",
			statusCode:  http.StatusOK,
			body:        `{"outcome":"queued"}`,
			wantErrText: "unsupported",
		},
		{
			name:        "rejects missing outcome",
			statusCode:  http.StatusOK,
			body:        `{}`,
			wantErrText: "unsupported",
		},
		{
			name:        "rejects trailing json",
			statusCode:  http.StatusOK,
			body:        `{"outcome":"sent"}{}`,
			wantErrText: "unexpected trailing JSON input",
		},
		{
			name:        "rejects unexpected status",
			statusCode:  http.StatusBadGateway,
			body:        `{"error":"temporary"}`,
			wantErrText: "unexpected HTTP status 502",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(tt.statusCode)
				_, err := io.WriteString(w, tt.body)
				require.NoError(t, err)
			}))
			defer server.Close()

			client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

			_, err := client.SendLoginCode(context.Background(), validInput())
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErrText)
		})
	}
}

func TestRESTClientRequestTimeout(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(40 * time.Millisecond)
		writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 10*time.Millisecond)

	_, err := client.SendLoginCode(context.Background(), validInput())
	require.Error(t, err)
	assert.ErrorContains(t, err, "context deadline exceeded")
}

func TestRESTClientContextAndValidation(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Errorf, not t.Fatalf: FailNow must not be called from the
		// handler goroutine.
		t.Errorf("unexpected upstream call")
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	tests := []struct {
		name string
		run  func() error
	}{
		{
			name: "nil context",
			run: func() error {
				_, err := client.SendLoginCode(nil, validInput())
				return err
			},
		},
		{
			name: "cancelled context",
			run: func() error {
				_, err := client.SendLoginCode(cancelledCtx, validInput())
				return err
			},
		},
		{
			name: "invalid email",
			run: func() error {
				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
					Email: common.Email(" bad@example.com "),
					Code:  "123456",
				})
				return err
			},
		},
		{
			name: "invalid code",
			run: func() error {
				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
					Email: common.Email("pilot@example.com"),
					Code:  " 123456 ",
				})
				return err
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			err := tt.run()
			require.Error(t, err)
		})
	}
}

type capturedRequest struct {
	Method      string
	Path        string
	ContentType string
	Body        string
}

func captureRequest(t *testing.T, request *http.Request) capturedRequest {
	t.Helper()

	body, err := io.ReadAll(request.Body)
	require.NoError(t, err)

	return capturedRequest{
		Method:      request.Method,
		Path:        request.URL.Path,
		ContentType: request.Header.Get("Content-Type"),
		Body:        strings.TrimSpace(string(body)),
	}
}

func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
	t.Helper()

	payload, err := json.Marshal(value)
	require.NoError(t, err)

	writer.Header().Set("Content-Type", "application/json")
	writer.WriteHeader(statusCode)
	_, err = writer.Write(payload)
	require.NoError(t, err)
}

func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
	t.Helper()

	client, err := NewRESTClient(Config{
		BaseURL:        baseURL,
		RequestTimeout: timeout,
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, client.Close())
	})

	return client
}

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (fn roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
	return fn(request)
}
@@ -0,0 +1,179 @@
// Package mail provides runtime mail-delivery adapters for the auth/session
// service.
package mail

import (
	"context"
	"errors"
	"fmt"
	"sync"

	"galaxy/authsession/internal/ports"
)

var errForcedFailure = errors.New("stub mail sender: forced failure")

// StubMode identifies the deterministic outcome used by StubSender for one
// delivery attempt.
type StubMode string

const (
	// StubModeSent reports that the adapter accepts delivery and returns the
	// stable sent outcome expected by the auth flow.
	StubModeSent StubMode = "sent"

	// StubModeSuppressed reports that the adapter intentionally suppresses
	// outward delivery while still returning a successful suppressed outcome.
	StubModeSuppressed StubMode = "suppressed"

	// StubModeFailed reports that the adapter returns an explicit delivery
	// failure instead of a successful outcome.
	StubModeFailed StubMode = "failed"
)

// IsKnown reports whether mode is one of the supported stub delivery modes.
func (mode StubMode) IsKnown() bool {
	switch mode {
	case StubModeSent, StubModeSuppressed, StubModeFailed:
		return true
	default:
		return false
	}
}

// StubStep overrides the default stub behavior for one queued delivery
// attempt.
type StubStep struct {
	// Mode selects the delivery behavior for this queued step.
	Mode StubMode

	// Err optionally overrides the failure returned when Mode is StubModeFailed.
	Err error
}

// Validate reports whether step contains one supported queued behavior.
func (step StubStep) Validate() error {
	if !step.Mode.IsKnown() {
		return fmt.Errorf("stub mail step mode %q is unsupported", step.Mode)
	}

	return nil
}

// Attempt records one validated delivery request handled by StubSender.
type Attempt struct {
	// Input stores the validated cleartext mail-delivery request exactly as it
	// was passed into SendLoginCode.
	Input ports.SendLoginCodeInput

	// Mode stores the resolved stub mode after queued overrides were applied.
	Mode StubMode
}

// StubSender is a deterministic runtime MailSender implementation intended
// for development, local integration, and explicit stub-based tests.
//
// The zero value is ready to use and defaults to StubModeSent.
type StubSender struct {
	// DefaultMode controls the fallback behavior when Script is empty. The zero
	// value is treated as StubModeSent so the zero-value sender is usable
	// without extra configuration.
	DefaultMode StubMode

	// DefaultError optionally overrides the failure returned when DefaultMode
	// resolves to StubModeFailed.
	DefaultError error

	// Script stores queued one-shot overrides consumed in FIFO order before the
	// default behavior is used.
	Script []StubStep

	mu       sync.Mutex
	attempts []Attempt
}

// SendLoginCode records one validated delivery request and returns the
// deterministic stub outcome selected by the queued script or the default
// mode.
func (s *StubSender) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
	if ctx == nil {
		return ports.SendLoginCodeResult{}, errors.New("stub mail sender: nil context")
	}
	if err := ctx.Err(); err != nil {
		return ports.SendLoginCodeResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.SendLoginCodeResult{}, err
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	mode, errOverride, err := s.resolveNextStepLocked()
	if err != nil {
		return ports.SendLoginCodeResult{}, err
	}

	s.attempts = append(s.attempts, Attempt{
		Input: input,
		Mode:  mode,
	})

	switch mode {
	case StubModeSent:
		return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSent}, nil
	case StubModeSuppressed:
		return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSuppressed}, nil
	case StubModeFailed:
		if errOverride != nil {
			return ports.SendLoginCodeResult{}, errOverride
		}
		return ports.SendLoginCodeResult{}, errForcedFailure
	default:
		return ports.SendLoginCodeResult{}, fmt.Errorf("stub mail sender: unsupported resolved mode %q", mode)
	}
}

// RecordedAttempts returns a stable defensive copy of every validated delivery
// attempt handled by the stub.
func (s *StubSender) RecordedAttempts() []Attempt {
	s.mu.Lock()
	defer s.mu.Unlock()

	return append([]Attempt(nil), s.attempts...)
}

func (s *StubSender) resolveNextStepLocked() (StubMode, error, error) {
	if len(s.Script) > 0 {
		step := s.Script[0]
		s.Script = append([]StubStep(nil), s.Script[1:]...)
		if err := step.Validate(); err != nil {
			return "", nil, fmt.Errorf("stub mail sender: %w", err)
		}
		if step.Mode == StubModeFailed {
			if step.Err != nil {
				return step.Mode, step.Err, nil
			}
			return step.Mode, errForcedFailure, nil
		}
		return step.Mode, nil, nil
	}

	mode := s.DefaultMode
	if mode == "" {
		mode = StubModeSent
	}
	if !mode.IsKnown() {
		return "", nil, fmt.Errorf("stub mail sender: default mode %q is unsupported", mode)
	}
	if mode == StubModeFailed {
		if s.DefaultError != nil {
			return mode, s.DefaultError, nil
		}
		return mode, errForcedFailure, nil
	}

	return mode, nil, nil
}

var _ ports.MailSender = (*StubSender)(nil)
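StubSender's script-then-default resolution can be sketched without the `galaxy` port types; the simplified stub below is illustrative only and mirrors the FIFO consumption in `resolveNextStepLocked`:

```go
package main

import "fmt"

// scriptedStub consumes one-shot overrides in FIFO order, then falls
// back to a default mode, like StubSender.resolveNextStepLocked.
type scriptedStub struct {
	defaultMode string
	script      []string
}

func (s *scriptedStub) next() string {
	if len(s.script) > 0 {
		mode := s.script[0]
		s.script = s.script[1:]
		return mode
	}
	if s.defaultMode == "" {
		return "sent" // zero value stays usable, as with StubSender
	}
	return s.defaultMode
}

func main() {
	stub := &scriptedStub{
		defaultMode: "sent",
		script:      []string{"suppressed", "failed"},
	}
	for i := 0; i < 3; i++ {
		fmt.Println(stub.next())
	}
	// prints "suppressed", "failed", then the default "sent"
}
```

The real adapter additionally validates each step, records every attempt under a mutex, and maps the resolved mode to a port outcome or error.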
@@ -0,0 +1,198 @@
package mail

import (
	"context"
	"errors"
	"testing"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestStubSenderSendLoginCode(t *testing.T) {
	t.Parallel()

	t.Run("zero value defaults to sent", func(t *testing.T) {
		t.Parallel()

		sender := &StubSender{}

		result, err := sender.SendLoginCode(context.Background(), validInput())
		require.NoError(t, err)
		assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)

		attempts := sender.RecordedAttempts()
		require.Len(t, attempts, 1)
		assert.Equal(t, StubModeSent, attempts[0].Mode)
		assert.Equal(t, validInput(), attempts[0].Input)
	})

	t.Run("default suppressed", func(t *testing.T) {
		t.Parallel()

		sender := &StubSender{DefaultMode: StubModeSuppressed}

		result, err := sender.SendLoginCode(context.Background(), validInput())
		require.NoError(t, err)
		assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, result.Outcome)

		attempts := sender.RecordedAttempts()
		require.Len(t, attempts, 1)
		assert.Equal(t, StubModeSuppressed, attempts[0].Mode)
	})

	t.Run("default failed uses configured error", func(t *testing.T) {
		t.Parallel()

		wantErr := errors.New("delivery refused")
		sender := &StubSender{
			DefaultMode:  StubModeFailed,
			DefaultError: wantErr,
		}

		result, err := sender.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorIs(t, err, wantErr)
		assert.Equal(t, ports.SendLoginCodeResult{}, result)

		attempts := sender.RecordedAttempts()
		require.Len(t, attempts, 1)
		assert.Equal(t, StubModeFailed, attempts[0].Mode)
	})

	t.Run("default failed uses stable fallback error", func(t *testing.T) {
		t.Parallel()

		sender := &StubSender{DefaultMode: StubModeFailed}

		_, err := sender.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.EqualError(t, err, "stub mail sender: forced failure")
	})

	t.Run("script overrides default and is consumed fifo", func(t *testing.T) {
		t.Parallel()

		wantErr := errors.New("step failed")
		sender := &StubSender{
			DefaultMode: StubModeSent,
			Script: []StubStep{
				{Mode: StubModeSuppressed},
				{Mode: StubModeFailed, Err: wantErr},
			},
		}

		first, err := sender.SendLoginCode(context.Background(), validInput())
		require.NoError(t, err)
		assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, first.Outcome)

		second, err := sender.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorIs(t, err, wantErr)
		assert.Equal(t, ports.SendLoginCodeResult{}, second)

		third, err := sender.SendLoginCode(context.Background(), validInput())
		require.NoError(t, err)
		assert.Equal(t, ports.SendLoginCodeOutcomeSent, third.Outcome)

		attempts := sender.RecordedAttempts()
		require.Len(t, attempts, 3)
		assert.Equal(t, []StubMode{StubModeSuppressed, StubModeFailed, StubModeSent}, []StubMode{
			attempts[0].Mode,
			attempts[1].Mode,
			attempts[2].Mode,
		})
		assert.Empty(t, sender.Script)
	})

	t.Run("invalid default mode returns adapter error", func(t *testing.T) {
		t.Parallel()

		sender := &StubSender{DefaultMode: StubMode("queued")}

		_, err := sender.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorContains(t, err, `default mode "queued" is unsupported`)
		assert.Empty(t, sender.RecordedAttempts())
	})

	t.Run("invalid scripted mode returns adapter error", func(t *testing.T) {
		t.Parallel()

		sender := &StubSender{
			Script: []StubStep{
				{Mode: StubMode("queued")},
			},
		}

		_, err := sender.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorContains(t, err, `mode "queued" is unsupported`)
		assert.Empty(t, sender.RecordedAttempts())
		assert.Empty(t, sender.Script)
	})
}

func TestStubSenderRecordedAttemptsAreDefensive(t *testing.T) {
	t.Parallel()

	sender := &StubSender{}

	_, err := sender.SendLoginCode(context.Background(), validInput())
	require.NoError(t, err)

	attempts := sender.RecordedAttempts()
	require.Len(t, attempts, 1)
	attempts[0].Mode = StubModeFailed
	attempts[0].Input.Code = "000000"

	again := sender.RecordedAttempts()
|
||||
require.Len(t, again, 1)
|
||||
assert.Equal(t, StubModeSent, again[0].Mode)
|
||||
assert.Equal(t, "654321", again[0].Input.Code)
|
||||
}
|
||||
|
||||
func TestStubSenderSendLoginCodeNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
|
||||
_, err := sender.SendLoginCode(nil, validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
}
|
||||
|
||||
func TestStubSenderSendLoginCodeCancelledContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := sender.SendLoginCode(ctx, validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
}
|
||||
|
||||
func TestStubSenderSendLoginCodeInvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
|
||||
_, err := sender.SendLoginCode(context.Background(), ports.SendLoginCodeInput{})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "send login code input email")
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
}
|
||||
|
||||
func validInput() ports.SendLoginCodeInput {
|
||||
return ports.SendLoginCodeInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
Code: "654321",
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,484 @@
// Package challengestore implements ports.ChallengeStore with Redis-backed
// strict JSON challenge records.
package challengestore

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"reflect"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/redis/go-redis/v9"
)

const expirationGracePeriod = 5 * time.Minute

// Config configures one Redis-backed challenge store instance.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string

	// Username is the optional Redis ACL username.
	Username string

	// Password is the optional Redis ACL password.
	Password string

	// DB is the Redis logical database index.
	DB int

	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool

	// KeyPrefix is the namespace prefix applied to every challenge key.
	KeyPrefix string

	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}

// Store persists challenges as one strict JSON value per Redis key.
type Store struct {
	client           *redis.Client
	keyPrefix        string
	operationTimeout time.Duration
}

type redisRecord struct {
	ChallengeID              string                  `json:"challenge_id"`
	Email                    string                  `json:"email"`
	CodeHashBase64           string                  `json:"code_hash_base64"`
	Status                   challenge.Status        `json:"status"`
	DeliveryState            challenge.DeliveryState `json:"delivery_state"`
	CreatedAt                string                  `json:"created_at"`
	ExpiresAt                string                  `json:"expires_at"`
	SendAttemptCount         int                     `json:"send_attempt_count"`
	ConfirmAttemptCount      int                     `json:"confirm_attempt_count"`
	LastAttemptAt            *string                 `json:"last_attempt_at,omitempty"`
	ConfirmedSessionID       string                  `json:"confirmed_session_id,omitempty"`
	ConfirmedClientPublicKey string                  `json:"confirmed_client_public_key,omitempty"`
	ConfirmedAt              *string                 `json:"confirmed_at,omitempty"`
}

// New constructs a Redis-backed challenge store from cfg.
func New(cfg Config) (*Store, error) {
	if strings.TrimSpace(cfg.Addr) == "" {
		return nil, errors.New("new redis challenge store: redis addr must not be empty")
	}
	if cfg.DB < 0 {
		return nil, errors.New("new redis challenge store: redis db must not be negative")
	}
	if strings.TrimSpace(cfg.KeyPrefix) == "" {
		return nil, errors.New("new redis challenge store: redis key prefix must not be empty")
	}
	if cfg.OperationTimeout <= 0 {
		return nil, errors.New("new redis challenge store: operation timeout must be positive")
	}

	options := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	return &Store{
		client:           redis.NewClient(options),
		keyPrefix:        cfg.KeyPrefix,
		operationTimeout: cfg.OperationTimeout,
	}, nil
}

// Close releases the underlying Redis client resources.
func (s *Store) Close() error {
	if s == nil || s.client == nil {
		return nil
	}

	return s.client.Close()
}

// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
	operationCtx, cancel, err := s.operationContext(ctx, "ping redis challenge store")
	if err != nil {
		return err
	}
	defer cancel()

	if err := s.client.Ping(operationCtx).Err(); err != nil {
		return fmt.Errorf("ping redis challenge store: %w", err)
	}

	return nil
}

// Get returns the stored challenge for challengeID.
func (s *Store) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
	if err := challengeID.Validate(); err != nil {
		return challenge.Challenge{}, fmt.Errorf("get challenge from redis: %w", err)
	}

	operationCtx, cancel, err := s.operationContext(ctx, "get challenge from redis")
	if err != nil {
		return challenge.Challenge{}, err
	}
	defer cancel()

	payload, err := s.client.Get(operationCtx, s.lookupKey(challengeID)).Bytes()
	switch {
	case errors.Is(err, redis.Nil):
		return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, ports.ErrNotFound)
	case err != nil:
		return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
	}

	record, err := decodeChallengeRecord(challengeID, payload)
	if err != nil {
		return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
	}

	return record, nil
}

// Create persists record as a new challenge.
func (s *Store) Create(ctx context.Context, record challenge.Challenge) error {
	if err := record.Validate(); err != nil {
		return fmt.Errorf("create challenge in redis: %w", err)
	}

	payload, err := marshalChallengeRecord(record)
	if err != nil {
		return fmt.Errorf("create challenge in redis: %w", err)
	}

	operationCtx, cancel, err := s.operationContext(ctx, "create challenge in redis")
	if err != nil {
		return err
	}
	defer cancel()

	created, err := s.client.SetNX(operationCtx, s.lookupKey(record.ID), payload, redisTTL(record.ExpiresAt)).Result()
	if err != nil {
		return fmt.Errorf("create challenge %q in redis: %w", record.ID, err)
	}
	if !created {
		return fmt.Errorf("create challenge %q in redis: %w", record.ID, ports.ErrConflict)
	}

	return nil
}

// CompareAndSwap replaces previous with next when the currently stored
// challenge matches previous exactly in canonical Redis representation.
func (s *Store) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
	if err := ports.ValidateComparableChallenges(previous, next); err != nil {
		return fmt.Errorf("compare and swap challenge in redis: %w", err)
	}

	nextPayload, err := marshalChallengeRecord(next)
	if err != nil {
		return fmt.Errorf("compare and swap challenge in redis: %w", err)
	}

	operationCtx, cancel, err := s.operationContext(ctx, "compare and swap challenge in redis")
	if err != nil {
		return err
	}
	defer cancel()

	key := s.lookupKey(previous.ID)
	watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
		payload, err := tx.Get(operationCtx, key).Bytes()
		switch {
		case errors.Is(err, redis.Nil):
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrNotFound)
		case err != nil:
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}

		current, err := decodeChallengeRecord(previous.ID, payload)
		if err != nil {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}

		matches, err := equalStoredChallenges(current, previous)
		if err != nil {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}
		if !matches {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
		}

		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			pipe.Set(operationCtx, key, nextPayload, redisTTL(next.ExpiresAt))
			return nil
		})
		if err != nil {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}

		return nil
	}, key)

	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}
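// Callers typically drive CompareAndSwap from a bounded read-modify-write
// retry loop that re-reads the record whenever ports.ErrConflict comes back.
// The loop below is a hedged sketch of one possible call site (names such as
// store, id, and the retry budget are illustrative, not part of this package):
//
//	for attempt := 0; attempt < 3; attempt++ {
//		current, err := store.Get(ctx, id)
//		if err != nil {
//			return err
//		}
//		next := current
//		next.Attempts.Send++
//		switch err := store.CompareAndSwap(ctx, current, next); {
//		case err == nil:
//			return nil
//		case !errors.Is(err, ports.ErrConflict):
//			return err
//		}
//	}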

func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	if s == nil || s.client == nil {
		return nil, nil, fmt.Errorf("%s: nil store", operation)
	}
	if ctx == nil {
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}

	operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
	return operationCtx, cancel, nil
}

func (s *Store) lookupKey(challengeID common.ChallengeID) string {
	return s.keyPrefix + encodeKeyComponent(challengeID.String())
}

func encodeKeyComponent(value string) string {
	return base64.RawURLEncoding.EncodeToString([]byte(value))
}

func marshalChallengeRecord(record challenge.Challenge) ([]byte, error) {
	stored, err := redisRecordFromChallenge(record)
	if err != nil {
		return nil, err
	}

	payload, err := json.Marshal(stored)
	if err != nil {
		return nil, fmt.Errorf("encode redis challenge record: %w", err)
	}

	return payload, nil
}

func redisRecordFromChallenge(record challenge.Challenge) (redisRecord, error) {
	if err := record.Validate(); err != nil {
		return redisRecord{}, fmt.Errorf("encode redis challenge record: %w", err)
	}

	stored := redisRecord{
		ChallengeID:         record.ID.String(),
		Email:               record.Email.String(),
		CodeHashBase64:      base64.StdEncoding.EncodeToString(record.CodeHash),
		Status:              record.Status,
		DeliveryState:       record.DeliveryState,
		CreatedAt:           formatTimestamp(record.CreatedAt),
		ExpiresAt:           formatTimestamp(record.ExpiresAt),
		SendAttemptCount:    record.Attempts.Send,
		ConfirmAttemptCount: record.Attempts.Confirm,
		LastAttemptAt:       formatOptionalTimestamp(record.Abuse.LastAttemptAt),
	}
	if record.Confirmation != nil {
		stored.ConfirmedSessionID = record.Confirmation.SessionID.String()
		stored.ConfirmedClientPublicKey = record.Confirmation.ClientPublicKey.String()
		stored.ConfirmedAt = formatOptionalTimestamp(&record.Confirmation.ConfirmedAt)
	}

	return stored, nil
}

func decodeChallengeRecord(expectedChallengeID common.ChallengeID, payload []byte) (challenge.Challenge, error) {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	var stored redisRecord
	if err := decoder.Decode(&stored); err != nil {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
	}
	if err := decoder.Decode(&struct{}{}); err != io.EOF {
		if err == nil {
			return challenge.Challenge{}, errors.New("decode redis challenge record: unexpected trailing JSON input")
		}
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
	}

	record, err := challengeFromRedisRecord(stored)
	if err != nil {
		return challenge.Challenge{}, err
	}
	if record.ID != expectedChallengeID {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: challenge_id %q does not match requested %q", record.ID, expectedChallengeID)
	}

	return record, nil
}

func challengeFromRedisRecord(stored redisRecord) (challenge.Challenge, error) {
	createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
	if err != nil {
		return challenge.Challenge{}, err
	}
	expiresAt, err := parseTimestamp("expires_at", stored.ExpiresAt)
	if err != nil {
		return challenge.Challenge{}, err
	}
	lastAttemptAt, err := parseOptionalTimestamp("last_attempt_at", stored.LastAttemptAt)
	if err != nil {
		return challenge.Challenge{}, err
	}

	codeHash, err := base64.StdEncoding.Strict().DecodeString(stored.CodeHashBase64)
	if err != nil {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: code_hash_base64: %w", err)
	}

	record := challenge.Challenge{
		ID:            common.ChallengeID(stored.ChallengeID),
		Email:         common.Email(stored.Email),
		CodeHash:      codeHash,
		Status:        stored.Status,
		DeliveryState: stored.DeliveryState,
		CreatedAt:     createdAt,
		ExpiresAt:     expiresAt,
		Attempts: challenge.AttemptCounters{
			Send:    stored.SendAttemptCount,
			Confirm: stored.ConfirmAttemptCount,
		},
		Abuse: challenge.AbuseMetadata{
			LastAttemptAt: lastAttemptAt,
		},
	}

	confirmation, err := parseConfirmation(stored)
	if err != nil {
		return challenge.Challenge{}, err
	}
	record.Confirmation = confirmation

	if err := record.Validate(); err != nil {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
	}

	return record, nil
}

func parseConfirmation(stored redisRecord) (*challenge.Confirmation, error) {
	hasSessionID := strings.TrimSpace(stored.ConfirmedSessionID) != ""
	hasClientPublicKey := strings.TrimSpace(stored.ConfirmedClientPublicKey) != ""
	hasConfirmedAt := stored.ConfirmedAt != nil

	if !hasSessionID && !hasClientPublicKey && !hasConfirmedAt {
		return nil, nil
	}
	if !hasSessionID || !hasClientPublicKey || !hasConfirmedAt {
		return nil, errors.New("decode redis challenge record: confirmation metadata must be either fully present or fully absent")
	}

	confirmedAt, err := parseTimestamp("confirmed_at", *stored.ConfirmedAt)
	if err != nil {
		return nil, err
	}
	rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ConfirmedClientPublicKey)
	if err != nil {
		return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
	}
	clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
	if err != nil {
		return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
	}

	return &challenge.Confirmation{
		SessionID:       common.DeviceSessionID(stored.ConfirmedSessionID),
		ClientPublicKey: clientPublicKey,
		ConfirmedAt:     confirmedAt,
	}, nil
}

func parseOptionalTimestamp(fieldName string, value *string) (*time.Time, error) {
	if value == nil {
		return nil, nil
	}

	parsed, err := parseTimestamp(fieldName, *value)
	if err != nil {
		return nil, err
	}

	return &parsed, nil
}

func parseTimestamp(fieldName string, value string) (time.Time, error) {
	if strings.TrimSpace(value) == "" {
		return time.Time{}, fmt.Errorf("decode redis challenge record: %s must not be empty", fieldName)
	}

	parsed, err := time.Parse(time.RFC3339Nano, value)
	if err != nil {
		return time.Time{}, fmt.Errorf("decode redis challenge record: %s: %w", fieldName, err)
	}

	canonical := parsed.UTC().Format(time.RFC3339Nano)
	if value != canonical {
		return time.Time{}, fmt.Errorf("decode redis challenge record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
	}

	return parsed.UTC(), nil
}

func formatTimestamp(value time.Time) string {
	return value.UTC().Format(time.RFC3339Nano)
}

func formatOptionalTimestamp(value *time.Time) *string {
	if value == nil {
		return nil
	}

	formatted := formatTimestamp(*value)
	return &formatted
}

func redisTTL(expiresAt time.Time) time.Duration {
	ttl := time.Until(expiresAt.UTC())
	if ttl < 0 {
		ttl = 0
	}

	return ttl + expirationGracePeriod
}
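// Worked example of the rule above (concrete durations are illustrative): a
// challenge expiring 90 seconds from now is written with a TTL of 90s + 5m,
// while an already-expired challenge is clamped to the bare 5m grace period so
// late readers still observe a decodable record instead of a silent miss:
//
//	redisTTL(time.Now().Add(90 * time.Second)) // ~ 6m30s
//	redisTTL(time.Now().Add(-1 * time.Minute)) // = expirationGracePeriod (5m)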

func equalStoredChallenges(left challenge.Challenge, right challenge.Challenge) (bool, error) {
	leftRecord, err := redisRecordFromChallenge(left)
	if err != nil {
		return false, err
	}
	rightRecord, err := redisRecordFromChallenge(right)
	if err != nil {
		return false, err
	}

	return reflect.DeepEqual(leftRecord, rightRecord), nil
}

var _ ports.ChallengeStore = (*Store)(nil)
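// Example wiring for a consumer of this package (a sketch; the concrete
// address, prefix, and timeout below are illustrative, not defaults of the
// adapter):
//
//	store, err := challengestore.New(challengestore.Config{
//		Addr:             "localhost:6379",
//		KeyPrefix:        "authsession:challenge:",
//		OperationTimeout: 250 * time.Millisecond,
//	})
//	if err != nil {
//		// handle configuration error
//	}
//	defer store.Close()
//	_ = store.Ping(context.Background())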
@@ -0,0 +1,531 @@
package challengestore

import (
	"context"
	"crypto/ed25519"
	"encoding/json"
	"testing"
	"time"

	"galaxy/authsession/internal/adapters/contracttest"
	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestStoreContract(t *testing.T) {
	t.Parallel()

	contracttest.RunChallengeStoreContractTests(t, func(t *testing.T) ports.ChallengeStore {
		t.Helper()

		server := miniredis.RunT(t)
		return newTestStore(t, server, Config{})
	})
}

func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               2,
				KeyPrefix:        "authsession:challenge:",
				OperationTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				KeyPrefix:        "authsession:challenge:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               -1,
				KeyPrefix:        "authsession:challenge:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty key prefix",
			cfg: Config{
				Addr:             server.Addr(),
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis key prefix must not be empty",
		},
		{
			name: "non-positive operation timeout",
			cfg: Config{
				Addr:      server.Addr(),
				KeyPrefix: "authsession:challenge:",
			},
			wantErr: "operation timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			store, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}

func TestStorePing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	require.NoError(t, store.Ping(context.Background()))
}

func TestStoreCreateAndGet(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	now := time.Unix(1_775_130_000, 0).UTC()

	record := testChallenge(now)

	require.NoError(t, store.Create(context.Background(), record))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)

	got.CodeHash[0] = 0xFF
	keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
	keyBytes[0] = 0xFE

	again, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record.CodeHash, again.CodeHash)
	require.NotNil(t, again.Confirmation)
	assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
}

func TestStoreCreateAndGetPendingChallenge(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	now := time.Unix(1_775_130_100, 0).UTC()

	record := testPendingChallenge(now)

	require.NoError(t, store.Create(context.Background(), record))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)
}

func TestStoreCreateAndGetThrottledChallenge(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	now := time.Unix(1_775_130_150, 0).UTC()

	record := testPendingChallenge(now)
	record.Status = challenge.StatusDeliveryThrottled
	record.DeliveryState = challenge.DeliveryThrottled
	record.Attempts.Send = 1
	record.Abuse.LastAttemptAt = timePointer(now)
	require.NoError(t, record.Validate())

	require.NoError(t, store.Create(context.Background(), record))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)
}

func TestStoreGetStrictDecode(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_775_130_200, 0).UTC()
	baseRecord := testChallenge(now)
	baseStored, err := redisRecordFromChallenge(baseRecord)
	require.NoError(t, err)

	tests := []struct {
		name        string
		mutate      func(redisRecord) string
		wantErrText string
	}{
		{
			name: "malformed json",
			mutate: func(_ redisRecord) string {
				return "{"
			},
			wantErrText: "decode redis challenge record",
		},
		{
			name: "trailing json input",
			mutate: func(record redisRecord) string {
				return mustMarshalJSON(t, record) + "{}"
			},
			wantErrText: "unexpected trailing JSON input",
		},
		{
			name: "unknown field",
			mutate: func(record redisRecord) string {
				payload := map[string]any{
					"challenge_id":                record.ChallengeID,
					"email":                       record.Email,
					"code_hash_base64":            record.CodeHashBase64,
					"status":                      record.Status,
					"delivery_state":              record.DeliveryState,
					"created_at":                  record.CreatedAt,
					"expires_at":                  record.ExpiresAt,
					"send_attempt_count":          record.SendAttemptCount,
					"confirm_attempt_count":       record.ConfirmAttemptCount,
					"last_attempt_at":             record.LastAttemptAt,
					"confirmed_session_id":        record.ConfirmedSessionID,
					"confirmed_client_public_key": record.ConfirmedClientPublicKey,
					"confirmed_at":                record.ConfirmedAt,
					"unexpected":                  true,
				}
				return mustMarshalJSON(t, payload)
			},
			wantErrText: "unknown field",
		},
		{
			name: "unsupported status",
			mutate: func(record redisRecord) string {
				record.Status = challenge.Status("paused")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			name: "unsupported delivery state",
			mutate: func(record redisRecord) string {
				record.DeliveryState = challenge.DeliveryState("queued")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `delivery state "queued" is unsupported`,
		},
		{
			name: "missing required email",
			mutate: func(record redisRecord) string {
				record.Email = ""
				return mustMarshalJSON(t, record)
			},
			wantErrText: "challenge email",
		},
		{
			name: "challenge id mismatch",
			mutate: func(record redisRecord) string {
				record.ChallengeID = "other-challenge"
				return mustMarshalJSON(t, record)
			},
			wantErrText: `does not match requested`,
		},
		{
			name: "non canonical utc timestamp",
			mutate: func(record redisRecord) string {
				record.CreatedAt = "2026-04-04T12:00:00+03:00"
				return mustMarshalJSON(t, record)
			},
			wantErrText: "canonical UTC RFC3339Nano timestamp",
		},
		{
			name: "partial confirmation metadata",
			mutate: func(record redisRecord) string {
				record.ConfirmedAt = nil
				return mustMarshalJSON(t, record)
			},
			wantErrText: "confirmation metadata must be either fully present or fully absent",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := miniredis.RunT(t)
			store := newTestStore(t, server, Config{})
			server.Set(store.lookupKey(baseRecord.ID), tt.mutate(baseStored))

			_, err := store.Get(context.Background(), baseRecord.ID)
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErrText)
		})
	}
}

func TestStoreKeySchemeAndTTL(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{KeyPrefix: "authsession:challenge:"})
	now := time.Now().UTC()

	prefixed := testPendingChallenge(now)
	prefixed.ID = common.ChallengeID("challenge:opaque/id?value")
	require.NoError(t, store.Create(context.Background(), prefixed))

	key := store.lookupKey(prefixed.ID)
	assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key)
	assert.True(t, server.Exists(key))

	freshTTL := server.TTL(key)
	assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod)
	assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second)

	expired := testPendingChallenge(now.Add(-10 * time.Minute))
	expired.ID = common.ChallengeID("expired-challenge")
	expired.CreatedAt = now.Add(-20 * time.Minute)
	expired.ExpiresAt = now.Add(-1 * time.Minute)
	require.NoError(t, store.Create(context.Background(), expired))

	expiredTTL := server.TTL(store.lookupKey(expired.ID))
	assert.LessOrEqual(t, expiredTTL, expirationGracePeriod)
	assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second)
}

func TestStoreCreateConflict(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC())

	require.NoError(t, store.Create(context.Background(), record))

	err := store.Create(context.Background(), record)
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrConflict)
}

func TestStoreGetNotFound(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrNotFound)
}

func TestStoreCompareAndSwap(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_775_130_400, 0).UTC()

	t.Run("success", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		previous := testPendingChallenge(now)
		next := previous
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		next.Attempts.Send = 1
		next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute))

		require.NoError(t, store.Create(context.Background(), previous))
		require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))

		got, err := store.Get(context.Background(), previous.ID)
		require.NoError(t, err)
		assert.Equal(t, next, got)
	})

	t.Run("conflict when stored record differs", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		stored := testPendingChallenge(now)
		previous := stored
		previous.Attempts.Send = 99
		next := stored
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent

		require.NoError(t, store.Create(context.Background(), stored))

		err := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrConflict)
	})

	t.Run("not found", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		previous := testPendingChallenge(now)
		next := previous
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent

		err := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})

	t.Run("corrupt stored record returns adapter error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
previous := testPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
server.Set(store.lookupKey(previous.ID), "{")
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.NotErrorIs(t, err, ports.ErrConflict)
|
||||
assert.ErrorContains(t, err, "decode redis challenge record")
|
||||
})
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "authsession:challenge:"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func testPendingChallenge(now time.Time) challenge.Challenge {
|
||||
return challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-pending"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-pending-code"),
|
||||
Status: challenge.StatusPendingSend,
|
||||
DeliveryState: challenge.DeliveryPending,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.InitialTTL),
|
||||
}
|
||||
}
|
||||
|
||||
func testChallenge(now time.Time) challenge.Challenge {
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-confirmed"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-code"),
|
||||
Status: challenge.StatusConfirmedPendingExpire,
|
||||
DeliveryState: challenge.DeliverySent,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.ConfirmedRetention),
|
||||
Attempts: challenge.AttemptCounters{
|
||||
Send: 1,
|
||||
Confirm: 2,
|
||||
},
|
||||
Abuse: challenge.AbuseMetadata{
|
||||
LastAttemptAt: timePointer(now.Add(30 * time.Second)),
|
||||
},
|
||||
Confirmation: &challenge.Confirmation{
|
||||
SessionID: common.DeviceSessionID("device-session-1"),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
ConfirmedAt: now.Add(1 * time.Minute),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func timePointer(value time.Time) *time.Time {
|
||||
return &value
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func TestStorePingNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
err := store.Ping(nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func TestStoreGetNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Get(nil, common.ChallengeID("challenge"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
@@ -0,0 +1,169 @@
// Package configprovider implements ports.ConfigProvider with Redis-backed
// dynamic auth/session configuration.
package configprovider

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"galaxy/authsession/internal/ports"

	"github.com/redis/go-redis/v9"
)

// Config configures one Redis-backed config provider instance.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string

	// Username is the optional Redis ACL username.
	Username string

	// Password is the optional Redis ACL password.
	Password string

	// DB is the Redis logical database index.
	DB int

	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool

	// SessionLimitKey identifies the single Redis string key that stores the
	// active-session-limit configuration value.
	SessionLimitKey string

	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}

// Store reads dynamic auth/session configuration from Redis.
type Store struct {
	client           *redis.Client
	sessionLimitKey  string
	operationTimeout time.Duration
}

// New constructs a Redis-backed config provider from cfg.
func New(cfg Config) (*Store, error) {
	switch {
	case strings.TrimSpace(cfg.Addr) == "":
		return nil, errors.New("new redis config provider: redis addr must not be empty")
	case cfg.DB < 0:
		return nil, errors.New("new redis config provider: redis db must not be negative")
	case strings.TrimSpace(cfg.SessionLimitKey) == "":
		return nil, errors.New("new redis config provider: session limit key must not be empty")
	case cfg.OperationTimeout <= 0:
		return nil, errors.New("new redis config provider: operation timeout must be positive")
	}

	options := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	return &Store{
		client:           redis.NewClient(options),
		sessionLimitKey:  cfg.SessionLimitKey,
		operationTimeout: cfg.OperationTimeout,
	}, nil
}

// Close releases the underlying Redis client resources.
func (s *Store) Close() error {
	if s == nil || s.client == nil {
		return nil
	}

	return s.client.Close()
}

// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
	operationCtx, cancel, err := s.operationContext(ctx, "ping redis config provider")
	if err != nil {
		return err
	}
	defer cancel()

	if err := s.client.Ping(operationCtx).Err(); err != nil {
		return fmt.Errorf("ping redis config provider: %w", err)
	}

	return nil
}

// LoadSessionLimit returns the current active-session-limit configuration.
// Missing or invalid Redis values are treated as "limit absent" by policy.
func (s *Store) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) {
	operationCtx, cancel, err := s.operationContext(ctx, "load session limit from redis")
	if err != nil {
		return ports.SessionLimitConfig{}, err
	}
	defer cancel()

	value, err := s.client.Get(operationCtx, s.sessionLimitKey).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return ports.SessionLimitConfig{}, nil
	case err != nil:
		return ports.SessionLimitConfig{}, fmt.Errorf("load session limit from redis: %w", err)
	}

	config, valid := parseSessionLimitConfig(value)
	if !valid {
		return ports.SessionLimitConfig{}, nil
	}
	if err := config.Validate(); err != nil {
		return ports.SessionLimitConfig{}, nil
	}

	return config, nil
}

func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	if s == nil || s.client == nil {
		return nil, nil, fmt.Errorf("%s: nil store", operation)
	}
	if ctx == nil {
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}

	operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
	return operationCtx, cancel, nil
}

func parseSessionLimitConfig(raw string) (ports.SessionLimitConfig, bool) {
	if strings.TrimSpace(raw) == "" || strings.TrimSpace(raw) != raw {
		return ports.SessionLimitConfig{}, false
	}
	for _, symbol := range raw {
		if symbol < '0' || symbol > '9' {
			return ports.SessionLimitConfig{}, false
		}
	}

	parsed, err := strconv.ParseInt(raw, 10, strconv.IntSize)
	if err != nil || parsed <= 0 {
		return ports.SessionLimitConfig{}, false
	}

	limit := int(parsed)
	return ports.SessionLimitConfig{
		ActiveSessionLimit: &limit,
	}, true
}

var _ ports.ConfigProvider = (*Store)(nil)
@@ -0,0 +1,283 @@
package configprovider

import (
	"context"
	"strconv"
	"testing"
	"time"

	"galaxy/authsession/internal/adapters/contracttest"
	"galaxy/authsession/internal/ports"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestStoreContract(t *testing.T) {
	t.Parallel()

	contracttest.RunConfigProviderContractTests(t, func(t *testing.T) contracttest.ConfigProviderHarness {
		t.Helper()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		return contracttest.ConfigProviderHarness{
			Provider: store,
			SeedDisabled: func(t *testing.T) {
				t.Helper()
				server.Del(store.sessionLimitKey)
			},
			SeedLimit: func(t *testing.T, limit int) {
				t.Helper()
				server.Set(store.sessionLimitKey, strconv.Itoa(limit))
			},
		}
	})
}

func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               2,
				SessionLimitKey:  "authsession:config:active-session-limit",
				OperationTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionLimitKey:  "authsession:config:active-session-limit",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               -1,
				SessionLimitKey:  "authsession:config:active-session-limit",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session limit key",
			cfg: Config{
				Addr:             server.Addr(),
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "session limit key must not be empty",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:            server.Addr(),
				SessionLimitKey: "authsession:config:active-session-limit",
			},
			wantErr: "operation timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			store, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}

func TestStorePing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	require.NoError(t, store.Ping(context.Background()))
}

func TestStoreLoadSessionLimit(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		seed       func(*testing.T, *miniredis.Miniredis, *Store)
		wantConfig ports.SessionLimitConfig
	}{
		{
			name:       "missing key means disabled",
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "valid positive integer",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "5")
			},
			wantConfig: configWithLimit(5),
		},
		{
			name: "empty string is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "whitespace only is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, " ")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "whitespace padded integer is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, " 5 ")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "non integer text is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "five")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "zero is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "0")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "negative integer is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "-3")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "overflow is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "999999999999999999999999999999")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := miniredis.RunT(t)
			store := newTestStore(t, server, Config{})
			if tt.seed != nil {
				tt.seed(t, server, store)
			}

			got, err := store.LoadSessionLimit(context.Background())
			require.NoError(t, err)
			assert.Equal(t, tt.wantConfig, got)
		})
	}
}

func TestStoreLoadSessionLimitBackendFailure(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	server.Close()

	_, err := store.LoadSessionLimit(context.Background())
	require.Error(t, err)
	assert.ErrorContains(t, err, "load session limit from redis")
}

func TestStoreLoadSessionLimitNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	_, err := store.LoadSessionLimit(nil)
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}

func TestStorePingNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	err := store.Ping(nil)
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}

func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
	t.Helper()

	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	if cfg.SessionLimitKey == "" {
		cfg.SessionLimitKey = "authsession:config:active-session-limit"
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}

	store, err := New(cfg)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, store.Close())
	})

	return store
}

func configWithLimit(limit int) ports.SessionLimitConfig {
	return ports.SessionLimitConfig{
		ActiveSessionLimit: &limit,
	}
}
@@ -0,0 +1,223 @@
// Package projectionpublisher implements
// ports.GatewaySessionProjectionPublisher with Redis-backed gateway-compatible
// cache snapshots and session lifecycle events.
package projectionpublisher

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/gatewayprojection"
	"galaxy/authsession/internal/ports"

	"github.com/redis/go-redis/v9"
)

// Config configures one Redis-backed gateway session projection publisher.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string

	// Username is the optional Redis ACL username.
	Username string

	// Password is the optional Redis ACL password.
	Password string

	// DB is the Redis logical database index.
	DB int

	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool

	// SessionCacheKeyPrefix is the namespace prefix applied to gateway session
	// cache keys. The raw device session identifier is appended directly.
	SessionCacheKeyPrefix string

	// SessionEventsStream identifies the gateway session lifecycle Redis Stream.
	SessionEventsStream string

	// StreamMaxLen bounds the session lifecycle stream with approximate
	// trimming via XADD MAXLEN ~.
	StreamMaxLen int64

	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}

// Publisher publishes gateway-compatible session projections into Redis cache
// and stream namespaces.
type Publisher struct {
	client                *redis.Client
	sessionCacheKeyPrefix string
	sessionEventsStream   string
	streamMaxLen          int64
	operationTimeout      time.Duration
}

type cacheRecord struct {
	DeviceSessionID string                   `json:"device_session_id"`
	UserID          string                   `json:"user_id"`
	ClientPublicKey string                   `json:"client_public_key"`
	Status          gatewayprojection.Status `json:"status"`
	RevokedAtMS     *int64                   `json:"revoked_at_ms,omitempty"`
}

// New constructs a Redis-backed gateway session projection publisher from
// cfg.
func New(cfg Config) (*Publisher, error) {
	switch {
	case strings.TrimSpace(cfg.Addr) == "":
		return nil, errors.New("new redis projection publisher: redis addr must not be empty")
	case cfg.DB < 0:
		return nil, errors.New("new redis projection publisher: redis db must not be negative")
	case strings.TrimSpace(cfg.SessionCacheKeyPrefix) == "":
		return nil, errors.New("new redis projection publisher: session cache key prefix must not be empty")
	case strings.TrimSpace(cfg.SessionEventsStream) == "":
		return nil, errors.New("new redis projection publisher: session events stream must not be empty")
	case cfg.StreamMaxLen <= 0:
		return nil, errors.New("new redis projection publisher: stream max len must be positive")
	case cfg.OperationTimeout <= 0:
		return nil, errors.New("new redis projection publisher: operation timeout must be positive")
	}

	options := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	return &Publisher{
		client:                redis.NewClient(options),
		sessionCacheKeyPrefix: cfg.SessionCacheKeyPrefix,
		sessionEventsStream:   cfg.SessionEventsStream,
		streamMaxLen:          cfg.StreamMaxLen,
		operationTimeout:      cfg.OperationTimeout,
	}, nil
}

// Close releases the underlying Redis client resources.
func (p *Publisher) Close() error {
	if p == nil || p.client == nil {
		return nil
	}

	return p.client.Close()
}

// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (p *Publisher) Ping(ctx context.Context) error {
	operationCtx, cancel, err := p.operationContext(ctx, "ping redis projection publisher")
	if err != nil {
		return err
	}
	defer cancel()

	if err := p.client.Ping(operationCtx).Err(); err != nil {
		return fmt.Errorf("ping redis projection publisher: %w", err)
	}

	return nil
}

// PublishSession writes one gateway-compatible session snapshot into the
// gateway cache namespace and appends the same snapshot to the gateway session
// event stream within one Redis transaction.
func (p *Publisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error {
	if err := snapshot.Validate(); err != nil {
		return fmt.Errorf("publish session projection to redis: %w", err)
	}

	payload, err := marshalCacheRecord(snapshot)
	if err != nil {
		return fmt.Errorf("publish session projection to redis: %w", err)
	}
	values := buildStreamValues(snapshot)

	operationCtx, cancel, err := p.operationContext(ctx, "publish session projection to redis")
	if err != nil {
		return err
	}
	defer cancel()

	key := p.sessionCacheKey(snapshot.DeviceSessionID)
	_, err = p.client.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
		pipe.Set(operationCtx, key, payload, 0)
		pipe.XAdd(operationCtx, &redis.XAddArgs{
			Stream: p.sessionEventsStream,
			MaxLen: p.streamMaxLen,
			Approx: true,
			Values: values,
		})
		return nil
	})
	if err != nil {
		return fmt.Errorf("publish session projection %q to redis: %w", snapshot.DeviceSessionID, err)
	}

	return nil
}

func (p *Publisher) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	if p == nil || p.client == nil {
		return nil, nil, fmt.Errorf("%s: nil publisher", operation)
	}
	if ctx == nil {
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}

	operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout)
	return operationCtx, cancel, nil
}

func (p *Publisher) sessionCacheKey(deviceSessionID interface{ String() string }) string {
	return p.sessionCacheKeyPrefix + deviceSessionID.String()
}

func marshalCacheRecord(snapshot gatewayprojection.Snapshot) ([]byte, error) {
	record := cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          snapshot.Status,
	}
	if snapshot.RevokedAt != nil {
		revokedAtMS := snapshot.RevokedAt.UTC().UnixMilli()
		record.RevokedAtMS = &revokedAtMS
	}

	payload, err := json.Marshal(record)
	if err != nil {
		return nil, fmt.Errorf("marshal gateway session cache record: %w", err)
	}

	return payload, nil
}

func buildStreamValues(snapshot gatewayprojection.Snapshot) map[string]any {
	values := map[string]any{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(snapshot.Status),
	}
	if snapshot.RevokedAt != nil {
		values["revoked_at_ms"] = fmt.Sprint(snapshot.RevokedAt.UTC().UnixMilli())
	}

	return values
}

var _ ports.GatewaySessionProjectionPublisher = (*Publisher)(nil)
@@ -0,0 +1,442 @@
|
||||
package projectionpublisher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 3,
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty session cache key prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session cache key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty session events stream",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session events stream must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non positive stream max len",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "stream max len must be positive",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publisher, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, publisher.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublisherPing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	require.NoError(t, publisher.Ping(context.Background()))
}

func TestPublisherPublishSessionActive(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})
	snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)

	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))

	key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
	assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key)
	assert.True(t, server.Exists(key))
	assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String())))

	payload, err := server.Get(key)
	require.NoError(t, err)
	record := decodeCachePayload(t, payload)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}, record)
	assert.Zero(t, server.TTL(key))

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 1)
	assert.Equal(t, map[string]string{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}, stringifyValues(entries[0].Values))
}

func TestPublisherPublishSessionRevoked(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})
	revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
	snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)

	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))

	key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
	payload, err := server.Get(key)
	require.NoError(t, err)
	record := decodeCachePayload(t, payload)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusRevoked,
		RevokedAtMS:     int64Pointer(revokedAt.UnixMilli()),
	}, record)

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 1)
	assert.Equal(t, map[string]string{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776000123456",
	}, stringifyValues(entries[0].Values))
}

func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
	deviceSessionID := "device-session-456"

	active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
	revokedAt := time.Unix(1_776_010_000, 0).UTC()
	revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)

	require.NoError(t, publisher.PublishSession(context.Background(), active))
	require.NoError(t, publisher.PublishSession(context.Background(), revoked))

	payload, err := server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID))
	require.NoError(t, err)
	record := decodeCachePayload(t, payload)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
	assert.Equal(t, gatewayprojection.StatusRevoked, record.Status)

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 2)
	assert.Equal(t, map[string]string{
		"device_session_id": active.DeviceSessionID.String(),
		"user_id":           active.UserID.String(),
		"client_public_key": active.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}, stringifyValues(entries[0].Values))
	assert.Equal(t, map[string]string{
		"device_session_id": revoked.DeviceSessionID.String(),
		"user_id":           revoked.UserID.String(),
		"client_public_key": revoked.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776010000000",
	}, stringifyValues(entries[1].Values))
}

func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
	snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)

	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))

	payload, err := server.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID))
	require.NoError(t, err)
	record := decodeCachePayload(t, payload)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}, record)

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 2)
	assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values))
}

func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{StreamMaxLen: 2})

	for index := range 6 {
		snapshot := testSnapshot(
			common.DeviceSessionID("device-session-"+string(rune('a'+index))).String(),
			gatewayprojection.StatusActive,
			nil,
		)
		require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	}

	streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result()
	require.NoError(t, err)
	assert.LessOrEqual(t, streamLength, int64(2))
}

func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})
	snapshot := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID("device-session-123"),
		UserID:          common.UserID("user-123"),
		Status:          gatewayprojection.StatusActive,
	}

	err := publisher.PublishSession(context.Background(), snapshot)
	require.Error(t, err)
	assert.ErrorContains(t, err, "gateway projection client public key")
	assert.Empty(t, server.Keys())
}

func TestPublisherPublishSessionNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	err := publisher.PublishSession(nil, testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}

func TestPublisherPublishSessionBackendFailure(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})
	server.Close()

	err := publisher.PublishSession(context.Background(), testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
	require.Error(t, err)
	assert.ErrorContains(t, err, "publish session projection")
}

func TestPublisherPingNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	err := publisher.Ping(nil)
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}

func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
	t.Helper()

	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	if cfg.SessionCacheKeyPrefix == "" {
		cfg.SessionCacheKeyPrefix = "gateway:session:"
	}
	if cfg.SessionEventsStream == "" {
		cfg.SessionEventsStream = "gateway:session_events"
	}
	if cfg.StreamMaxLen == 0 {
		cfg.StreamMaxLen = 1024
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}

	publisher, err := New(cfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, publisher.Close())
	})

	return publisher
}

func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for index := range raw {
		raw[index] = byte(index + 1)
	}

	snapshot := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID("user-123"),
		ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
		Status:          status,
		RevokedAt:       revokedAt,
	}
	if status == gatewayprojection.StatusRevoked {
		snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
		snapshot.RevokeActorType = common.RevokeActorType("system")
	}

	return snapshot
}

func decodeCachePayload(t *testing.T, payload string) cacheRecord {
	t.Helper()

	decoder := json.NewDecoder(bytes.NewReader([]byte(payload)))
	decoder.DisallowUnknownFields()

	var record cacheRecord
	require.NoError(t, decoder.Decode(&record))
	err := decoder.Decode(&struct{}{})
	if err == nil {
		require.FailNow(t, "expected cache payload EOF after first JSON value")
	}
	require.ErrorIs(t, err, io.EOF)

	var fieldSet map[string]json.RawMessage
	require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet))
	expectedFields := map[string]struct{}{
		"device_session_id": {},
		"user_id":           {},
		"client_public_key": {},
		"status":            {},
	}
	if record.RevokedAtMS != nil {
		expectedFields["revoked_at_ms"] = struct{}{}
	}
	assert.Equal(t, len(expectedFields), len(fieldSet))
	for field := range fieldSet {
		_, ok := expectedFields[field]
		assert.Truef(t, ok, "unexpected cache payload field %q", field)
	}

	return record
}

func stringifyValues(values map[string]any) map[string]string {
	stringified := make(map[string]string, len(values))
	for key, value := range values {
		stringified[key] = fmt.Sprint(value)
	}
	return stringified
}

func encodeBase64URL(value string) string {
	return base64.RawURLEncoding.EncodeToString([]byte(value))
}

func int64Pointer(value int64) *int64 {
	return &value
}
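
The `decodeCachePayload` helper above layers two strictness checks on top of `encoding/json`: `DisallowUnknownFields` rejects fields the struct does not declare, and a second `Decode` must return `io.EOF` to prove nothing trails the first JSON value. The same idiom, reduced to a standalone sketch with a hypothetical two-field record (stdlib only, not the service's actual types):

```go
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
)

// record is a hypothetical payload shape used only for this sketch.
type record struct {
	DeviceSessionID string `json:"device_session_id"`
	Status          string `json:"status"`
}

// decodeStrict rejects unknown fields and trailing JSON values.
func decodeStrict(payload []byte) (record, error) {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	var r record
	if err := decoder.Decode(&r); err != nil {
		return record{}, err
	}
	// A second decode must hit io.EOF; anything else means extra data followed.
	if err := decoder.Decode(&struct{}{}); !errors.Is(err, io.EOF) {
		return record{}, errors.New("trailing data after JSON value")
	}
	return r, nil
}

func main() {
	_, err := decodeStrict([]byte(`{"device_session_id":"abc","status":"active"}`))
	fmt.Println(err == nil) // exact schema decodes cleanly

	_, err = decodeStrict([]byte(`{"device_session_id":"abc","extra":1}`))
	fmt.Println(err != nil) // unknown field is rejected
}
```

This is stricter than a plain `json.Unmarshal`, which silently drops unknown fields and would let a schema drift between producer and consumer go unnoticed.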
@@ -0,0 +1,152 @@
// Package sendemailcodeabuse implements ports.SendEmailCodeAbuseProtector with
// one Redis TTL key per normalized e-mail address.
package sendemailcodeabuse

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/redis/go-redis/v9"
)

// Config configures one Redis-backed send-email-code abuse protector.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string

	// Username is the optional Redis ACL username.
	Username string

	// Password is the optional Redis ACL password.
	Password string

	// DB is the Redis logical database index.
	DB int

	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool

	// KeyPrefix is the namespace prefix applied to every resend-throttle key.
	KeyPrefix string

	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}

// Protector applies the fixed resend cooldown with one Redis key per
// normalized e-mail address.
type Protector struct {
	client           *redis.Client
	keyPrefix        string
	operationTimeout time.Duration
}

// New constructs a Redis-backed resend-throttle protector from cfg.
func New(cfg Config) (*Protector, error) {
	switch {
	case strings.TrimSpace(cfg.Addr) == "":
		return nil, errors.New("new redis send email code abuse protector: redis addr must not be empty")
	case cfg.DB < 0:
		return nil, errors.New("new redis send email code abuse protector: redis db must not be negative")
	case strings.TrimSpace(cfg.KeyPrefix) == "":
		return nil, errors.New("new redis send email code abuse protector: redis key prefix must not be empty")
	case cfg.OperationTimeout <= 0:
		return nil, errors.New("new redis send email code abuse protector: operation timeout must be positive")
	}

	options := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	return &Protector{
		client:           redis.NewClient(options),
		keyPrefix:        cfg.KeyPrefix,
		operationTimeout: cfg.OperationTimeout,
	}, nil
}

// Close releases the underlying Redis client resources.
func (p *Protector) Close() error {
	if p == nil || p.client == nil {
		return nil
	}

	return p.client.Close()
}

// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (p *Protector) Ping(ctx context.Context) error {
	operationCtx, cancel, err := p.operationContext(ctx, "ping redis send email code abuse protector")
	if err != nil {
		return err
	}
	defer cancel()

	if err := p.client.Ping(operationCtx).Err(); err != nil {
		return fmt.Errorf("ping redis send email code abuse protector: %w", err)
	}

	return nil
}

// CheckAndReserve applies the fixed resend cooldown using one TTL key per
// normalized e-mail address.
func (p *Protector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
	if err := input.Validate(); err != nil {
		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err)
	}

	operationCtx, cancel, err := p.operationContext(ctx, "check and reserve send email code abuse")
	if err != nil {
		return ports.SendEmailCodeAbuseResult{}, err
	}
	defer cancel()

	key := p.lookupKey(input.Email)
	value := input.Now.UTC().Add(challenge.ResendThrottleCooldown).Format(time.RFC3339Nano)
	created, err := p.client.SetNX(operationCtx, key, value, challenge.ResendThrottleCooldown).Result()
	if err != nil {
		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse for %q: %w", input.Email, err)
	}
	if created {
		return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeAllowed}, nil
	}

	return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeThrottled}, nil
}

func (p *Protector) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	if p == nil || p.client == nil {
		return nil, nil, fmt.Errorf("%s: nil protector", operation)
	}
	if ctx == nil {
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}

	operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout)
	return operationCtx, cancel, nil
}

func (p *Protector) lookupKey(email common.Email) string {
	return p.keyPrefix + base64.RawURLEncoding.EncodeToString([]byte(email.String()))
}

var _ ports.SendEmailCodeAbuseProtector = (*Protector)(nil)
@@ -0,0 +1,176 @@
package sendemailcodeabuse

import (
	"context"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               1,
				KeyPrefix:        "authsession:send-email-code-throttle:",
				OperationTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				KeyPrefix:        "authsession:send-email-code-throttle:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               -1,
				KeyPrefix:        "authsession:send-email-code-throttle:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty key prefix",
			cfg: Config{
				Addr:             server.Addr(),
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis key prefix must not be empty",
		},
		{
			name: "non-positive timeout",
			cfg: Config{
				Addr:      server.Addr(),
				KeyPrefix: "authsession:send-email-code-throttle:",
			},
			wantErr: "operation timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			protector, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, protector.Close())
			})
		})
	}
}

func TestProtectorPing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	protector := newTestProtector(t, server, Config{})

	require.NoError(t, protector.Ping(context.Background()))
}

func TestProtectorCheckAndReserve(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	protector := newTestProtector(t, server, Config{})
	email := common.Email("pilot@example.com")
	now := time.Unix(10, 0).UTC()

	result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now,
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)

	key := protector.lookupKey(email)
	assert.True(t, server.Exists(key))
	ttl := server.TTL(key)
	assert.LessOrEqual(t, ttl, challenge.ResendThrottleCooldown)
	assert.GreaterOrEqual(t, ttl, challenge.ResendThrottleCooldown-2*time.Second)

	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(30 * time.Second),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
	ttlAfterThrottle := server.TTL(key)
	assert.LessOrEqual(t, ttlAfterThrottle, ttl)

	server.FastForward(challenge.ResendThrottleCooldown)

	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(challenge.ResendThrottleCooldown),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
}

func TestProtectorNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	protector := newTestProtector(t, server, Config{})

	_, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{
		Email: common.Email("pilot@example.com"),
		Now:   time.Unix(10, 0).UTC(),
	})
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}

func newTestProtector(t *testing.T, server *miniredis.Miniredis, cfg Config) *Protector {
	t.Helper()

	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	if cfg.KeyPrefix == "" {
		cfg.KeyPrefix = "authsession:send-email-code-throttle:"
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}

	protector, err := New(cfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, protector.Close())
	})

	return protector
}
@@ -0,0 +1,723 @@
// Package sessionstore implements ports.SessionStore with Redis-backed strict
// JSON source-of-truth session records and per-user indexes.
package sessionstore

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"slices"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/ports"

	"github.com/redis/go-redis/v9"
)

const mutationRetryLimit = 3

// Config configures one Redis-backed session store instance.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string

	// Username is the optional Redis ACL username.
	Username string

	// Password is the optional Redis ACL password.
	Password string

	// DB is the Redis logical database index.
	DB int

	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool

	// SessionKeyPrefix is the namespace prefix applied to primary session keys.
	SessionKeyPrefix string

	// UserSessionsKeyPrefix is the namespace prefix applied to all-session user
	// indexes.
	UserSessionsKeyPrefix string

	// UserActiveSessionsKeyPrefix is the namespace prefix applied to active
	// session user indexes.
	UserActiveSessionsKeyPrefix string

	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}

// Store persists source-of-truth sessions in Redis and maintains user-scoped
// indexes for list and count operations.
type Store struct {
	client                      *redis.Client
	sessionKeyPrefix            string
	userSessionsKeyPrefix       string
	userActiveSessionsKeyPrefix string
	operationTimeout            time.Duration
}

type redisRecord struct {
	DeviceSessionID       string               `json:"device_session_id"`
	UserID                string               `json:"user_id"`
	ClientPublicKeyBase64 string               `json:"client_public_key_base64"`
	Status                devicesession.Status `json:"status"`
	CreatedAt             string               `json:"created_at"`
	RevokedAt             *string              `json:"revoked_at,omitempty"`
	RevokeReasonCode      string               `json:"revoke_reason_code,omitempty"`
	RevokeActorType       string               `json:"revoke_actor_type,omitempty"`
	RevokeActorID         string               `json:"revoke_actor_id,omitempty"`
}

// New constructs a Redis-backed session store from cfg.
func New(cfg Config) (*Store, error) {
	switch {
	case strings.TrimSpace(cfg.Addr) == "":
		return nil, errors.New("new redis session store: redis addr must not be empty")
	case cfg.DB < 0:
		return nil, errors.New("new redis session store: redis db must not be negative")
	case strings.TrimSpace(cfg.SessionKeyPrefix) == "":
		return nil, errors.New("new redis session store: session key prefix must not be empty")
	case strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "":
		return nil, errors.New("new redis session store: user sessions key prefix must not be empty")
	case strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "":
		return nil, errors.New("new redis session store: user active sessions key prefix must not be empty")
	case cfg.OperationTimeout <= 0:
		return nil, errors.New("new redis session store: operation timeout must be positive")
	}

	options := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	return &Store{
		client:                      redis.NewClient(options),
		sessionKeyPrefix:            cfg.SessionKeyPrefix,
		userSessionsKeyPrefix:       cfg.UserSessionsKeyPrefix,
		userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix,
		operationTimeout:            cfg.OperationTimeout,
	}, nil
}

// Close releases the underlying Redis client resources.
func (s *Store) Close() error {
	if s == nil || s.client == nil {
		return nil
	}

	return s.client.Close()
}

// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
	operationCtx, cancel, err := s.operationContext(ctx, "ping redis session store")
	if err != nil {
		return err
	}
	defer cancel()

	if err := s.client.Ping(operationCtx).Err(); err != nil {
		return fmt.Errorf("ping redis session store: %w", err)
	}

	return nil
}

// Get returns the stored session for deviceSessionID.
func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
	if err := deviceSessionID.Validate(); err != nil {
		return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err)
	}

	operationCtx, cancel, err := s.operationContext(ctx, "get session from redis")
	if err != nil {
		return devicesession.Session{}, err
	}
	defer cancel()

	record, err := s.loadSession(operationCtx, deviceSessionID)
	if err != nil {
		switch {
		case errors.Is(err, ports.ErrNotFound):
			return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, ports.ErrNotFound)
		default:
			return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err)
		}
	}

	return record, nil
}

// ListByUserID returns every stored session for userID in newest-first order.
func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
	if err := userID.Validate(); err != nil {
		return nil, fmt.Errorf("list sessions by user id from redis: %w", err)
	}

	operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis")
	if err != nil {
		return nil, err
	}
	defer cancel()

	deviceSessionIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result()
	if err != nil {
		return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err)
	}
	if len(deviceSessionIDs) == 0 {
		return []devicesession.Session{}, nil
	}

	records := make([]devicesession.Session, 0, len(deviceSessionIDs))
	for _, rawDeviceSessionID := range deviceSessionIDs {
		deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
		record, err := s.loadSession(operationCtx, deviceSessionID)
		if err != nil {
			switch {
			case errors.Is(err, ports.ErrNotFound):
				return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, deviceSessionID)
			default:
				return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, deviceSessionID, err)
			}
		}
		if record.UserID != userID {
			return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, deviceSessionID, record.UserID)
		}
		records = append(records, record)
	}

	sortSessionsNewestFirst(records)
	return records, nil
}

// CountActiveByUserID returns the number of active sessions currently stored
// for userID.
func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
	if err := userID.Validate(); err != nil {
		return 0, fmt.Errorf("count active sessions by user id from redis: %w", err)
	}

	operationCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis")
	if err != nil {
		return 0, err
	}
	defer cancel()

	count, err := s.client.ZCard(operationCtx, s.userActiveSessionsKey(userID)).Result()
	if err != nil {
		return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err)
	}

	return int(count), nil
}

// Create persists record as a new device session.
func (s *Store) Create(ctx context.Context, record devicesession.Session) error {
	if err := record.Validate(); err != nil {
		return fmt.Errorf("create session in redis: %w", err)
	}

	payload, err := marshalSessionRecord(record)
	if err != nil {
		return fmt.Errorf("create session in redis: %w", err)
	}

	deviceSessionKey := s.sessionKey(record.ID)
	allSessionsKey := s.userSessionsKey(record.UserID)
	activeSessionsKey := s.userActiveSessionsKey(record.UserID)

	operationCtx, cancel, err := s.operationContext(ctx, "create session in redis")
	if err != nil {
		return err
	}
	defer cancel()

	watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
		_, err := tx.Get(operationCtx, deviceSessionKey).Bytes()
		switch {
		case errors.Is(err, redis.Nil):
		case err != nil:
			return fmt.Errorf("create session %q in redis: %w", record.ID, err)
		default:
			return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
		}

		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			pipe.Set(operationCtx, deviceSessionKey, payload, 0)
			pipe.ZAdd(operationCtx, allSessionsKey, redis.Z{
				Score:  createdAtScore(record.CreatedAt),
				Member: record.ID.String(),
			})
			if record.Status == devicesession.StatusActive {
				pipe.ZAdd(operationCtx, activeSessionsKey, redis.Z{
					Score:  createdAtScore(record.CreatedAt),
					Member: record.ID.String(),
				})
			}
			return nil
		})
		if err != nil {
			return fmt.Errorf("create session %q in redis: %w", record.ID, err)
		}

		return nil
	}, deviceSessionKey)

	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}

// Revoke stores a revoked view of one target session.
func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
	if err := input.Validate(); err != nil {
		return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err)
	}

	var result ports.RevokeSessionResult
	err := s.runMutation(ctx, "revoke session in redis", func(operationCtx context.Context) error {
		deviceSessionKey := s.sessionKey(input.DeviceSessionID)

		watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
			current, err := s.loadSessionWithGetter(operationCtx, input.DeviceSessionID, tx.Get)
			if err != nil {
				switch {
				case errors.Is(err, ports.ErrNotFound):
					return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, ports.ErrNotFound)
				default:
					return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
				}
			}

			if current.Status == devicesession.StatusRevoked {
				result = ports.RevokeSessionResult{
					Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
					Session: current,
				}
				return result.Validate()
			}

			next := current
			next.Status = devicesession.StatusRevoked
			revocation := input.Revocation
			next.Revocation = &revocation
			if err := next.Validate(); err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}

			payload, err := marshalSessionRecord(next)
			if err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}

			_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
				pipe.Set(operationCtx, deviceSessionKey, payload, 0)
				pipe.ZRem(operationCtx, s.userActiveSessionsKey(current.UserID), current.ID.String())
				return nil
			})
			if err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}

			result = ports.RevokeSessionResult{
				Outcome: ports.RevokeSessionOutcomeRevoked,
				Session: next,
			}
			return result.Validate()
		}, deviceSessionKey)

		switch {
		case errors.Is(watchErr, redis.TxFailedErr):
			return errRetryMutation
		case watchErr != nil:
			return watchErr
		default:
			return nil
		}
	})
	if err != nil {
		return ports.RevokeSessionResult{}, err
	}

	return result, nil
}

// RevokeAllByUserID stores revoked views for all currently active sessions
// owned by input.UserID.
func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
	if err := input.Validate(); err != nil {
		return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err)
	}

	var result ports.RevokeUserSessionsResult
	err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error {
		activeSessionsKey := s.userActiveSessionsKey(input.UserID)

		watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
			deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result()
			if err != nil {
				return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
			}
			if len(deviceSessionIDs) == 0 {
				// Force EXEC so WATCH observes concurrent active-index changes even
				// for the no-op path.
				_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
					pipe.ZCard(operationCtx, activeSessionsKey)
					return nil
				})
				if err != nil {
					return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
				}

				result = ports.RevokeUserSessionsResult{
					Outcome:  ports.RevokeUserSessionsOutcomeNoActiveSessions,
					UserID:   input.UserID,
					Sessions: []devicesession.Session{},
				}
				return result.Validate()
			}

			records := make([]devicesession.Session, 0, len(deviceSessionIDs))
			for _, rawDeviceSessionID := range deviceSessionIDs {
				deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
				record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, tx.Get)
				if err != nil {
					switch {
					case errors.Is(err, ports.ErrNotFound):
						return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID)
					default:
						return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
					}
				}
				if record.UserID != input.UserID {
					return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID)
				}
				if record.Status != devicesession.StatusActive {
					return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status)
				}

				next := record
				next.Status = devicesession.StatusRevoked
				revocation := input.Revocation
				next.Revocation = &revocation
				if err := next.Validate(); err != nil {
					return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
				}
				records = append(records, next)
			}

			_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
				for _, record := range records {
					payload, err := marshalSessionRecord(record)
					if err != nil {
						return fmt.Errorf("session %q: %w", record.ID, err)
					}
					pipe.Set(operationCtx, s.sessionKey(record.ID), payload, 0)
					pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String())
				}
				return nil
			})
			if err != nil {
				return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
			}

			sortSessionsNewestFirst(records)
			result = ports.RevokeUserSessionsResult{
				Outcome:  ports.RevokeUserSessionsOutcomeRevoked,
				UserID:   input.UserID,
				Sessions: records,
			}
			return result.Validate()
		}, activeSessionsKey)

		switch {
		case errors.Is(watchErr, redis.TxFailedErr):
			return errRetryMutation
		case watchErr != nil:
			return watchErr
		default:
			return nil
		}
	})
	if err != nil {
		return ports.RevokeUserSessionsResult{}, err
	}

	return result, nil
}

var errRetryMutation = errors.New("redis session store: retry mutation")

// runMutation runs execute with a fresh per-attempt timeout context, retrying
// up to mutationRetryLimit times when execute reports errRetryMutation (a lost
// optimistic WATCH race).
func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error {
	for attempt := 0; attempt < mutationRetryLimit; attempt++ {
		operationCtx, cancel, err := s.operationContext(ctx, operation)
		if err != nil {
			return err
		}

		err = execute(operationCtx)
		cancel()

		switch {
		case errors.Is(err, errRetryMutation):
			if attempt == mutationRetryLimit-1 {
				return fmt.Errorf("%s: mutation retry limit exceeded", operation)
			}
			continue
		default:
			return err
		}
	}

	return fmt.Errorf("%s: mutation retry limit exceeded", operation)
}

func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	if s == nil || s.client == nil {
		return nil, nil, fmt.Errorf("%s: nil store", operation)
	}
	if ctx == nil {
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}

	operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
	return operationCtx, cancel, nil
}

func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
	return s.loadSessionWithGetter(ctx, deviceSessionID, s.client.Get)
}

func (s *Store) loadSessionWithGetter(
	ctx context.Context,
	deviceSessionID common.DeviceSessionID,
	getter func(context.Context, string) *redis.StringCmd,
) (devicesession.Session, error) {
	payload, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes()
	switch {
	case errors.Is(err, redis.Nil):
		return devicesession.Session{}, ports.ErrNotFound
	case err != nil:
		return devicesession.Session{}, err
	}

	record, err := decodeSessionRecord(deviceSessionID, payload)
	if err != nil {
		return devicesession.Session{}, err
	}

	return record, nil
}

func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string {
	return s.sessionKeyPrefix + encodeKeyComponent(deviceSessionID.String())
}

func (s *Store) userSessionsKey(userID common.UserID) string {
	return s.userSessionsKeyPrefix + encodeKeyComponent(userID.String())
}

func (s *Store) userActiveSessionsKey(userID common.UserID) string {
	return s.userActiveSessionsKeyPrefix + encodeKeyComponent(userID.String())
}

// encodeKeyComponent makes opaque IDs safe to embed in Redis key names.
func encodeKeyComponent(value string) string {
	return base64.RawURLEncoding.EncodeToString([]byte(value))
}

func marshalSessionRecord(record devicesession.Session) ([]byte, error) {
	stored, err := redisRecordFromSession(record)
	if err != nil {
		return nil, err
	}

	payload, err := json.Marshal(stored)
	if err != nil {
		return nil, fmt.Errorf("encode redis session record: %w", err)
	}

	return payload, nil
}

func redisRecordFromSession(record devicesession.Session) (redisRecord, error) {
	if err := record.Validate(); err != nil {
		return redisRecord{}, fmt.Errorf("encode redis session record: %w", err)
	}

	stored := redisRecord{
		DeviceSessionID:       record.ID.String(),
		UserID:                record.UserID.String(),
		ClientPublicKeyBase64: record.ClientPublicKey.String(),
		Status:                record.Status,
		CreatedAt:             formatTimestamp(record.CreatedAt),
	}
	if record.Revocation != nil {
		stored.RevokedAt = formatOptionalTimestamp(&record.Revocation.At)
		stored.RevokeReasonCode = record.Revocation.ReasonCode.String()
		stored.RevokeActorType = record.Revocation.ActorType.String()
		stored.RevokeActorID = record.Revocation.ActorID
	}

	return stored, nil
}

func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	var stored redisRecord
	if err := decoder.Decode(&stored); err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
	}
	if err := decoder.Decode(&struct{}{}); err != io.EOF {
		if err == nil {
			return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input")
		}
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
	}

	record, err := sessionFromRedisRecord(stored)
	if err != nil {
		return devicesession.Session{}, err
	}
	if record.ID != expectedDeviceSessionID {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", record.ID, expectedDeviceSessionID)
	}

	return record, nil
}

func sessionFromRedisRecord(stored redisRecord) (devicesession.Session, error) {
	createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
	if err != nil {
		return devicesession.Session{}, err
	}

	rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ClientPublicKeyBase64)
	if err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
	}
	clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
	if err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
	}

	record := devicesession.Session{
		ID:              common.DeviceSessionID(stored.DeviceSessionID),
		UserID:          common.UserID(stored.UserID),
		ClientPublicKey: clientPublicKey,
		Status:          stored.Status,
		CreatedAt:       createdAt,
	}

	revocation, err := parseRevocation(stored)
	if err != nil {
		return devicesession.Session{}, err
	}
	record.Revocation = revocation

	if err := record.Validate(); err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
	}

	return record, nil
}

func parseRevocation(stored redisRecord) (*devicesession.Revocation, error) {
	hasRevokedAt := stored.RevokedAt != nil
	hasReasonCode := strings.TrimSpace(stored.RevokeReasonCode) != ""
	hasActorType := strings.TrimSpace(stored.RevokeActorType) != ""
	hasActorID := strings.TrimSpace(stored.RevokeActorID) != ""

	if !hasRevokedAt && !hasReasonCode && !hasActorType && !hasActorID {
		return nil, nil
	}
	if !hasRevokedAt || !hasReasonCode || !hasActorType {
		return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent")
	}

	revokedAt, err := parseTimestamp("revoked_at", *stored.RevokedAt)
	if err != nil {
		return nil, err
	}

	return &devicesession.Revocation{
		At:         revokedAt,
		ReasonCode: common.RevokeReasonCode(stored.RevokeReasonCode),
		ActorType:  common.RevokeActorType(stored.RevokeActorType),
		ActorID:    stored.RevokeActorID,
	}, nil
}

func parseTimestamp(fieldName string, value string) (time.Time, error) {
	if strings.TrimSpace(value) == "" {
		return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName)
	}

	parsed, err := time.Parse(time.RFC3339Nano, value)
	if err != nil {
		return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err)
	}

	canonical := parsed.UTC().Format(time.RFC3339Nano)
	if value != canonical {
		return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
	}

	return parsed.UTC(), nil
}

func formatTimestamp(value time.Time) string {
	return value.UTC().Format(time.RFC3339Nano)
}

func formatOptionalTimestamp(value *time.Time) *string {
	if value == nil {
		return nil
	}

	formatted := formatTimestamp(*value)
	return &formatted
}

func createdAtScore(createdAt time.Time) float64 {
	return float64(createdAt.UTC().UnixMicro())
}

func sortSessionsNewestFirst(records []devicesession.Session) {
	slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int {
		switch {
		case left.CreatedAt.Equal(right.CreatedAt):
			return strings.Compare(left.ID.String(), right.ID.String())
		case left.CreatedAt.After(right.CreatedAt):
			return -1
		default:
			return 1
		}
	})
}

var _ ports.SessionStore = (*Store)(nil)

@@ -0,0 +1,635 @@
package sessionstore

import (
	"context"
	"crypto/ed25519"
	"encoding/json"
	"testing"
	"time"

	"galaxy/authsession/internal/adapters/contracttest"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/ports"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestStoreContract(t *testing.T) {
	t.Parallel()

	contracttest.RunSessionStoreContractTests(t, func(t *testing.T) ports.SessionStore {
		t.Helper()

		server := miniredis.RunT(t)
		return newTestStore(t, server, Config{})
	})
}

func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:                        server.Addr(),
				DB:                          1,
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:                        server.Addr(),
				DB:                          -1,
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session prefix",
			cfg: Config{
				Addr:                        server.Addr(),
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "session key prefix must not be empty",
		},
		{
			name: "empty all sessions prefix",
			cfg: Config{
				Addr:                        server.Addr(),
				SessionKeyPrefix:            "authsession:session:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "user sessions key prefix must not be empty",
		},
		{
			name: "empty active sessions prefix",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionKeyPrefix:      "authsession:session:",
				UserSessionsKeyPrefix: "authsession:user-sessions:",
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "user active sessions key prefix must not be empty",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:                        server.Addr(),
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
			},
			wantErr: "operation timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			store, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}

func TestStorePing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	require.NoError(t, store.Ping(context.Background()))
}

func TestStoreCreateAndGetActive(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())

	require.NoError(t, store.Create(context.Background(), record))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)

	// Mutating the returned copy must not leak back into the store.
	got.Revocation = &devicesession.Revocation{
		At:         got.CreatedAt.Add(time.Minute),
		ReasonCode: devicesession.RevokeReasonAdminRevoke,
		ActorType:  common.RevokeActorType("admin"),
	}

	again, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Nil(t, again.Revocation)
	assert.Equal(t, record, again)
}

func TestStoreCreateAndGetRevoked(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	record := revokedSessionFixture("device-session-2", "user-1", time.Unix(1_775_240_100, 0).UTC())

	require.NoError(t, store.Create(context.Background(), record))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)

	count, err := store.CountActiveByUserID(context.Background(), record.UserID)
	require.NoError(t, err)
	assert.Zero(t, count)
}

func TestStoreGetNotFound(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrNotFound)
}

func TestStoreCreateConflict(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_200, 0).UTC())

	require.NoError(t, store.Create(context.Background(), record))

	err := store.Create(context.Background(), record)
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrConflict)
}

func TestStoreIndexesAndOrdering(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	older := activeSessionFixture("device-session-old", "user-1", time.Unix(10, 0).UTC())
	newer := activeSessionFixture("device-session-new", "user-1", time.Unix(20, 0).UTC())
	revoked := revokedSessionFixture("device-session-revoked", "user-1", time.Unix(15, 0).UTC())
	otherUser := activeSessionFixture("device-session-other", "user-2", time.Unix(30, 0).UTC())

	for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
		require.NoError(t, store.Create(context.Background(), record))
	}

	got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
	require.NoError(t, err)
	require.Len(t, got, 3)
	assert.Equal(t, []common.DeviceSessionID{newer.ID, revoked.ID, older.ID}, []common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID})

	count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
	require.NoError(t, err)
	assert.Equal(t, 2, count)

	unknown, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
	require.NoError(t, err)
	assert.Empty(t, unknown)
}

func TestStoreKeyPrefixesAndEncodedPrimaryKey(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{
		SessionKeyPrefix:            "custom:session:",
		UserSessionsKeyPrefix:       "custom:user-sessions:",
		UserActiveSessionsKeyPrefix: "custom:user-active-sessions:",
	})

	record := activeSessionFixture("device/session:opaque?1", "user/opaque:1", time.Unix(40, 0).UTC())
	require.NoError(t, store.Create(context.Background(), record))

	primaryKey := store.sessionKey(record.ID)
	assert.Equal(t, "custom:session:"+encodeKeyComponent(record.ID.String()), primaryKey)
	assert.True(t, server.Exists(primaryKey))

	allSessionsKey := store.userSessionsKey(record.UserID)
	activeSessionsKey := store.userActiveSessionsKey(record.UserID)
	assert.Equal(t, "custom:user-sessions:"+encodeKeyComponent(record.UserID.String()), allSessionsKey)
	assert.Equal(t, "custom:user-active-sessions:"+encodeKeyComponent(record.UserID.String()), activeSessionsKey)

	allMembers, err := server.ZMembers(allSessionsKey)
	require.NoError(t, err)
	assert.Equal(t, []string{record.ID.String()}, allMembers)

	activeMembers, err := server.ZMembers(activeSessionsKey)
	require.NoError(t, err)
	assert.Equal(t, []string{record.ID.String()}, activeMembers)
}

func TestStoreRevoke(t *testing.T) {
	t.Parallel()

	t.Run("active session", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		record := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		revocation := devicesession.Revocation{
			At:         time.Unix(200, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonLogoutAll,
			ActorType:  common.RevokeActorType("system"),
		}

		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: record.ID,
			Revocation:      revocation,
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
		require.NotNil(t, result.Session.Revocation)
		assert.Equal(t, revocation, *result.Session.Revocation)

		count, err := store.CountActiveByUserID(context.Background(), record.UserID)
		require.NoError(t, err)
		assert.Zero(t, count)
	})

	t.Run("already revoked keeps stored revocation", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		record := revokedSessionFixture("device-session-2", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: record.ID,
			Revocation: devicesession.Revocation{
				At:         time.Unix(300, 0).UTC(),
				ReasonCode: devicesession.RevokeReasonAdminRevoke,
				ActorType:  common.RevokeActorType("admin"),
				ActorID:    "admin-1",
			},
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
		require.NotNil(t, result.Session.Revocation)
		assert.Equal(t, *record.Revocation, *result.Session.Revocation)
	})

	t.Run("unknown session", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: common.DeviceSessionID("missing-session"),
			Revocation: devicesession.Revocation{
				At:         time.Unix(200, 0).UTC(),
				ReasonCode: devicesession.RevokeReasonLogoutAll,
				ActorType:  common.RevokeActorType("system"),
			},
		})
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})
}

func TestStoreRevokeAllByUserID(t *testing.T) {
	t.Parallel()

	t.Run("revokes active sessions newest first and clears active index", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		older := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
		newer := activeSessionFixture("device-session-2", "user-1", time.Unix(200, 0).UTC())
		alreadyRevoked := revokedSessionFixture("device-session-3", "user-1", time.Unix(150, 0).UTC())
		otherUser := activeSessionFixture("device-session-4", "user-2", time.Unix(250, 0).UTC())

		for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
			require.NoError(t, store.Create(context.Background(), record))
		}

		revocation := devicesession.Revocation{
			At:         time.Unix(300, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonAdminRevoke,
			ActorType:  common.RevokeActorType("admin"),
			ActorID:    "admin-1",
		}

		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID:     common.UserID("user-1"),
			Revocation: revocation,
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
		require.Len(t, result.Sessions, 2)
		assert.Equal(t, []common.DeviceSessionID{newer.ID, older.ID}, []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID})
		assert.Equal(t, revocation, *result.Sessions[0].Revocation)
		assert.Equal(t, revocation, *result.Sessions[1].Revocation)

		count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
		require.NoError(t, err)
		assert.Zero(t, count)

		otherCount, err := store.CountActiveByUserID(context.Background(), common.UserID("user-2"))
		require.NoError(t, err)
		assert.Equal(t, 1, otherCount)
	})

	t.Run("no active sessions", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		record := revokedSessionFixture("device-session-5", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID: common.UserID("user-1"),
			Revocation: devicesession.Revocation{
				At:         time.Unix(400, 0).UTC(),
				ReasonCode: devicesession.RevokeReasonAdminRevoke,
				ActorType:  common.RevokeActorType("admin"),
			},
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
		assert.Empty(t, result.Sessions)
	})
}

func TestStoreStrictDecodeCorruption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_240_300, 0).UTC()
|
||||
baseRecord := revokedSessionFixture("device-session-corrupt", "user-1", now)
|
||||
stored, err := redisRecordFromSession(baseRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(redisRecord) string
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "malformed json",
|
||||
mutate: func(_ redisRecord) string {
|
||||
return "{"
|
||||
},
|
||||
wantErrText: "decode redis session record",
|
||||
},
|
||||
{
|
||||
name: "trailing json input",
|
||||
mutate: func(record redisRecord) string {
|
||||
return mustMarshalJSON(t, record) + "{}"
|
||||
},
|
||||
wantErrText: "unexpected trailing JSON input",
|
||||
},
|
||||
{
|
||||
name: "unknown field",
|
||||
mutate: func(record redisRecord) string {
|
||||
payload := map[string]any{
|
||||
"device_session_id": record.DeviceSessionID,
|
||||
"user_id": record.UserID,
|
||||
"client_public_key_base64": record.ClientPublicKeyBase64,
|
||||
"status": record.Status,
|
||||
"created_at": record.CreatedAt,
|
||||
"revoked_at": record.RevokedAt,
|
||||
"revoke_reason_code": record.RevokeReasonCode,
|
||||
"revoke_actor_type": record.RevokeActorType,
|
||||
"revoke_actor_id": record.RevokeActorID,
|
||||
"unexpected": true,
|
||||
}
|
||||
return mustMarshalJSON(t, payload)
|
||||
},
|
||||
wantErrText: "unknown field",
|
||||
},
|
||||
{
|
||||
name: "unsupported status",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.Status = devicesession.Status("paused")
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `status "paused" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "non canonical timestamp",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.CreatedAt = "2026-04-04T12:00:00+03:00"
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "canonical UTC RFC3339Nano timestamp",
|
||||
},
|
||||
{
|
||||
name: "incomplete revocation metadata",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.RevokeActorType = ""
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "revocation metadata must be either fully present or fully absent",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
server.Set(store.sessionKey(baseRecord.ID), tt.mutate(stored))
|
||||
|
||||
_, err := store.Get(context.Background(), baseRecord.ID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreListByUserIDDetectsCorruptIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("missing primary record", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
userID := common.UserID("user-1")
|
||||
_, err := server.ZAdd(store.userSessionsKey(userID), 100, "missing-session")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.ListByUserID(context.Background(), userID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "references missing session")
|
||||
})
|
||||
|
||||
t.Run("wrong user id in primary record", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-2", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
|
||||
_, err := server.ZAdd(store.userSessionsKey(common.UserID("user-1")), createdAtScore(record.CreatedAt), record.ID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.ListByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `belongs to "user-2"`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
|
||||
_, err := server.ZAdd(store.userActiveSessionsKey(record.UserID), createdAtScore(record.CreatedAt), record.ID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: record.UserID,
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(200, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `is "revoked"`)
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.SessionKeyPrefix == "" {
|
||||
cfg.SessionKeyPrefix = "authsession:session:"
|
||||
}
|
||||
if cfg.UserSessionsKeyPrefix == "" {
|
||||
cfg.UserSessionsKeyPrefix = "authsession:user-sessions:"
|
||||
}
|
||||
if cfg.UserActiveSessionsKeyPrefix == "" {
|
||||
cfg.UserActiveSessionsKeyPrefix = "authsession:user-active-sessions:"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
|
||||
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
record := activeSessionFixture(deviceSessionID, userID, createdAt)
|
||||
record.Status = devicesession.StatusRevoked
|
||||
record.Revocation = &devicesession.Revocation{
|
||||
At: createdAt.Add(time.Minute),
|
||||
ReasonCode: devicesession.RevokeReasonDeviceLogout,
|
||||
ActorType: common.RevokeActorType("user"),
|
||||
ActorID: "user-actor",
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
func seedSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record devicesession.Session) error {
|
||||
t.Helper()
|
||||
|
||||
stored, err := redisRecordFromSession(record)
|
||||
require.NoError(t, err)
|
||||
server.Set(key, mustMarshalJSON(t, stored))
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(payload)
|
||||
}
|
||||
@@ -0,0 +1,382 @@
// Package userservice provides runtime user-directory adapters for the
// auth/session service.
package userservice

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/userresolution"
	"galaxy/authsession/internal/ports"
)

const (
	resolveByEmailPath = "/api/v1/internal/user-resolutions/by-email"
	existsByUserIDPath = "/api/v1/internal/users/%s/exists"
	ensureByEmailPath  = "/api/v1/internal/users/ensure-by-email"
	blockByUserIDPath  = "/api/v1/internal/users/%s/block"
	blockByEmailPath   = "/api/v1/internal/user-blocks/by-email"
)

// Config configures one HTTP-based UserDirectory client.
type Config struct {
	// BaseURL is the absolute base URL of the future user-service internal
	// HTTP API.
	BaseURL string

	// RequestTimeout bounds each outbound user-service request.
	RequestTimeout time.Duration
}

// RESTClient implements ports.UserDirectory over a frozen internal REST
// contract.
type RESTClient struct {
	baseURL        string
	requestTimeout time.Duration
	httpClient     *http.Client
}

// NewRESTClient constructs a REST-backed UserDirectory adapter from cfg.
func NewRESTClient(cfg Config) (*RESTClient, error) {
	transport := http.DefaultTransport.(*http.Transport).Clone()

	return newRESTClient(cfg, &http.Client{Transport: transport})
}

func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) {
	switch {
	case strings.TrimSpace(cfg.BaseURL) == "":
		return nil, errors.New("new user service REST client: base URL must not be empty")
	case cfg.RequestTimeout <= 0:
		return nil, errors.New("new user service REST client: request timeout must be positive")
	case httpClient == nil:
		return nil, errors.New("new user service REST client: http client must not be nil")
	}

	parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
	if err != nil {
		return nil, fmt.Errorf("new user service REST client: parse base URL: %w", err)
	}
	if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
		return nil, errors.New("new user service REST client: base URL must be absolute")
	}

	return &RESTClient{
		baseURL:        parsedBaseURL.String(),
		requestTimeout: cfg.RequestTimeout,
		httpClient:     httpClient,
	}, nil
}

// Close releases idle HTTP connections owned by the client transport.
func (c *RESTClient) Close() error {
	if c == nil || c.httpClient == nil {
		return nil
	}

	type idleCloser interface {
		CloseIdleConnections()
	}

	if transport, ok := c.httpClient.Transport.(idleCloser); ok {
		transport.CloseIdleConnections()
	}

	return nil
}

// ResolveByEmail returns the current coarse user-resolution state for email
// without creating any new user record.
func (c *RESTClient) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
	if err := validateContext(ctx, "resolve by email"); err != nil {
		return userresolution.Result{}, err
	}
	if err := email.Validate(); err != nil {
		return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
	}

	var response struct {
		Kind            userresolution.Kind            `json:"kind"`
		UserID          string                         `json:"user_id,omitempty"`
		BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"`
	}

	if err := c.doJSON(ctx, "resolve by email", http.MethodPost, resolveByEmailPath, map[string]string{
		"email": email.String(),
	}, &response, true); err != nil {
		return userresolution.Result{}, err
	}

	result := userresolution.Result{
		Kind:            response.Kind,
		UserID:          common.UserID(response.UserID),
		BlockReasonCode: response.BlockReasonCode,
	}
	if err := result.Validate(); err != nil {
		return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
	}

	return result, nil
}

// ExistsByUserID reports whether userID currently identifies a stored user
// record.
func (c *RESTClient) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
	if err := validateContext(ctx, "exists by user id"); err != nil {
		return false, err
	}
	if err := userID.Validate(); err != nil {
		return false, fmt.Errorf("exists by user id: %w", err)
	}

	var response struct {
		Exists bool `json:"exists"`
	}

	if err := c.doJSON(ctx, "exists by user id", http.MethodGet, fmt.Sprintf(existsByUserIDPath, url.PathEscape(userID.String())), nil, &response, true); err != nil {
		return false, err
	}

	return response.Exists, nil
}

// EnsureUserByEmail returns an existing user for email, creates a new user
// when registration is allowed, or reports a blocked outcome.
func (c *RESTClient) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
	if err := validateContext(ctx, "ensure user by email"); err != nil {
		return ports.EnsureUserResult{}, err
	}
	if err := email.Validate(); err != nil {
		return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
	}

	var response struct {
		Outcome         ports.EnsureUserOutcome        `json:"outcome"`
		UserID          string                         `json:"user_id,omitempty"`
		BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"`
	}

	if err := c.doJSON(ctx, "ensure user by email", http.MethodPost, ensureByEmailPath, map[string]string{
		"email": email.String(),
	}, &response, false); err != nil {
		return ports.EnsureUserResult{}, err
	}

	result := ports.EnsureUserResult{
		Outcome:         response.Outcome,
		UserID:          common.UserID(response.UserID),
		BlockReasonCode: response.BlockReasonCode,
	}
	if err := result.Validate(); err != nil {
		return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
	}

	return result, nil
}

// BlockByUserID applies a block state to the user identified by input.UserID.
// Unknown user ids wrap ports.ErrNotFound.
func (c *RESTClient) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
	if err := validateContext(ctx, "block by user id"); err != nil {
		return ports.BlockUserResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
	}

	payload, statusCode, err := c.doRequest(ctx, "block by user id", http.MethodPost, fmt.Sprintf(blockByUserIDPath, url.PathEscape(input.UserID.String())), map[string]string{
		"reason_code": input.ReasonCode.String(),
	}, false)
	if err != nil {
		return ports.BlockUserResult{}, err
	}
	if statusCode == http.StatusNotFound {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
	}
	if statusCode != http.StatusOK {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id: unexpected HTTP status %d", statusCode)
	}

	var response struct {
		Outcome ports.BlockUserOutcome `json:"outcome"`
		UserID  string                 `json:"user_id,omitempty"`
	}
	if err := decodeJSONPayload(payload, &response); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
	}

	result := ports.BlockUserResult{
		Outcome: response.Outcome,
		UserID:  common.UserID(response.UserID),
	}
	if err := result.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
	}

	return result, nil
}

// BlockByEmail applies a block state to input.Email even when no user record
// currently exists for that e-mail address.
func (c *RESTClient) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
	if err := validateContext(ctx, "block by email"); err != nil {
		return ports.BlockUserResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
	}

	var response struct {
		Outcome ports.BlockUserOutcome `json:"outcome"`
		UserID  string                 `json:"user_id,omitempty"`
	}

	if err := c.doJSON(ctx, "block by email", http.MethodPost, blockByEmailPath, map[string]string{
		"email":       input.Email.String(),
		"reason_code": input.ReasonCode.String(),
	}, &response, false); err != nil {
		return ports.BlockUserResult{}, err
	}

	result := ports.BlockUserResult{
		Outcome: response.Outcome,
		UserID:  common.UserID(response.UserID),
	}
	if err := result.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
	}

	return result, nil
}

func (c *RESTClient) doJSON(ctx context.Context, operation string, method string, requestPath string, requestBody any, responseTarget any, retryRead bool) error {
	payload, statusCode, err := c.doRequest(ctx, operation, method, requestPath, requestBody, retryRead)
	if err != nil {
		return err
	}
	if statusCode != http.StatusOK {
		return fmt.Errorf("%s: unexpected HTTP status %d", operation, statusCode)
	}
	if err := decodeJSONPayload(payload, responseTarget); err != nil {
		return fmt.Errorf("%s: %w", operation, err)
	}

	return nil
}

func (c *RESTClient) doRequest(ctx context.Context, operation string, method string, requestPath string, requestBody any, retryRead bool) ([]byte, int, error) {
	bodyBytes, err := marshalOptionalRequestBody(requestBody)
	if err != nil {
		return nil, 0, fmt.Errorf("%s: %w", operation, err)
	}

	attempts := 1
	if retryRead {
		attempts = 2
	}

	var lastErr error
	for attempt := 0; attempt < attempts; attempt++ {
		attemptCtx, cancel := context.WithTimeout(ctx, c.requestTimeout)

		request, err := http.NewRequestWithContext(attemptCtx, method, c.baseURL+requestPath, bytes.NewReader(bodyBytes))
		if err != nil {
			cancel()
			return nil, 0, fmt.Errorf("%s: build request: %w", operation, err)
		}
		if method == http.MethodPost {
			request.Header.Set("Content-Type", "application/json")
		}

		response, err := c.httpClient.Do(request)
		if err != nil {
			cancel()
			lastErr = fmt.Errorf("%s: %w", operation, err)
			if retryRead && attempt == 0 && ctx.Err() == nil {
				continue
			}

			return nil, 0, lastErr
		}

		payload, readErr := io.ReadAll(response.Body)
		closeErr := response.Body.Close()
		cancel()
		if readErr != nil {
			lastErr = fmt.Errorf("%s: read response body: %w", operation, readErr)
			if retryRead && attempt == 0 && ctx.Err() == nil {
				continue
			}

			return nil, 0, lastErr
		}
		if closeErr != nil {
			lastErr = fmt.Errorf("%s: close response body: %w", operation, closeErr)
			if retryRead && attempt == 0 && ctx.Err() == nil {
				continue
			}

			return nil, 0, lastErr
		}

		if retryRead && attempt == 0 && isRetriableUserServiceStatus(response.StatusCode) {
			lastErr = fmt.Errorf("%s: unexpected HTTP status %d", operation, response.StatusCode)
			continue
		}

		return payload, response.StatusCode, nil
	}

	return nil, 0, lastErr
}

func marshalOptionalRequestBody(value any) ([]byte, error) {
	if value == nil {
		return nil, nil
	}

	payload, err := json.Marshal(value)
	if err != nil {
		return nil, fmt.Errorf("marshal request body: %w", err)
	}

	return payload, nil
}

func decodeJSONPayload(payload []byte, target any) error {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	if err := decoder.Decode(target); err != nil {
		return fmt.Errorf("decode response body: %w", err)
	}
	if err := decoder.Decode(&struct{}{}); err != io.EOF {
		if err == nil {
			return errors.New("decode response body: unexpected trailing JSON input")
		}

		return fmt.Errorf("decode response body: %w", err)
	}

	return nil
}

func isRetriableUserServiceStatus(statusCode int) bool {
	switch statusCode {
	case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
		return true
	default:
		return false
	}
}

var _ ports.UserDirectory = (*RESTClient)(nil)
@@ -0,0 +1,622 @@
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewRESTClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
BaseURL: "http://127.0.0.1:8080",
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty base url",
|
||||
cfg: Config{
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
wantErr: "base URL must not be empty",
|
||||
},
|
||||
{
|
||||
name: "relative base url",
|
||||
cfg: Config{
|
||||
BaseURL: "/relative",
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
wantErr: "base URL must be absolute",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
BaseURL: "http://127.0.0.1:8080",
|
||||
},
|
||||
wantErr: "request timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, err := NewRESTClient(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientEndpointSuccessCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*testing.T, *RESTClient)
|
||||
}{
|
||||
{
|
||||
name: "resolve by email",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Case@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userresolution.Result{
|
||||
Kind: userresolution.KindExisting,
|
||||
UserID: common.UserID("user-123"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exists by user id",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ensure user by email",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.EnsureUserResult{
|
||||
Outcome: ports.EnsureUserOutcomeCreated,
|
||||
UserID: common.UserID("user-234"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block by user id",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-123"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeBlocked,
|
||||
UserID: common.UserID("user-123"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block by email",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("blocked@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
|
||||
UserID: common.UserID("user-345"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestsMu sync.Mutex
|
||||
var requests []capturedRequest
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, captureRequest(t, r))
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == resolveByEmailPath:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"kind": "existing",
|
||||
"user_id": "user-123",
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v1/internal/users/user-123/exists":
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"exists": true})
|
||||
case r.Method == http.MethodPost && r.URL.Path == ensureByEmailPath:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"outcome": "created",
|
||||
"user_id": "user-234",
|
||||
})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/internal/users/user-123/block":
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"outcome": "blocked",
|
||||
"user_id": "user-123",
|
||||
})
|
||||
case r.Method == http.MethodPost && r.URL.Path == blockByEmailPath:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"outcome": "already_blocked",
|
||||
"user_id": "user-345",
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
tt.run(t, client)
|
||||
|
||||
requestsMu.Lock()
|
||||
defer requestsMu.Unlock()
|
||||
|
||||
require.Len(t, requests, 1)
|
||||
switch tt.name {
|
||||
case "resolve by email":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: resolveByEmailPath,
|
||||
ContentType: "application/json",
|
||||
Body: `{"email":"Pilot+Case@example.com"}`,
|
||||
}, requests[0])
|
||||
case "exists by user id":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/v1/internal/users/user-123/exists",
|
||||
}, requests[0])
|
||||
case "ensure user by email":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: ensureByEmailPath,
|
||||
ContentType: "application/json",
|
||||
Body: `{"email":"created@example.com"}`,
|
||||
}, requests[0])
|
||||
case "block by user id":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/v1/internal/users/user-123/block",
|
||||
ContentType: "application/json",
|
||||
Body: `{"reason_code":"policy_blocked"}`,
|
||||
}, requests[0])
|
||||
case "block by email":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: blockByEmailPath,
|
||||
ContentType: "application/json",
|
||||
Body: `{"email":"blocked@example.com","reason_code":"policy_blocked"}`,
|
||||
}, requests[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientPreservesNormalizedEmailExactly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var captured string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
request := captureRequest(t, r)
|
||||
captured = request.Body
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
_, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Alias@Example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{"email":"Pilot+Alias@Example.com"}`, captured)
|
||||
}
|
||||
|
||||
func TestRESTClientBlockByUserIDNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("missing-user"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRESTClientReadMethodsRetryOnce(t *testing.T) {
	t.Parallel()

	t.Run("resolve by email retries on 503", func(t *testing.T) {
		t.Parallel()

		var calls int
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			calls++
			if calls == 1 {
				http.Error(w, "temporary", http.StatusServiceUnavailable)
				return
			}

			writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
		}))
		defer server.Close()

		client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

		result, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
		require.NoError(t, err)
		assert.Equal(t, userresolution.KindCreatable, result.Kind)
		assert.Equal(t, 2, calls)
	})

	t.Run("exists by user id retries on transport failure", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			writeJSON(t, w, http.StatusOK, map[string]any{"exists": true})
		}))
		defer server.Close()

		baseTransport := server.Client().Transport
		client, err := newRESTClient(Config{
			BaseURL:        server.URL,
			RequestTimeout: 250 * time.Millisecond,
		}, &http.Client{
			Transport: &failOnceRoundTripper{
				next: baseTransport,
				err:  errors.New("temporary transport failure"),
			},
		})
		require.NoError(t, err)

		exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
		require.NoError(t, err)
		assert.True(t, exists)
	})
}

func TestRESTClientMutationMethodsDoNotRetry(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		run  func(*RESTClient) error
	}{
		{
			name: "ensure user by email",
			run: func(client *RESTClient) error {
				_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
				return err
			},
		},
		{
			name: "block by user id",
			run: func(client *RESTClient) error {
				_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
					UserID:     common.UserID("user-123"),
					ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
				})
				return err
			},
		},
		{
			name: "block by email",
			run: func(client *RESTClient) error {
				_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
					Email:      common.Email("pilot@example.com"),
					ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
				})
				return err
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var calls int
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				calls++
				http.Error(w, "temporary", http.StatusServiceUnavailable)
			}))
			defer server.Close()

			client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

			err := tt.run(client)
			require.Error(t, err)
			assert.Equal(t, 1, calls)
		})
	}
}

func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		statusCode  int
		body        string
		wantErrText string
		run         func(*RESTClient) error
	}{
		{
			name:        "resolve by email rejects unknown field",
			statusCode:  http.StatusOK,
			body:        `{"kind":"creatable","extra":true}`,
			wantErrText: "decode response body",
			run: func(client *RESTClient) error {
				_, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
				return err
			},
		},
		{
			name:        "ensure user by email rejects malformed outcome",
			statusCode:  http.StatusOK,
			body:        `{"outcome":"mystery"}`,
			wantErrText: "unsupported",
			run: func(client *RESTClient) error {
				_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
				return err
			},
		},
		{
			name:        "ensure user by email rejects missing user id for created outcome",
			statusCode:  http.StatusOK,
			body:        `{"outcome":"created"}`,
			wantErrText: "user id",
			run: func(client *RESTClient) error {
				_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
				return err
			},
		},
		{
			name:        "exists by user id rejects trailing json",
			statusCode:  http.StatusOK,
			body:        `{"exists":true}{}`,
			wantErrText: "unexpected trailing JSON input",
			run: func(client *RESTClient) error {
				_, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
				return err
			},
		},
		{
			name:        "block by email rejects unexpected status",
			statusCode:  http.StatusBadGateway,
			body:        `{"error":"temporary"}`,
			wantErrText: "unexpected HTTP status 502",
			run: func(client *RESTClient) error {
				_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
					Email:      common.Email("pilot@example.com"),
					ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
				})
				return err
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(tt.statusCode)
				_, err := io.WriteString(w, tt.body)
				require.NoError(t, err)
			}))
			defer server.Close()

			client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

			err := tt.run(client)
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErrText)
		})
	}
}

func TestRESTClientRequestTimeout(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(40 * time.Millisecond)
		writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 10*time.Millisecond)

	_, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
	require.Error(t, err)
	assert.ErrorContains(t, err, "context deadline exceeded")
}

func TestRESTClientContextAndValidation(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Fatalf("unexpected upstream call")
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	tests := []struct {
		name string
		run  func() error
	}{
		{
			name: "nil context",
			run: func() error {
				_, err := client.ResolveByEmail(nil, common.Email("pilot@example.com"))
				return err
			},
		},
		{
			name: "cancelled context",
			run: func() error {
				_, err := client.ExistsByUserID(cancelledCtx, common.UserID("user-123"))
				return err
			},
		},
		{
			name: "invalid email",
			run: func() error {
				_, err := client.EnsureUserByEmail(context.Background(), common.Email(" bad@example.com "))
				return err
			},
		},
		{
			name: "invalid user id",
			run: func() error {
				_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
					UserID:     common.UserID(" bad "),
					ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
				})
				return err
			},
		},
		{
			name: "invalid reason code",
			run: func() error {
				_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
					Email:      common.Email("pilot@example.com"),
					ReasonCode: userresolution.BlockReasonCode(" bad "),
				})
				return err
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			err := tt.run()
			require.Error(t, err)
		})
	}
}

type capturedRequest struct {
	Method      string
	Path        string
	ContentType string
	Body        string
}

func captureRequest(t *testing.T, request *http.Request) capturedRequest {
	t.Helper()

	body, err := io.ReadAll(request.Body)
	require.NoError(t, err)

	return capturedRequest{
		Method:      request.Method,
		Path:        request.URL.Path,
		ContentType: request.Header.Get("Content-Type"),
		Body:        strings.TrimSpace(string(body)),
	}
}

func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
	t.Helper()

	payload, err := json.Marshal(value)
	require.NoError(t, err)

	writer.Header().Set("Content-Type", "application/json")
	writer.WriteHeader(statusCode)
	_, err = writer.Write(payload)
	require.NoError(t, err)
}

func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
	t.Helper()

	client, err := NewRESTClient(Config{
		BaseURL:        baseURL,
		RequestTimeout: timeout,
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, client.Close())
	})

	return client
}

type failOnceRoundTripper struct {
	mu   sync.Mutex
	next http.RoundTripper
	err  error
	done bool
}

func (rt *failOnceRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
	rt.mu.Lock()
	if !rt.done {
		rt.done = true
		err := rt.err
		rt.mu.Unlock()
		return nil, err
	}
	next := rt.next
	rt.mu.Unlock()

	return next.RoundTrip(request)
}

@@ -0,0 +1,361 @@
// Package userservice provides runtime user-directory adapters for the
// auth/session service.
package userservice

import (
	"context"
	"fmt"
	"sync"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/userresolution"
	"galaxy/authsession/internal/ports"
)

type entry struct {
	userID          common.UserID
	blockReasonCode userresolution.BlockReasonCode
}

// StubDirectory is a concurrency-safe in-process UserDirectory stub intended
// for development, local integration, and explicit stub-based tests.
//
// The zero value is ready to use. Unknown e-mail addresses resolve as
// creatable, unknown user identifiers do not exist, and EnsureUserByEmail
// creates deterministic user ids such as "user-1", "user-2", and so on.
type StubDirectory struct {
	mu             sync.Mutex
	byEmail        map[common.Email]entry
	emailByUserID  map[common.UserID]common.Email
	createdUserIDs []common.UserID
	nextUserNumber int
}

// ResolveByEmail returns the current coarse user-resolution state for email
// without creating any new user record.
func (d *StubDirectory) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
	if err := validateContext(ctx, "resolve by email"); err != nil {
		return userresolution.Result{}, err
	}
	if err := email.Validate(); err != nil {
		return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	result, err := d.resolveLocked(email)
	if err != nil {
		return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
	}

	return result, nil
}

// ExistsByUserID reports whether userID currently identifies a stored user
// record.
func (d *StubDirectory) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
	if err := validateContext(ctx, "exists by user id"); err != nil {
		return false, err
	}
	if err := userID.Validate(); err != nil {
		return false, fmt.Errorf("exists by user id: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	_, ok := d.emailByUserID[userID]
	return ok, nil
}

// EnsureUserByEmail returns an existing user for email, creates a new user
// when registration is allowed, or reports a blocked outcome.
func (d *StubDirectory) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
	if err := validateContext(ctx, "ensure user by email"); err != nil {
		return ports.EnsureUserResult{}, err
	}
	if err := email.Validate(); err != nil {
		return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	d.ensureMapsLocked()

	stored, ok := d.byEmail[email]
	if ok {
		if !stored.blockReasonCode.IsZero() {
			result := ports.EnsureUserResult{
				Outcome:         ports.EnsureUserOutcomeBlocked,
				BlockReasonCode: stored.blockReasonCode,
			}
			if err := result.Validate(); err != nil {
				return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
			}

			return result, nil
		}

		result := ports.EnsureUserResult{
			Outcome: ports.EnsureUserOutcomeExisting,
			UserID:  stored.userID,
		}
		if err := result.Validate(); err != nil {
			return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
		}

		return result, nil
	}

	userID, err := d.nextCreatedUserIDLocked()
	if err != nil {
		return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
	}
	d.byEmail[email] = entry{userID: userID}
	d.emailByUserID[userID] = email

	result := ports.EnsureUserResult{
		Outcome: ports.EnsureUserOutcomeCreated,
		UserID:  userID,
	}
	if err := result.Validate(); err != nil {
		return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
	}

	return result, nil
}

// BlockByUserID applies a block state to the user identified by input.UserID.
// Unknown user ids wrap ports.ErrNotFound.
func (d *StubDirectory) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
	if err := validateContext(ctx, "block by user id"); err != nil {
		return ports.BlockUserResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	email, ok := d.emailByUserID[input.UserID]
	if !ok {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
	}

	stored := d.byEmail[email]
	if !stored.blockReasonCode.IsZero() {
		result := ports.BlockUserResult{
			Outcome: ports.BlockUserOutcomeAlreadyBlocked,
			UserID:  input.UserID,
		}
		if err := result.Validate(); err != nil {
			return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
		}

		return result, nil
	}

	stored.blockReasonCode = input.ReasonCode
	d.byEmail[email] = stored

	result := ports.BlockUserResult{
		Outcome: ports.BlockUserOutcomeBlocked,
		UserID:  input.UserID,
	}
	if err := result.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
	}

	return result, nil
}

// BlockByEmail applies a block state to input.Email even when no user record
// currently exists for that e-mail address.
func (d *StubDirectory) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
	if err := validateContext(ctx, "block by email"); err != nil {
		return ports.BlockUserResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	d.ensureMapsLocked()

	stored := d.byEmail[input.Email]
	if !stored.blockReasonCode.IsZero() {
		result := ports.BlockUserResult{
			Outcome: ports.BlockUserOutcomeAlreadyBlocked,
			UserID:  stored.userID,
		}
		if err := result.Validate(); err != nil {
			return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
		}

		return result, nil
	}

	stored.blockReasonCode = input.ReasonCode
	d.byEmail[input.Email] = stored
	if !stored.userID.IsZero() {
		d.emailByUserID[stored.userID] = input.Email
	}

	result := ports.BlockUserResult{
		Outcome: ports.BlockUserOutcomeBlocked,
		UserID:  stored.userID,
	}
	if err := result.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
	}

	return result, nil
}

// SeedExisting preloads one existing unblocked user record into the runtime
// stub.
func (d *StubDirectory) SeedExisting(email common.Email, userID common.UserID) error {
	if err := email.Validate(); err != nil {
		return fmt.Errorf("seed existing email: %w", err)
	}
	if err := userID.Validate(); err != nil {
		return fmt.Errorf("seed existing user id: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	d.ensureMapsLocked()
	d.byEmail[email] = entry{userID: userID}
	d.emailByUserID[userID] = email

	return nil
}

// SeedBlockedEmail preloads one blocked e-mail address that does not
// necessarily belong to an existing user record.
func (d *StubDirectory) SeedBlockedEmail(email common.Email, reasonCode userresolution.BlockReasonCode) error {
	if err := email.Validate(); err != nil {
		return fmt.Errorf("seed blocked email: %w", err)
	}
	if err := reasonCode.Validate(); err != nil {
		return fmt.Errorf("seed blocked email reason code: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	d.ensureMapsLocked()
	d.byEmail[email] = entry{blockReasonCode: reasonCode}

	return nil
}

// SeedBlockedUser preloads one blocked existing user record into the runtime
// stub.
func (d *StubDirectory) SeedBlockedUser(email common.Email, userID common.UserID, reasonCode userresolution.BlockReasonCode) error {
	if err := d.SeedExisting(email, userID); err != nil {
		return err
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	stored := d.byEmail[email]
	stored.blockReasonCode = reasonCode
	d.byEmail[email] = stored

	return nil
}

// QueueCreatedUserIDs appends deterministic user identifiers that
// EnsureUserByEmail consumes before falling back to generated ids.
func (d *StubDirectory) QueueCreatedUserIDs(userIDs ...common.UserID) error {
	for index, userID := range userIDs {
		if err := userID.Validate(); err != nil {
			return fmt.Errorf("queue created user id %d: %w", index, err)
		}
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	d.createdUserIDs = append(d.createdUserIDs, userIDs...)
	return nil
}

func (d *StubDirectory) ensureMapsLocked() {
	if d.byEmail == nil {
		d.byEmail = make(map[common.Email]entry)
	}
	if d.emailByUserID == nil {
		d.emailByUserID = make(map[common.UserID]common.Email)
	}
}

func (d *StubDirectory) resolveLocked(email common.Email) (userresolution.Result, error) {
	stored, ok := d.byEmail[email]
	if !ok {
		result := userresolution.Result{Kind: userresolution.KindCreatable}
		if err := result.Validate(); err != nil {
			return userresolution.Result{}, err
		}

		return result, nil
	}
	if !stored.blockReasonCode.IsZero() {
		result := userresolution.Result{
			Kind:            userresolution.KindBlocked,
			BlockReasonCode: stored.blockReasonCode,
		}
		if err := result.Validate(); err != nil {
			return userresolution.Result{}, err
		}

		return result, nil
	}

	result := userresolution.Result{
		Kind:   userresolution.KindExisting,
		UserID: stored.userID,
	}
	if err := result.Validate(); err != nil {
		return userresolution.Result{}, err
	}

	return result, nil
}

func (d *StubDirectory) nextCreatedUserIDLocked() (common.UserID, error) {
	if len(d.createdUserIDs) > 0 {
		userID := d.createdUserIDs[0]
		d.createdUserIDs = d.createdUserIDs[1:]
		return userID, nil
	}

	d.nextUserNumber++
	userID := common.UserID(fmt.Sprintf("user-%d", d.nextUserNumber))
	if err := userID.Validate(); err != nil {
		return "", err
	}

	return userID, nil
}

func validateContext(ctx context.Context, operation string) error {
	if ctx == nil {
		return fmt.Errorf("%s: nil context", operation)
	}
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("%s: %w", operation, err)
	}

	return nil
}

var _ ports.UserDirectory = (*StubDirectory)(nil)
@@ -0,0 +1,329 @@
package userservice

import (
	"context"
	"errors"
	"testing"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/userresolution"
	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestStubDirectoryResolveByEmail(t *testing.T) {
	t.Parallel()

	directory := &StubDirectory{}
	require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))
	require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")))

	tests := []struct {
		name           string
		email          common.Email
		wantKind       userresolution.Kind
		wantUserID     common.UserID
		wantReasonCode userresolution.BlockReasonCode
	}{
		{
			name:     "zero value unknown email is creatable",
			email:    common.Email("new@example.com"),
			wantKind: userresolution.KindCreatable,
		},
		{
			name:       "existing email",
			email:      common.Email("existing@example.com"),
			wantKind:   userresolution.KindExisting,
			wantUserID: common.UserID("user-existing"),
		},
		{
			name:           "blocked email",
			email:          common.Email("blocked@example.com"),
			wantKind:       userresolution.KindBlocked,
			wantReasonCode: userresolution.BlockReasonCode("policy_block"),
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			result, err := directory.ResolveByEmail(context.Background(), tt.email)
			require.NoError(t, err)
			assert.Equal(t, tt.wantKind, result.Kind)
			assert.Equal(t, tt.wantUserID, result.UserID)
			assert.Equal(t, tt.wantReasonCode, result.BlockReasonCode)
		})
	}
}

func TestStubDirectoryEnsureUserByEmail(t *testing.T) {
	t.Parallel()

	t.Run("existing", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}
		require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))

		result, err := directory.EnsureUserByEmail(context.Background(), common.Email("existing@example.com"))
		require.NoError(t, err)
		assert.Equal(t, ports.EnsureUserOutcomeExisting, result.Outcome)
		assert.Equal(t, common.UserID("user-existing"), result.UserID)
	})

	t.Run("blocked", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}
		require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")))

		result, err := directory.EnsureUserByEmail(context.Background(), common.Email("blocked@example.com"))
		require.NoError(t, err)
		assert.Equal(t, ports.EnsureUserOutcomeBlocked, result.Outcome)
		assert.Equal(t, userresolution.BlockReasonCode("policy_block"), result.BlockReasonCode)
	})

	t.Run("created queued then existing", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}
		require.NoError(t, directory.QueueCreatedUserIDs(common.UserID("user-created")))

		first, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
		require.NoError(t, err)
		assert.Equal(t, ports.EnsureUserOutcomeCreated, first.Outcome)
		assert.Equal(t, common.UserID("user-created"), first.UserID)

		second, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
		require.NoError(t, err)
		assert.Equal(t, ports.EnsureUserOutcomeExisting, second.Outcome)
		assert.Equal(t, common.UserID("user-created"), second.UserID)
	})

	t.Run("created fallback id", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}

		result, err := directory.EnsureUserByEmail(context.Background(), common.Email("fallback@example.com"))
		require.NoError(t, err)
		assert.Equal(t, ports.EnsureUserOutcomeCreated, result.Outcome)
		assert.Equal(t, common.UserID("user-1"), result.UserID)
	})
}

func TestStubDirectoryExistsByUserID(t *testing.T) {
	t.Parallel()

	directory := &StubDirectory{}
	require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))

	exists, err := directory.ExistsByUserID(context.Background(), common.UserID("user-existing"))
	require.NoError(t, err)
	assert.True(t, exists)

	exists, err = directory.ExistsByUserID(context.Background(), common.UserID("missing"))
	require.NoError(t, err)
	assert.False(t, exists)
}

func TestStubDirectoryBlockByEmail(t *testing.T) {
	t.Parallel()

	t.Run("unknown email becomes blocked without user id", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}

		result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
			Email:      common.Email("blocked@example.com"),
			ReasonCode: userresolution.BlockReasonCode("policy_block"),
		})
		require.NoError(t, err)
		assert.Equal(t, ports.BlockUserOutcomeBlocked, result.Outcome)
		assert.True(t, result.UserID.IsZero())

		resolution, err := directory.ResolveByEmail(context.Background(), common.Email("blocked@example.com"))
		require.NoError(t, err)
		assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
	})

	t.Run("existing user preserves linked user id and repeat is already blocked", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}
		require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))

		first, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
			Email:      common.Email("pilot@example.com"),
			ReasonCode: userresolution.BlockReasonCode("policy_block"),
		})
		require.NoError(t, err)
		assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome)
		assert.Equal(t, common.UserID("user-1"), first.UserID)

		second, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
			Email:      common.Email("pilot@example.com"),
			ReasonCode: userresolution.BlockReasonCode("policy_block"),
		})
		require.NoError(t, err)
		assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome)
		assert.Equal(t, common.UserID("user-1"), second.UserID)
	})
}

func TestStubDirectoryBlockByUserID(t *testing.T) {
	t.Parallel()

	t.Run("unknown user wraps ErrNotFound", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}

		_, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
			UserID:     common.UserID("missing"),
			ReasonCode: userresolution.BlockReasonCode("policy_block"),
		})
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})

	t.Run("existing user blocks then returns already blocked", func(t *testing.T) {
		t.Parallel()

		directory := &StubDirectory{}
		require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))

		first, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
			UserID:     common.UserID("user-1"),
			ReasonCode: userresolution.BlockReasonCode("policy_block"),
		})
		require.NoError(t, err)
		assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome)
		assert.Equal(t, common.UserID("user-1"), first.UserID)

		second, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
			UserID:     common.UserID("user-1"),
			ReasonCode: userresolution.BlockReasonCode("policy_block"),
		})
		require.NoError(t, err)
		assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome)
		assert.Equal(t, common.UserID("user-1"), second.UserID)
	})
}

func TestStubDirectoryContextAndValidation(t *testing.T) {
	t.Parallel()

	directory := &StubDirectory{}
	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	tests := []struct {
		name string
		run  func() error
		want string
	}{
		{
			name: "resolve nil context",
			run: func() error {
				_, err := directory.ResolveByEmail(nil, common.Email("pilot@example.com"))
				return err
			},
			want: "nil context",
		},
		{
			name: "ensure cancelled context",
			run: func() error {
				_, err := directory.EnsureUserByEmail(cancelledCtx, common.Email("pilot@example.com"))
				return err
			},
			want: context.Canceled.Error(),
		},
		{
			name: "exists invalid user id",
			run: func() error {
				_, err := directory.ExistsByUserID(context.Background(), common.UserID(" bad "))
				return err
			},
			want: "exists by user id",
		},
		{
			name: "block by email invalid email",
			run: func() error {
				_, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
					Email:      common.Email("bad"),
					ReasonCode: userresolution.BlockReasonCode("policy_block"),
				})
				return err
			},
			want: "block by email",
		},
		{
			name: "seed invalid user id",
			run: func() error {
				return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID(" bad "))
			},
			want: "seed existing user id",
		},
		{
			name: "queue invalid created user id",
			run: func() error {
				return directory.QueueCreatedUserIDs(common.UserID(" bad "))
			},
			want: "queue created user id 0",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.run()
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.want)
		})
	}
}

func TestStubDirectorySeedBlockedUser(t *testing.T) {
	t.Parallel()

	directory := &StubDirectory{}
	require.NoError(t, directory.SeedBlockedUser(
		common.Email("pilot@example.com"),
		common.UserID("user-1"),
		userresolution.BlockReasonCode("policy_block"),
	))

	result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
		Email:      common.Email("pilot@example.com"),
		ReasonCode: userresolution.BlockReasonCode("policy_block"),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, result.Outcome)
	assert.Equal(t, common.UserID("user-1"), result.UserID)
}

func TestStubDirectoryCancelledContextWrapsContextError(t *testing.T) {
	t.Parallel()

	directory := &StubDirectory{}
	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := directory.BlockByUserID(cancelledCtx, ports.BlockUserByIDInput{
		UserID:     common.UserID("user-1"),
		ReasonCode: userresolution.BlockReasonCode("policy_block"),
	})
	require.Error(t, err)
	assert.True(t, errors.Is(err, context.Canceled))
	assert.ErrorContains(t, err, "block by user id")
}
@@ -0,0 +1,3 @@
|
||||
// Package internalhttp exposes the trusted internal HTTP API used for session
|
||||
// read, revoke, and block operations.
|
||||
package internalhttp
|
||||
@@ -0,0 +1,286 @@
package internalhttp

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"encoding/base64"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"galaxy/authsession/internal/adapters/userservice"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/service/blockuser"
	"galaxy/authsession/internal/service/getsession"
	"galaxy/authsession/internal/service/listusersessions"
	"galaxy/authsession/internal/service/revokeallusersessions"
	"galaxy/authsession/internal/service/revokedevicesession"
	"galaxy/authsession/internal/testkit"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestInternalHTTPEndToEndGetSession(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := getJSON(t, server.URL+"/api/v1/internal/sessions/device-session-1")

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.JSONEq(t, `{"session":{"device_session_id":"device-session-1","user_id":"user-1","client_public_key":"`+validClientPublicKey+`","status":"active","created_at":"2026-04-05T12:00:00Z"}}`, response.Body)
}

func TestInternalHTTPEndToEndListUserSessions(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	key := testClientPublicKey(t, validClientPublicKey)
	require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", key, time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
	require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-2", "user-1", key, time.Date(2026, 4, 5, 12, 1, 0, 0, time.UTC))))
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := getJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions")

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.Contains(t, response.Body, `"device_session_id":"device-session-2"`)
	assert.Contains(t, response.Body, `"device_session_id":"device-session-1"`)
	assert.Less(t, bytes.Index([]byte(response.Body), []byte(`"device_session_id":"device-session-2"`)), bytes.Index([]byte(response.Body), []byte(`"device_session_id":"device-session-1"`)))
}

func TestInternalHTTPEndToEndListUserSessionsUnknownUserReturnsEmptyArray(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := getJSON(t, server.URL+"/api/v1/internal/users/unknown-user/sessions")

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.JSONEq(t, `{"sessions":[]}`, response.Body)
}

func TestInternalHTTPEndToEndGetSessionNotFound(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := getJSON(t, server.URL+"/api/v1/internal/sessions/missing-session")

	assert.Equal(t, http.StatusNotFound, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"session_not_found","message":"session not found"}}`, response.Body)
}

func TestInternalHTTPEndToEndRevokeDeviceSession(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(t, server.URL+"/api/v1/internal/sessions/device-session-1/revoke", `{"reason_code":"admin_revoke","actor":{"type":"system"}}`)

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.JSONEq(t, `{"outcome":"revoked","device_session_id":"device-session-1","affected_session_count":1}`, response.Body)
}

func TestInternalHTTPEndToEndRevokeAllUserSessions(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	key := testClientPublicKey(t, validClientPublicKey)
	require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", key, time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
	require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-2", "user-1", key, time.Date(2026, 4, 5, 12, 1, 0, 0, time.UTC))))
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`)

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.JSONEq(t, `{"outcome":"revoked","user_id":"user-1","affected_session_count":2,"affected_device_session_ids":["device-session-2","device-session-1"]}`, response.Body)
}

func TestInternalHTTPEndToEndRevokeAllUserSessionsNoActiveSessions(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`)

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-1","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body)
}

func TestInternalHTTPEndToEndRevokeAllUserSessionsUnknownUser(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(t, server.URL+"/api/v1/internal/users/missing-user/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`)

	assert.Equal(t, http.StatusNotFound, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"subject_not_found","message":"subject not found"}}`, response.Body)
}

func TestInternalHTTPEndToEndBlockUserByEmail(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`)

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body)
}

func TestInternalHTTPEndToEndBlockUserByUserID(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"user_id":"user-1","reason_code":"policy_blocked","actor":{"type":"admin"}}`)

	assert.Equal(t, http.StatusOK, response.StatusCode)
	assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"user_id","subject_value":"user-1","affected_session_count":1,"affected_device_session_ids":["device-session-1"]}`, response.Body)
}

func TestInternalHTTPEndToEndBlockUserUnknownUserID(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t)
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"user_id":"missing-user","reason_code":"policy_blocked","actor":{"type":"admin"}}`)

	assert.Equal(t, http.StatusNotFound, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"subject_not_found","message":"subject not found"}}`, response.Body)
}

type endToEndApp struct {
	handler       http.Handler
	sessionStore  *testkit.InMemorySessionStore
	userDirectory *userservice.StubDirectory
}

func newEndToEndApp(t *testing.T) endToEndApp {
	t.Helper()

	sessionStore := &testkit.InMemorySessionStore{}
	userDirectory := &userservice.StubDirectory{}
	publisher := &testkit.RecordingProjectionPublisher{}
	clock := testkit.FixedClock{Time: time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)}

	getSessionService, err := getsession.New(sessionStore)
	require.NoError(t, err)
	listUserSessionsService, err := listusersessions.New(sessionStore)
	require.NoError(t, err)
	revokeDeviceSessionService, err := revokedevicesession.New(sessionStore, publisher, clock)
	require.NoError(t, err)
	revokeAllUserSessionsService, err := revokeallusersessions.New(sessionStore, userDirectory, publisher, clock)
	require.NoError(t, err)
	blockUserService, err := blockuser.New(userDirectory, sessionStore, publisher, clock)
	require.NoError(t, err)

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession:            getSessionService,
		ListUserSessions:      listUserSessionsService,
		RevokeDeviceSession:   revokeDeviceSessionService,
		RevokeAllUserSessions: revokeAllUserSessionsService,
		BlockUser:             blockUserService,
	})

	return endToEndApp{
		handler:       handler,
		sessionStore:  sessionStore,
		userDirectory: userDirectory,
	}
}

type httpResponse struct {
	StatusCode int
	Body       string
}

func getJSON(t *testing.T, url string) httpResponse {
	t.Helper()

	response, err := http.Get(url)
	require.NoError(t, err)
	defer response.Body.Close()

	payload, err := io.ReadAll(response.Body)
	require.NoError(t, err)

	return httpResponse{StatusCode: response.StatusCode, Body: string(payload)}
}

func postJSON(t *testing.T, url string, body string) httpResponse {
	t.Helper()

	response, err := http.Post(url, "application/json", bytes.NewBufferString(body))
	require.NoError(t, err)
	defer response.Body.Close()

	payload, err := io.ReadAll(response.Body)
	require.NoError(t, err)

	return httpResponse{StatusCode: response.StatusCode, Body: string(payload)}
}

func postJSONValue(t *testing.T, url string, value any) httpResponse {
	t.Helper()

	body, err := json.Marshal(value)
	require.NoError(t, err)
	return postJSON(t, url, string(body))
}

func activeSession(id string, userID string, key common.ClientPublicKey, createdAt time.Time) devicesession.Session {
	return devicesession.Session{
		ID:              common.DeviceSessionID(id),
		UserID:          common.UserID(userID),
		ClientPublicKey: key,
		Status:          devicesession.StatusActive,
		CreatedAt:       createdAt,
	}
}

func testClientPublicKey(t *testing.T, encoded string) common.ClientPublicKey {
	t.Helper()

	decoded, err := base64.StdEncoding.DecodeString(encoded)
	require.NoError(t, err)

	key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded))
	require.NoError(t, err)
	return key
}

const validClientPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
@@ -0,0 +1,513 @@
|
||||
package internalhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/service/blockuser"
|
||||
"galaxy/authsession/internal/service/getsession"
|
||||
"galaxy/authsession/internal/service/listusersessions"
|
||||
"galaxy/authsession/internal/service/revokeallusersessions"
|
||||
"galaxy/authsession/internal/service/revokedevicesession"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/telemetry"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
||||
)
|
||||
|
||||
const jsonContentType = "application/json; charset=utf-8"
|
||||
|
||||
const internalHTTPServiceName = "galaxy-authsession-internal"
|
||||
|
||||
type errorResponse struct {
|
||||
Error errorBody `json:"error"`
|
||||
}
|
||||
|
||||
type errorBody struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type actorRequest struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
type sessionResponseDTO struct {
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ClientPublicKey string `json:"client_public_key"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
RevokedAt *string `json:"revoked_at,omitempty"`
|
||||
RevokeReasonCode *string `json:"revoke_reason_code,omitempty"`
|
||||
RevokeActorType *string `json:"revoke_actor_type,omitempty"`
|
||||
RevokeActorID *string `json:"revoke_actor_id,omitempty"`
|
||||
}
|
||||
|
||||
type getSessionResponse struct {
|
||||
Session sessionResponseDTO `json:"session"`
|
||||
}
|
||||
|
||||
type listUserSessionsResponse struct {
|
||||
Sessions []sessionResponseDTO `json:"sessions"`
|
||||
}
|
||||
|
||||
type revokeDeviceSessionRequest struct {
|
||||
ReasonCode string `json:"reason_code"`
|
||||
Actor actorRequest `json:"actor"`
|
||||
}
|
||||
|
||||
type revokeDeviceSessionResponse struct {
|
||||
Outcome string `json:"outcome"`
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
AffectedSessionCount int64 `json:"affected_session_count"`
|
||||
}
|
||||
|
||||
type revokeAllUserSessionsRequest struct {
|
||||
ReasonCode string `json:"reason_code"`
|
||||
Actor actorRequest `json:"actor"`
|
||||
}
|
||||
|
||||
type revokeAllUserSessionsResponse struct {
|
||||
Outcome string `json:"outcome"`
|
||||
UserID string `json:"user_id"`
|
||||
AffectedSessionCount int64 `json:"affected_session_count"`
|
||||
AffectedDeviceSessionIDs []string `json:"affected_device_session_ids"`
|
||||
}
|
||||
|
||||
type blockUserRequest struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ReasonCode string `json:"reason_code"`
|
||||
Actor actorRequest `json:"actor"`
|
||||
}
|
||||
|
||||
type blockUserResponse struct {
|
||||
Outcome string `json:"outcome"`
|
||||
SubjectKind string `json:"subject_kind"`
|
||||
SubjectValue string `json:"subject_value"`
|
||||
AffectedSessionCount int64 `json:"affected_session_count"`
|
||||
AffectedDeviceSessionIDs []string `json:"affected_device_session_ids"`
|
||||
}
|
||||
|
||||
var configureGinModeOnce sync.Once
|
||||
|
||||
func newHandlerWithConfig(cfg Config, deps Dependencies) (http.Handler, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalizedDeps, err := normalizeDependencies(deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configureGinModeOnce.Do(func() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
})
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(newOTelMiddleware(normalizedDeps.Telemetry))
|
||||
engine.Use(withInternalObservability(normalizedDeps.Logger, normalizedDeps.Telemetry))
|
||||
engine.GET("/api/v1/internal/sessions/:device_session_id", handleGetSession(normalizedDeps.GetSession, cfg.RequestTimeout))
|
||||
engine.GET("/api/v1/internal/users/:user_id/sessions", handleListUserSessions(normalizedDeps.ListUserSessions, cfg.RequestTimeout))
|
||||
engine.POST("/api/v1/internal/sessions/:device_session_id/revoke", handleRevokeDeviceSession(normalizedDeps.RevokeDeviceSession, cfg.RequestTimeout))
|
||||
engine.POST("/api/v1/internal/users/:user_id/sessions/revoke-all", handleRevokeAllUserSessions(normalizedDeps.RevokeAllUserSessions, cfg.RequestTimeout))
|
||||
engine.POST("/api/v1/internal/user-blocks", handleBlockUser(normalizedDeps.BlockUser, cfg.RequestTimeout))
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func newOTelMiddleware(runtime *telemetry.Runtime) gin.HandlerFunc {
|
||||
options := []otelgin.Option{}
|
||||
if runtime != nil {
|
||||
options = append(
|
||||
options,
|
||||
otelgin.WithTracerProvider(runtime.TracerProvider()),
|
||||
otelgin.WithMeterProvider(runtime.MeterProvider()),
|
||||
)
|
||||
}
|
||||
|
||||
return otelgin.Middleware(internalHTTPServiceName, options...)
|
||||
}
|
||||
|
||||
func handleGetSession(useCase GetSessionUseCase, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := useCase.Execute(callCtx, getsession.Input{
|
||||
DeviceSessionID: c.Param("device_session_id"),
|
||||
})
|
||||
if err != nil {
|
||||
abortWithProjection(c, projectInternalError(err))
|
||||
return
|
||||
}
|
||||
if err := validateGetSessionResult(&result); err != nil {
|
||||
abortWithProjection(c, internalErrorProjection(fmt.Errorf("get session response: %w", err)))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, getSessionResponse{Session: toSessionResponseDTO(result.Session)})
|
||||
}
|
||||
}
|
||||
|
||||
func handleListUserSessions(useCase ListUserSessionsUseCase, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := useCase.Execute(callCtx, listusersessions.Input{
|
||||
UserID: c.Param("user_id"),
|
||||
})
|
||||
if err != nil {
|
||||
abortWithProjection(c, projectInternalError(err))
|
||||
return
|
||||
}
|
||||
if err := validateListUserSessionsResult(&result); err != nil {
|
||||
abortWithProjection(c, internalErrorProjection(fmt.Errorf("list user sessions response: %w", err)))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, listUserSessionsResponse{Sessions: toSessionResponseDTOs(result.Sessions)})
|
||||
}
|
||||
}
|
||||
|
||||
func handleRevokeDeviceSession(useCase RevokeDeviceSessionUseCase, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var request revokeDeviceSessionRequest
|
||||
if err := decodeJSONRequest(c.Request, &request); err != nil {
|
||||
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
|
||||
return
|
||||
}
|
||||
if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil {
|
||||
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := useCase.Execute(callCtx, revokedevicesession.Input{
|
||||
DeviceSessionID: c.Param("device_session_id"),
|
||||
ReasonCode: request.ReasonCode,
|
||||
ActorType: request.Actor.Type,
|
||||
ActorID: request.Actor.ID,
|
||||
})
|
||||
if err != nil {
|
||||
abortWithProjection(c, projectInternalError(err))
|
||||
return
|
||||
}
|
||||
if err := validateRevokeDeviceSessionResult(&result); err != nil {
|
||||
abortWithProjection(c, internalErrorProjection(fmt.Errorf("revoke device session response: %w", err)))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, revokeDeviceSessionResponse{
|
||||
Outcome: result.Outcome,
|
||||
DeviceSessionID: result.DeviceSessionID,
|
||||
AffectedSessionCount: result.AffectedSessionCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func handleRevokeAllUserSessions(useCase RevokeAllUserSessionsUseCase, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var request revokeAllUserSessionsRequest
|
||||
if err := decodeJSONRequest(c.Request, &request); err != nil {
|
||||
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
|
||||
return
|
||||
}
|
||||
if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil {
|
||||
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := useCase.Execute(callCtx, revokeallusersessions.Input{
|
||||
UserID: c.Param("user_id"),
|
||||
ReasonCode: request.ReasonCode,
|
||||
ActorType: request.Actor.Type,
|
||||
ActorID: request.Actor.ID,
|
||||
})
|
||||
if err != nil {
|
||||
abortWithProjection(c, projectInternalError(err))
|
||||
return
|
||||
}
|
||||
if err := validateRevokeAllUserSessionsResult(&result); err != nil {
|
||||
abortWithProjection(c, internalErrorProjection(fmt.Errorf("revoke all user sessions response: %w", err)))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, revokeAllUserSessionsResponse{
|
||||
Outcome: result.Outcome,
|
||||
UserID: result.UserID,
|
||||
AffectedSessionCount: result.AffectedSessionCount,
|
||||
AffectedDeviceSessionIDs: cloneStrings(result.AffectedDeviceSessionIDs),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func handleBlockUser(useCase BlockUserUseCase, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var request blockUserRequest
|
||||
if err := decodeJSONRequest(c.Request, &request); err != nil {
|
||||
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
|
||||
return
|
||||
}
|
||||
if err := validateBlockUserRequest(&request); err != nil {
|
||||
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
|
||||
return
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := useCase.Execute(callCtx, blockuser.Input{
|
||||
UserID: request.UserID,
|
||||
Email: request.Email,
|
||||
ReasonCode: request.ReasonCode,
|
||||
ActorType: request.Actor.Type,
|
||||
ActorID: request.Actor.ID,
|
||||
})
|
||||
if err != nil {
|
||||
abortWithProjection(c, projectInternalError(err))
|
||||
return
|
||||
}
|
||||
if err := validateBlockUserResult(&result); err != nil {
|
||||
abortWithProjection(c, internalErrorProjection(fmt.Errorf("block user response: %w", err)))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, blockUserResponse{
|
||||
Outcome: result.Outcome,
|
||||
SubjectKind: result.SubjectKind,
|
||||
SubjectValue: result.SubjectValue,
|
||||
AffectedSessionCount: result.AffectedSessionCount,
|
||||
AffectedDeviceSessionIDs: cloneStrings(result.AffectedDeviceSessionIDs),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toSessionResponseDTO(session shared.Session) sessionResponseDTO {
|
||||
return sessionResponseDTO{
|
||||
DeviceSessionID: session.DeviceSessionID,
|
||||
UserID: session.UserID,
|
||||
ClientPublicKey: session.ClientPublicKey,
|
||||
Status: session.Status,
|
||||
CreatedAt: session.CreatedAt,
|
||||
RevokedAt: cloneStringPointer(session.RevokedAt),
|
||||
RevokeReasonCode: cloneStringPointer(session.RevokeReasonCode),
|
||||
RevokeActorType: cloneStringPointer(session.RevokeActorType),
|
||||
RevokeActorID: cloneStringPointer(session.RevokeActorID),
|
||||
}
|
||||
}
|
||||
|
||||
func toSessionResponseDTOs(sessions []shared.Session) []sessionResponseDTO {
|
||||
result := make([]sessionResponseDTO, 0, len(sessions))
|
||||
for _, session := range sessions {
|
||||
result = append(result, toSessionResponseDTO(session))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func cloneStrings(values []string) []string {
|
||||
result := make([]string, 0, len(values))
|
||||
return append(result, values...)
|
||||
}
|
||||
|
||||
func cloneStringPointer(value *string) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func validateAuditRequest(reasonCode string, actor actorRequest) error {
|
||||
if strings.TrimSpace(reasonCode) == "" {
|
||||
return errors.New("reason_code must not be empty")
|
||||
}
|
||||
if strings.TrimSpace(actor.Type) == "" {
|
||||
return errors.New("actor.type must not be empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateBlockUserRequest(request *blockUserRequest) error {
|
||||
if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hasUserID := strings.TrimSpace(request.UserID) != ""
|
||||
hasEmail := strings.TrimSpace(request.Email) != ""
|
||||
switch {
|
||||
case hasUserID && hasEmail:
|
||||
return errors.New("exactly one of user_id or email must be provided")
|
||||
case !hasUserID && !hasEmail:
|
||||
return errors.New("exactly one of user_id or email must be provided")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateSessionDTO(session *shared.Session) error {
|
||||
switch {
|
||||
case strings.TrimSpace(session.DeviceSessionID) == "":
|
||||
return errors.New("session.device_session_id must not be empty")
|
||||
case strings.TrimSpace(session.UserID) == "":
|
||||
return errors.New("session.user_id must not be empty")
|
||||
case strings.TrimSpace(session.ClientPublicKey) == "":
|
||||
return errors.New("session.client_public_key must not be empty")
|
||||
case strings.TrimSpace(session.CreatedAt) == "":
|
||||
return errors.New("session.created_at must not be empty")
|
||||
}
|
||||
|
||||
if _, err := time.Parse(time.RFC3339, session.CreatedAt); err != nil {
|
||||
return fmt.Errorf("session.created_at: %w", err)
|
||||
}
|
||||
|
||||
switch session.Status {
|
||||
case "active":
|
||||
if session.RevokedAt != nil || session.RevokeReasonCode != nil || session.RevokeActorType != nil || session.RevokeActorID != nil {
|
||||
return errors.New("active session must not contain revoke metadata")
|
||||
}
|
||||
case "revoked":
|
||||
switch {
|
||||
case session.RevokedAt == nil || strings.TrimSpace(*session.RevokedAt) == "":
|
||||
return errors.New("revoked session must contain revoked_at")
|
||||
case session.RevokeReasonCode == nil || strings.TrimSpace(*session.RevokeReasonCode) == "":
|
||||
return errors.New("revoked session must contain revoke_reason_code")
|
||||
case session.RevokeActorType == nil || strings.TrimSpace(*session.RevokeActorType) == "":
|
||||
return errors.New("revoked session must contain revoke_actor_type")
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, *session.RevokedAt); err != nil {
|
||||
return fmt.Errorf("session.revoked_at: %w", err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("session.status %q is unsupported", session.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateGetSessionResult(result *getsession.Result) error {
|
||||
return validateSessionDTO(&result.Session)
|
||||
}
|
||||
|
||||
func validateListUserSessionsResult(result *listusersessions.Result) error {
|
||||
if result.Sessions == nil {
|
||||
return errors.New("sessions must not be null")
|
||||
}
|
||||
|
||||
for index := range result.Sessions {
|
||||
if err := validateSessionDTO(&result.Sessions[index]); err != nil {
|
||||
return fmt.Errorf("sessions[%d]: %w", index, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRevokeDeviceSessionResult(result *revokedevicesession.Result) error {
|
||||
switch result.Outcome {
|
||||
case "revoked":
|
||||
if result.AffectedSessionCount != 1 {
|
||||
return errors.New("revoked outcome must affect exactly one session")
|
||||
}
|
||||
case "already_revoked":
|
||||
if result.AffectedSessionCount != 0 {
|
||||
return errors.New("already_revoked outcome must affect zero sessions")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("revoke device session outcome %q is unsupported", result.Outcome)
|
||||
}
|
||||
if strings.TrimSpace(result.DeviceSessionID) == "" {
|
||||
return errors.New("device_session_id must not be empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRevokeAllUserSessionsResult(result *revokeallusersessions.Result) error {
|
||||
switch result.Outcome {
|
||||
case "revoked", "no_active_sessions":
|
||||
default:
|
||||
return fmt.Errorf("revoke all user sessions outcome %q is unsupported", result.Outcome)
|
||||
}
|
||||
if strings.TrimSpace(result.UserID) == "" {
|
||||
return errors.New("user_id must not be empty")
|
||||
}
|
||||
if result.AffectedSessionCount < 0 {
|
||||
return errors.New("affected_session_count must not be negative")
|
||||
}
|
||||
if result.AffectedDeviceSessionIDs == nil {
|
||||
return errors.New("affected_device_session_ids must not be null")
|
||||
}
|
||||
if int64(len(result.AffectedDeviceSessionIDs)) != result.AffectedSessionCount {
|
||||
return errors.New("affected_device_session_ids length must match affected_session_count")
|
||||
}
|
||||
for index, deviceSessionID := range result.AffectedDeviceSessionIDs {
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return fmt.Errorf("affected_device_session_ids[%d] must not be empty", index)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateBlockUserResult(result *blockuser.Result) error {
	switch result.Outcome {
	case "blocked", "already_blocked":
	default:
		return fmt.Errorf("block user outcome %q is unsupported", result.Outcome)
	}
	switch result.SubjectKind {
	case blockuser.SubjectKindUserID, blockuser.SubjectKindEmail:
	default:
		return fmt.Errorf("subject_kind %q is unsupported", result.SubjectKind)
	}
	if strings.TrimSpace(result.SubjectValue) == "" {
		return errors.New("subject_value must not be empty")
	}
	if result.AffectedSessionCount < 0 {
		return errors.New("affected_session_count must not be negative")
	}
	if result.AffectedDeviceSessionIDs == nil {
		return errors.New("affected_device_session_ids must not be null")
	}
	if int64(len(result.AffectedDeviceSessionIDs)) != result.AffectedSessionCount {
		return errors.New("affected_device_session_ids length must match affected_session_count")
	}
	for index, deviceSessionID := range result.AffectedDeviceSessionIDs {
		if strings.TrimSpace(deviceSessionID) == "" {
			return fmt.Errorf("affected_device_session_ids[%d] must not be empty", index)
		}
	}

	return nil
}

func projectInternalError(err error) shared.InternalErrorProjection {
	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
		return shared.ProjectInternalError(shared.ServiceUnavailable(err))
	}

	return shared.ProjectInternalError(err)
}

func internalErrorProjection(err error) shared.InternalErrorProjection {
	return shared.ProjectInternalError(shared.InternalError(err))
}

@@ -0,0 +1,784 @@
package internalhttp

import (
	"bytes"
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"galaxy/authsession/internal/service/blockuser"
	"galaxy/authsession/internal/service/getsession"
	"galaxy/authsession/internal/service/listusersessions"
	"galaxy/authsession/internal/service/revokeallusersessions"
	"galaxy/authsession/internal/service/revokedevicesession"
	"galaxy/authsession/internal/service/shared"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

func TestGetSessionHandlerSuccess(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession: getSessionFunc(func(_ context.Context, input getsession.Input) (getsession.Result, error) {
			assert.Equal(t, getsession.Input{DeviceSessionID: "device-session-123"}, input)
			return getsession.Result{
				Session: validSessionDTO(),
			}, nil
		}),
		ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser:             blockUserFunc(unexpectedBlockUser),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil)

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"session":{"device_session_id":"device-session-123","user_id":"user-123","client_public_key":"public-key-material","status":"active","created_at":"2026-04-05T12:00:00Z"}}`, recorder.Body.String())
}

func TestListUserSessionsHandlerSuccess(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession: getSessionFunc(unexpectedGetSession),
		ListUserSessions: listUserSessionsFunc(func(_ context.Context, input listusersessions.Input) (listusersessions.Result, error) {
			assert.Equal(t, listusersessions.Input{UserID: "user-123"}, input)
			first := validSessionDTO()
			second := validRevokedSessionDTO()
			second.DeviceSessionID = "device-session-122"
			return listusersessions.Result{Sessions: []shared.Session{first, second}}, nil
		}),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser:             blockUserFunc(unexpectedBlockUser),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/users/user-123/sessions", nil)

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.Contains(t, recorder.Body.String(), `"sessions":[`)
	assert.Contains(t, recorder.Body.String(), `"device_session_id":"device-session-123"`)
	assert.Contains(t, recorder.Body.String(), `"device_session_id":"device-session-122"`)
}

func TestListUserSessionsHandlerUnknownUserReturnsEmptyArray(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession: getSessionFunc(unexpectedGetSession),
		ListUserSessions: listUserSessionsFunc(func(_ context.Context, input listusersessions.Input) (listusersessions.Result, error) {
			assert.Equal(t, listusersessions.Input{UserID: "unknown-user"}, input)
			return listusersessions.Result{Sessions: []shared.Session{}}, nil
		}),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser:             blockUserFunc(unexpectedBlockUser),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/users/unknown-user/sessions", nil)

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"sessions":[]}`, recorder.Body.String())
}

func TestRevokeDeviceSessionHandlerAlreadyRevoked(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession:       getSessionFunc(unexpectedGetSession),
		ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession: revokeDeviceSessionFunc(func(_ context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error) {
			assert.Equal(t, revokedevicesession.Input{
				DeviceSessionID: "device-session-123",
				ReasonCode:      "admin_revoke",
				ActorType:       "system",
			}, input)
			return revokedevicesession.Result{
				Outcome:              "already_revoked",
				DeviceSessionID:      "device-session-123",
				AffectedSessionCount: 0,
			}, nil
		}),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser:             blockUserFunc(unexpectedBlockUser),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/internal/sessions/device-session-123/revoke",
		bytes.NewBufferString(`{"reason_code":"admin_revoke","actor":{"type":"system"}}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"outcome":"already_revoked","device_session_id":"device-session-123","affected_session_count":0}`, recorder.Body.String())
}

func TestRevokeAllUserSessionsHandlerNoActiveSessions(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession:          getSessionFunc(unexpectedGetSession),
		ListUserSessions:    listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(func(_ context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) {
			assert.Equal(t, revokeallusersessions.Input{
				UserID:     "user-123",
				ReasonCode: "logout_all",
				ActorType:  "system",
			}, input)
			return revokeallusersessions.Result{
				Outcome:                  "no_active_sessions",
				UserID:                   "user-123",
				AffectedSessionCount:     0,
				AffectedDeviceSessionIDs: []string{},
			}, nil
		}),
		BlockUser: blockUserFunc(unexpectedBlockUser),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/internal/users/user-123/sessions/revoke-all",
		bytes.NewBufferString(`{"reason_code":"logout_all","actor":{"type":"system"}}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-123","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String())
}

func TestBlockUserHandlerSuccessByEmail(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession:            getSessionFunc(unexpectedGetSession),
		ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser: blockUserFunc(func(_ context.Context, input blockuser.Input) (blockuser.Result, error) {
			assert.Equal(t, blockuser.Input{
				Email:      "pilot@example.com",
				ReasonCode: "policy_blocked",
				ActorType:  "admin",
			}, input)
			return blockuser.Result{
				Outcome:                  "blocked",
				SubjectKind:              blockuser.SubjectKindEmail,
				SubjectValue:             "pilot@example.com",
				AffectedSessionCount:     0,
				AffectedDeviceSessionIDs: []string{},
			}, nil
		}),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/internal/user-blocks",
		bytes.NewBufferString(`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String())
}

func TestBlockUserHandlerSuccessByUserID(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession:            getSessionFunc(unexpectedGetSession),
		ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser: blockUserFunc(func(_ context.Context, input blockuser.Input) (blockuser.Result, error) {
			assert.Equal(t, blockuser.Input{
				UserID:     "user-123",
				ReasonCode: "policy_blocked",
				ActorType:  "admin",
			}, input)
			return blockuser.Result{
				Outcome:                  "already_blocked",
				SubjectKind:              blockuser.SubjectKindUserID,
				SubjectValue:             "user-123",
				AffectedSessionCount:     0,
				AffectedDeviceSessionIDs: []string{},
			}, nil
		}),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/internal/user-blocks",
		bytes.NewBufferString(`{"user_id":"user-123","reason_code":"policy_blocked","actor":{"type":"admin"}}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"outcome":"already_blocked","subject_kind":"user_id","subject_value":"user-123","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String())
}

func TestInternalHandlersRejectInvalidPathParams(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		method     string
		target     string
		body       string
		wantStatus int
		wantBody   string
	}{
		{
			name:       "get session empty device session id",
			method:     http.MethodGet,
			target:     "/api/v1/internal/sessions/%20",
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"device session id must not be empty"}}`,
		},
		{
			name:       "list sessions empty user id",
			method:     http.MethodGet,
			target:     "/api/v1/internal/users/%20/sessions",
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"user id must not be empty"}}`,
		},
		{
			name:       "revoke all empty user id",
			method:     http.MethodPost,
			target:     "/api/v1/internal/users/%20/sessions/revoke-all",
			body:       `{"reason_code":"logout_all","actor":{"type":"system"}}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"user id must not be empty"}}`,
		},
	}

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
			return getsession.Result{}, shared.InvalidRequest("device session id must not be empty")
		}),
		ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) {
			return listusersessions.Result{}, shared.InvalidRequest("user id must not be empty")
		}),
		RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
			return revokeallusersessions.Result{}, shared.InvalidRequest("user id must not be empty")
		}),
		BlockUser: blockUserFunc(unexpectedBlockUser),
	})

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			recorder := httptest.NewRecorder()
			request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
			if tt.body != "" {
				request.Header.Set("Content-Type", "application/json")
			}

			handler.ServeHTTP(recorder, request)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.JSONEq(t, tt.wantBody, recorder.Body.String())
		})
	}
}

func TestInternalMutationHandlersRejectInvalidRequests(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		method     string
		target     string
		body       string
		wantStatus int
		wantBody   string
	}{
		{
			name:       "revoke device session empty body",
			method:     http.MethodPost,
			target:     "/api/v1/internal/sessions/device-session-123/revoke",
			body:       ``,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body must not be empty"}}`,
		},
		{
			name:       "revoke device session malformed json",
			method:     http.MethodPost,
			target:     "/api/v1/internal/sessions/device-session-123/revoke",
			body:       `{"reason_code":`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
		},
		{
			name:       "revoke device session multiple objects",
			method:     http.MethodPost,
			target:     "/api/v1/internal/sessions/device-session-123/revoke",
			body:       `{"reason_code":"admin_revoke","actor":{"type":"system"}}{"reason_code":"admin_revoke","actor":{"type":"system"}}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body must contain a single JSON object"}}`,
		},
		{
			name:       "revoke device session unknown field",
			method:     http.MethodPost,
			target:     "/api/v1/internal/sessions/device-session-123/revoke",
			body:       `{"reason_code":"admin_revoke","actor":{"type":"system"},"extra":true}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body contains unknown field \"extra\""}}`,
		},
		{
			name:       "revoke device session invalid json type",
			method:     http.MethodPost,
			target:     "/api/v1/internal/sessions/device-session-123/revoke",
			body:       `{"reason_code":123,"actor":{"type":"system"}}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body contains an invalid value for \"reason_code\""}}`,
		},
		{
			name:       "revoke all missing reason code",
			method:     http.MethodPost,
			target:     "/api/v1/internal/users/user-123/sessions/revoke-all",
			body:       `{"reason_code":" ","actor":{"type":"system"}}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"reason_code must not be empty"}}`,
		},
		{
			name:       "block user missing actor type",
			method:     http.MethodPost,
			target:     "/api/v1/internal/user-blocks",
			body:       `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":" "}}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"actor.type must not be empty"}}`,
		},
		{
			name:       "block user missing subject",
			method:     http.MethodPost,
			target:     "/api/v1/internal/user-blocks",
			body:       `{"reason_code":"policy_blocked","actor":{"type":"system"}}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"exactly one of user_id or email must be provided"}}`,
		},
		{
			name:       "block user conflicting subjects",
			method:     http.MethodPost,
			target:     "/api/v1/internal/user-blocks",
			body:       `{"user_id":"user-123","email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"system"}}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"exactly one of user_id or email must be provided"}}`,
		},
	}

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		GetSession:            getSessionFunc(unexpectedGetSession),
		ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser:             blockUserFunc(unexpectedBlockUser),
	})

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			recorder := httptest.NewRecorder()
			request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
			if tt.body != "" {
				request.Header.Set("Content-Type", "application/json")
			}

			handler.ServeHTTP(recorder, request)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.JSONEq(t, tt.wantBody, recorder.Body.String())
		})
	}
}

func TestInternalHandlersMapServiceErrors(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		method     string
		target     string
		body       string
		deps       Dependencies
		wantStatus int
		wantBody   string
	}{
		{
			name:   "get session not found",
			method: http.MethodGet,
			target: "/api/v1/internal/sessions/missing",
			deps: Dependencies{
				GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
					return getsession.Result{}, shared.SessionNotFound()
				}),
				ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
				RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
				RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
				BlockUser:             blockUserFunc(unexpectedBlockUser),
			},
			wantStatus: http.StatusNotFound,
			wantBody:   `{"error":{"code":"session_not_found","message":"session not found"}}`,
		},
		{
			name:   "revoke all subject not found",
			method: http.MethodPost,
			target: "/api/v1/internal/users/missing/sessions/revoke-all",
			body:   `{"reason_code":"logout_all","actor":{"type":"system"}}`,
			deps: Dependencies{
				GetSession:          getSessionFunc(unexpectedGetSession),
				ListUserSessions:    listUserSessionsFunc(unexpectedListUserSessions),
				RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
				RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
					return revokeallusersessions.Result{}, shared.SubjectNotFound()
				}),
				BlockUser: blockUserFunc(unexpectedBlockUser),
			},
			wantStatus: http.StatusNotFound,
			wantBody:   `{"error":{"code":"subject_not_found","message":"subject not found"}}`,
		},
		{
			name:   "service unavailable",
			method: http.MethodGet,
			target: "/api/v1/internal/sessions/device-session-123",
			deps: Dependencies{
				GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
					return getsession.Result{}, shared.ServiceUnavailable(errors.New("redis timeout"))
				}),
				ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
				RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
				RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
				BlockUser:             blockUserFunc(unexpectedBlockUser),
			},
			wantStatus: http.StatusServiceUnavailable,
			wantBody:   `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
		},
		{
			name:   "internal error",
			method: http.MethodGet,
			target: "/api/v1/internal/sessions/device-session-123",
			deps: Dependencies{
				GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
					return getsession.Result{}, shared.InternalError(errors.New("broken invariant"))
				}),
				ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
				RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
				RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
				BlockUser:             blockUserFunc(unexpectedBlockUser),
			},
			wantStatus: http.StatusInternalServerError,
			wantBody:   `{"error":{"code":"internal_error","message":"internal server error"}}`,
		},
		{
			name:   "unexpected error hidden",
			method: http.MethodGet,
			target: "/api/v1/internal/sessions/device-session-123",
			deps: Dependencies{
				GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
					return getsession.Result{}, errors.New("boom")
				}),
				ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
				RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
				RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
				BlockUser:             blockUserFunc(unexpectedBlockUser),
			},
			wantStatus: http.StatusInternalServerError,
			wantBody:   `{"error":{"code":"internal_error","message":"internal server error"}}`,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			handler := mustNewHandler(t, DefaultConfig(), tt.deps)
			recorder := httptest.NewRecorder()
			request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
			if tt.body != "" {
				request.Header.Set("Content-Type", "application/json")
			}

			handler.ServeHTTP(recorder, request)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.JSONEq(t, tt.wantBody, recorder.Body.String())
		})
	}
}

func TestInternalHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
	t.Parallel()

	cfg := DefaultConfig()
	cfg.RequestTimeout = 5 * time.Millisecond

	handler := mustNewHandler(t, cfg, Dependencies{
		GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
			return getsession.Result{}, context.DeadlineExceeded
		}),
		ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser:             blockUserFunc(unexpectedBlockUser),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil)

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
	assert.JSONEq(t, `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, recorder.Body.String())
}

func TestInternalHandlersRejectInvalidSuccessPayloads(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		method string
		target string
		body   string
		deps   Dependencies
	}{
		{
			name:   "get session malformed response",
			method: http.MethodGet,
			target: "/api/v1/internal/sessions/device-session-123",
			deps: Dependencies{
				GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
					dto := validSessionDTO()
					dto.DeviceSessionID = ""
					return getsession.Result{Session: dto}, nil
				}),
				ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
				RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
				RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
				BlockUser:             blockUserFunc(unexpectedBlockUser),
			},
		},
		{
			name:   "revoke all malformed response",
			method: http.MethodPost,
			target: "/api/v1/internal/users/user-123/sessions/revoke-all",
			body:   `{"reason_code":"logout_all","actor":{"type":"system"}}`,
			deps: Dependencies{
				GetSession:          getSessionFunc(unexpectedGetSession),
				ListUserSessions:    listUserSessionsFunc(unexpectedListUserSessions),
				RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
				RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
					return revokeallusersessions.Result{
						Outcome:                  "revoked",
						UserID:                   "user-123",
						AffectedSessionCount:     2,
						AffectedDeviceSessionIDs: []string{"device-session-1"},
					}, nil
				}),
				BlockUser: blockUserFunc(unexpectedBlockUser),
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			handler := mustNewHandler(t, DefaultConfig(), tt.deps)
			recorder := httptest.NewRecorder()
			request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
			if tt.body != "" {
				request.Header.Set("Content-Type", "application/json")
			}

			handler.ServeHTTP(recorder, request)

			assert.Equal(t, http.StatusInternalServerError, recorder.Code)
			assert.JSONEq(t, `{"error":{"code":"internal_error","message":"internal server error"}}`, recorder.Body.String())
		})
	}
}

func TestInternalHandlerLogsDoNotContainSensitiveFields(t *testing.T) {
	t.Parallel()

	logger, buffer := newObservedLogger()
	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		Logger:                logger,
		GetSession:            getSessionFunc(unexpectedGetSession),
		ListUserSessions:      listUserSessionsFunc(unexpectedListUserSessions),
		RevokeDeviceSession:   revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
		BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) {
			return blockuser.Result{
				Outcome:                  "blocked",
				SubjectKind:              blockuser.SubjectKindEmail,
				SubjectValue:             "pilot@example.com",
				AffectedSessionCount:     0,
				AffectedDeviceSessionIDs: []string{},
			}, nil
		}),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/internal/user-blocks",
		bytes.NewBufferString(`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin","id":"admin-1"}}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	require.Equal(t, http.StatusOK, recorder.Code)
	logOutput := buffer.String()
	assert.NotContains(t, logOutput, "pilot@example.com")
	assert.NotContains(t, logOutput, "admin-1")
	assert.NotContains(t, logOutput, "reason_code")
}

func mustNewHandler(t *testing.T, cfg Config, deps Dependencies) http.Handler {
	t.Helper()

	handler, err := newHandlerWithConfig(cfg, deps)
	require.NoError(t, err)
	return handler
}

type getSessionFunc func(ctx context.Context, input getsession.Input) (getsession.Result, error)

func (f getSessionFunc) Execute(ctx context.Context, input getsession.Input) (getsession.Result, error) {
	return f(ctx, input)
}

type listUserSessionsFunc func(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error)

func (f listUserSessionsFunc) Execute(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error) {
	return f(ctx, input)
}

type revokeDeviceSessionFunc func(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error)

func (f revokeDeviceSessionFunc) Execute(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error) {
	return f(ctx, input)
}

type revokeAllUserSessionsFunc func(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error)

func (f revokeAllUserSessionsFunc) Execute(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) {
	return f(ctx, input)
}

type blockUserFunc func(ctx context.Context, input blockuser.Input) (blockuser.Result, error)

func (f blockUserFunc) Execute(ctx context.Context, input blockuser.Input) (blockuser.Result, error) {
	return f(ctx, input)
}

func validSessionDTO() shared.Session {
	return shared.Session{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-material",
		Status:          "active",
		CreatedAt:       "2026-04-05T12:00:00Z",
	}
}

func validRevokedSessionDTO() shared.Session {
	dto := validSessionDTO()
	dto.Status = "revoked"
	revokedAt := "2026-04-05T12:01:00Z"
	reasonCode := "admin_revoke"
	actorType := "admin"
	actorID := "admin-1"
	dto.RevokedAt = &revokedAt
	dto.RevokeReasonCode = &reasonCode
	dto.RevokeActorType = &actorType
	dto.RevokeActorID = &actorID
	return dto
}

func newObservedLogger() (*zap.Logger, *bytes.Buffer) {
	buffer := &bytes.Buffer{}
	encoderConfig := zap.NewProductionEncoderConfig()
	encoderConfig.TimeKey = ""

	core := zapcore.NewCore(
		zapcore.NewJSONEncoder(encoderConfig),
		zapcore.AddSync(buffer),
		zap.DebugLevel,
	)

	return zap.New(core), buffer
}

func unexpectedGetSession(context.Context, getsession.Input) (getsession.Result, error) {
|
||||
return getsession.Result{}, errors.New("unexpected call")
|
||||
}
|
||||
|
||||
func unexpectedListUserSessions(context.Context, listusersessions.Input) (listusersessions.Result, error) {
|
||||
return listusersessions.Result{}, errors.New("unexpected call")
|
||||
}
|
||||
|
||||
func unexpectedRevokeDeviceSession(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) {
|
||||
return revokedevicesession.Result{}, errors.New("unexpected call")
|
||||
}
|
||||
|
||||
func unexpectedRevokeAllUserSessions(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
|
||||
return revokeallusersessions.Result{}, errors.New("unexpected call")
|
||||
}
|
||||
|
||||
func unexpectedBlockUser(context.Context, blockuser.Input) (blockuser.Result, error) {
|
||||
return blockuser.Result{}, errors.New("unexpected call")
|
||||
}
|
||||
@@ -0,0 +1,93 @@
package internalhttp

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"galaxy/authsession/internal/service/shared"

	"github.com/gin-gonic/gin"
)

const internalErrorCodeContextKey = "internal_error_code"

type malformedJSONRequestError struct {
	message string
}

func (e *malformedJSONRequestError) Error() string {
	if e == nil {
		return ""
	}

	return e.message
}

func decodeJSONRequest(request *http.Request, target any) error {
	if request == nil || request.Body == nil {
		return &malformedJSONRequestError{message: "request body must not be empty"}
	}

	return decodeJSONReader(request.Body, target)
}

func decodeJSONReader(reader io.Reader, target any) error {
	decoder := json.NewDecoder(reader)
	decoder.DisallowUnknownFields()

	if err := decoder.Decode(target); err != nil {
		return describeJSONDecodeError(err)
	}

	if err := decoder.Decode(&struct{}{}); err != nil {
		if errors.Is(err, io.EOF) {
			return nil
		}

		return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
	}

	return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
}

func describeJSONDecodeError(err error) error {
	var syntaxErr *json.SyntaxError
	var typeErr *json.UnmarshalTypeError

	switch {
	case errors.Is(err, io.EOF):
		return &malformedJSONRequestError{message: "request body must not be empty"}
	case errors.As(err, &syntaxErr):
		return &malformedJSONRequestError{message: "request body contains malformed JSON"}
	case errors.Is(err, io.ErrUnexpectedEOF):
		return &malformedJSONRequestError{message: "request body contains malformed JSON"}
	case errors.As(err, &typeErr):
		if strings.TrimSpace(typeErr.Field) != "" {
			return &malformedJSONRequestError{
				message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
			}
		}

		return &malformedJSONRequestError{message: "request body contains an invalid JSON value"}
	case strings.HasPrefix(err.Error(), "json: unknown field "):
		return &malformedJSONRequestError{
			message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
		}
	default:
		return &malformedJSONRequestError{message: "request body contains invalid JSON"}
	}
}

func abortWithProjection(c *gin.Context, projection shared.InternalErrorProjection) {
	c.Set(internalErrorCodeContextKey, projection.Code)
	c.AbortWithStatusJSON(projection.StatusCode, errorResponse{
		Error: errorBody{
			Code:    projection.Code,
			Message: projection.Message,
		},
	})
}
@@ -0,0 +1,86 @@
package internalhttp

import (
	"time"

	authlogging "galaxy/authsession/internal/logging"
	"galaxy/authsession/internal/telemetry"

	"github.com/gin-gonic/gin"
	"go.opentelemetry.io/otel/attribute"
	"go.uber.org/zap"
)

type edgeOutcome string

const (
	edgeOutcomeSuccess  edgeOutcome = "success"
	edgeOutcomeRejected edgeOutcome = "rejected"
	edgeOutcomeFailed   edgeOutcome = "failed"
)

func withInternalObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
	if logger == nil {
		logger = zap.NewNop()
	}

	return func(c *gin.Context) {
		start := time.Now()
		c.Next()

		statusCode := c.Writer.Status()
		route := c.FullPath()
		if route == "" {
			route = "unmatched"
		}

		errorCode, _ := c.Get(internalErrorCodeContextKey)
		errorCodeValue, _ := errorCode.(string)
		outcome := outcomeFromStatusCode(statusCode)
		duration := time.Since(start)

		fields := []zap.Field{
			zap.String("component", "internal_http"),
			zap.String("transport", "http"),
			zap.String("route", route),
			zap.String("method", c.Request.Method),
			zap.Int("status_code", statusCode),
			zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
			zap.String("edge_outcome", string(outcome)),
		}
		if errorCodeValue != "" {
			fields = append(fields, zap.String("error_code", errorCodeValue))
		}
		fields = append(fields, authlogging.TraceFieldsFromContext(c.Request.Context())...)

		metricAttrs := []attribute.KeyValue{
			attribute.String("route", route),
			attribute.String("method", c.Request.Method),
			attribute.String("edge_outcome", string(outcome)),
		}
		if errorCodeValue != "" {
			metricAttrs = append(metricAttrs, attribute.String("error_code", errorCodeValue))
		}
		metrics.RecordInternalHTTPRequest(c.Request.Context(), metricAttrs, duration)

		switch outcome {
		case edgeOutcomeSuccess:
			logger.Info("internal request completed", fields...)
		case edgeOutcomeFailed:
			logger.Error("internal request failed", fields...)
		default:
			logger.Warn("internal request rejected", fields...)
		}
	}
}

func outcomeFromStatusCode(statusCode int) edgeOutcome {
	switch {
	case statusCode >= 500:
		return edgeOutcomeFailed
	case statusCode >= 400:
		return edgeOutcomeRejected
	default:
		return edgeOutcomeSuccess
	}
}
@@ -0,0 +1,121 @@
package internalhttp

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	"galaxy/authsession/internal/service/blockuser"
	"galaxy/authsession/internal/service/getsession"
	"galaxy/authsession/internal/service/listusersessions"
	"galaxy/authsession/internal/service/revokeallusersessions"
	"galaxy/authsession/internal/service/revokedevicesession"
	"galaxy/authsession/internal/service/shared"
	authtelemetry "galaxy/authsession/internal/telemetry"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel/attribute"
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
	"go.opentelemetry.io/otel/sdk/metric/metricdata"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/sdk/trace/tracetest"
)

func TestInternalHandlerEmitsTraceFieldsAndMetrics(t *testing.T) {
	t.Parallel()

	logger, buffer := newObservedLogger()
	telemetryRuntime, reader, recorder := newObservedInternalTelemetryRuntime(t)
	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		Logger:    logger,
		Telemetry: telemetryRuntime,
		GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
			return getsession.Result{Session: validSessionDTO()}, nil
		}),
		ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) {
			return listusersessions.Result{Sessions: []shared.Session{}}, nil
		}),
		RevokeDeviceSession: revokeDeviceSessionFunc(func(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) {
			return revokedevicesession.Result{}, nil
		}),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
			return revokeallusersessions.Result{}, nil
		}),
		BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) {
			return blockuser.Result{}, nil
		}),
	})

	recorderHTTP := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil)

	handler.ServeHTTP(recorderHTTP, request)

	require.Equal(t, http.StatusOK, recorderHTTP.Code)
	require.NotEmpty(t, recorder.Ended())
	assert.Contains(t, buffer.String(), "otel_trace_id")
	assert.Contains(t, buffer.String(), "otel_span_id")

	assertMetricCount(t, reader, "authsession.internal_http.requests", map[string]string{
		"route":        "/api/v1/internal/sessions/:device_session_id",
		"method":       http.MethodGet,
		"edge_outcome": "success",
	}, 1)
}

func newObservedInternalTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader, *tracetest.SpanRecorder) {
	t.Helper()

	reader := sdkmetric.NewManualReader()
	meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
	recorder := tracetest.NewSpanRecorder()
	tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))

	runtime, err := authtelemetry.NewWithProviders(meterProvider, tracerProvider)
	require.NoError(t, err)

	return runtime, reader, recorder
}

func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
	t.Helper()

	var resourceMetrics metricdata.ResourceMetrics
	require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))

	for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
		for _, metric := range scopeMetrics.Metrics {
			if metric.Name != metricName {
				continue
			}

			sum, ok := metric.Data.(metricdata.Sum[int64])
			require.True(t, ok)

			for _, point := range sum.DataPoints {
				if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
					assert.Equal(t, wantValue, point.Value)
					return
				}
			}
		}
	}

	require.Failf(t, "metric not found", "metric %q with attrs %v not found", metricName, wantAttrs)
}

func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
	if len(values) != len(want) {
		return false
	}

	for _, value := range values {
		if want[string(value.Key)] != value.Value.AsString() {
			return false
		}
	}

	return true
}
@@ -0,0 +1,271 @@
package internalhttp

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"galaxy/authsession/internal/service/blockuser"
	"galaxy/authsession/internal/service/getsession"
	"galaxy/authsession/internal/service/listusersessions"
	"galaxy/authsession/internal/service/revokeallusersessions"
	"galaxy/authsession/internal/service/revokedevicesession"
	"galaxy/authsession/internal/telemetry"

	"go.uber.org/zap"
)

const (
	defaultAddr              = ":8081"
	defaultReadHeaderTimeout = 2 * time.Second
	defaultReadTimeout       = 10 * time.Second
	defaultIdleTimeout       = time.Minute
	defaultRequestTimeout    = 3 * time.Second
)

// GetSessionUseCase describes the trusted internal get-session service
// consumed by the HTTP transport layer.
type GetSessionUseCase interface {
	// Execute loads one device session for trusted internal callers.
	Execute(ctx context.Context, input getsession.Input) (getsession.Result, error)
}

// ListUserSessionsUseCase describes the trusted internal list-user-sessions
// service consumed by the HTTP transport layer.
type ListUserSessionsUseCase interface {
	// Execute lists all sessions of one user for trusted internal callers.
	Execute(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error)
}

// RevokeDeviceSessionUseCase describes the trusted internal single-session
// revoke service consumed by the HTTP transport layer.
type RevokeDeviceSessionUseCase interface {
	// Execute revokes one device session and returns the frozen
	// acknowledgement.
	Execute(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error)
}

// RevokeAllUserSessionsUseCase describes the trusted internal bulk-revoke
// service consumed by the HTTP transport layer.
type RevokeAllUserSessionsUseCase interface {
	// Execute revokes all active sessions of one user and returns the frozen
	// acknowledgement.
	Execute(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error)
}

// BlockUserUseCase describes the trusted internal block-user service consumed
// by the HTTP transport layer.
type BlockUserUseCase interface {
	// Execute applies a block state to one subject and returns the frozen
	// acknowledgement.
	Execute(ctx context.Context, input blockuser.Input) (blockuser.Result, error)
}

// Config describes the trusted internal HTTP listener owned by authsession.
type Config struct {
	// Addr is the TCP listen address used by the trusted internal HTTP server.
	Addr string

	// ReadHeaderTimeout bounds how long the listener may spend reading request
	// headers before the server rejects the connection.
	ReadHeaderTimeout time.Duration

	// ReadTimeout bounds how long the listener may spend reading one trusted
	// internal request.
	ReadTimeout time.Duration

	// IdleTimeout bounds how long the listener keeps an idle keep-alive
	// connection open.
	IdleTimeout time.Duration

	// RequestTimeout bounds one application-layer internal use-case call.
	RequestTimeout time.Duration
}

// Validate reports whether cfg contains a usable internal HTTP listener
// configuration.
func (cfg Config) Validate() error {
	switch {
	case cfg.Addr == "":
		return errors.New("internal HTTP addr must not be empty")
	case cfg.ReadHeaderTimeout <= 0:
		return errors.New("internal HTTP read header timeout must be positive")
	case cfg.ReadTimeout <= 0:
		return errors.New("internal HTTP read timeout must be positive")
	case cfg.IdleTimeout <= 0:
		return errors.New("internal HTTP idle timeout must be positive")
	case cfg.RequestTimeout <= 0:
		return errors.New("internal HTTP request timeout must be positive")
	default:
		return nil
	}
}

// DefaultConfig returns the default trusted internal HTTP listener settings.
func DefaultConfig() Config {
	return Config{
		Addr:              defaultAddr,
		ReadHeaderTimeout: defaultReadHeaderTimeout,
		ReadTimeout:       defaultReadTimeout,
		IdleTimeout:       defaultIdleTimeout,
		RequestTimeout:    defaultRequestTimeout,
	}
}

// Dependencies describes the collaborators used by the trusted internal HTTP
// transport layer.
type Dependencies struct {
	// GetSession executes the trusted internal get-session use case.
	GetSession GetSessionUseCase

	// ListUserSessions executes the trusted internal list-user-sessions use
	// case.
	ListUserSessions ListUserSessionsUseCase

	// RevokeDeviceSession executes the trusted internal single-session revoke
	// use case.
	RevokeDeviceSession RevokeDeviceSessionUseCase

	// RevokeAllUserSessions executes the trusted internal bulk-revoke use case.
	RevokeAllUserSessions RevokeAllUserSessionsUseCase

	// BlockUser executes the trusted internal block-user use case.
	BlockUser BlockUserUseCase

	// Logger writes structured transport logs. When nil, a no-op logger is
	// used.
	Logger *zap.Logger

	// Telemetry records OpenTelemetry spans and low-cardinality HTTP metrics.
	// When nil, the transport still serves requests with no-op providers.
	Telemetry *telemetry.Runtime
}

// Server owns the trusted internal HTTP listener exposed by authsession.
type Server struct {
	cfg Config

	handler http.Handler
	logger  *zap.Logger

	stateMu  sync.RWMutex
	server   *http.Server
	listener net.Listener
}

// NewServer constructs one trusted internal HTTP server for cfg and deps.
func NewServer(cfg Config, deps Dependencies) (*Server, error) {
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("new internal HTTP server: %w", err)
	}

	handler, err := newHandlerWithConfig(cfg, deps)
	if err != nil {
		return nil, fmt.Errorf("new internal HTTP server: %w", err)
	}

	logger := deps.Logger
	if logger == nil {
		logger = zap.NewNop()
	}
	logger = logger.Named("internal_http")

	return &Server{
		cfg:     cfg,
		handler: handler,
		logger:  logger,
	}, nil
}

// Run binds the configured listener and serves the trusted internal HTTP
// surface until Shutdown closes the server.
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run internal HTTP server: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run internal HTTP server: listen on %q: %w", s.cfg.Addr, err)
	}

	server := &http.Server{
		Handler:           s.handler,
		ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
		ReadTimeout:       s.cfg.ReadTimeout,
		IdleTimeout:       s.cfg.IdleTimeout,
	}

	s.stateMu.Lock()
	s.server = server
	s.listener = listener
	s.stateMu.Unlock()

	s.logger.Info("internal HTTP server started", zap.String("addr", listener.Addr().String()))

	defer func() {
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()

	err = server.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, http.ErrServerClosed):
		s.logger.Info("internal HTTP server stopped")
		return nil
	default:
		return fmt.Errorf("run internal HTTP server: serve on %q: %w", s.cfg.Addr, err)
	}
}

// Shutdown gracefully stops the trusted internal HTTP server within ctx.
func (s *Server) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown internal HTTP server: nil context")
	}

	s.stateMu.RLock()
	server := s.server
	s.stateMu.RUnlock()

	if server == nil {
		return nil
	}

	if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
		return fmt.Errorf("shutdown internal HTTP server: %w", err)
	}

	return nil
}

func normalizeDependencies(deps Dependencies) (Dependencies, error) {
	switch {
	case deps.GetSession == nil:
		return Dependencies{}, errors.New("get session use case must not be nil")
	case deps.ListUserSessions == nil:
		return Dependencies{}, errors.New("list user sessions use case must not be nil")
	case deps.RevokeDeviceSession == nil:
		return Dependencies{}, errors.New("revoke device session use case must not be nil")
	case deps.RevokeAllUserSessions == nil:
		return Dependencies{}, errors.New("revoke all user sessions use case must not be nil")
	case deps.BlockUser == nil:
		return Dependencies{}, errors.New("block user use case must not be nil")
	case deps.Logger == nil:
		deps.Logger = zap.NewNop()
	}

	deps.Logger = deps.Logger.Named("internal_http")
	return deps, nil
}
@@ -0,0 +1,106 @@
package internalhttp

import (
	"bytes"
	"context"
	"net/http"
	"testing"
	"time"

	"galaxy/authsession/internal/service/blockuser"
	"galaxy/authsession/internal/service/getsession"
	"galaxy/authsession/internal/service/listusersessions"
	"galaxy/authsession/internal/service/revokeallusersessions"
	"galaxy/authsession/internal/service/revokedevicesession"
	"galaxy/authsession/internal/service/shared"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewServerRejectsInvalidConfiguration(t *testing.T) {
	t.Parallel()

	cfg := DefaultConfig()
	cfg.Addr = ""

	_, err := NewServer(cfg, validDependencies())

	require.Error(t, err)
	assert.Contains(t, err.Error(), "addr")
}

func TestServerRunAndShutdown(t *testing.T) {
	t.Parallel()

	cfg := DefaultConfig()
	cfg.Addr = "127.0.0.1:0"

	server, err := NewServer(cfg, validDependencies())
	require.NoError(t, err)

	runErr := make(chan error, 1)
	go func() {
		runErr <- server.Run(context.Background())
	}()

	require.Eventually(t, func() bool {
		server.stateMu.RLock()
		defer server.stateMu.RUnlock()
		return server.listener != nil
	}, time.Second, 10*time.Millisecond)

	server.stateMu.RLock()
	addr := server.listener.Addr().String()
	server.stateMu.RUnlock()

	response, err := http.Post(
		"http://"+addr+"/api/v1/internal/sessions/device-session-123/revoke",
		"application/json",
		bytes.NewBufferString(`{"reason_code":"admin_revoke","actor":{"type":"system"}}`),
	)
	require.NoError(t, err)
	defer response.Body.Close()

	assert.Equal(t, http.StatusOK, response.StatusCode)

	shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	require.NoError(t, server.Shutdown(shutdownCtx))
	require.NoError(t, <-runErr)
}

func validDependencies() Dependencies {
	return Dependencies{
		GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
			return getsession.Result{Session: validSessionDTO()}, nil
		}),
		ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) {
			return listusersessions.Result{Sessions: []shared.Session{validSessionDTO()}}, nil
		}),
		RevokeDeviceSession: revokeDeviceSessionFunc(func(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) {
			return revokedevicesession.Result{
				Outcome:              "revoked",
				DeviceSessionID:      "device-session-123",
				AffectedSessionCount: 1,
			}, nil
		}),
		RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
			return revokeallusersessions.Result{
				Outcome:                  "revoked",
				UserID:                   "user-123",
				AffectedSessionCount:     1,
				AffectedDeviceSessionIDs: []string{"device-session-123"},
			}, nil
		}),
		BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) {
			return blockuser.Result{
				Outcome:                  "blocked",
				SubjectKind:              blockuser.SubjectKindEmail,
				SubjectValue:             "pilot@example.com",
				AffectedSessionCount:     0,
				AffectedDeviceSessionIDs: []string{},
			}, nil
		}),
	}
}
@@ -0,0 +1,3 @@
// Package publichttp exposes the public HTTP transport expected by the
// gateway-facing authentication flow.
package publichttp
@@ -0,0 +1,391 @@
|
||||
package publichttp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/adapters/mail"
|
||||
"galaxy/authsession/internal/adapters/userservice"
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/confirmemailcode"
|
||||
"galaxy/authsession/internal/service/sendemailcode"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPublicHTTPEndToEndSendThenConfirm(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := newEndToEndApp(t, endToEndOptions{})
|
||||
server := httptest.NewServer(app.handler)
|
||||
defer server.Close()
|
||||
|
||||
sendResponse := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
|
||||
assert.Equal(t, http.StatusOK, sendResponse.StatusCode)
|
||||
assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, sendResponse.Body)
|
||||
|
||||
attempts := app.mailSender.RecordedAttempts()
|
||||
require.Len(t, attempts, 1)
|
||||
|
||||
confirmBody := map[string]string{
|
||||
"challenge_id": "challenge-1",
|
||||
"code": attempts[0].Input.Code,
|
||||
"client_public_key": validClientPublicKey,
|
||||
}
|
||||
confirmResponse := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", confirmBody)
|
||||
|
||||
assert.Equal(t, http.StatusOK, confirmResponse.StatusCode)
|
||||
assert.JSONEq(t, `{"device_session_id":"device-session-1"}`, confirmResponse.Body)
|
||||
}
|
||||
|
||||
func TestPublicHTTPEndToEndBlockedSendReturnsChallengeID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := newEndToEndApp(t, endToEndOptions{
|
||||
SeedBlockedEmail: true,
|
||||
})
|
||||
server := httptest.NewServer(app.handler)
|
||||
defer server.Close()
|
||||
|
||||
response := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
|
||||
|
||||
assert.Equal(t, http.StatusOK, response.StatusCode)
|
||||
assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, response.Body)
|
||||
assert.Empty(t, app.mailSender.RecordedAttempts())
|
||||
}
|
||||
|
||||
func TestPublicHTTPEndToEndThrottledSendStillReturnsChallengeID(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t, endToEndOptions{
		AbuseProtector: &testkit.InMemorySendEmailCodeAbuseProtector{},
	})
	server := httptest.NewServer(app.handler)
	defer server.Close()

	first := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, first.StatusCode)
	assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, first.Body)

	second := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, second.StatusCode)
	assert.JSONEq(t, `{"challenge_id":"challenge-2"}`, second.Body)
	assert.Len(t, app.mailSender.RecordedAttempts(), 1)
}

func TestPublicHTTPEndToEndInvalidClientPublicKey(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t, endToEndOptions{
		SeedChallenge: seedChallengeOptions{
			ID:     "challenge-123",
			Code:   "123456",
			Status: challenge.StatusSent,
		},
	})
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSON(
		t,
		server.URL+"/api/v1/public/auth/confirm-email-code",
		`{"challenge_id":"challenge-123","code":"123456","client_public_key":"invalid"}`,
	)

	assert.Equal(t, http.StatusBadRequest, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key"}}`, response.Body)
}

func TestPublicHTTPEndToEndChallengeNotFound(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t, endToEndOptions{})
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      "missing",
		"code":              "123456",
		"client_public_key": validClientPublicKey,
	})

	assert.Equal(t, http.StatusNotFound, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"challenge_not_found","message":"challenge not found"}}`, response.Body)
}

func TestPublicHTTPEndToEndChallengeExpired(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t, endToEndOptions{
		SeedChallenge: seedChallengeOptions{
			ID:        "challenge-123",
			Code:      "123456",
			Status:    challenge.StatusSent,
			ExpiresAt: time.Date(2026, 4, 5, 11, 59, 0, 0, time.UTC),
		},
	})
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      "challenge-123",
		"code":              "123456",
		"client_public_key": validClientPublicKey,
	})

	assert.Equal(t, http.StatusGone, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"challenge_expired","message":"challenge expired"}}`, response.Body)
}

func TestPublicHTTPEndToEndInvalidCode(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t, endToEndOptions{
		SeedChallenge: seedChallengeOptions{
			ID:     "challenge-123",
			Code:   "123456",
			Status: challenge.StatusSent,
		},
	})
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      "challenge-123",
		"code":              "654321",
		"client_public_key": validClientPublicKey,
	})

	assert.Equal(t, http.StatusBadRequest, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"invalid_code","message":"confirmation code is invalid"}}`, response.Body)
}

func TestPublicHTTPEndToEndThrottledChallengeConfirmReturnsInvalidCode(t *testing.T) {
	t.Parallel()

	app := newEndToEndApp(t, endToEndOptions{
		SeedChallenge: seedChallengeOptions{
			ID:     "challenge-123",
			Code:   "123456",
			Status: challenge.StatusDeliveryThrottled,
		},
	})
	server := httptest.NewServer(app.handler)
	defer server.Close()

	response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      "challenge-123",
		"code":              "123456",
		"client_public_key": validClientPublicKey,
	})

	assert.Equal(t, http.StatusBadRequest, response.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"invalid_code","message":"confirmation code is invalid"}}`, response.Body)
}

func TestPublicHTTPEndToEndSessionLimitExceeded(t *testing.T) {
	t.Parallel()

	limit := 1
	app := newEndToEndApp(t, endToEndOptions{
		Config:           ports.SessionLimitConfig{ActiveSessionLimit: &limit},
		SeedExistingUser: true,
		SeedActiveSession: &devicesession.Session{
			ID:              common.DeviceSessionID("device-session-existing"),
			UserID:          common.UserID("user-1"),
			ClientPublicKey: mustClientPublicKey(t, secondValidClientPublicKey),
			Status:          devicesession.StatusActive,
			CreatedAt:       time.Date(2026, 4, 5, 11, 58, 0, 0, time.UTC),
		},
	})
	server := httptest.NewServer(app.handler)
	defer server.Close()

	sendResponse := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
	assert.Equal(t, http.StatusOK, sendResponse.StatusCode)

	attempts := app.mailSender.RecordedAttempts()
	require.Len(t, attempts, 1)

	confirmResponse := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
		"challenge_id":      "challenge-1",
		"code":              attempts[0].Input.Code,
		"client_public_key": validClientPublicKey,
	})

	assert.Equal(t, http.StatusConflict, confirmResponse.StatusCode)
	assert.JSONEq(t, `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`, confirmResponse.Body)
}

type endToEndOptions struct {
	Config            ports.SessionLimitConfig
	AbuseProtector    ports.SendEmailCodeAbuseProtector
	SeedBlockedEmail  bool
	SeedExistingUser  bool
	SeedChallenge     seedChallengeOptions
	SeedActiveSession *devicesession.Session
}

type seedChallengeOptions struct {
	ID        string
	Code      string
	Status    challenge.Status
	ExpiresAt time.Time
}

type endToEndApp struct {
	handler    http.Handler
	mailSender *mail.StubSender
}

func newEndToEndApp(t *testing.T, options endToEndOptions) endToEndApp {
	t.Helper()

	now := time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)
	challengeStore := &testkit.InMemoryChallengeStore{}
	sessionStore := &testkit.InMemorySessionStore{}
	userDirectory := &userservice.StubDirectory{}
	mailSender := &mail.StubSender{}
	idGenerator := &testkit.SequenceIDGenerator{}
	codeGenerator := testkit.FixedCodeGenerator{Code: "123456"}
	codeHasher := testkit.DeterministicCodeHasher{}
	clock := testkit.FixedClock{Time: now}
	publisher := &testkit.RecordingProjectionPublisher{}

	if options.SeedBlockedEmail {
		require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_blocked")))
	}
	if options.SeedExistingUser {
		require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	}
	if options.SeedActiveSession != nil {
		require.NoError(t, sessionStore.Create(context.Background(), *options.SeedActiveSession))
	}
	if options.SeedChallenge.ID != "" {
		expiresAt := options.SeedChallenge.ExpiresAt
		if expiresAt.IsZero() {
			expiresAt = now.Add(challenge.InitialTTL)
		}

		record := challenge.Challenge{
			ID:            common.ChallengeID(options.SeedChallenge.ID),
			Email:         common.Email("pilot@example.com"),
			CodeHash:      mustHashCode(t, options.SeedChallenge.Code),
			Status:        options.SeedChallenge.Status,
			DeliveryState: deliveryStateForSeedChallenge(options.SeedChallenge.Status),
			CreatedAt:     now.Add(-time.Minute),
			ExpiresAt:     expiresAt,
		}
		require.NoError(t, challengeStore.Create(context.Background(), record))
	}

	sendService, err := sendemailcode.NewWithRuntime(
		challengeStore,
		userDirectory,
		idGenerator,
		codeGenerator,
		codeHasher,
		mailSender,
		options.AbuseProtector,
		clock,
		nil,
	)
	require.NoError(t, err)

	confirmService, err := confirmemailcode.New(
		challengeStore,
		sessionStore,
		userDirectory,
		testkit.StaticConfigProvider{Config: options.Config},
		publisher,
		idGenerator,
		codeHasher,
		clock,
	)
	require.NoError(t, err)

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		SendEmailCode:    sendService,
		ConfirmEmailCode: confirmService,
	})

	return endToEndApp{
		handler:    handler,
		mailSender: mailSender,
	}
}

func deliveryStateForSeedChallenge(status challenge.Status) challenge.DeliveryState {
	switch status {
	case challenge.StatusDeliverySuppressed:
		return challenge.DeliverySuppressed
	case challenge.StatusDeliveryThrottled:
		return challenge.DeliveryThrottled
	default:
		return challenge.DeliverySent
	}
}

type httpResponse struct {
	StatusCode int
	Body       string
}

func postJSON(t *testing.T, url string, body string) httpResponse {
	t.Helper()

	response, err := http.Post(url, "application/json", bytes.NewBufferString(body))
	require.NoError(t, err)
	defer response.Body.Close()

	payload, err := io.ReadAll(response.Body)
	require.NoError(t, err)

	return httpResponse{StatusCode: response.StatusCode, Body: string(payload)}
}

func postJSONValue(t *testing.T, url string, value any) httpResponse {
	t.Helper()

	body, err := json.Marshal(value)
	require.NoError(t, err)
	return postJSON(t, url, string(body))
}

func mustHashCode(t *testing.T, code string) []byte {
	t.Helper()

	sum := sha256.Sum256([]byte(code))
	return sum[:]
}

func mustClientPublicKey(t *testing.T, encoded string) common.ClientPublicKey {
	t.Helper()

	decoded, err := base64.StdEncoding.DecodeString(encoded)
	require.NoError(t, err)

	key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded))
	require.NoError(t, err)
	return key
}

const (
	validClientPublicKey       = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
	secondValidClientPublicKey = "ICEiIyQlJicoKSorLC0uLzAxMjM0NTY3ODk6Ozw9Pj8="
)
@@ -0,0 +1,242 @@
package publichttp

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/mail"
	"strings"
	"sync"
	"time"

	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/sendemailcode"
	"galaxy/authsession/internal/service/shared"
	"galaxy/authsession/internal/telemetry"

	"github.com/gin-gonic/gin"
	"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
)

const jsonContentType = "application/json; charset=utf-8"

const publicHTTPServiceName = "galaxy-authsession-public"

type sendEmailCodeRequest struct {
	Email string `json:"email"`
}

type sendEmailCodeResponse struct {
	ChallengeID string `json:"challenge_id"`
}

type confirmEmailCodeRequest struct {
	ChallengeID     string `json:"challenge_id"`
	Code            string `json:"code"`
	ClientPublicKey string `json:"client_public_key"`
}

type confirmEmailCodeResponse struct {
	DeviceSessionID string `json:"device_session_id"`
}

type errorResponse struct {
	Error errorBody `json:"error"`
}

type errorBody struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

var configureGinModeOnce sync.Once

func newHandlerWithConfig(cfg Config, deps Dependencies) (http.Handler, error) {
	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	normalizedDeps, err := normalizeDependencies(deps)
	if err != nil {
		return nil, err
	}

	configureGinModeOnce.Do(func() {
		gin.SetMode(gin.ReleaseMode)
	})

	engine := gin.New()
	engine.Use(newOTelMiddleware(normalizedDeps.Telemetry))
	engine.Use(withPublicObservability(normalizedDeps.Logger, normalizedDeps.Telemetry))
	engine.POST(
		"/api/v1/public/auth/send-email-code",
		handleSendEmailCode(normalizedDeps.SendEmailCode, cfg.RequestTimeout),
	)
	engine.POST(
		"/api/v1/public/auth/confirm-email-code",
		handleConfirmEmailCode(normalizedDeps.ConfirmEmailCode, cfg.RequestTimeout),
	)

	return engine, nil
}

func newOTelMiddleware(runtime *telemetry.Runtime) gin.HandlerFunc {
	options := []otelgin.Option{}
	if runtime != nil {
		options = append(
			options,
			otelgin.WithTracerProvider(runtime.TracerProvider()),
			otelgin.WithMeterProvider(runtime.MeterProvider()),
		)
	}

	return otelgin.Middleware(publicHTTPServiceName, options...)
}

func handleSendEmailCode(useCase SendEmailCodeUseCase, timeout time.Duration) gin.HandlerFunc {
	return func(c *gin.Context) {
		var request sendEmailCodeRequest
		if err := decodeJSONRequest(c.Request, &request); err != nil {
			abortWithProjection(c, projectSendEmailCodeError(shared.InvalidRequest(err.Error())))
			return
		}
		if err := validateSendEmailCodeRequest(&request); err != nil {
			abortWithProjection(c, projectSendEmailCodeError(shared.InvalidRequest(err.Error())))
			return
		}

		callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
		defer cancel()

		result, err := useCase.Execute(callCtx, sendemailcode.Input{Email: request.Email})
		if err != nil {
			abortWithProjection(c, projectSendEmailCodeError(err))
			return
		}
		if err := validateSendEmailCodeResult(&result); err != nil {
			abortWithProjection(c, unavailableProjection(fmt.Errorf("send email code response: %w", err)))
			return
		}

		c.JSON(http.StatusOK, sendEmailCodeResponse{ChallengeID: result.ChallengeID})
	}
}

func handleConfirmEmailCode(useCase ConfirmEmailCodeUseCase, timeout time.Duration) gin.HandlerFunc {
	return func(c *gin.Context) {
		var request confirmEmailCodeRequest
		if err := decodeJSONRequest(c.Request, &request); err != nil {
			abortWithProjection(c, projectConfirmEmailCodeError(shared.InvalidRequest(err.Error())))
			return
		}
		if err := validateConfirmEmailCodeRequest(&request); err != nil {
			abortWithProjection(c, projectConfirmEmailCodeError(shared.InvalidRequest(err.Error())))
			return
		}

		callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
		defer cancel()

		result, err := useCase.Execute(callCtx, confirmemailcode.Input{
			ChallengeID:     request.ChallengeID,
			Code:            request.Code,
			ClientPublicKey: request.ClientPublicKey,
		})
		if err != nil {
			abortWithProjection(c, projectConfirmEmailCodeError(err))
			return
		}
		if err := validateConfirmEmailCodeResult(&result); err != nil {
			abortWithProjection(c, unavailableProjection(fmt.Errorf("confirm email code response: %w", err)))
			return
		}

		c.JSON(http.StatusOK, confirmEmailCodeResponse{DeviceSessionID: result.DeviceSessionID})
	}
}

func validateSendEmailCodeRequest(request *sendEmailCodeRequest) error {
	request.Email = strings.TrimSpace(request.Email)
	if request.Email == "" {
		return errors.New("email must not be empty")
	}

	parsedAddress, err := mail.ParseAddress(request.Email)
	if err != nil || parsedAddress.Name != "" || parsedAddress.Address != request.Email {
		return errors.New("email must be a single valid email address")
	}

	return nil
}

func validateSendEmailCodeResult(result *sendemailcode.Result) error {
	result.ChallengeID = strings.TrimSpace(result.ChallengeID)
	if result.ChallengeID == "" {
		return errors.New("challenge_id must not be empty")
	}

	return nil
}

func validateConfirmEmailCodeRequest(request *confirmEmailCodeRequest) error {
	request.ChallengeID = strings.TrimSpace(request.ChallengeID)
	if request.ChallengeID == "" {
		return errors.New("challenge_id must not be empty")
	}

	request.Code = strings.TrimSpace(request.Code)
	if request.Code == "" {
		return errors.New("code must not be empty")
	}

	request.ClientPublicKey = strings.TrimSpace(request.ClientPublicKey)
	if request.ClientPublicKey == "" {
		return errors.New("client_public_key must not be empty")
	}

	return nil
}

func validateConfirmEmailCodeResult(result *confirmemailcode.Result) error {
	result.DeviceSessionID = strings.TrimSpace(result.DeviceSessionID)
	if result.DeviceSessionID == "" {
		return errors.New("device_session_id must not be empty")
	}

	return nil
}

func projectSendEmailCodeError(err error) shared.PublicErrorProjection {
	if isTimeoutOrCanceled(err) {
		return unavailableProjection(err)
	}

	projection := shared.ProjectPublicError(err)
	if !shared.IsSendEmailCodePublicErrorCode(projection.Code) {
		return unavailableProjection(err)
	}

	return projection
}

func projectConfirmEmailCodeError(err error) shared.PublicErrorProjection {
	if isTimeoutOrCanceled(err) {
		return unavailableProjection(err)
	}

	projection := shared.ProjectPublicError(err)
	if !shared.IsConfirmEmailCodePublicErrorCode(projection.Code) {
		return unavailableProjection(err)
	}

	return projection
}

func unavailableProjection(err error) shared.PublicErrorProjection {
	return shared.ProjectPublicError(shared.ServiceUnavailable(err))
}

func isTimeoutOrCanceled(err error) bool {
	return errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)
}
@@ -0,0 +1,463 @@
package publichttp

import (
	"bytes"
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/sendemailcode"
	"galaxy/authsession/internal/service/shared"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

func TestSendEmailCodeHandlerSuccess(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
			return sendemailcode.Result{ChallengeID: "challenge-123"}, nil
		}),
		ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
			return confirmemailcode.Result{}, errors.New("unexpected call")
		}),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/send-email-code",
		bytes.NewBufferString(`{"email":" pilot@example.com "}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String())
}

func TestConfirmEmailCodeHandlerSuccess(t *testing.T) {
	t.Parallel()

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
			return sendemailcode.Result{}, errors.New("unexpected call")
		}),
		ConfirmEmailCode: confirmEmailCodeFunc(func(_ context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) {
			assert.Equal(t, confirmemailcode.Input{
				ChallengeID:     "challenge-123",
				Code:            "123456",
				ClientPublicKey: "public-key-material",
			}, input)
			return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil
		}),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/confirm-email-code",
		bytes.NewBufferString(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.JSONEq(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String())
}

func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		target     string
		body       string
		wantStatus int
		wantBody   string
	}{
		{
			name:       "empty body",
			target:     "/api/v1/public/auth/send-email-code",
			body:       ``,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body must not be empty"}}`,
		},
		{
			name:       "malformed json",
			target:     "/api/v1/public/auth/send-email-code",
			body:       `{"email":`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
		},
		{
			name:       "multiple objects",
			target:     "/api/v1/public/auth/send-email-code",
			body:       `{"email":"pilot@example.com"}{"email":"next@example.com"}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body must contain a single JSON object"}}`,
		},
		{
			name:       "unknown field",
			target:     "/api/v1/public/auth/send-email-code",
			body:       `{"email":"pilot@example.com","extra":true}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body contains unknown field \"extra\""}}`,
		},
		{
			name:       "invalid json type",
			target:     "/api/v1/public/auth/send-email-code",
			body:       `{"email":123}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"request body contains an invalid value for \"email\""}}`,
		},
		{
			name:       "invalid email",
			target:     "/api/v1/public/auth/send-email-code",
			body:       `{"email":"not-an-email"}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`,
		},
		{
			name:       "empty code",
			target:     "/api/v1/public/auth/confirm-email-code",
			body:       `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"code must not be empty"}}`,
		},
	}

	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
			return sendemailcode.Result{}, errors.New("unexpected call")
		}),
		ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
			return confirmemailcode.Result{}, errors.New("unexpected call")
		}),
	})

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			recorder := httptest.NewRecorder()
			request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body))
			if tt.body != "" {
				request.Header.Set("Content-Type", "application/json")
			}

			handler.ServeHTTP(recorder, request)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.JSONEq(t, tt.wantBody, recorder.Body.String())
		})
	}
}

func TestPublicAuthHandlersMapServiceErrors(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		target     string
		body       string
		deps       Dependencies
		wantStatus int
		wantBody   string
	}{
		{
			name:   "send route hides blocked by policy",
			target: "/api/v1/public/auth/send-email-code",
			body:   `{"email":"pilot@example.com"}`,
			deps: Dependencies{
				SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
					return sendemailcode.Result{}, shared.BlockedByPolicy()
				}),
				ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
					return confirmemailcode.Result{}, errors.New("unexpected call")
				}),
			},
			wantStatus: http.StatusServiceUnavailable,
			wantBody:   `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
		},
		{
			name:   "confirm invalid client public key",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			deps: Dependencies{
				SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
					return sendemailcode.Result{}, errors.New("unexpected call")
				}),
				ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
					return confirmemailcode.Result{}, shared.InvalidClientPublicKey()
				}),
			},
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key"}}`,
		},
		{
			name:   "confirm challenge not found",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			deps: Dependencies{
				SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
					return sendemailcode.Result{}, errors.New("unexpected call")
				}),
				ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
					return confirmemailcode.Result{}, shared.ChallengeNotFound()
				}),
			},
			wantStatus: http.StatusNotFound,
			wantBody:   `{"error":{"code":"challenge_not_found","message":"challenge not found"}}`,
		},
		{
			name:   "confirm challenge expired",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			deps: Dependencies{
				SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
					return sendemailcode.Result{}, errors.New("unexpected call")
				}),
				ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
					return confirmemailcode.Result{}, shared.ChallengeExpired()
				}),
			},
			wantStatus: http.StatusGone,
			wantBody:   `{"error":{"code":"challenge_expired","message":"challenge expired"}}`,
		},
		{
			name:   "confirm blocked by policy",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			deps: Dependencies{
				SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
					return sendemailcode.Result{}, errors.New("unexpected call")
				}),
				ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
					return confirmemailcode.Result{}, shared.BlockedByPolicy()
				}),
			},
			wantStatus: http.StatusForbidden,
			wantBody:   `{"error":{"code":"blocked_by_policy","message":"authentication is blocked by policy"}}`,
		},
		{
			name:   "confirm session limit exceeded",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			deps: Dependencies{
				SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
					return sendemailcode.Result{}, errors.New("unexpected call")
				}),
				ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
					return confirmemailcode.Result{}, shared.SessionLimitExceeded()
				}),
			},
			wantStatus: http.StatusConflict,
			wantBody:   `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`,
		},
		{
			name:   "confirm hides internal error",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			deps: Dependencies{
				SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
					return sendemailcode.Result{}, errors.New("unexpected call")
				}),
				ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
					return confirmemailcode.Result{}, shared.InternalError(errors.New("broken invariant"))
				}),
			},
			wantStatus: http.StatusServiceUnavailable,
			wantBody:   `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			handler := mustNewHandler(t, DefaultConfig(), tt.deps)
			recorder := httptest.NewRecorder()
			request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body))
			request.Header.Set("Content-Type", "application/json")

			handler.ServeHTTP(recorder, request)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.JSONEq(t, tt.wantBody, recorder.Body.String())
		})
	}
}

func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
	t.Parallel()

	cfg := DefaultConfig()
	cfg.RequestTimeout = 5 * time.Millisecond

	handler := mustNewHandler(t, cfg, Dependencies{
		SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
			return sendemailcode.Result{}, context.DeadlineExceeded
		}),
		ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
			return confirmemailcode.Result{}, errors.New("unexpected call")
		}),
	})

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/send-email-code",
		bytes.NewBufferString(`{"email":"pilot@example.com"}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
	assert.JSONEq(t, `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, recorder.Body.String())
}

func TestPublicAuthHandlersRejectInvalidSuccessPayloads(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
target string
|
||||
body string
|
||||
deps Dependencies
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "send email blank challenge id",
|
||||
target: "/api/v1/public/auth/send-email-code",
|
||||
body: `{"email":"pilot@example.com"}`,
|
||||
deps: Dependencies{
|
||||
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
|
||||
return sendemailcode.Result{ChallengeID: " "}, nil
|
||||
}),
|
||||
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
|
||||
return confirmemailcode.Result{}, errors.New("unexpected call")
|
||||
}),
|
||||
},
|
||||
wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
|
||||
},
|
||||
{
|
||||
name: "confirm blank device session id",
|
||||
target: "/api/v1/public/auth/confirm-email-code",
|
||||
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
|
||||
deps: Dependencies{
|
||||
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
|
||||
return sendemailcode.Result{}, errors.New("unexpected call")
|
||||
}),
|
||||
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
|
||||
return confirmemailcode.Result{DeviceSessionID: " "}, nil
|
||||
}),
|
||||
},
|
||||
wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := mustNewHandler(t, DefaultConfig(), tt.deps)
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.ServeHTTP(recorder, request)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
assert.JSONEq(t, tt.wantBody, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, buffer := newObservedLogger()
|
||||
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
|
||||
Logger: logger,
|
||||
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
|
||||
return sendemailcode.Result{}, errors.New("unexpected call")
|
||||
}),
|
||||
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
|
||||
return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil
|
||||
}),
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/confirm-email-code",
|
||||
bytes.NewBufferString(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`),
|
||||
)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.ServeHTTP(recorder, request)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
logOutput := buffer.String()
|
||||
assert.NotContains(t, logOutput, "challenge-123")
|
||||
assert.NotContains(t, logOutput, "123456")
|
||||
assert.NotContains(t, logOutput, "public-key-material")
|
||||
assert.NotContains(t, logOutput, "pilot@example.com")
|
||||
assert.NotContains(t, logOutput, "device-session-123")
|
||||
}
|
||||
|
||||
func mustNewHandler(t *testing.T, cfg Config, deps Dependencies) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
handler, err := newHandlerWithConfig(cfg, deps)
|
||||
require.NoError(t, err)
|
||||
return handler
|
||||
}
|
||||
|
||||
type sendEmailCodeFunc func(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error)
|
||||
|
||||
func (f sendEmailCodeFunc) Execute(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error) {
|
||||
return f(ctx, input)
|
||||
}
|
||||
|
||||
type confirmEmailCodeFunc func(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error)
|
||||
|
||||
func (f confirmEmailCodeFunc) Execute(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) {
|
||||
return f(ctx, input)
|
||||
}
|
||||
|
||||
func newObservedLogger() (*zap.Logger, *bytes.Buffer) {
|
||||
buffer := &bytes.Buffer{}
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = ""
|
||||
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(encoderConfig),
|
||||
zapcore.AddSync(buffer),
|
||||
zap.DebugLevel,
|
||||
)
|
||||
|
||||
return zap.New(core), buffer
|
||||
}
|
||||
@@ -0,0 +1,93 @@
package publichttp

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"galaxy/authsession/internal/service/shared"

	"github.com/gin-gonic/gin"
)

const publicErrorCodeContextKey = "public_error_code"

type malformedJSONRequestError struct {
	message string
}

func (e *malformedJSONRequestError) Error() string {
	if e == nil {
		return ""
	}

	return e.message
}

func decodeJSONRequest(request *http.Request, target any) error {
	if request == nil || request.Body == nil {
		return &malformedJSONRequestError{message: "request body must not be empty"}
	}

	return decodeJSONReader(request.Body, target)
}

func decodeJSONReader(reader io.Reader, target any) error {
	decoder := json.NewDecoder(reader)
	decoder.DisallowUnknownFields()

	if err := decoder.Decode(target); err != nil {
		return describeJSONDecodeError(err)
	}

	if err := decoder.Decode(&struct{}{}); err != nil {
		if errors.Is(err, io.EOF) {
			return nil
		}

		return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
	}

	return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
}

func describeJSONDecodeError(err error) error {
	var syntaxErr *json.SyntaxError
	var typeErr *json.UnmarshalTypeError

	switch {
	case errors.Is(err, io.EOF):
		return &malformedJSONRequestError{message: "request body must not be empty"}
	case errors.As(err, &syntaxErr):
		return &malformedJSONRequestError{message: "request body contains malformed JSON"}
	case errors.Is(err, io.ErrUnexpectedEOF):
		return &malformedJSONRequestError{message: "request body contains malformed JSON"}
	case errors.As(err, &typeErr):
		if strings.TrimSpace(typeErr.Field) != "" {
			return &malformedJSONRequestError{
				message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
			}
		}

		return &malformedJSONRequestError{message: "request body contains an invalid JSON value"}
	case strings.HasPrefix(err.Error(), "json: unknown field "):
		return &malformedJSONRequestError{
			message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
		}
	default:
		return &malformedJSONRequestError{message: "request body contains invalid JSON"}
	}
}

func abortWithProjection(c *gin.Context, projection shared.PublicErrorProjection) {
	c.Set(publicErrorCodeContextKey, projection.Code)
	c.AbortWithStatusJSON(projection.StatusCode, errorResponse{
		Error: errorBody{
			Code:    projection.Code,
			Message: projection.Message,
		},
	})
}
@@ -0,0 +1,86 @@
package publichttp

import (
	"time"

	authlogging "galaxy/authsession/internal/logging"
	"galaxy/authsession/internal/telemetry"

	"github.com/gin-gonic/gin"
	"go.opentelemetry.io/otel/attribute"
	"go.uber.org/zap"
)

type edgeOutcome string

const (
	edgeOutcomeSuccess  edgeOutcome = "success"
	edgeOutcomeRejected edgeOutcome = "rejected"
	edgeOutcomeFailed   edgeOutcome = "failed"
)

func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
	if logger == nil {
		logger = zap.NewNop()
	}

	return func(c *gin.Context) {
		start := time.Now()
		c.Next()

		statusCode := c.Writer.Status()
		route := c.FullPath()
		if route == "" {
			route = "unmatched"
		}

		errorCode, _ := c.Get(publicErrorCodeContextKey)
		errorCodeValue, _ := errorCode.(string)
		outcome := outcomeFromStatusCode(statusCode)
		duration := time.Since(start)

		fields := []zap.Field{
			zap.String("component", "public_http"),
			zap.String("transport", "http"),
			zap.String("route", route),
			zap.String("method", c.Request.Method),
			zap.Int("status_code", statusCode),
			zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
			zap.String("edge_outcome", string(outcome)),
		}
		if errorCodeValue != "" {
			fields = append(fields, zap.String("error_code", errorCodeValue))
		}
		fields = append(fields, authlogging.TraceFieldsFromContext(c.Request.Context())...)

		metricAttrs := []attribute.KeyValue{
			attribute.String("route", route),
			attribute.String("method", c.Request.Method),
			attribute.String("edge_outcome", string(outcome)),
		}
		if errorCodeValue != "" {
			metricAttrs = append(metricAttrs, attribute.String("error_code", errorCodeValue))
		}
		metrics.RecordPublicHTTPRequest(c.Request.Context(), metricAttrs, duration)

		switch outcome {
		case edgeOutcomeSuccess:
			logger.Info("public request completed", fields...)
		case edgeOutcomeFailed:
			logger.Error("public request failed", fields...)
		default:
			logger.Warn("public request rejected", fields...)
		}
	}
}

func outcomeFromStatusCode(statusCode int) edgeOutcome {
	switch {
	case statusCode >= 500:
		return edgeOutcomeFailed
	case statusCode >= 400:
		return edgeOutcomeRejected
	default:
		return edgeOutcomeSuccess
	}
}
@@ -0,0 +1,114 @@
package publichttp

import (
	"bytes"
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/sendemailcode"
	authtelemetry "galaxy/authsession/internal/telemetry"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel/attribute"
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
	"go.opentelemetry.io/otel/sdk/metric/metricdata"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/sdk/trace/tracetest"
)

func TestPublicHandlerEmitsTraceFieldsAndMetrics(t *testing.T) {
	t.Parallel()

	logger, buffer := newObservedLogger()
	telemetryRuntime, reader, recorder := newObservedPublicTelemetryRuntime(t)
	handler := mustNewHandler(t, DefaultConfig(), Dependencies{
		Logger:    logger,
		Telemetry: telemetryRuntime,
		SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
			return sendemailcode.Result{ChallengeID: "challenge-123"}, nil
		}),
		ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
			return confirmemailcode.Result{}, nil
		}),
	})

	recorderHTTP := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/send-email-code",
		bytes.NewBufferString(`{"email":"pilot@example.com"}`),
	)
	request.Header.Set("Content-Type", "application/json")

	handler.ServeHTTP(recorderHTTP, request)

	require.Equal(t, http.StatusOK, recorderHTTP.Code)
	require.NotEmpty(t, recorder.Ended())
	assert.Contains(t, buffer.String(), "otel_trace_id")
	assert.Contains(t, buffer.String(), "otel_span_id")

	assertMetricCount(t, reader, "authsession.public_http.requests", map[string]string{
		"route":        "/api/v1/public/auth/send-email-code",
		"method":       http.MethodPost,
		"edge_outcome": "success",
	}, 1)
}

func newObservedPublicTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader, *tracetest.SpanRecorder) {
	t.Helper()

	reader := sdkmetric.NewManualReader()
	meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
	recorder := tracetest.NewSpanRecorder()
	tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))

	runtime, err := authtelemetry.NewWithProviders(meterProvider, tracerProvider)
	require.NoError(t, err)

	return runtime, reader, recorder
}

func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
	t.Helper()

	var resourceMetrics metricdata.ResourceMetrics
	require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))

	for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
		for _, metric := range scopeMetrics.Metrics {
			if metric.Name != metricName {
				continue
			}

			sum, ok := metric.Data.(metricdata.Sum[int64])
			require.True(t, ok)

			for _, point := range sum.DataPoints {
				if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
					assert.Equal(t, wantValue, point.Value)
					return
				}
			}
		}
	}

	require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs)
}

func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
	if len(values) != len(want) {
		return false
	}

	for _, value := range values {
		if want[string(value.Key)] != value.Value.AsString() {
			return false
		}
	}

	return true
}
@@ -0,0 +1,228 @@
package publichttp

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/sendemailcode"
	"galaxy/authsession/internal/telemetry"

	"go.uber.org/zap"
)

const (
	defaultAddr              = ":8080"
	defaultReadHeaderTimeout = 2 * time.Second
	defaultReadTimeout       = 10 * time.Second
	defaultIdleTimeout       = time.Minute
	defaultRequestTimeout    = 3 * time.Second
)

// SendEmailCodeUseCase describes the public send-email-code application
// service consumed by the HTTP transport layer.
type SendEmailCodeUseCase interface {
	// Execute validates input and creates a new login challenge.
	Execute(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error)
}

// ConfirmEmailCodeUseCase describes the public confirm-email-code application
// service consumed by the HTTP transport layer.
type ConfirmEmailCodeUseCase interface {
	// Execute validates input and completes an existing login challenge.
	Execute(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error)
}

// Config describes the public HTTP listener owned by authsession.
type Config struct {
	// Addr is the TCP listen address used by the public HTTP server.
	Addr string

	// ReadHeaderTimeout bounds how long the listener may spend reading request
	// headers before the server rejects the connection.
	ReadHeaderTimeout time.Duration

	// ReadTimeout bounds how long the listener may spend reading one public
	// request.
	ReadTimeout time.Duration

	// IdleTimeout bounds how long the listener keeps an idle keep-alive
	// connection open.
	IdleTimeout time.Duration

	// RequestTimeout bounds one application-layer public-auth use-case call.
	RequestTimeout time.Duration
}

// Validate reports whether cfg contains a usable public HTTP listener
// configuration.
func (cfg Config) Validate() error {
	switch {
	case cfg.Addr == "":
		return errors.New("public HTTP addr must not be empty")
	case cfg.ReadHeaderTimeout <= 0:
		return errors.New("public HTTP read header timeout must be positive")
	case cfg.ReadTimeout <= 0:
		return errors.New("public HTTP read timeout must be positive")
	case cfg.IdleTimeout <= 0:
		return errors.New("public HTTP idle timeout must be positive")
	case cfg.RequestTimeout <= 0:
		return errors.New("public HTTP request timeout must be positive")
	default:
		return nil
	}
}

// DefaultConfig returns the default public HTTP listener settings aligned with
// the gateway public-auth transport timeouts.
func DefaultConfig() Config {
	return Config{
		Addr:              defaultAddr,
		ReadHeaderTimeout: defaultReadHeaderTimeout,
		ReadTimeout:       defaultReadTimeout,
		IdleTimeout:       defaultIdleTimeout,
		RequestTimeout:    defaultRequestTimeout,
	}
}

// Dependencies describes the collaborators used by the public HTTP transport
// layer.
type Dependencies struct {
	// SendEmailCode executes the public send-email-code use case.
	SendEmailCode SendEmailCodeUseCase

	// ConfirmEmailCode executes the public confirm-email-code use case.
	ConfirmEmailCode ConfirmEmailCodeUseCase

	// Logger writes structured transport logs. When nil, a no-op logger is
	// used.
	Logger *zap.Logger

	// Telemetry records OpenTelemetry spans and low-cardinality HTTP metrics.
	// When nil, the transport still serves requests with no-op providers.
	Telemetry *telemetry.Runtime
}

// Server owns the public auth HTTP listener exposed by authsession.
type Server struct {
	cfg Config

	handler http.Handler
	logger  *zap.Logger

	stateMu  sync.RWMutex
	server   *http.Server
	listener net.Listener
}

// NewServer constructs one public auth HTTP server for cfg and deps.
func NewServer(cfg Config, deps Dependencies) (*Server, error) {
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("new public HTTP server: %w", err)
	}

	handler, err := newHandlerWithConfig(cfg, deps)
	if err != nil {
		return nil, fmt.Errorf("new public HTTP server: %w", err)
	}

	logger := deps.Logger
	if logger == nil {
		logger = zap.NewNop()
	}
	logger = logger.Named("public_http")

	return &Server{
		cfg:     cfg,
		handler: handler,
		logger:  logger,
	}, nil
}

// Run binds the configured listener and serves the public auth HTTP surface
// until Shutdown closes the server.
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run public HTTP server: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run public HTTP server: listen on %q: %w", s.cfg.Addr, err)
	}

	server := &http.Server{
		Handler:           s.handler,
		ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
		ReadTimeout:       s.cfg.ReadTimeout,
		IdleTimeout:       s.cfg.IdleTimeout,
	}

	s.stateMu.Lock()
	s.server = server
	s.listener = listener
	s.stateMu.Unlock()

	s.logger.Info("public HTTP server started", zap.String("addr", listener.Addr().String()))

	defer func() {
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()

	err = server.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, http.ErrServerClosed):
		s.logger.Info("public HTTP server stopped")
		return nil
	default:
		return fmt.Errorf("run public HTTP server: serve on %q: %w", s.cfg.Addr, err)
	}
}

// Shutdown gracefully stops the public HTTP server within ctx.
func (s *Server) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown public HTTP server: nil context")
	}

	s.stateMu.RLock()
	server := s.server
	s.stateMu.RUnlock()

	if server == nil {
		return nil
	}

	if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
		return fmt.Errorf("shutdown public HTTP server: %w", err)
	}

	return nil
}

func normalizeDependencies(deps Dependencies) (Dependencies, error) {
	switch {
	case deps.SendEmailCode == nil:
		return Dependencies{}, errors.New("send email code use case must not be nil")
	case deps.ConfirmEmailCode == nil:
		return Dependencies{}, errors.New("confirm email code use case must not be nil")
	case deps.Logger == nil:
		deps.Logger = zap.NewNop()
	}

	deps.Logger = deps.Logger.Named("public_http")
	return deps, nil
}
@@ -0,0 +1,81 @@
package publichttp

import (
	"bytes"
	"context"
	"net/http"
	"testing"
	"time"

	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/sendemailcode"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewServerRejectsInvalidConfiguration(t *testing.T) {
	t.Parallel()

	cfg := DefaultConfig()
	cfg.Addr = ""

	_, err := NewServer(cfg, Dependencies{
		SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
			return sendemailcode.Result{}, nil
		}),
		ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
			return confirmemailcode.Result{}, nil
		}),
	})

	require.Error(t, err)
	assert.Contains(t, err.Error(), "addr")
}

func TestServerRunAndShutdown(t *testing.T) {
	t.Parallel()

	cfg := DefaultConfig()
	cfg.Addr = "127.0.0.1:0"

	server, err := NewServer(cfg, Dependencies{
		SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
			return sendemailcode.Result{ChallengeID: "challenge-123"}, nil
		}),
		ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
			return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil
		}),
	})
	require.NoError(t, err)

	runErr := make(chan error, 1)
	go func() {
		runErr <- server.Run(context.Background())
	}()

	require.Eventually(t, func() bool {
		server.stateMu.RLock()
		defer server.stateMu.RUnlock()
		return server.listener != nil
	}, time.Second, 10*time.Millisecond)

	server.stateMu.RLock()
	addr := server.listener.Addr().String()
	server.stateMu.RUnlock()

	response, err := http.Post(
		"http://"+addr+"/api/v1/public/auth/send-email-code",
		"application/json",
		bytes.NewBufferString(`{"email":"pilot@example.com"}`),
	)
	require.NoError(t, err)
	defer response.Body.Close()

	assert.Equal(t, http.StatusOK, response.StatusCode)

	shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	require.NoError(t, server.Shutdown(shutdownCtx))
	require.NoError(t, <-runErr)
}
@@ -0,0 +1,168 @@
// Package app wires the authsession process lifecycle and coordinates
// component startup and graceful shutdown.
package app

import (
	"context"
	"errors"
	"fmt"
	"sync"

	"galaxy/authsession/internal/config"
)

// Component is a long-lived authsession subsystem that participates in
// coordinated startup and graceful shutdown.
type Component interface {
	// Run starts the component and blocks until it stops.
	Run(context.Context) error

	// Shutdown stops the component within the provided timeout-bounded context.
	Shutdown(context.Context) error
}

// App owns the process-level lifecycle of authsession and its registered
// components.
type App struct {
	cfg        config.Config
	components []Component
}

// New constructs an App with a defensive copy of the supplied components.
func New(cfg config.Config, components ...Component) *App {
	clonedComponents := append([]Component(nil), components...)

	return &App{
		cfg:        cfg,
		components: clonedComponents,
	}
}

// Run starts all configured components, waits for cancellation or the first
// component failure, and then executes best-effort graceful shutdown.
func (a *App) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run authsession app: nil context")
	}
	if err := a.validate(); err != nil {
		return err
	}
	if len(a.components) == 0 {
		<-ctx.Done()
		return nil
	}

	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	results := make(chan componentResult, len(a.components))
	var runWG sync.WaitGroup

	for idx, component := range a.components {
		runWG.Add(1)

		go func(index int, component Component) {
			defer runWG.Done()
			results <- componentResult{
				index: index,
				err:   component.Run(runCtx),
			}
		}(idx, component)
	}

	var runErr error

	select {
	case <-ctx.Done():
	case result := <-results:
		runErr = classifyComponentResult(ctx, result)
	}

	cancel()

	shutdownErr := a.shutdownComponents()
	waitErr := a.waitForComponents(&runWG)

	return errors.Join(runErr, shutdownErr, waitErr)
}

type componentResult struct {
	index int
	err   error
}

func (a *App) validate() error {
	if a.cfg.ShutdownTimeout <= 0 {
		return fmt.Errorf("run authsession app: shutdown timeout must be positive, got %s", a.cfg.ShutdownTimeout)
	}

	for idx, component := range a.components {
		if component == nil {
			return fmt.Errorf("run authsession app: component %d is nil", idx)
		}
	}

	return nil
}

func classifyComponentResult(parentCtx context.Context, result componentResult) error {
	switch {
	case result.err == nil:
		if parentCtx.Err() != nil {
			return nil
		}
		return fmt.Errorf("run authsession app: component %d exited without error before shutdown", result.index)
	case errors.Is(result.err, context.Canceled) && parentCtx.Err() != nil:
		return nil
	default:
		return fmt.Errorf("run authsession app: component %d: %w", result.index, result.err)
	}
}

func (a *App) shutdownComponents() error {
	var shutdownWG sync.WaitGroup
	errs := make(chan error, len(a.components))

	for idx, component := range a.components {
		shutdownWG.Add(1)

		go func(index int, component Component) {
			defer shutdownWG.Done()

			shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
			defer cancel()

			if err := component.Shutdown(shutdownCtx); err != nil {
				errs <- fmt.Errorf("shutdown authsession component %d: %w", index, err)
			}
		}(idx, component)
	}

	shutdownWG.Wait()
	close(errs)

	var joined error
	for err := range errs {
		joined = errors.Join(joined, err)
	}

	return joined
}

func (a *App) waitForComponents(runWG *sync.WaitGroup) error {
	done := make(chan struct{})
	go func() {
		runWG.Wait()
		close(done)
	}()

	waitCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
	defer cancel()

	select {
	case <-done:
		return nil
	case <-waitCtx.Done():
		return fmt.Errorf("wait for authsession components: %w", waitCtx.Err())
	}
}
@@ -0,0 +1,284 @@
package app

import (
	"context"
	"errors"
	"fmt"

	"galaxy/authsession/internal/adapters/local"
	"galaxy/authsession/internal/adapters/mail"
	"galaxy/authsession/internal/adapters/redis/challengestore"
	"galaxy/authsession/internal/adapters/redis/configprovider"
	"galaxy/authsession/internal/adapters/redis/projectionpublisher"
	"galaxy/authsession/internal/adapters/redis/sendemailcodeabuse"
	"galaxy/authsession/internal/adapters/redis/sessionstore"
	"galaxy/authsession/internal/adapters/userservice"
	"galaxy/authsession/internal/api/internalhttp"
	"galaxy/authsession/internal/api/publichttp"
	"galaxy/authsession/internal/config"
	"galaxy/authsession/internal/ports"
	"galaxy/authsession/internal/service/blockuser"
	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/getsession"
	"galaxy/authsession/internal/service/listusersessions"
	"galaxy/authsession/internal/service/revokeallusersessions"
	"galaxy/authsession/internal/service/revokedevicesession"
	"galaxy/authsession/internal/service/sendemailcode"
	"galaxy/authsession/internal/telemetry"

	"go.uber.org/zap"
)

type pinger interface {
	Ping(context.Context) error
}

type closer interface {
	Close() error
}

// Runtime owns the runnable authsession application plus the adapter cleanup
// functions that must run after the process stops.
type Runtime struct {
	// App coordinates the long-lived HTTP listeners.
	App *App

	cleanupFns []func() error
}

// NewRuntime constructs the runnable authsession process from cfg using the
// Stage 18 Redis adapters, local runtime helpers, and the selectable mail and
// user-service runtime adapters from Stages 20 and 21.
func NewRuntime(ctx context.Context, cfg config.Config, logger *zap.Logger, telemetryRuntime *telemetry.Runtime) (*Runtime, error) {
	if ctx == nil {
		return nil, errors.New("new authsession runtime: nil context")
	}
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("new authsession runtime: %w", err)
	}
	if logger == nil {
		logger = zap.NewNop()
	}

	runtime := &Runtime{}
	cleanupOnError := func(err error) (*Runtime, error) {
		return nil, errors.Join(err, runtime.Close())
	}

	challengeStore, err := challengestore.New(challengestore.Config{
		Addr:             cfg.Redis.Addr,
		Username:         cfg.Redis.Username,
		Password:         cfg.Redis.Password,
		DB:               cfg.Redis.DB,
		TLSEnabled:       cfg.Redis.TLSEnabled,
		KeyPrefix:        cfg.Redis.ChallengeKeyPrefix,
		OperationTimeout: cfg.Redis.OperationTimeout,
	})
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: challenge store: %w", err))
	}
	runtime.cleanupFns = append(runtime.cleanupFns, challengeStore.Close)

	sessionStore, err := sessionstore.New(sessionstore.Config{
		Addr:                        cfg.Redis.Addr,
		Username:                    cfg.Redis.Username,
		Password:                    cfg.Redis.Password,
		DB:                          cfg.Redis.DB,
		TLSEnabled:                  cfg.Redis.TLSEnabled,
		SessionKeyPrefix:            cfg.Redis.SessionKeyPrefix,
		UserSessionsKeyPrefix:       cfg.Redis.UserSessionsKeyPrefix,
		UserActiveSessionsKeyPrefix: cfg.Redis.UserActiveSessionsKeyPrefix,
		OperationTimeout:            cfg.Redis.OperationTimeout,
	})
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: session store: %w", err))
	}
	runtime.cleanupFns = append(runtime.cleanupFns, sessionStore.Close)

	configStore, err := configprovider.New(configprovider.Config{
		Addr:             cfg.Redis.Addr,
		Username:         cfg.Redis.Username,
		Password:         cfg.Redis.Password,
		DB:               cfg.Redis.DB,
		TLSEnabled:       cfg.Redis.TLSEnabled,
		SessionLimitKey:  cfg.Redis.SessionLimitKey,
		OperationTimeout: cfg.Redis.OperationTimeout,
	})
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: config provider: %w", err))
	}
	runtime.cleanupFns = append(runtime.cleanupFns, configStore.Close)

	publisher, err := projectionpublisher.New(projectionpublisher.Config{
		Addr:                  cfg.Redis.Addr,
		Username:              cfg.Redis.Username,
		Password:              cfg.Redis.Password,
		DB:                    cfg.Redis.DB,
		TLSEnabled:            cfg.Redis.TLSEnabled,
		SessionCacheKeyPrefix: cfg.Redis.GatewaySessionCacheKeyPrefix,
		SessionEventsStream:   cfg.Redis.GatewaySessionEventsStream,
		StreamMaxLen:          cfg.Redis.GatewaySessionEventsStreamMaxLen,
		OperationTimeout:      cfg.Redis.OperationTimeout,
	})
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: projection publisher: %w", err))
	}
	runtime.cleanupFns = append(runtime.cleanupFns, publisher.Close)

	abuseProtector, err := sendemailcodeabuse.New(sendemailcodeabuse.Config{
		Addr:             cfg.Redis.Addr,
		Username:         cfg.Redis.Username,
		Password:         cfg.Redis.Password,
		DB:               cfg.Redis.DB,
		TLSEnabled:       cfg.Redis.TLSEnabled,
		KeyPrefix:        cfg.Redis.SendEmailCodeThrottleKeyPrefix,
		OperationTimeout: cfg.Redis.OperationTimeout,
	})
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: send email code abuse protector: %w", err))
	}
	runtime.cleanupFns = append(runtime.cleanupFns, abuseProtector.Close)

	for name, dependency := range map[string]pinger{
		"challenge store":                 challengeStore,
		"session store":                   sessionStore,
		"config provider":                 configStore,
		"projection publisher":            publisher,
		"send email code abuse protector": abuseProtector,
	} {
		if err := dependency.Ping(ctx); err != nil {
			return cleanupOnError(fmt.Errorf("new authsession runtime: ping %s: %w", name, err))
		}
	}

	clock := local.Clock{}
	idGenerator := local.IDGenerator{}
	codeGenerator := local.CodeGenerator{}
	codeHasher := local.CodeHasher{}
	var mailSender ports.MailSender
	switch cfg.MailService.Mode {
	case "stub":
		mailSender = &mail.StubSender{}
	case "rest":
		restClient, err := mail.NewRESTClient(mail.Config{
			BaseURL:        cfg.MailService.BaseURL,
			RequestTimeout: cfg.MailService.RequestTimeout,
		})
		if err != nil {
			return cleanupOnError(fmt.Errorf("new authsession runtime: mail service REST client: %w", err))
		}
		runtime.cleanupFns = append(runtime.cleanupFns, restClient.Close)
		mailSender = restClient
	default:
		return cleanupOnError(fmt.Errorf("new authsession runtime: unsupported mail service mode %q", cfg.MailService.Mode))
	}
	var userDirectory ports.UserDirectory
	switch cfg.UserService.Mode {
	case "stub":
		userDirectory = &userservice.StubDirectory{}
	case "rest":
		restClient, err := userservice.NewRESTClient(userservice.Config{
			BaseURL:        cfg.UserService.BaseURL,
			RequestTimeout: cfg.UserService.RequestTimeout,
		})
		if err != nil {
			return cleanupOnError(fmt.Errorf("new authsession runtime: user service REST client: %w", err))
		}
		runtime.cleanupFns = append(runtime.cleanupFns, restClient.Close)
		userDirectory = restClient
	default:
		return cleanupOnError(fmt.Errorf("new authsession runtime: unsupported user service mode %q", cfg.UserService.Mode))
	}

	sendEmailCodeService, err := sendemailcode.NewWithObservability(
		challengeStore,
		userDirectory,
		idGenerator,
		codeGenerator,
		codeHasher,
		mailSender,
		abuseProtector,
		clock,
		logger,
		telemetryRuntime,
	)
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: send email code service: %w", err))
	}
	confirmEmailCodeService, err := confirmemailcode.NewWithObservability(
		challengeStore,
		sessionStore,
		userDirectory,
		configStore,
		publisher,
		idGenerator,
		codeHasher,
		clock,
		logger,
		telemetryRuntime,
	)
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: confirm email code service: %w", err))
	}
	getSessionService, err := getsession.New(sessionStore)
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: get session service: %w", err))
	}
	listUserSessionsService, err := listusersessions.New(sessionStore)
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: list user sessions service: %w", err))
	}
	revokeDeviceSessionService, err := revokedevicesession.NewWithObservability(sessionStore, publisher, clock, logger, telemetryRuntime)
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: revoke device session service: %w", err))
	}
	revokeAllUserSessionsService, err := revokeallusersessions.NewWithObservability(sessionStore, userDirectory, publisher, clock, logger, telemetryRuntime)
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: revoke all user sessions service: %w", err))
	}
	blockUserService, err := blockuser.NewWithObservability(userDirectory, sessionStore, publisher, clock, logger, telemetryRuntime)
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: block user service: %w", err))
	}

	publicServer, err := publichttp.NewServer(cfg.PublicHTTP, publichttp.Dependencies{
		SendEmailCode:    sendEmailCodeService,
		ConfirmEmailCode: confirmEmailCodeService,
		Logger:           logger,
		Telemetry:        telemetryRuntime,
	})
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: public HTTP server: %w", err))
	}

	internalServer, err := internalhttp.NewServer(cfg.InternalHTTP, internalhttp.Dependencies{
		GetSession:            getSessionService,
		ListUserSessions:      listUserSessionsService,
		RevokeDeviceSession:   revokeDeviceSessionService,
		RevokeAllUserSessions: revokeAllUserSessionsService,
		BlockUser:             blockUserService,
		Logger:                logger,
		Telemetry:             telemetryRuntime,
	})
	if err != nil {
		return cleanupOnError(fmt.Errorf("new authsession runtime: internal HTTP server: %w", err))
	}

	runtime.App = New(cfg, publicServer, internalServer)
	return runtime, nil
}

// Close releases the runtime-managed adapter resources. Close is idempotent in
// practice because every underlying adapter Close method is idempotent.
func (r *Runtime) Close() error {
	if r == nil {
		return nil
	}

	var joined error
	for index := len(r.cleanupFns) - 1; index >= 0; index-- {
		joined = errors.Join(joined, r.cleanupFns[index]())
	}

	return joined
}
@@ -0,0 +1,212 @@
package app

import (
	"bytes"
	"context"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"galaxy/authsession/internal/config"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
)

func TestNewRuntimeStartsAndStopsHTTPServers(t *testing.T) {
	t.Parallel()

	redisServer := miniredis.RunT(t)

	cfg := config.DefaultConfig()
	cfg.Redis.Addr = redisServer.Addr()
	cfg.PublicHTTP.Addr = mustFreeAddr(t)
	cfg.InternalHTTP.Addr = mustFreeAddr(t)

	runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, runtime.Close())
	}()

	runCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	runErrCh := make(chan error, 1)
	go func() {
		runErrCh <- runtime.App.Run(runCtx)
	}()

	require.Eventually(t, func() bool {
		response, err := http.Post(
			"http://"+cfg.PublicHTTP.Addr+"/api/v1/public/auth/send-email-code",
			"application/json",
			bytes.NewBufferString(`{"email":"pilot@example.com"}`),
		)
		if err != nil {
			return false
		}
		defer response.Body.Close()
		_, _ = io.ReadAll(response.Body)

		return response.StatusCode == http.StatusOK
	}, 5*time.Second, 25*time.Millisecond)

	require.Eventually(t, func() bool {
		response, err := http.Get("http://" + cfg.InternalHTTP.Addr + "/api/v1/internal/sessions/missing")
		if err != nil {
			return false
		}
		defer response.Body.Close()
		_, _ = io.ReadAll(response.Body)

		return response.StatusCode == http.StatusNotFound
	}, 5*time.Second, 25*time.Millisecond)

	cancel()
	require.NoError(t, <-runErrCh)
}

func TestNewRuntimeUsesRESTUserDirectoryWhenConfigured(t *testing.T) {
	t.Parallel()

	redisServer := miniredis.RunT(t)
	userService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodGet && r.URL.Path == "/api/v1/internal/users/user-1/exists" {
			w.Header().Set("Content-Type", "application/json")
			_, _ = io.WriteString(w, `{"exists":true}`)
			return
		}

		http.NotFound(w, r)
	}))
	defer userService.Close()

	cfg := config.DefaultConfig()
	cfg.Redis.Addr = redisServer.Addr()
	cfg.PublicHTTP.Addr = mustFreeAddr(t)
	cfg.InternalHTTP.Addr = mustFreeAddr(t)
	cfg.UserService.Mode = "rest"
	cfg.UserService.BaseURL = userService.URL
	cfg.UserService.RequestTimeout = 250 * time.Millisecond

	runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, runtime.Close())
	}()

	runCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	runErrCh := make(chan error, 1)
	go func() {
		runErrCh <- runtime.App.Run(runCtx)
	}()

	require.Eventually(t, func() bool {
		response, err := http.Post(
			"http://"+cfg.InternalHTTP.Addr+"/api/v1/internal/users/user-1/sessions/revoke-all",
			"application/json",
			bytes.NewBufferString(`{"reason_code":"logout_all","actor":{"type":"system"}}`),
		)
		if err != nil {
			return false
		}
		defer response.Body.Close()

		payload, err := io.ReadAll(response.Body)
		if err != nil {
			return false
		}

		return response.StatusCode == http.StatusOK &&
			bytes.Contains(payload, []byte(`"outcome":"no_active_sessions"`)) &&
			bytes.Contains(payload, []byte(`"user_id":"user-1"`))
	}, 5*time.Second, 25*time.Millisecond)

	cancel()
	require.NoError(t, <-runErrCh)
}

func TestNewRuntimeUsesRESTMailSenderWhenConfigured(t *testing.T) {
	t.Parallel()

	redisServer := miniredis.RunT(t)
	var calls atomic.Int64
	mailService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPost && r.URL.Path == "/api/v1/internal/login-code-deliveries" {
			calls.Add(1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = io.WriteString(w, `{"outcome":"suppressed"}`)
			return
		}

		http.NotFound(w, r)
	}))
	defer mailService.Close()

	cfg := config.DefaultConfig()
	cfg.Redis.Addr = redisServer.Addr()
	cfg.PublicHTTP.Addr = mustFreeAddr(t)
	cfg.InternalHTTP.Addr = mustFreeAddr(t)
	cfg.MailService.Mode = "rest"
	cfg.MailService.BaseURL = mailService.URL
	cfg.MailService.RequestTimeout = 250 * time.Millisecond

	runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, runtime.Close())
	}()

	runCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	runErrCh := make(chan error, 1)
	go func() {
		runErrCh <- runtime.App.Run(runCtx)
	}()

	require.Eventually(t, func() bool {
		response, err := http.Post(
			"http://"+cfg.PublicHTTP.Addr+"/api/v1/public/auth/send-email-code",
			"application/json",
			bytes.NewBufferString(`{"email":"pilot@example.com"}`),
		)
		if err != nil {
			return false
		}
		defer response.Body.Close()

		payload, err := io.ReadAll(response.Body)
		if err != nil {
			return false
		}

		return response.StatusCode == http.StatusOK &&
			bytes.Contains(payload, []byte(`"challenge_id":"`)) &&
			calls.Load() == 1
	}, 5*time.Second, 25*time.Millisecond)

	cancel()
	require.NoError(t, <-runErrCh)
}

func mustFreeAddr(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, listener.Close())
	}()

	return listener.Addr().String()
}
@@ -0,0 +1,610 @@
// Package config loads the authsession process configuration from environment
// variables.
package config

import (
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"galaxy/authsession/internal/api/internalhttp"
	"galaxy/authsession/internal/api/publichttp"

	"go.uber.org/zap/zapcore"
)

const (
	shutdownTimeoutEnvVar = "AUTHSESSION_SHUTDOWN_TIMEOUT"
	logLevelEnvVar        = "AUTHSESSION_LOG_LEVEL"

	publicHTTPAddrEnvVar              = "AUTHSESSION_PUBLIC_HTTP_ADDR"
	publicHTTPReadHeaderTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_READ_HEADER_TIMEOUT"
	publicHTTPReadTimeoutEnvVar       = "AUTHSESSION_PUBLIC_HTTP_READ_TIMEOUT"
	publicHTTPIdleTimeoutEnvVar       = "AUTHSESSION_PUBLIC_HTTP_IDLE_TIMEOUT"
	publicHTTPRequestTimeoutEnvVar    = "AUTHSESSION_PUBLIC_HTTP_REQUEST_TIMEOUT"

	internalHTTPAddrEnvVar              = "AUTHSESSION_INTERNAL_HTTP_ADDR"
	internalHTTPReadHeaderTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_READ_HEADER_TIMEOUT"
	internalHTTPReadTimeoutEnvVar       = "AUTHSESSION_INTERNAL_HTTP_READ_TIMEOUT"
	internalHTTPIdleTimeoutEnvVar       = "AUTHSESSION_INTERNAL_HTTP_IDLE_TIMEOUT"
	internalHTTPRequestTimeoutEnvVar    = "AUTHSESSION_INTERNAL_HTTP_REQUEST_TIMEOUT"

	redisAddrEnvVar             = "AUTHSESSION_REDIS_ADDR"
	redisUsernameEnvVar         = "AUTHSESSION_REDIS_USERNAME"
	redisPasswordEnvVar         = "AUTHSESSION_REDIS_PASSWORD"
	redisDBEnvVar               = "AUTHSESSION_REDIS_DB"
	redisTLSEnabledEnvVar       = "AUTHSESSION_REDIS_TLS_ENABLED"
	redisOperationTimeoutEnvVar = "AUTHSESSION_REDIS_OPERATION_TIMEOUT"

	redisChallengeKeyPrefixEnvVar               = "AUTHSESSION_REDIS_CHALLENGE_KEY_PREFIX"
	redisSessionKeyPrefixEnvVar                 = "AUTHSESSION_REDIS_SESSION_KEY_PREFIX"
	redisUserSessionsKeyPrefixEnvVar            = "AUTHSESSION_REDIS_USER_SESSIONS_KEY_PREFIX"
	redisUserActiveSessionsKeyPrefixEnvVar      = "AUTHSESSION_REDIS_USER_ACTIVE_SESSIONS_KEY_PREFIX"
	redisSessionLimitKeyEnvVar                  = "AUTHSESSION_REDIS_SESSION_LIMIT_KEY"
	redisGatewaySessionCacheKeyPrefixEnvVar     = "AUTHSESSION_REDIS_GATEWAY_SESSION_CACHE_KEY_PREFIX"
	redisGatewaySessionEventsStreamEnvVar       = "AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM"
	redisGatewaySessionEventsStreamMaxLenEnvVar = "AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM_MAX_LEN"
	redisSendEmailCodeThrottleKeyPrefixEnvVar   = "AUTHSESSION_REDIS_SEND_EMAIL_CODE_THROTTLE_KEY_PREFIX"

	userServiceModeEnvVar           = "AUTHSESSION_USER_SERVICE_MODE"
	userServiceBaseURLEnvVar        = "AUTHSESSION_USER_SERVICE_BASE_URL"
	userServiceRequestTimeoutEnvVar = "AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT"

	mailServiceModeEnvVar           = "AUTHSESSION_MAIL_SERVICE_MODE"
	mailServiceBaseURLEnvVar        = "AUTHSESSION_MAIL_SERVICE_BASE_URL"
	mailServiceRequestTimeoutEnvVar = "AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT"

	otelServiceNameEnvVar                = "OTEL_SERVICE_NAME"
	otelTracesExporterEnvVar             = "OTEL_TRACES_EXPORTER"
	otelMetricsExporterEnvVar            = "OTEL_METRICS_EXPORTER"
	otelExporterOTLPProtocolEnvVar       = "OTEL_EXPORTER_OTLP_PROTOCOL"
	otelExporterOTLPTracesProtocolEnvVar = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"
	otelExporterOTLPMetricsProtocolEnvVar = "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL"
	otelStdoutTracesEnabledEnvVar         = "AUTHSESSION_OTEL_STDOUT_TRACES_ENABLED"
	otelStdoutMetricsEnabledEnvVar        = "AUTHSESSION_OTEL_STDOUT_METRICS_ENABLED"

	defaultShutdownTimeout                  = 5 * time.Second
	defaultLogLevel                         = "info"
	defaultRedisDB                          = 0
	defaultRedisOperationTimeout            = 250 * time.Millisecond
	defaultChallengeKeyPrefix               = "authsession:challenge:"
	defaultSessionKeyPrefix                 = "authsession:session:"
	defaultUserSessionsKeyPrefix            = "authsession:user-sessions:"
	defaultUserActiveSessionsKeyPrefix      = "authsession:user-active-sessions:"
	defaultSessionLimitKey                  = "authsession:config:active-session-limit"
	defaultGatewaySessionCacheKeyPrefix     = "gateway:session:"
	defaultGatewaySessionEventsStream       = "gateway:session_events"
	defaultGatewaySessionEventsStreamMaxLen = 1024
	defaultSendEmailCodeThrottleKeyPrefix   = "authsession:send-email-code-throttle:"
	defaultUserServiceMode                  = userServiceModeStub
	defaultUserServiceRequestTimeout        = time.Second
	defaultMailServiceMode                  = mailServiceModeStub
	defaultMailServiceRequestTimeout        = time.Second
	defaultOTelServiceName                  = "galaxy-authsession"
	otelExporterNone                        = "none"
	otelExporterOTLP                        = "otlp"
	otelProtocolHTTPProtobuf                = "http/protobuf"
	otelProtocolGRPC                        = "grpc"
	userServiceModeStub                     = "stub"
	userServiceModeREST                     = "rest"
	mailServiceModeStub                     = "stub"
	mailServiceModeREST                     = "rest"
)

// Config stores the full process-level authsession configuration.
type Config struct {
	// ShutdownTimeout bounds graceful shutdown of every long-lived component.
	ShutdownTimeout time.Duration

	// Logging configures the process-wide structured logger.
	Logging LoggingConfig

	// PublicHTTP configures the public HTTP listener.
	PublicHTTP publichttp.Config

	// InternalHTTP configures the trusted internal HTTP listener.
	InternalHTTP internalhttp.Config

	// Redis configures the Redis-backed adapters.
	Redis RedisConfig

	// UserService configures the selectable runtime user-directory adapter.
	UserService UserServiceConfig

	// MailService configures the selectable runtime mail-delivery adapter.
	MailService MailServiceConfig

	// Telemetry configures the process-wide OpenTelemetry runtime.
	Telemetry TelemetryConfig
}

// LoggingConfig configures the process-wide structured logger.
type LoggingConfig struct {
	// Level stores the zap-compatible log level string.
	Level string
}

// RedisConfig configures the Redis-backed authsession adapters.
type RedisConfig struct {
	// Addr is the shared Redis address used by the authsession adapters.
	Addr string

	// Username is the optional Redis ACL username.
	Username string

	// Password is the optional Redis ACL password.
	Password string

	// DB is the Redis logical database index.
	DB int

	// TLSEnabled configures whether Redis connections use TLS.
	TLSEnabled bool

	// OperationTimeout bounds each adapter Redis round trip.
	OperationTimeout time.Duration

	// ChallengeKeyPrefix namespaces the challenge source-of-truth records.
	ChallengeKeyPrefix string

	// SessionKeyPrefix namespaces the primary session records.
	SessionKeyPrefix string

	// UserSessionsKeyPrefix namespaces the all-session user index.
	UserSessionsKeyPrefix string

	// UserActiveSessionsKeyPrefix namespaces the active-session user index.
	UserActiveSessionsKeyPrefix string

	// SessionLimitKey stores the exact session-limit Redis key.
	SessionLimitKey string

	// GatewaySessionCacheKeyPrefix namespaces the projected gateway session
	// cache keys.
	GatewaySessionCacheKeyPrefix string

	// GatewaySessionEventsStream stores the projected gateway session-events
	// Redis Stream key.
	GatewaySessionEventsStream string

	// GatewaySessionEventsStreamMaxLen bounds the projected gateway session
	// event stream with approximate trimming.
	GatewaySessionEventsStreamMaxLen int64

	// SendEmailCodeThrottleKeyPrefix namespaces the resend-throttle TTL keys.
	SendEmailCodeThrottleKeyPrefix string
}

// UserServiceConfig configures the runtime user-directory integration mode.
type UserServiceConfig struct {
	// Mode selects the runtime adapter implementation. Supported values are
	// `stub` and `rest`.
	Mode string

	// BaseURL is the absolute base URL of the REST-backed user-service when
	// Mode is `rest`.
	BaseURL string

	// RequestTimeout bounds each outbound user-service request when Mode is
	// `rest`.
	RequestTimeout time.Duration
}

// MailServiceConfig configures the runtime mail-delivery integration mode.
type MailServiceConfig struct {
	// Mode selects the runtime adapter implementation. Supported values are
	// `stub` and `rest`.
	Mode string

	// BaseURL is the absolute base URL of the REST-backed mail service when
	// Mode is `rest`.
	BaseURL string

	// RequestTimeout bounds each outbound mail-service request when Mode is
	// `rest`.
	RequestTimeout time.Duration
}

// TelemetryConfig configures the authsession OpenTelemetry runtime.
type TelemetryConfig struct {
	// ServiceName overrides the default OpenTelemetry service name.
	ServiceName string

	// TracesExporter selects the external traces exporter. Supported values are
	// `none` and `otlp`.
	TracesExporter string

	// MetricsExporter selects the external metrics exporter. Supported values
	// are `none` and `otlp`.
	MetricsExporter string

	// TracesProtocol selects the OTLP traces protocol when TracesExporter is
	// `otlp`.
	TracesProtocol string

	// MetricsProtocol selects the OTLP metrics protocol when MetricsExporter is
	// `otlp`.
	MetricsProtocol string

	// StdoutTracesEnabled enables the additional stdout trace exporter used for
	// local development and debugging.
	StdoutTracesEnabled bool

	// StdoutMetricsEnabled enables the additional stdout metric exporter used
	// for local development and debugging.
	StdoutMetricsEnabled bool
}

// DefaultConfig returns the default authsession process configuration with all
// optional values filled.
func DefaultConfig() Config {
	return Config{
		ShutdownTimeout: defaultShutdownTimeout,
		Logging: LoggingConfig{
			Level: defaultLogLevel,
		},
		PublicHTTP:   publichttp.DefaultConfig(),
		InternalHTTP: internalhttp.DefaultConfig(),
		Redis: RedisConfig{
			DB:                               defaultRedisDB,
			OperationTimeout:                 defaultRedisOperationTimeout,
			ChallengeKeyPrefix:               defaultChallengeKeyPrefix,
			SessionKeyPrefix:                 defaultSessionKeyPrefix,
			UserSessionsKeyPrefix:            defaultUserSessionsKeyPrefix,
			UserActiveSessionsKeyPrefix:      defaultUserActiveSessionsKeyPrefix,
			SessionLimitKey:                  defaultSessionLimitKey,
			GatewaySessionCacheKeyPrefix:     defaultGatewaySessionCacheKeyPrefix,
			GatewaySessionEventsStream:       defaultGatewaySessionEventsStream,
			GatewaySessionEventsStreamMaxLen: defaultGatewaySessionEventsStreamMaxLen,
			SendEmailCodeThrottleKeyPrefix:   defaultSendEmailCodeThrottleKeyPrefix,
		},
		UserService: UserServiceConfig{
			Mode:           defaultUserServiceMode,
			RequestTimeout: defaultUserServiceRequestTimeout,
		},
		MailService: MailServiceConfig{
			Mode:           defaultMailServiceMode,
			RequestTimeout: defaultMailServiceRequestTimeout,
		},
		Telemetry: TelemetryConfig{
			ServiceName:     defaultOTelServiceName,
			TracesExporter:  otelExporterNone,
			MetricsExporter: otelExporterNone,
		},
	}
}

// LoadFromEnv loads the authsession process configuration from environment
// variables, applying documented defaults where appropriate.
func LoadFromEnv() (Config, error) {
	cfg := DefaultConfig()

	var err error

	cfg.ShutdownTimeout, err = loadDurationEnvWithDefault(shutdownTimeoutEnvVar, cfg.ShutdownTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}

	cfg.Logging.Level = loadStringEnvWithDefault(logLevelEnvVar, cfg.Logging.Level)
	if err := validateLogLevel(cfg.Logging.Level); err != nil {
		return Config{}, fmt.Errorf("load authsession config: %s: %w", logLevelEnvVar, err)
	}

	cfg.PublicHTTP.Addr = loadStringEnvWithDefault(publicHTTPAddrEnvVar, cfg.PublicHTTP.Addr)
	cfg.PublicHTTP.ReadHeaderTimeout, err = loadDurationEnvWithDefault(publicHTTPReadHeaderTimeoutEnvVar, cfg.PublicHTTP.ReadHeaderTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.PublicHTTP.ReadTimeout, err = loadDurationEnvWithDefault(publicHTTPReadTimeoutEnvVar, cfg.PublicHTTP.ReadTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.PublicHTTP.IdleTimeout, err = loadDurationEnvWithDefault(publicHTTPIdleTimeoutEnvVar, cfg.PublicHTTP.IdleTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.PublicHTTP.RequestTimeout, err = loadDurationEnvWithDefault(publicHTTPRequestTimeoutEnvVar, cfg.PublicHTTP.RequestTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}

	cfg.InternalHTTP.Addr = loadStringEnvWithDefault(internalHTTPAddrEnvVar, cfg.InternalHTTP.Addr)
	cfg.InternalHTTP.ReadHeaderTimeout, err = loadDurationEnvWithDefault(internalHTTPReadHeaderTimeoutEnvVar, cfg.InternalHTTP.ReadHeaderTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.InternalHTTP.ReadTimeout, err = loadDurationEnvWithDefault(internalHTTPReadTimeoutEnvVar, cfg.InternalHTTP.ReadTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.InternalHTTP.IdleTimeout, err = loadDurationEnvWithDefault(internalHTTPIdleTimeoutEnvVar, cfg.InternalHTTP.IdleTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.InternalHTTP.RequestTimeout, err = loadDurationEnvWithDefault(internalHTTPRequestTimeoutEnvVar, cfg.InternalHTTP.RequestTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}

	cfg.Redis.Addr = loadStringEnvWithDefault(redisAddrEnvVar, cfg.Redis.Addr)
	cfg.Redis.Username = os.Getenv(redisUsernameEnvVar)
	cfg.Redis.Password = os.Getenv(redisPasswordEnvVar)
	cfg.Redis.DB, err = loadIntEnvWithDefault(redisDBEnvVar, cfg.Redis.DB)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.Redis.TLSEnabled, err = loadBoolEnvWithDefault(redisTLSEnabledEnvVar, cfg.Redis.TLSEnabled)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.Redis.OperationTimeout, err = loadDurationEnvWithDefault(redisOperationTimeoutEnvVar, cfg.Redis.OperationTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.Redis.ChallengeKeyPrefix = loadStringEnvWithDefault(redisChallengeKeyPrefixEnvVar, cfg.Redis.ChallengeKeyPrefix)
	cfg.Redis.SessionKeyPrefix = loadStringEnvWithDefault(redisSessionKeyPrefixEnvVar, cfg.Redis.SessionKeyPrefix)
	cfg.Redis.UserSessionsKeyPrefix = loadStringEnvWithDefault(redisUserSessionsKeyPrefixEnvVar, cfg.Redis.UserSessionsKeyPrefix)
	cfg.Redis.UserActiveSessionsKeyPrefix = loadStringEnvWithDefault(redisUserActiveSessionsKeyPrefixEnvVar, cfg.Redis.UserActiveSessionsKeyPrefix)
	cfg.Redis.SessionLimitKey = loadStringEnvWithDefault(redisSessionLimitKeyEnvVar, cfg.Redis.SessionLimitKey)
	cfg.Redis.GatewaySessionCacheKeyPrefix = loadStringEnvWithDefault(redisGatewaySessionCacheKeyPrefixEnvVar, cfg.Redis.GatewaySessionCacheKeyPrefix)
	cfg.Redis.GatewaySessionEventsStream = loadStringEnvWithDefault(redisGatewaySessionEventsStreamEnvVar, cfg.Redis.GatewaySessionEventsStream)
	streamMaxLen, err := loadInt64EnvWithDefault(redisGatewaySessionEventsStreamMaxLenEnvVar, cfg.Redis.GatewaySessionEventsStreamMaxLen)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}
	cfg.Redis.GatewaySessionEventsStreamMaxLen = streamMaxLen
	cfg.Redis.SendEmailCodeThrottleKeyPrefix = loadStringEnvWithDefault(redisSendEmailCodeThrottleKeyPrefixEnvVar, cfg.Redis.SendEmailCodeThrottleKeyPrefix)

	cfg.UserService.Mode = strings.TrimSpace(loadStringEnvWithDefault(userServiceModeEnvVar, cfg.UserService.Mode))
	cfg.UserService.BaseURL = loadStringEnvWithDefault(userServiceBaseURLEnvVar, cfg.UserService.BaseURL)
	cfg.UserService.RequestTimeout, err = loadDurationEnvWithDefault(userServiceRequestTimeoutEnvVar, cfg.UserService.RequestTimeout)
	if err != nil {
		return Config{}, fmt.Errorf("load authsession config: %w", err)
	}

	cfg.MailService.Mode = strings.TrimSpace(loadStringEnvWithDefault(mailServiceModeEnvVar, cfg.MailService.Mode))
	cfg.MailService.BaseURL = loadStringEnvWithDefault(mailServiceBaseURLEnvVar, cfg.MailService.BaseURL)
||||
cfg.MailService.RequestTimeout, err = loadDurationEnvWithDefault(mailServiceRequestTimeoutEnvVar, cfg.MailService.RequestTimeout)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("load authsession config: %w", err)
|
||||
}
|
||||
|
||||
cfg.Telemetry.ServiceName = loadStringEnvWithDefault(otelServiceNameEnvVar, cfg.Telemetry.ServiceName)
|
||||
cfg.Telemetry.TracesExporter = normalizeExporterValue(loadStringEnvWithDefault(otelTracesExporterEnvVar, cfg.Telemetry.TracesExporter))
|
||||
cfg.Telemetry.MetricsExporter = normalizeExporterValue(loadStringEnvWithDefault(otelMetricsExporterEnvVar, cfg.Telemetry.MetricsExporter))
|
||||
cfg.Telemetry.TracesProtocol = loadOTLPProtocol(
|
||||
os.Getenv(otelExporterOTLPTracesProtocolEnvVar),
|
||||
os.Getenv(otelExporterOTLPProtocolEnvVar),
|
||||
cfg.Telemetry.TracesExporter,
|
||||
)
|
||||
cfg.Telemetry.MetricsProtocol = loadOTLPProtocol(
|
||||
os.Getenv(otelExporterOTLPMetricsProtocolEnvVar),
|
||||
os.Getenv(otelExporterOTLPProtocolEnvVar),
|
||||
cfg.Telemetry.MetricsExporter,
|
||||
)
|
||||
cfg.Telemetry.StdoutTracesEnabled, err = loadBoolEnvWithDefault(otelStdoutTracesEnabledEnvVar, cfg.Telemetry.StdoutTracesEnabled)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("load authsession config: %w", err)
|
||||
}
|
||||
cfg.Telemetry.StdoutMetricsEnabled, err = loadBoolEnvWithDefault(otelStdoutMetricsEnabledEnvVar, cfg.Telemetry.StdoutMetricsEnabled)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("load authsession config: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Validate reports whether cfg contains a consistent authsession process
|
||||
// configuration.
|
||||
func (cfg Config) Validate() error {
|
||||
switch {
|
||||
case cfg.ShutdownTimeout <= 0:
|
||||
return fmt.Errorf("load authsession config: %s must be positive", shutdownTimeoutEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.Addr) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisAddrEnvVar)
|
||||
case cfg.Redis.DB < 0:
|
||||
return fmt.Errorf("load authsession config: %s must not be negative", redisDBEnvVar)
|
||||
case cfg.Redis.OperationTimeout <= 0:
|
||||
return fmt.Errorf("load authsession config: %s must be positive", redisOperationTimeoutEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.ChallengeKeyPrefix) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisChallengeKeyPrefixEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.SessionKeyPrefix) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisSessionKeyPrefixEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.UserSessionsKeyPrefix) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisUserSessionsKeyPrefixEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.UserActiveSessionsKeyPrefix) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisUserActiveSessionsKeyPrefixEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.SessionLimitKey) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisSessionLimitKeyEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.GatewaySessionCacheKeyPrefix) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisGatewaySessionCacheKeyPrefixEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.GatewaySessionEventsStream) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisGatewaySessionEventsStreamEnvVar)
|
||||
case cfg.Redis.GatewaySessionEventsStreamMaxLen <= 0:
|
||||
return fmt.Errorf("load authsession config: %s must be positive", redisGatewaySessionEventsStreamMaxLenEnvVar)
|
||||
case strings.TrimSpace(cfg.Redis.SendEmailCodeThrottleKeyPrefix) == "":
|
||||
return fmt.Errorf("load authsession config: %s must not be empty", redisSendEmailCodeThrottleKeyPrefixEnvVar)
|
||||
}
|
||||
|
||||
if err := cfg.PublicHTTP.Validate(); err != nil {
|
||||
return fmt.Errorf("load authsession config: public HTTP: %w", err)
|
||||
}
|
||||
if err := cfg.InternalHTTP.Validate(); err != nil {
|
||||
return fmt.Errorf("load authsession config: internal HTTP: %w", err)
|
||||
}
|
||||
if err := cfg.UserService.Validate(); err != nil {
|
||||
return fmt.Errorf("load authsession config: %w", err)
|
||||
}
|
||||
if err := cfg.MailService.Validate(); err != nil {
|
||||
return fmt.Errorf("load authsession config: %w", err)
|
||||
}
|
||||
if err := cfg.Telemetry.Validate(); err != nil {
|
||||
return fmt.Errorf("load authsession config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate reports whether cfg contains a supported user-service runtime
|
||||
// configuration.
|
||||
func (cfg UserServiceConfig) Validate() error {
|
||||
switch cfg.Mode {
|
||||
case userServiceModeStub:
|
||||
return nil
|
||||
case userServiceModeREST:
|
||||
if strings.TrimSpace(cfg.BaseURL) == "" {
|
||||
return fmt.Errorf("%s must not be empty in rest mode", userServiceBaseURLEnvVar)
|
||||
}
|
||||
if cfg.RequestTimeout <= 0 {
|
||||
return fmt.Errorf("%s must be positive in rest mode", userServiceRequestTimeoutEnvVar)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%s %q is unsupported", userServiceModeEnvVar, cfg.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate reports whether cfg contains a supported mail-service runtime
|
||||
// configuration.
|
||||
func (cfg MailServiceConfig) Validate() error {
|
||||
switch cfg.Mode {
|
||||
case mailServiceModeStub:
|
||||
return nil
|
||||
case mailServiceModeREST:
|
||||
if strings.TrimSpace(cfg.BaseURL) == "" {
|
||||
return fmt.Errorf("%s must not be empty in rest mode", mailServiceBaseURLEnvVar)
|
||||
}
|
||||
if cfg.RequestTimeout <= 0 {
|
||||
return fmt.Errorf("%s must be positive in rest mode", mailServiceRequestTimeoutEnvVar)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%s %q is unsupported", mailServiceModeEnvVar, cfg.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate reports whether cfg contains a supported OpenTelemetry exporter
|
||||
// configuration.
|
||||
func (cfg TelemetryConfig) Validate() error {
|
||||
switch cfg.TracesExporter {
|
||||
case otelExporterNone, otelExporterOTLP:
|
||||
default:
|
||||
return fmt.Errorf("%s %q is unsupported", otelTracesExporterEnvVar, cfg.TracesExporter)
|
||||
}
|
||||
|
||||
switch cfg.MetricsExporter {
|
||||
case otelExporterNone, otelExporterOTLP:
|
||||
default:
|
||||
return fmt.Errorf("%s %q is unsupported", otelMetricsExporterEnvVar, cfg.MetricsExporter)
|
||||
}
|
||||
|
||||
if cfg.TracesProtocol != "" && cfg.TracesProtocol != otelProtocolHTTPProtobuf && cfg.TracesProtocol != otelProtocolGRPC {
|
||||
return fmt.Errorf("%s %q is unsupported", otelExporterOTLPTracesProtocolEnvVar, cfg.TracesProtocol)
|
||||
}
|
||||
if cfg.MetricsProtocol != "" && cfg.MetricsProtocol != otelProtocolHTTPProtobuf && cfg.MetricsProtocol != otelProtocolGRPC {
|
||||
return fmt.Errorf("%s %q is unsupported", otelExporterOTLPMetricsProtocolEnvVar, cfg.MetricsProtocol)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadStringEnvWithDefault(name string, value string) string {
|
||||
if raw, ok := os.LookupEnv(name); ok {
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func loadDurationEnvWithDefault(name string, value time.Duration) (time.Duration, error) {
|
||||
raw, ok := os.LookupEnv(name)
|
||||
if !ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
parsed, err := time.ParseDuration(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s: %w", name, err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func loadIntEnvWithDefault(name string, value int) (int, error) {
|
||||
raw, ok := os.LookupEnv(name)
|
||||
if !ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s: %w", name, err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func loadInt64EnvWithDefault(name string, value int64) (int64, error) {
|
||||
raw, ok := os.LookupEnv(name)
|
||||
if !ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
parsed, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s: %w", name, err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func loadBoolEnvWithDefault(name string, value bool) (bool, error) {
|
||||
raw, ok := os.LookupEnv(name)
|
||||
if !ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
parsed, err := strconv.ParseBool(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%s: %w", name, err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}

func validateLogLevel(value string) error {
	var level zapcore.Level
	if err := level.UnmarshalText([]byte(strings.TrimSpace(value))); err != nil {
		return err
	}

	return nil
}

func normalizeExporterValue(value string) string {
	switch strings.TrimSpace(value) {
	case "", otelExporterNone:
		return otelExporterNone
	default:
		return strings.TrimSpace(value)
	}
}

func loadOTLPProtocol(primary string, fallback string, exporter string) string {
	protocol := strings.TrimSpace(primary)
	if protocol == "" {
		protocol = strings.TrimSpace(fallback)
	}
	if protocol == "" && exporter == otelExporterOTLP {
		return otelProtocolHTTPProtobuf
	}

	return protocol
}
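The precedence implemented by `loadOTLPProtocol` is: the signal-specific protocol variable wins, then the generic `OTEL_EXPORTER_OTLP_PROTOCOL` fallback, and only an `otlp` exporter receives the `http/protobuf` default. This standalone sketch re-declares the resolver; the two constants are assumed to mirror the package's unexported values:

```go
package main

import (
	"fmt"
	"strings"
)

// Assumed to mirror the package's unexported constants.
const (
	otelExporterOTLP         = "otlp"
	otelProtocolHTTPProtobuf = "http/protobuf"
)

// loadOTLPProtocol re-declares the resolver above: signal-specific value,
// then generic fallback, then a default only for the otlp exporter.
func loadOTLPProtocol(primary string, fallback string, exporter string) string {
	protocol := strings.TrimSpace(primary)
	if protocol == "" {
		protocol = strings.TrimSpace(fallback)
	}
	if protocol == "" && exporter == otelExporterOTLP {
		return otelProtocolHTTPProtobuf
	}

	return protocol
}

func main() {
	fmt.Println(loadOTLPProtocol("grpc", "http/protobuf", "otlp")) // signal-specific wins
	fmt.Println(loadOTLPProtocol("", "grpc", "otlp"))              // generic fallback applies
	fmt.Println(loadOTLPProtocol("", "", "otlp"))                  // otlp gets the default
	fmt.Println(loadOTLPProtocol("", "", "none") == "")            // no exporter, no default
}
```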
@@ -0,0 +1,161 @@
package config

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestLoadFromEnvUsesDefaults(t *testing.T) {
	t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")

	cfg, err := LoadFromEnv()
	require.NoError(t, err)

	defaults := DefaultConfig()
	assert.Equal(t, defaults.ShutdownTimeout, cfg.ShutdownTimeout)
	assert.Equal(t, defaults.Logging.Level, cfg.Logging.Level)
	assert.Equal(t, defaults.PublicHTTP, cfg.PublicHTTP)
	assert.Equal(t, defaults.InternalHTTP, cfg.InternalHTTP)
	assert.Equal(t, "127.0.0.1:6379", cfg.Redis.Addr)
	assert.Equal(t, defaults.Redis.DB, cfg.Redis.DB)
	assert.Equal(t, defaults.Redis.OperationTimeout, cfg.Redis.OperationTimeout)
	assert.Equal(t, defaults.UserService, cfg.UserService)
	assert.Equal(t, defaults.MailService, cfg.MailService)
	assert.Equal(t, defaults.Telemetry.ServiceName, cfg.Telemetry.ServiceName)
	assert.Equal(t, defaults.Telemetry.TracesExporter, cfg.Telemetry.TracesExporter)
	assert.Equal(t, defaults.Telemetry.MetricsExporter, cfg.Telemetry.MetricsExporter)
	assert.False(t, cfg.Telemetry.StdoutTracesEnabled)
	assert.False(t, cfg.Telemetry.StdoutMetricsEnabled)
}

func TestLoadFromEnvAppliesOverrides(t *testing.T) {
	t.Setenv(shutdownTimeoutEnvVar, "9s")
	t.Setenv(logLevelEnvVar, "debug")
	t.Setenv(publicHTTPAddrEnvVar, "127.0.0.1:18080")
	t.Setenv(internalHTTPAddrEnvVar, "127.0.0.1:18081")
	t.Setenv(redisAddrEnvVar, "127.0.0.1:6380")
	t.Setenv(redisUsernameEnvVar, "alice")
	t.Setenv(redisPasswordEnvVar, "secret")
	t.Setenv(redisDBEnvVar, "3")
	t.Setenv(redisTLSEnabledEnvVar, "true")
	t.Setenv(redisOperationTimeoutEnvVar, "750ms")
	t.Setenv(userServiceModeEnvVar, "rest")
	t.Setenv(userServiceBaseURLEnvVar, "http://127.0.0.1:19090")
	t.Setenv(userServiceRequestTimeoutEnvVar, "900ms")
	t.Setenv(mailServiceModeEnvVar, "rest")
	t.Setenv(mailServiceBaseURLEnvVar, "http://127.0.0.1:19091")
	t.Setenv(mailServiceRequestTimeoutEnvVar, "950ms")
	t.Setenv(otelServiceNameEnvVar, "custom-authsession")
	t.Setenv(otelTracesExporterEnvVar, "otlp")
	t.Setenv(otelMetricsExporterEnvVar, "otlp")
	t.Setenv(otelExporterOTLPProtocolEnvVar, "grpc")
	t.Setenv(otelStdoutTracesEnabledEnvVar, "true")
	t.Setenv(otelStdoutMetricsEnabledEnvVar, "true")

	cfg, err := LoadFromEnv()
	require.NoError(t, err)

	assert.Equal(t, 9*time.Second, cfg.ShutdownTimeout)
	assert.Equal(t, "debug", cfg.Logging.Level)
	assert.Equal(t, "127.0.0.1:18080", cfg.PublicHTTP.Addr)
	assert.Equal(t, "127.0.0.1:18081", cfg.InternalHTTP.Addr)
	assert.Equal(t, "127.0.0.1:6380", cfg.Redis.Addr)
	assert.Equal(t, "alice", cfg.Redis.Username)
	assert.Equal(t, "secret", cfg.Redis.Password)
	assert.Equal(t, 3, cfg.Redis.DB)
	assert.True(t, cfg.Redis.TLSEnabled)
	assert.Equal(t, 750*time.Millisecond, cfg.Redis.OperationTimeout)
	assert.Equal(t, UserServiceConfig{
		Mode:           "rest",
		BaseURL:        "http://127.0.0.1:19090",
		RequestTimeout: 900 * time.Millisecond,
	}, cfg.UserService)
	assert.Equal(t, MailServiceConfig{
		Mode:           "rest",
		BaseURL:        "http://127.0.0.1:19091",
		RequestTimeout: 950 * time.Millisecond,
	}, cfg.MailService)
	assert.Equal(t, "custom-authsession", cfg.Telemetry.ServiceName)
	assert.Equal(t, "otlp", cfg.Telemetry.TracesExporter)
	assert.Equal(t, "otlp", cfg.Telemetry.MetricsExporter)
	assert.Equal(t, "grpc", cfg.Telemetry.TracesProtocol)
	assert.Equal(t, "grpc", cfg.Telemetry.MetricsProtocol)
	assert.True(t, cfg.Telemetry.StdoutTracesEnabled)
	assert.True(t, cfg.Telemetry.StdoutMetricsEnabled)
}

func TestLoadFromEnvRejectsInvalidValues(t *testing.T) {
	tests := []struct {
		name    string
		envName string
		envVal  string
	}{
		{name: "invalid duration", envName: shutdownTimeoutEnvVar, envVal: "later"},
		{name: "invalid bool", envName: otelStdoutTracesEnabledEnvVar, envVal: "sometimes"},
		{name: "invalid log level", envName: logLevelEnvVar, envVal: "verbose"},
		{name: "invalid traces protocol", envName: otelExporterOTLPTracesProtocolEnvVar, envVal: "udp"},
		{name: "invalid user service mode", envName: userServiceModeEnvVar, envVal: "grpc"},
		{name: "invalid user service timeout", envName: userServiceRequestTimeoutEnvVar, envVal: "never"},
		{name: "invalid mail service mode", envName: mailServiceModeEnvVar, envVal: "grpc"},
		{name: "invalid mail service timeout", envName: mailServiceRequestTimeoutEnvVar, envVal: "never"},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
			t.Setenv(tt.envName, tt.envVal)
			if tt.envName == otelExporterOTLPTracesProtocolEnvVar {
				t.Setenv(otelTracesExporterEnvVar, "otlp")
			}

			_, err := LoadFromEnv()
			require.Error(t, err)
			assert.Contains(t, err.Error(), tt.envName)
		})
	}
}

func TestLoadFromEnvRejectsInvalidRESTUserServiceConfiguration(t *testing.T) {
	t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
	t.Setenv(userServiceModeEnvVar, "rest")

	t.Run("missing base url", func(t *testing.T) {
		_, err := LoadFromEnv()
		require.Error(t, err)
		assert.Contains(t, err.Error(), userServiceBaseURLEnvVar)
	})

	t.Run("non positive timeout", func(t *testing.T) {
		t.Setenv(userServiceBaseURLEnvVar, "http://127.0.0.1:19090")
		t.Setenv(userServiceRequestTimeoutEnvVar, "0s")

		_, err := LoadFromEnv()
		require.Error(t, err)
		assert.Contains(t, err.Error(), userServiceRequestTimeoutEnvVar)
	})
}

func TestLoadFromEnvRejectsInvalidRESTMailServiceConfiguration(t *testing.T) {
	t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
	t.Setenv(mailServiceModeEnvVar, "rest")

	t.Run("missing base url", func(t *testing.T) {
		_, err := LoadFromEnv()
		require.Error(t, err)
		assert.Contains(t, err.Error(), mailServiceBaseURLEnvVar)
	})

	t.Run("non positive timeout", func(t *testing.T) {
		t.Setenv(mailServiceBaseURLEnvVar, "http://127.0.0.1:19091")
		t.Setenv(mailServiceRequestTimeoutEnvVar, "0s")

		_, err := LoadFromEnv()
		require.Error(t, err)
		assert.Contains(t, err.Error(), mailServiceRequestTimeoutEnvVar)
	})
}
@@ -0,0 +1,342 @@
// Package challenge defines the source-of-truth domain model for one e-mail
// confirmation challenge.
package challenge

import (
	"errors"
	"fmt"
	"time"

	"galaxy/authsession/internal/domain/common"
)

// Status identifies the coarse lifecycle state of one challenge.
type Status string

const (
	// StatusPendingSend reports that the challenge has been created but its
	// delivery outcome has not been recorded yet.
	StatusPendingSend Status = "pending_send"

	// StatusSent reports that the confirmation code was delivered successfully.
	StatusSent Status = "sent"

	// StatusDeliverySuppressed reports that the outward send succeeded but
	// actual delivery was intentionally suppressed by policy.
	StatusDeliverySuppressed Status = "delivery_suppressed"

	// StatusDeliveryThrottled reports that a fresh challenge was created but
	// delivery was skipped because the auth-side resend cooldown is still
	// active.
	StatusDeliveryThrottled Status = "delivery_throttled"

	// StatusConfirmedPendingExpire reports that the challenge was confirmed
	// successfully and is temporarily retained for idempotent retry handling.
	StatusConfirmedPendingExpire Status = "confirmed_pending_expire"

	// StatusExpired reports that the challenge can no longer be confirmed.
	StatusExpired Status = "expired"

	// StatusFailed reports that the challenge reached a terminal failure state.
	StatusFailed Status = "failed"

	// StatusCancelled reports that the challenge was cancelled explicitly.
	StatusCancelled Status = "cancelled"
)

// IsKnown reports whether Status is one of the challenge states supported by
// the current domain model.
func (s Status) IsKnown() bool {
	switch s {
	case StatusPendingSend,
		StatusSent,
		StatusDeliverySuppressed,
		StatusDeliveryThrottled,
		StatusConfirmedPendingExpire,
		StatusExpired,
		StatusFailed,
		StatusCancelled:
		return true
	default:
		return false
	}
}

// IsTerminal reports whether Status can no longer accept any lifecycle
// transition in the v1 challenge state machine.
func (s Status) IsTerminal() bool {
	switch s {
	case StatusExpired, StatusFailed, StatusCancelled:
		return true
	default:
		return false
	}
}

// AcceptsFreshConfirm reports whether Status may still consume a first
// successful confirmation attempt.
func (s Status) AcceptsFreshConfirm() bool {
	switch s {
	case StatusSent, StatusDeliverySuppressed:
		return true
	default:
		return false
	}
}

// IsConfirmedRetryState reports whether Status should use the idempotent retry
// path for a previously successful confirmation.
func (s Status) IsConfirmedRetryState() bool {
	return s == StatusConfirmedPendingExpire
}

// CanTransitionTo reports whether the current challenge Status may move to
// next under the coarse lifecycle rules fixed by Stage 2.
func (s Status) CanTransitionTo(next Status) bool {
	switch s {
	case StatusPendingSend:
		switch next {
		case StatusSent, StatusDeliverySuppressed, StatusDeliveryThrottled, StatusFailed, StatusCancelled, StatusExpired:
			return true
		}
	case StatusSent, StatusDeliverySuppressed:
		switch next {
		case StatusConfirmedPendingExpire, StatusFailed, StatusCancelled, StatusExpired:
			return true
		}
	case StatusConfirmedPendingExpire:
		return next == StatusExpired
	}

	return false
}
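The transition table fixed by `CanTransitionTo` can be walked end to end. This standalone sketch re-declares the type, constants, and table verbatim so the happy path and two rejected moves are visible:

```go
package main

import "fmt"

// Re-declared from the domain model above for a standalone walk-through.
type Status string

const (
	StatusPendingSend            Status = "pending_send"
	StatusSent                   Status = "sent"
	StatusDeliverySuppressed     Status = "delivery_suppressed"
	StatusDeliveryThrottled      Status = "delivery_throttled"
	StatusConfirmedPendingExpire Status = "confirmed_pending_expire"
	StatusExpired                Status = "expired"
	StatusFailed                 Status = "failed"
	StatusCancelled              Status = "cancelled"
)

// CanTransitionTo copies the Stage-2 transition table.
func (s Status) CanTransitionTo(next Status) bool {
	switch s {
	case StatusPendingSend:
		switch next {
		case StatusSent, StatusDeliverySuppressed, StatusDeliveryThrottled, StatusFailed, StatusCancelled, StatusExpired:
			return true
		}
	case StatusSent, StatusDeliverySuppressed:
		switch next {
		case StatusConfirmedPendingExpire, StatusFailed, StatusCancelled, StatusExpired:
			return true
		}
	case StatusConfirmedPendingExpire:
		return next == StatusExpired
	}

	return false
}

func main() {
	// Happy path: pending_send -> sent -> confirmed_pending_expire -> expired.
	fmt.Println(StatusPendingSend.CanTransitionTo(StatusSent))               // true
	fmt.Println(StatusSent.CanTransitionTo(StatusConfirmedPendingExpire))    // true
	fmt.Println(StatusConfirmedPendingExpire.CanTransitionTo(StatusExpired)) // true

	// Confirmation cannot skip the delivery-outcome step, and terminal
	// states accept no further transitions.
	fmt.Println(StatusPendingSend.CanTransitionTo(StatusConfirmedPendingExpire)) // false
	fmt.Println(StatusExpired.CanTransitionTo(StatusSent))                       // false
}
```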

// DeliveryState identifies the recorded delivery result of one challenge.
type DeliveryState string

const (
	// DeliveryPending reports that no delivery outcome has been recorded yet.
	DeliveryPending DeliveryState = "pending"

	// DeliverySent reports that the challenge code was sent successfully.
	DeliverySent DeliveryState = "sent"

	// DeliverySuppressed reports that the outward flow stays success-shaped
	// while actual delivery is intentionally skipped.
	DeliverySuppressed DeliveryState = "suppressed"

	// DeliveryThrottled reports that the outward flow stays success-shaped
	// while actual delivery is skipped because the resend cooldown is active.
	DeliveryThrottled DeliveryState = "throttled"

	// DeliveryFailed reports that delivery was attempted and failed explicitly.
	DeliveryFailed DeliveryState = "failed"
)

// IsKnown reports whether DeliveryState is one of the delivery states
// supported by the current domain model.
func (s DeliveryState) IsKnown() bool {
	switch s {
	case DeliveryPending, DeliverySent, DeliverySuppressed, DeliveryThrottled, DeliveryFailed:
		return true
	default:
		return false
	}
}

// CanTransitionTo reports whether the current DeliveryState may move to next
// under the coarse delivery rules fixed by Stage 2.
func (s DeliveryState) CanTransitionTo(next DeliveryState) bool {
	if s != DeliveryPending {
		return false
	}

	switch next {
	case DeliverySent, DeliverySuppressed, DeliveryThrottled, DeliveryFailed:
		return true
	default:
		return false
	}
}

// AttemptCounters groups the mutable send and confirm counters tracked by one
// challenge aggregate.
type AttemptCounters struct {
	// Send counts delivery attempts initiated for the challenge.
	Send int

	// Confirm counts confirmation attempts evaluated against the challenge.
	Confirm int
}

// Validate reports whether AttemptCounters contains only non-negative values.
func (c AttemptCounters) Validate() error {
	if c.Send < 0 {
		return errors.New("challenge send attempt count must not be negative")
	}
	if c.Confirm < 0 {
		return errors.New("challenge confirm attempt count must not be negative")
	}

	return nil
}

// AbuseMetadata stores minimal abuse-related timestamps without fixing later
// anti-abuse policy details too early.
type AbuseMetadata struct {
	// LastAttemptAt optionally records the last send or confirm attempt time
	// associated with the challenge.
	LastAttemptAt *time.Time
}

// Validate reports whether AbuseMetadata contains structurally valid values.
func (m AbuseMetadata) Validate() error {
	if m.LastAttemptAt != nil && m.LastAttemptAt.IsZero() {
		return errors.New("challenge abuse metadata last attempt time must not be zero")
	}

	return nil
}

// Confirmation stores the idempotency metadata recorded after a successful
// challenge confirmation.
type Confirmation struct {
	// SessionID is the created device session returned by the successful
	// confirmation.
	SessionID common.DeviceSessionID

	// ClientPublicKey is the validated client key bound to SessionID.
	ClientPublicKey common.ClientPublicKey

	// ConfirmedAt records when the successful confirmation happened.
	ConfirmedAt time.Time
}

// Validate reports whether Confirmation contains all metadata required for a
// confirmed challenge.
func (c Confirmation) Validate() error {
	if err := c.SessionID.Validate(); err != nil {
		return fmt.Errorf("challenge confirmation session id: %w", err)
	}
	if err := c.ClientPublicKey.Validate(); err != nil {
		return fmt.Errorf("challenge confirmation client public key: %w", err)
	}
	if c.ConfirmedAt.IsZero() {
		return errors.New("challenge confirmation time must not be zero")
	}

	return nil
}

// Challenge is the minimal source-of-truth aggregate shape fixed by Stage 2.
type Challenge struct {
	// ID identifies the challenge.
	ID common.ChallengeID

	// Email stores the normalized target e-mail address.
	Email common.Email

	// CodeHash stores only the hashed confirmation code.
	CodeHash []byte

	// Status reports the coarse challenge lifecycle state.
	Status Status

	// DeliveryState reports the recorded delivery outcome.
	DeliveryState DeliveryState

	// CreatedAt reports when the challenge was created.
	CreatedAt time.Time

	// ExpiresAt reports when the challenge becomes unusable.
	ExpiresAt time.Time

	// Attempts groups the send and confirm counters.
	Attempts AttemptCounters

	// Abuse stores minimal abuse-related timestamps.
	Abuse AbuseMetadata

	// Confirmation is present only after a successful confirm transition.
	Confirmation *Confirmation
}

// IsExpiredAt reports whether the challenge is unusable at now, either because
// it is already marked expired or because its expiration timestamp has passed.
func (c Challenge) IsExpiredAt(now time.Time) bool {
	return c.Status == StatusExpired || !c.ExpiresAt.After(now)
}

// Validate reports whether Challenge satisfies the Stage-2 structural and
// lifecycle invariants.
func (c Challenge) Validate() error {
	if err := c.ID.Validate(); err != nil {
		return fmt.Errorf("challenge id: %w", err)
	}
	if err := c.Email.Validate(); err != nil {
		return fmt.Errorf("challenge email: %w", err)
	}
	if len(c.CodeHash) == 0 {
		return errors.New("challenge code hash must not be empty")
	}
	if !c.Status.IsKnown() {
		return fmt.Errorf("challenge status %q is unsupported", c.Status)
	}
	if !c.DeliveryState.IsKnown() {
		return fmt.Errorf("challenge delivery state %q is unsupported", c.DeliveryState)
	}
	if c.CreatedAt.IsZero() {
		return errors.New("challenge creation time must not be zero")
	}
	if c.ExpiresAt.IsZero() {
		return errors.New("challenge expiration time must not be zero")
	}
	if c.ExpiresAt.Before(c.CreatedAt) {
		return errors.New("challenge expiration time must not be before creation time")
	}
	if err := c.Attempts.Validate(); err != nil {
		return err
	}
	if err := c.Abuse.Validate(); err != nil {
		return err
	}

	switch c.Status {
	case StatusPendingSend:
		if c.DeliveryState != DeliveryPending {
			return errors.New("pending_send challenge must keep pending delivery state")
		}
	case StatusSent:
		if c.DeliveryState != DeliverySent {
			return errors.New("sent challenge must keep sent delivery state")
		}
	case StatusDeliverySuppressed:
		if c.DeliveryState != DeliverySuppressed {
			return errors.New("delivery_suppressed challenge must keep suppressed delivery state")
		}
	case StatusDeliveryThrottled:
		if c.DeliveryState != DeliveryThrottled {
			return errors.New("delivery_throttled challenge must keep throttled delivery state")
		}
	case StatusConfirmedPendingExpire:
		if c.DeliveryState != DeliverySent && c.DeliveryState != DeliverySuppressed {
			return errors.New("confirmed_pending_expire challenge must come from sent or suppressed delivery state")
		}
	}

	if c.Status == StatusConfirmedPendingExpire {
		if c.Confirmation == nil {
			return errors.New("confirmed_pending_expire challenge must contain confirmation metadata")
		}
		if err := c.Confirmation.Validate(); err != nil {
			return fmt.Errorf("challenge confirmation: %w", err)
		}

		return nil
	}

	if c.Confirmation != nil {
		return errors.New("only confirmed_pending_expire challenge may contain confirmation metadata")
	}

	return nil
}
@@ -0,0 +1,439 @@
package challenge

import (
	"crypto/ed25519"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"galaxy/authsession/internal/domain/common"
)

func TestPolicyConstants(t *testing.T) {
	t.Parallel()

	if InitialTTL != 5*time.Minute {
		require.Failf(t, "test failed", "InitialTTL = %s, want %s", InitialTTL, 5*time.Minute)
	}
	if ResendThrottleCooldown != time.Minute {
		require.Failf(t, "test failed", "ResendThrottleCooldown = %s, want %s", ResendThrottleCooldown, time.Minute)
	}
	if ConfirmedRetention != 5*time.Minute {
		require.Failf(t, "test failed", "ConfirmedRetention = %s, want %s", ConfirmedRetention, 5*time.Minute)
	}
	if MaxInvalidConfirmAttempts != 5 {
		require.Failf(t, "test failed", "MaxInvalidConfirmAttempts = %d, want %d", MaxInvalidConfirmAttempts, 5)
	}
}

func TestStatusIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value Status
		want  bool
	}{
		{name: "pending send", value: StatusPendingSend, want: true},
		{name: "sent", value: StatusSent, want: true},
		{name: "suppressed", value: StatusDeliverySuppressed, want: true},
		{name: "throttled", value: StatusDeliveryThrottled, want: true},
		{name: "confirmed", value: StatusConfirmedPendingExpire, want: true},
		{name: "expired", value: StatusExpired, want: true},
		{name: "failed", value: StatusFailed, want: true},
		{name: "cancelled", value: StatusCancelled, want: true},
		{name: "unknown", value: Status("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := tt.value.IsKnown(); got != tt.want {
				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestStatusIsTerminal(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value Status
		want  bool
	}{
		{name: "pending send", value: StatusPendingSend, want: false},
		{name: "sent", value: StatusSent, want: false},
		{name: "delivery suppressed", value: StatusDeliverySuppressed, want: false},
		{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
		{name: "confirmed pending expire", value: StatusConfirmedPendingExpire, want: false},
		{name: "expired", value: StatusExpired, want: true},
		{name: "failed", value: StatusFailed, want: true},
		{name: "cancelled", value: StatusCancelled, want: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := tt.value.IsTerminal(); got != tt.want {
|
||||
require.Failf(t, "test failed", "IsTerminal() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusAcceptsFreshConfirm(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value Status
|
||||
want bool
|
||||
}{
|
||||
{name: "pending send", value: StatusPendingSend, want: false},
|
||||
{name: "sent", value: StatusSent, want: true},
|
||||
{name: "delivery suppressed", value: StatusDeliverySuppressed, want: true},
|
||||
{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
|
||||
{name: "confirmed", value: StatusConfirmedPendingExpire, want: false},
|
||||
{name: "expired", value: StatusExpired, want: false},
|
||||
{name: "failed", value: StatusFailed, want: false},
|
||||
{name: "cancelled", value: StatusCancelled, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := tt.value.AcceptsFreshConfirm(); got != tt.want {
|
||||
require.Failf(t, "test failed", "AcceptsFreshConfirm() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusIsConfirmedRetryState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value Status
|
||||
want bool
|
||||
}{
|
||||
{name: "sent", value: StatusSent, want: false},
|
||||
{name: "delivery suppressed", value: StatusDeliverySuppressed, want: false},
|
||||
{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
|
||||
{name: "confirmed", value: StatusConfirmedPendingExpire, want: true},
|
||||
{name: "expired", value: StatusExpired, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := tt.value.IsConfirmedRetryState(); got != tt.want {
|
||||
require.Failf(t, "test failed", "IsConfirmedRetryState() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusCanTransitionTo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
from Status
|
||||
to Status
|
||||
want bool
|
||||
}{
|
||||
{name: "pending to sent", from: StatusPendingSend, to: StatusSent, want: true},
|
||||
{name: "pending to suppressed", from: StatusPendingSend, to: StatusDeliverySuppressed, want: true},
|
||||
{name: "pending to throttled", from: StatusPendingSend, to: StatusDeliveryThrottled, want: true},
|
||||
{name: "pending to failed", from: StatusPendingSend, to: StatusFailed, want: true},
|
||||
{name: "pending to cancelled", from: StatusPendingSend, to: StatusCancelled, want: true},
|
||||
{name: "pending to expired", from: StatusPendingSend, to: StatusExpired, want: true},
|
||||
{name: "pending to confirmed", from: StatusPendingSend, to: StatusConfirmedPendingExpire, want: false},
|
||||
{name: "sent to confirmed", from: StatusSent, to: StatusConfirmedPendingExpire, want: true},
|
||||
{name: "sent to failed", from: StatusSent, to: StatusFailed, want: true},
|
||||
{name: "suppressed to confirmed", from: StatusDeliverySuppressed, to: StatusConfirmedPendingExpire, want: true},
|
||||
{name: "throttled to confirmed", from: StatusDeliveryThrottled, to: StatusConfirmedPendingExpire, want: false},
|
||||
{name: "confirmed to expired", from: StatusConfirmedPendingExpire, to: StatusExpired, want: true},
|
||||
{name: "confirmed to failed", from: StatusConfirmedPendingExpire, to: StatusFailed, want: false},
|
||||
{name: "expired terminal", from: StatusExpired, to: StatusCancelled, want: false},
|
||||
{name: "failed terminal", from: StatusFailed, to: StatusExpired, want: false},
|
||||
{name: "cancelled terminal", from: StatusCancelled, to: StatusExpired, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
|
||||
require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeliveryStateIsKnown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value DeliveryState
|
||||
want bool
|
||||
}{
|
||||
{name: "pending", value: DeliveryPending, want: true},
|
||||
{name: "sent", value: DeliverySent, want: true},
|
||||
{name: "suppressed", value: DeliverySuppressed, want: true},
|
||||
{name: "throttled", value: DeliveryThrottled, want: true},
|
||||
{name: "failed", value: DeliveryFailed, want: true},
|
||||
{name: "unknown", value: DeliveryState("unknown"), want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := tt.value.IsKnown(); got != tt.want {
|
||||
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeliveryStateCanTransitionTo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
from DeliveryState
|
||||
to DeliveryState
|
||||
want bool
|
||||
}{
|
||||
{name: "pending to sent", from: DeliveryPending, to: DeliverySent, want: true},
|
||||
{name: "pending to suppressed", from: DeliveryPending, to: DeliverySuppressed, want: true},
|
||||
{name: "pending to throttled", from: DeliveryPending, to: DeliveryThrottled, want: true},
|
||||
{name: "pending to failed", from: DeliveryPending, to: DeliveryFailed, want: true},
|
||||
{name: "sent terminal", from: DeliverySent, to: DeliveryFailed, want: false},
|
||||
{name: "suppressed terminal", from: DeliverySuppressed, to: DeliverySent, want: false},
|
||||
{name: "failed terminal", from: DeliveryFailed, to: DeliverySent, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
|
||||
require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChallengeIsExpiredAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_121_700, 0).UTC()
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(*Challenge)
|
||||
want bool
|
||||
}{
|
||||
{name: "active before expiration", want: false},
|
||||
{
|
||||
name: "expired status",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Status = StatusExpired
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "expiration timestamp passed",
|
||||
mutate: func(c *Challenge) {
|
||||
c.ExpiresAt = now
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "confirmed retained before expiration",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Status = StatusConfirmedPendingExpire
|
||||
c.DeliveryState = DeliverySent
|
||||
c.Confirmation = validConfirmation(t)
|
||||
c.ExpiresAt = now.Add(time.Second)
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challenge := validChallenge(t)
|
||||
challenge.CreatedAt = now.Add(-time.Minute)
|
||||
challenge.ExpiresAt = now.Add(time.Minute)
|
||||
if tt.mutate != nil {
|
||||
tt.mutate(&challenge)
|
||||
}
|
||||
if err := challenge.Validate(); err != nil {
|
||||
require.Failf(t, "test failed", "Validate() returned error: %v", err)
|
||||
}
|
||||
|
||||
if got := challenge.IsExpiredAt(now); got != tt.want {
|
||||
require.Failf(t, "test failed", "IsExpiredAt() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChallengeValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(*Challenge)
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "valid pending"},
|
||||
{
|
||||
name: "valid confirmed",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Status = StatusConfirmedPendingExpire
|
||||
c.DeliveryState = DeliverySent
|
||||
c.Confirmation = validConfirmation(t)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "confirmed requires metadata",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Status = StatusConfirmedPendingExpire
|
||||
c.DeliveryState = DeliverySent
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unconfirmed rejects metadata",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Confirmation = validConfirmation(t)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "pending requires pending delivery",
|
||||
mutate: func(c *Challenge) {
|
||||
c.DeliveryState = DeliverySent
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sent requires sent delivery",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Status = StatusSent
|
||||
c.DeliveryState = DeliverySuppressed
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "throttled requires throttled delivery",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Status = StatusDeliveryThrottled
|
||||
c.DeliveryState = DeliverySent
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expiration before creation",
|
||||
mutate: func(c *Challenge) {
|
||||
c.ExpiresAt = c.CreatedAt.Add(-time.Second)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative confirm attempts",
|
||||
mutate: func(c *Challenge) {
|
||||
c.Attempts.Confirm = -1
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challenge := validChallenge(t)
|
||||
if tt.mutate != nil {
|
||||
tt.mutate(&challenge)
|
||||
}
|
||||
|
||||
err := challenge.Validate()
|
||||
if tt.wantErr && err == nil {
|
||||
require.FailNow(t, "Validate() returned nil error")
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
require.Failf(t, "test failed", "Validate() returned error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func validChallenge(t *testing.T) Challenge {
|
||||
t.Helper()
|
||||
|
||||
return Challenge{
|
||||
ID: common.ChallengeID("challenge-123"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hash-123"),
|
||||
Status: StatusPendingSend,
|
||||
DeliveryState: DeliveryPending,
|
||||
CreatedAt: time.Unix(1_775_121_600, 0).UTC(),
|
||||
ExpiresAt: time.Unix(1_775_121_900, 0).UTC(),
|
||||
Attempts: AttemptCounters{
|
||||
Send: 0,
|
||||
Confirm: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validConfirmation(t *testing.T) *Confirmation {
|
||||
t.Helper()
|
||||
|
||||
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
|
||||
for index := range raw {
|
||||
raw[index] = byte(index + 1)
|
||||
}
|
||||
|
||||
key, err := common.NewClientPublicKey(raw)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
|
||||
}
|
||||
|
||||
return &Confirmation{
|
||||
SessionID: common.DeviceSessionID("device-session-123"),
|
||||
ClientPublicKey: key,
|
||||
ConfirmedAt: time.Unix(1_775_121_700, 0).UTC(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
package challenge

import "time"

const (
	// InitialTTL is the v1 lifetime of a newly created challenge before it
	// becomes expired.
	InitialTTL = 5 * time.Minute

	// ResendThrottleCooldown is the fixed Stage-17 cooldown applied to repeated
	// public send-email-code requests for the same normalized e-mail address.
	ResendThrottleCooldown = time.Minute

	// ConfirmedRetention is the v1 idempotency window kept after a successful
	// challenge confirmation.
	ConfirmedRetention = 5 * time.Minute

	// MaxInvalidConfirmAttempts is the v1 threshold after which repeated invalid
	// confirmation codes move a challenge into the failed state.
	MaxInvalidConfirmAttempts = 5
)

// V1 resend policy keeps every public send-email-code request independent:
// each call creates a fresh challenge, existing challenges are not reused or
// deduplicated, and Stage 17 adds a fixed auth-side resend cooldown that may
// record the fresh challenge as delivery_throttled.
@@ -0,0 +1,201 @@
// Package common defines small shared domain primitives used by auth/session
// aggregates and integration models.
package common

import (
	"bytes"
	"crypto/ed25519"
	"encoding/base64"
	"errors"
	"fmt"
	"net/mail"
	"strings"
)

// ChallengeID identifies one auth confirmation challenge owned by the service.
type ChallengeID string

// String returns ChallengeID as a plain string identifier.
func (id ChallengeID) String() string {
	return string(id)
}

// IsZero reports whether ChallengeID does not contain a usable identifier.
func (id ChallengeID) IsZero() bool {
	return strings.TrimSpace(string(id)) == ""
}

// Validate reports whether ChallengeID is non-empty and already normalized for
// domain use.
func (id ChallengeID) Validate() error {
	return validateToken("challenge id", string(id))
}

// DeviceSessionID identifies one persisted device session.
type DeviceSessionID string

// String returns DeviceSessionID as a plain string identifier.
func (id DeviceSessionID) String() string {
	return string(id)
}

// IsZero reports whether DeviceSessionID does not contain a usable identifier.
func (id DeviceSessionID) IsZero() bool {
	return strings.TrimSpace(string(id)) == ""
}

// Validate reports whether DeviceSessionID is non-empty and already
// normalized for domain use.
func (id DeviceSessionID) Validate() error {
	return validateToken("device session id", string(id))
}

// UserID identifies one user resolved through the user-service boundary.
type UserID string

// String returns UserID as a plain string identifier.
func (id UserID) String() string {
	return string(id)
}

// IsZero reports whether UserID does not contain a usable identifier.
func (id UserID) IsZero() bool {
	return strings.TrimSpace(string(id)) == ""
}

// Validate reports whether UserID is non-empty and already normalized for
// domain use.
func (id UserID) Validate() error {
	return validateToken("user id", string(id))
}

// Email stores one already-normalized e-mail address used by the auth domain.
type Email string

// String returns Email as the stored canonical e-mail string.
func (e Email) String() string {
	return string(e)
}

// IsZero reports whether Email does not contain a usable e-mail value.
func (e Email) IsZero() bool {
	return strings.TrimSpace(string(e)) == ""
}

// Validate reports whether Email is non-empty, does not contain surrounding
// whitespace, and matches the same single-address syntax expected by the
// public gateway contract.
func (e Email) Validate() error {
	raw := string(e)
	if err := validateToken("email", raw); err != nil {
		return err
	}

	parsedAddress, err := mail.ParseAddress(raw)
	if err != nil || parsedAddress.Name != "" || parsedAddress.Address != raw {
		return fmt.Errorf("email %q must be a single valid email address", raw)
	}

	return nil
}

// RevokeReasonCode stores one machine-readable revoke reason code.
type RevokeReasonCode string

// String returns RevokeReasonCode as its stored code value.
func (code RevokeReasonCode) String() string {
	return string(code)
}

// IsZero reports whether RevokeReasonCode is empty.
func (code RevokeReasonCode) IsZero() bool {
	return strings.TrimSpace(string(code)) == ""
}

// Validate reports whether RevokeReasonCode is non-empty and normalized for
// domain use.
func (code RevokeReasonCode) Validate() error {
	return validateToken("revoke reason code", string(code))
}

// RevokeActorType stores one machine-readable actor type for revoke audit.
type RevokeActorType string

// String returns RevokeActorType as its stored type value.
func (actorType RevokeActorType) String() string {
	return string(actorType)
}

// IsZero reports whether RevokeActorType is empty.
func (actorType RevokeActorType) IsZero() bool {
	return strings.TrimSpace(string(actorType)) == ""
}

// Validate reports whether RevokeActorType is non-empty and normalized for
// domain use.
func (actorType RevokeActorType) Validate() error {
	return validateToken("revoke actor type", string(actorType))
}

// ClientPublicKey stores one validated Ed25519 public key in parsed binary
// form inside the domain model.
type ClientPublicKey struct {
	value ed25519.PublicKey
}

// NewClientPublicKey validates value and returns a defensive copy suitable for
// storing inside domain aggregates.
func NewClientPublicKey(value ed25519.PublicKey) (ClientPublicKey, error) {
	key := ClientPublicKey{
		value: bytes.Clone(value),
	}
	if err := key.Validate(); err != nil {
		return ClientPublicKey{}, err
	}

	return key, nil
}

// String returns ClientPublicKey as the standard base64-encoded raw 32-byte
// Ed25519 public key string.
func (key ClientPublicKey) String() string {
	if key.IsZero() {
		return ""
	}

	return base64.StdEncoding.EncodeToString(key.value)
}

// IsZero reports whether ClientPublicKey does not contain key material.
func (key ClientPublicKey) IsZero() bool {
	return len(key.value) == 0
}

// Validate reports whether ClientPublicKey contains exactly one Ed25519 public
// key.
func (key ClientPublicKey) Validate() error {
	switch len(key.value) {
	case 0:
		return errors.New("client public key must not be empty")
	case ed25519.PublicKeySize:
		return nil
	default:
		return fmt.Errorf("client public key must contain exactly %d bytes", ed25519.PublicKeySize)
	}
}

// PublicKey returns a defensive copy of the parsed Ed25519 public key.
func (key ClientPublicKey) PublicKey() ed25519.PublicKey {
	return bytes.Clone(key.value)
}

func validateToken(name string, value string) error {
	switch {
	case strings.TrimSpace(value) == "":
		return fmt.Errorf("%s must not be empty", name)
	case strings.TrimSpace(value) != value:
		return fmt.Errorf("%s must not contain surrounding whitespace", name)
	default:
		return nil
	}
}
@@ -0,0 +1,133 @@
package common

import (
	"crypto/ed25519"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestChallengeIDValidate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		value   ChallengeID
		wantErr bool
	}{
		{name: "valid", value: ChallengeID("challenge-123")},
		{name: "empty", value: ChallengeID(""), wantErr: true},
		{name: "whitespace", value: ChallengeID(" challenge-123 "), wantErr: true},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.value.Validate()
			if tt.wantErr && err == nil {
				require.FailNow(t, "Validate() returned nil error")
			}
			if !tt.wantErr && err != nil {
				require.Failf(t, "test failed", "Validate() returned error: %v", err)
			}
		})
	}
}

func TestEmailValidate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		value   Email
		wantErr bool
	}{
		{name: "valid", value: Email("pilot@example.com")},
		{name: "invalid", value: Email("pilot"), wantErr: true},
		{name: "surrounding whitespace", value: Email(" pilot@example.com "), wantErr: true},
		{name: "display name", value: Email("Pilot <pilot@example.com>"), wantErr: true},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.value.Validate()
			if tt.wantErr && err == nil {
				require.FailNow(t, "Validate() returned nil error")
			}
			if !tt.wantErr && err != nil {
				require.Failf(t, "test failed", "Validate() returned error: %v", err)
			}
		})
	}
}

func TestNewClientPublicKey(t *testing.T) {
	t.Parallel()

	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for i := range raw {
		raw[i] = byte(i)
	}

	key, err := NewClientPublicKey(raw)
	if err != nil {
		require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
	}

	if key.IsZero() {
		require.FailNow(t, "IsZero() = true, want false")
	}

	cloned := key.PublicKey()
	if len(cloned) != ed25519.PublicKeySize {
		require.Failf(t, "test failed", "PublicKey() length = %d, want %d", len(cloned), ed25519.PublicKeySize)
	}

	raw[0] = 99
	if key.PublicKey()[0] == 99 {
		require.FailNow(t, "PublicKey() was mutated through constructor input")
	}
}

func TestClientPublicKeyValidate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		value   ClientPublicKey
		wantErr bool
	}{
		{name: "empty", value: ClientPublicKey{}, wantErr: true},
		{
			name:    "short",
			value:   ClientPublicKey{value: make(ed25519.PublicKey, ed25519.PublicKeySize-1)},
			wantErr: true,
		},
		{
			name:  "valid",
			value: ClientPublicKey{value: make(ed25519.PublicKey, ed25519.PublicKeySize)},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.value.Validate()
			if tt.wantErr && err == nil {
				require.FailNow(t, "Validate() returned nil error")
			}
			if !tt.wantErr && err != nil {
				require.Failf(t, "test failed", "Validate() returned error: %v", err)
			}
		})
	}
}
@@ -0,0 +1,162 @@
// Package devicesession defines the source-of-truth domain model for one
// authenticated device session.
package devicesession

import (
	"errors"
	"fmt"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/common"
)

// Status identifies the coarse lifecycle state of one device session.
type Status string

const (
	// StatusActive reports that the session may be used for authenticated
	// request verification.
	StatusActive Status = "active"

	// StatusRevoked reports that the session has been revoked and must no
	// longer authenticate requests.
	StatusRevoked Status = "revoked"
)

// RevokeReasonDeviceLogout reports that one device logged itself out.
const RevokeReasonDeviceLogout common.RevokeReasonCode = "device_logout"

// RevokeReasonLogoutAll reports that the session was revoked by a
// user-scoped logout-all action.
const RevokeReasonLogoutAll common.RevokeReasonCode = "logout_all"

// RevokeReasonAdminRevoke reports that the session was revoked
// administratively.
const RevokeReasonAdminRevoke common.RevokeReasonCode = "admin_revoke"

// RevokeReasonUserBlocked reports that the session was revoked because future
// auth flow for the user or e-mail was blocked.
const RevokeReasonUserBlocked common.RevokeReasonCode = "user_blocked"

// IsKnown reports whether Status is one of the device-session states
// supported by the current domain model.
func (s Status) IsKnown() bool {
	switch s {
	case StatusActive, StatusRevoked:
		return true
	default:
		return false
	}
}

// CanTransitionTo reports whether the current device-session Status may move
// to next under the Stage-2 lifecycle rules.
func (s Status) CanTransitionTo(next Status) bool {
	return s == StatusActive && next == StatusRevoked
}

// IsKnownRevokeReasonCode reports whether code is one of the built-in revoke
// reasons fixed by the Stage-2 domain model.
func IsKnownRevokeReasonCode(code common.RevokeReasonCode) bool {
	switch code {
	case RevokeReasonDeviceLogout,
		RevokeReasonLogoutAll,
		RevokeReasonAdminRevoke,
		RevokeReasonUserBlocked:
		return true
	default:
		return false
	}
}

// Revocation stores the audit metadata recorded when a session is revoked.
type Revocation struct {
	// At reports when the revoke took effect.
	At time.Time

	// ReasonCode stores one machine-readable revoke reason code.
	ReasonCode common.RevokeReasonCode

	// ActorType stores one machine-readable initiator type.
	ActorType common.RevokeActorType

	// ActorID optionally stores a stable initiator identifier.
	ActorID string
}

// Validate reports whether Revocation contains all metadata required for a
// revoked session.
func (r Revocation) Validate() error {
	if r.At.IsZero() {
		return errors.New("session revocation time must not be zero")
	}
	if err := r.ReasonCode.Validate(); err != nil {
		return fmt.Errorf("session revocation reason code: %w", err)
	}
	if err := r.ActorType.Validate(); err != nil {
		return fmt.Errorf("session revocation actor type: %w", err)
	}
	if strings.TrimSpace(r.ActorID) != r.ActorID {
		return errors.New("session revocation actor id must not contain surrounding whitespace")
	}

	return nil
}

// Session is the minimal source-of-truth aggregate shape fixed by Stage 2.
type Session struct {
	// ID identifies the device session.
	ID common.DeviceSessionID

	// UserID identifies the durable user linkage for the session.
	UserID common.UserID

	// ClientPublicKey stores the validated device public key in parsed form.
	ClientPublicKey common.ClientPublicKey

	// Status reports the coarse lifecycle state of the session.
	Status Status

	// CreatedAt reports when the session was created.
	CreatedAt time.Time

	// Revocation is present only when Status is StatusRevoked.
	Revocation *Revocation
}

// Validate reports whether Session satisfies the Stage-2 structural and
// lifecycle invariants.
func (s Session) Validate() error {
	if err := s.ID.Validate(); err != nil {
		return fmt.Errorf("session id: %w", err)
	}
	if err := s.UserID.Validate(); err != nil {
		return fmt.Errorf("session user id: %w", err)
	}
	if err := s.ClientPublicKey.Validate(); err != nil {
		return fmt.Errorf("session client public key: %w", err)
	}
	if !s.Status.IsKnown() {
		return fmt.Errorf("session status %q is unsupported", s.Status)
	}
	if s.CreatedAt.IsZero() {
		return errors.New("session creation time must not be zero")
	}

	switch s.Status {
	case StatusActive:
		if s.Revocation != nil {
			return errors.New("active session must not contain revocation metadata")
		}
	case StatusRevoked:
		if s.Revocation == nil {
			return errors.New("revoked session must contain revocation metadata")
		}
		if err := s.Revocation.Validate(); err != nil {
			return fmt.Errorf("session revocation: %w", err)
		}
	}

	return nil
}
@@ -0,0 +1,186 @@
package devicesession

import (
	"crypto/ed25519"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"galaxy/authsession/internal/domain/common"
)

func TestStatusIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value Status
		want  bool
	}{
		{name: "active", value: StatusActive, want: true},
		{name: "revoked", value: StatusRevoked, want: true},
		{name: "unknown", value: Status("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			require.Equal(t, tt.want, tt.value.IsKnown())
		})
	}
}

func TestStatusCanTransitionTo(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		from Status
		to   Status
		want bool
	}{
		{name: "active to revoked", from: StatusActive, to: StatusRevoked, want: true},
		{name: "active to active", from: StatusActive, to: StatusActive, want: false},
		{name: "revoked terminal", from: StatusRevoked, to: StatusActive, want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			require.Equal(t, tt.want, tt.from.CanTransitionTo(tt.to))
		})
	}
}

func TestIsKnownRevokeReasonCode(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value common.RevokeReasonCode
		want  bool
	}{
		{name: "device logout", value: RevokeReasonDeviceLogout, want: true},
		{name: "logout all", value: RevokeReasonLogoutAll, want: true},
		{name: "admin revoke", value: RevokeReasonAdminRevoke, want: true},
		{name: "user blocked", value: RevokeReasonUserBlocked, want: true},
		{name: "custom code", value: common.RevokeReasonCode("custom_policy"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			require.Equal(t, tt.want, IsKnownRevokeReasonCode(tt.value))
		})
	}
}

func TestSessionValidate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		mutate  func(*Session)
		wantErr bool
	}{
		{name: "active valid"},
		{
			name: "revoked valid",
			mutate: func(s *Session) {
				s.Status = StatusRevoked
				s.Revocation = validRevocation()
			},
		},
		{
			name: "active rejects revocation",
			mutate: func(s *Session) {
				s.Revocation = validRevocation()
			},
			wantErr: true,
		},
		{
			name: "revoked requires revocation",
			mutate: func(s *Session) {
				s.Status = StatusRevoked
			},
			wantErr: true,
		},
		{
			name: "revoked requires complete metadata",
			mutate: func(s *Session) {
				s.Status = StatusRevoked
				revocation := validRevocation()
				revocation.ReasonCode = ""
				s.Revocation = revocation
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			session := validSession(t)
			if tt.mutate != nil {
				tt.mutate(&session)
			}

			err := session.Validate()
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}

func validSession(t *testing.T) Session {
	t.Helper()

	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for index := range raw {
		raw[index] = byte(index + 7)
	}

	key, err := common.NewClientPublicKey(raw)
	require.NoError(t, err)

	return Session{
		ID:              common.DeviceSessionID("device-session-123"),
		UserID:          common.UserID("user-123"),
		ClientPublicKey: key,
		Status:          StatusActive,
		CreatedAt:       time.Unix(1_775_121_600, 0).UTC(),
	}
}

func validRevocation() *Revocation {
	return &Revocation{
		At:         time.Unix(1_775_121_800, 0).UTC(),
		ReasonCode: RevokeReasonAdminRevoke,
		ActorType:  common.RevokeActorType("admin"),
		ActorID:    "admin-123",
	}
}
@@ -0,0 +1,141 @@
// Package gatewayprojection defines the gateway-facing integration snapshot
// model that stays separate from source-of-truth session entities.
package gatewayprojection

import (
	"crypto/ed25519"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/common"
)

// Status identifies the coarse lifecycle state projected to the gateway.
type Status string

const (
	// StatusActive reports that the projected session may authenticate
	// requests on the gateway hot path.
	StatusActive Status = "active"

	// StatusRevoked reports that the projected session must be rejected on the
	// gateway hot path.
	StatusRevoked Status = "revoked"
)

// IsKnown reports whether Status is one of the projection states supported by
// the current integration model.
func (s Status) IsKnown() bool {
	switch s {
	case StatusActive, StatusRevoked:
		return true
	default:
		return false
	}
}

// Snapshot stores the gateway-facing session projection without exposing any
// Redis-specific field naming or storage encoding.
type Snapshot struct {
	// DeviceSessionID identifies the projected device session.
	DeviceSessionID common.DeviceSessionID

	// UserID identifies the projected user.
	UserID common.UserID

	// ClientPublicKey stores the standard base64-encoded raw 32-byte Ed25519
	// public key string expected by the gateway.
	ClientPublicKey string

	// Status reports whether the projected session is active or revoked.
	Status Status

	// RevokedAt optionally reports when the revoke took effect.
	RevokedAt *time.Time

	// RevokeReasonCode optionally stores the machine-readable revoke reason.
	RevokeReasonCode common.RevokeReasonCode

	// RevokeActorType optionally stores the machine-readable revoke actor type.
	RevokeActorType common.RevokeActorType

	// RevokeActorID optionally stores a stable revoke actor identifier.
	RevokeActorID string
}

// Validate reports whether Snapshot satisfies the Stage-2 structural
// invariants.
func (s Snapshot) Validate() error {
	if err := s.DeviceSessionID.Validate(); err != nil {
		return fmt.Errorf("gateway projection device session id: %w", err)
	}
	if err := s.UserID.Validate(); err != nil {
		return fmt.Errorf("gateway projection user id: %w", err)
	}
	if err := validateClientPublicKey(s.ClientPublicKey); err != nil {
		return fmt.Errorf("gateway projection client public key: %w", err)
	}
	if !s.Status.IsKnown() {
		return fmt.Errorf("gateway projection status %q is unsupported", s.Status)
	}

	if s.Status == StatusActive {
		if s.RevokedAt != nil {
			return errors.New("active gateway projection must not contain revoked time")
		}
		if !s.RevokeReasonCode.IsZero() {
			return errors.New("active gateway projection must not contain revoke reason code")
		}
		if !s.RevokeActorType.IsZero() {
			return errors.New("active gateway projection must not contain revoke actor type")
		}
		if s.RevokeActorID != "" {
			return errors.New("active gateway projection must not contain revoke actor id")
		}
		return nil
	}

	if s.RevokedAt != nil && s.RevokedAt.IsZero() {
		return errors.New("gateway projection revoked time must not be zero")
	}
	if !s.RevokeReasonCode.IsZero() {
		if err := s.RevokeReasonCode.Validate(); err != nil {
			return fmt.Errorf("gateway projection revoke reason code: %w", err)
		}
	}
	if !s.RevokeActorType.IsZero() {
		if err := s.RevokeActorType.Validate(); err != nil {
			return fmt.Errorf("gateway projection revoke actor type: %w", err)
		}
	}
	if s.RevokeActorType.IsZero() && s.RevokeActorID != "" {
		return errors.New("gateway projection revoke actor id requires revoke actor type")
	}
	if strings.TrimSpace(s.RevokeActorID) != s.RevokeActorID {
		return errors.New("gateway projection revoke actor id must not contain surrounding whitespace")
	}

	return nil
}

func validateClientPublicKey(value string) error {
	switch {
	case strings.TrimSpace(value) == "":
		return errors.New("client public key must not be empty")
	case strings.TrimSpace(value) != value:
		return errors.New("client public key must not contain surrounding whitespace")
	}

	decoded, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		return fmt.Errorf("client public key must be valid base64: %w", err)
	}
	if len(decoded) != ed25519.PublicKeySize {
		return fmt.Errorf("client public key must contain exactly %d bytes", ed25519.PublicKeySize)
	}

	return nil
}
@@ -0,0 +1,146 @@
package gatewayprojection

import (
	"crypto/ed25519"
	"encoding/base64"
	"reflect"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
)

func TestStatusIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value Status
		want  bool
	}{
		{name: "active", value: StatusActive, want: true},
		{name: "revoked", value: StatusRevoked, want: true},
		{name: "unknown", value: Status("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			require.Equal(t, tt.want, tt.value.IsKnown())
		})
	}
}

func TestSnapshotValidate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		mutate  func(*Snapshot)
		wantErr bool
	}{
		{name: "active valid"},
		{
			name: "revoked valid",
			mutate: func(snapshot *Snapshot) {
				snapshot.Status = StatusRevoked
				revokedAt := time.Unix(1_775_121_900, 0).UTC()
				snapshot.RevokedAt = &revokedAt
				snapshot.RevokeReasonCode = common.RevokeReasonCode("admin_revoke")
				snapshot.RevokeActorType = common.RevokeActorType("admin")
				snapshot.RevokeActorID = "admin-123"
			},
		},
		{
			name: "active rejects revoke metadata",
			mutate: func(snapshot *Snapshot) {
				snapshot.RevokeReasonCode = common.RevokeReasonCode("admin_revoke")
			},
			wantErr: true,
		},
		{
			name: "invalid key encoding",
			mutate: func(snapshot *Snapshot) {
				snapshot.ClientPublicKey = "not-base64"
			},
			wantErr: true,
		},
		{
			name: "actor id requires actor type",
			mutate: func(snapshot *Snapshot) {
				snapshot.Status = StatusRevoked
				snapshot.RevokeActorID = "admin-123"
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			snapshot := validSnapshot()
			if tt.mutate != nil {
				tt.mutate(&snapshot)
			}

			err := snapshot.Validate()
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}

func TestSnapshotStaysSeparateFromSessionDomainShape(t *testing.T) {
	t.Parallel()

	snapshotType := reflect.TypeOf(Snapshot{})
	sessionType := reflect.TypeOf(devicesession.Session{})

	clientPublicKeyField, ok := snapshotType.FieldByName("ClientPublicKey")
	require.True(t, ok, "Snapshot is missing ClientPublicKey field")
	require.Equal(t, reflect.String, clientPublicKeyField.Type.Kind())

	sessionClientPublicKeyField, ok := sessionType.FieldByName("ClientPublicKey")
	require.True(t, ok, "devicesession.Session is missing ClientPublicKey field")
	require.NotEqual(t, sessionClientPublicKeyField.Type, clientPublicKeyField.Type,
		"Snapshot.ClientPublicKey must stay separate from devicesession.Session.ClientPublicKey type")

	_, ok = snapshotType.FieldByName("RevokedAtMS")
	require.False(t, ok, "Snapshot must not expose Redis-specific RevokedAtMS field")
}

func validSnapshot() Snapshot {
	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for index := range raw {
		raw[index] = byte(index + 17)
	}

	return Snapshot{
		DeviceSessionID: common.DeviceSessionID("device-session-123"),
		UserID:          common.UserID("user-123"),
		ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
		Status:          StatusActive,
	}
}
@@ -0,0 +1,89 @@
// Package sessionlimit defines the domain decision shape used for active
// device-session limit evaluation.
package sessionlimit

import (
	"errors"
	"fmt"
)

// Kind identifies the coarse outcome of evaluating the active-session limit.
type Kind string

const (
	// KindDisabled reports that no configured limit is currently active.
	KindDisabled Kind = "disabled"

	// KindAllowed reports that creating the next session is allowed.
	KindAllowed Kind = "allowed"

	// KindExceeded reports that creating the next session would exceed the
	// configured limit.
	KindExceeded Kind = "exceeded"
)

// IsKnown reports whether Kind is one of the session-limit outcomes supported
// by the current domain model.
func (k Kind) IsKnown() bool {
	switch k {
	case KindDisabled, KindAllowed, KindExceeded:
		return true
	default:
		return false
	}
}

// Decision stores the result of evaluating one possible next session creation.
type Decision struct {
	// Kind reports the coarse decision outcome.
	Kind Kind

	// ConfiguredLimit stores the active configured limit when one exists.
	ConfiguredLimit *int

	// ActiveSessionCount stores the current active-session count before create.
	ActiveSessionCount int

	// NextSessionCount stores the count that would exist after creating the
	// next session.
	NextSessionCount int
}

// Validate reports whether Decision satisfies the Stage-2 structural
// invariants.
func (d Decision) Validate() error {
	if !d.Kind.IsKnown() {
		return fmt.Errorf("session-limit decision kind %q is unsupported", d.Kind)
	}
	if d.ActiveSessionCount < 0 {
		return errors.New("session-limit active session count must not be negative")
	}
	if d.NextSessionCount < 0 {
		return errors.New("session-limit next session count must not be negative")
	}
	if d.NextSessionCount != d.ActiveSessionCount+1 {
		return errors.New("session-limit next session count must equal active session count plus one")
	}

	switch d.Kind {
	case KindDisabled:
		if d.ConfiguredLimit != nil {
			return errors.New("disabled session-limit decision must not contain configured limit")
		}
	case KindAllowed, KindExceeded:
		if d.ConfiguredLimit == nil {
			return errors.New("limited session-limit decision must contain configured limit")
		}
		if *d.ConfiguredLimit <= 0 {
			return errors.New("session-limit configured limit must be positive")
		}
		if d.Kind == KindAllowed && d.NextSessionCount > *d.ConfiguredLimit {
			return errors.New("allowed session-limit decision must not exceed configured limit")
		}
		if d.Kind == KindExceeded && d.NextSessionCount <= *d.ConfiguredLimit {
			return errors.New("exceeded session-limit decision must be above configured limit")
		}
	}

	return nil
}
@@ -0,0 +1,128 @@
package sessionlimit

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestKindIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value Kind
		want  bool
	}{
		{name: "disabled", value: KindDisabled, want: true},
		{name: "allowed", value: KindAllowed, want: true},
		{name: "exceeded", value: KindExceeded, want: true},
		{name: "unknown", value: Kind("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			require.Equal(t, tt.want, tt.value.IsKnown())
		})
	}
}

func TestDecisionValidate(t *testing.T) {
	t.Parallel()

	limitTwo := 2
	limitThree := 3

	tests := []struct {
		name    string
		value   Decision
		wantErr bool
	}{
		{
			name: "disabled valid",
			value: Decision{
				Kind:               KindDisabled,
				ActiveSessionCount: 0,
				NextSessionCount:   1,
			},
		},
		{
			name: "allowed valid",
			value: Decision{
				Kind:               KindAllowed,
				ConfiguredLimit:    &limitThree,
				ActiveSessionCount: 1,
				NextSessionCount:   2,
			},
		},
		{
			name: "exceeded valid",
			value: Decision{
				Kind:               KindExceeded,
				ConfiguredLimit:    &limitTwo,
				ActiveSessionCount: 2,
				NextSessionCount:   3,
			},
		},
		{
			name: "disabled rejects limit",
			value: Decision{
				Kind:               KindDisabled,
				ConfiguredLimit:    &limitTwo,
				ActiveSessionCount: 0,
				NextSessionCount:   1,
			},
			wantErr: true,
		},
		{
			name: "allowed requires limit",
			value: Decision{
				Kind:               KindAllowed,
				ActiveSessionCount: 0,
				NextSessionCount:   1,
			},
			wantErr: true,
		},
		{
			name: "allowed rejects overflow",
			value: Decision{
				Kind:               KindAllowed,
				ConfiguredLimit:    &limitTwo,
				ActiveSessionCount: 2,
				NextSessionCount:   3,
			},
			wantErr: true,
		},
		{
			name: "next count must be active plus one",
			value: Decision{
				Kind:               KindDisabled,
				ActiveSessionCount: 2,
				NextSessionCount:   2,
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.value.Validate()
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
@@ -0,0 +1,110 @@
// Package userresolution defines the domain result returned by the user
// resolution boundary before session creation.
package userresolution

import (
	"errors"
	"fmt"
	"strings"

	"galaxy/authsession/internal/domain/common"
)

// Kind identifies the coarse user-resolution result for one normalized e-mail.
type Kind string

const (
	// KindExisting reports that the e-mail belongs to an existing user.
	KindExisting Kind = "existing"

	// KindCreatable reports that the e-mail is free and user creation is
	// allowed.
	KindCreatable Kind = "creatable"

	// KindBlocked reports that the e-mail or subject is blocked from login or
	// registration.
	KindBlocked Kind = "blocked"
)

// IsKnown reports whether Kind is one of the user-resolution kinds supported
// by the current domain model.
func (k Kind) IsKnown() bool {
	switch k {
	case KindExisting, KindCreatable, KindBlocked:
		return true
	default:
		return false
	}
}

// BlockReasonCode stores one machine-readable user-block reason.
type BlockReasonCode string

// String returns BlockReasonCode as its stored code value.
func (code BlockReasonCode) String() string {
	return string(code)
}

// IsZero reports whether BlockReasonCode is empty.
func (code BlockReasonCode) IsZero() bool {
	return strings.TrimSpace(string(code)) == ""
}

// Validate reports whether BlockReasonCode is non-empty and normalized for
// domain use.
func (code BlockReasonCode) Validate() error {
	switch {
	case code.IsZero():
		return errors.New("block reason code must not be empty")
	case strings.TrimSpace(string(code)) != string(code):
		return errors.New("block reason code must not contain surrounding whitespace")
	default:
		return nil
	}
}

// Result stores the coarse user-resolution outcome consumed by later auth
// workflow stages.
type Result struct {
	// Kind reports the coarse resolution outcome.
	Kind Kind

	// UserID is set only when Kind is KindExisting.
	UserID common.UserID

	// BlockReasonCode is set only when Kind is KindBlocked.
	BlockReasonCode BlockReasonCode
}

// Validate reports whether Result satisfies the Stage-2 structural invariants.
func (r Result) Validate() error {
	if !r.Kind.IsKnown() {
		return fmt.Errorf("user resolution kind %q is unsupported", r.Kind)
	}

	switch r.Kind {
	case KindExisting:
		if err := r.UserID.Validate(); err != nil {
			return fmt.Errorf("user resolution user id: %w", err)
		}
		if !r.BlockReasonCode.IsZero() {
			return errors.New("existing user resolution must not contain block reason code")
		}
	case KindCreatable:
		if !r.UserID.IsZero() {
			return errors.New("creatable user resolution must not contain user id")
		}
		if !r.BlockReasonCode.IsZero() {
			return errors.New("creatable user resolution must not contain block reason code")
		}
	case KindBlocked:
		if !r.UserID.IsZero() {
			return errors.New("blocked user resolution must not contain user id")
		}
		if err := r.BlockReasonCode.Validate(); err != nil {
			return fmt.Errorf("user resolution block reason code: %w", err)
		}
	}

	return nil
}
@@ -0,0 +1,113 @@
package userresolution

import (
	"testing"

	"github.com/stretchr/testify/require"

	"galaxy/authsession/internal/domain/common"
)

func TestKindIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value Kind
		want  bool
	}{
		{name: "existing", value: KindExisting, want: true},
		{name: "creatable", value: KindCreatable, want: true},
		{name: "blocked", value: KindBlocked, want: true},
		{name: "unknown", value: Kind("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			require.Equal(t, tt.want, tt.value.IsKnown())
		})
	}
}

func TestResultValidate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		value   Result
		wantErr bool
	}{
		{
			name: "existing valid",
			value: Result{
				Kind:   KindExisting,
				UserID: common.UserID("user-123"),
			},
		},
		{
			name: "creatable valid",
			value: Result{
				Kind: KindCreatable,
			},
		},
		{
			name: "blocked valid",
			value: Result{
				Kind:            KindBlocked,
				BlockReasonCode: BlockReasonCode("policy_blocked"),
			},
		},
		{
			name: "existing requires user id",
			value: Result{
				Kind: KindExisting,
			},
			wantErr: true,
		},
		{
			name: "creatable rejects user id",
			value: Result{
				Kind:   KindCreatable,
				UserID: common.UserID("user-123"),
			},
			wantErr: true,
		},
		{
			name: "blocked requires reason",
			value: Result{
				Kind: KindBlocked,
			},
			wantErr: true,
		},
		{
			name: "blocked rejects user id",
			value: Result{
				Kind:            KindBlocked,
				UserID:          common.UserID("user-123"),
				BlockReasonCode: BlockReasonCode("policy_blocked"),
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.value.Validate()
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
@@ -0,0 +1,82 @@
// Package logging configures the authsession structured logger and provides
// context-aware helpers for attaching OpenTelemetry trace identifiers.
package logging

import (
	"context"
	"strings"

	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

// New constructs the process-wide JSON logger from level.
func New(level string) (*zap.Logger, error) {
	atomicLevel := zap.NewAtomicLevel()
	if err := atomicLevel.UnmarshalText([]byte(strings.TrimSpace(level))); err != nil {
		return nil, err
	}

	zapCfg := zap.NewProductionConfig()
	zapCfg.Level = atomicLevel
	zapCfg.Sampling = nil
	zapCfg.Encoding = "json"
	zapCfg.EncoderConfig.TimeKey = "timestamp"
	zapCfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
	zapCfg.OutputPaths = []string{"stdout"}
	zapCfg.ErrorOutputPaths = []string{"stderr"}

	return zapCfg.Build()
}

// TraceFieldsFromContext returns zap fields for the active OpenTelemetry span
// when ctx carries a valid span context.
func TraceFieldsFromContext(ctx context.Context) []zap.Field {
	if ctx == nil {
		return nil
	}

	spanContext := trace.SpanContextFromContext(ctx)
	if !spanContext.IsValid() {
		return nil
	}

	return []zap.Field{
		zap.String("otel_trace_id", spanContext.TraceID().String()),
		zap.String("otel_span_id", spanContext.SpanID().String()),
	}
}

// Sync flushes logger and ignores the benign stdout or stderr sync errors
// commonly returned by containerized or redirected process outputs.
func Sync(logger *zap.Logger) error {
	if logger == nil {
		return nil
	}

	err := logger.Sync()
	if err == nil || isIgnorableSyncError(err) {
		return nil
	}

	return err
}

func isIgnorableSyncError(err error) bool {
	if err == nil {
		return false
	}

	message := strings.ToLower(err.Error())
	switch {
	case strings.Contains(message, "invalid argument"):
		return true
	case strings.Contains(message, "bad file descriptor"):
		return true
	case strings.Contains(message, "inappropriate ioctl for device"):
		return true
	default:
		return false
	}
}
@@ -0,0 +1,37 @@
package logging

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/sdk/trace/tracetest"
)

func TestNewRejectsInvalidLogLevel(t *testing.T) {
	t.Parallel()

	_, err := New("verbose")

	require.Error(t, err)
}

func TestTraceFieldsFromContextReturnsTraceAndSpanIDs(t *testing.T) {
	t.Parallel()

	recorder := tracetest.NewSpanRecorder()
	provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))

	ctx, span := provider.Tracer("test").Start(context.Background(), "operation")
	defer span.End()

	fields := TraceFieldsFromContext(ctx)

	require.Len(t, fields, 2)
	assert.Equal(t, "otel_trace_id", fields[0].Key)
	assert.Equal(t, "otel_span_id", fields[1].Key)
	assert.NotEmpty(t, fields[0].String)
	assert.NotEmpty(t, fields[1].String)
}
@@ -0,0 +1,43 @@
package ports

import (
	"context"
	"fmt"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
)

// ChallengeStore provides source-of-truth persistence for auth confirmation
// challenges without exposing storage-specific primitives.
type ChallengeStore interface {
	// Get returns the stored challenge for challengeID. Implementations must
	// wrap ErrNotFound when challengeID does not exist.
	Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error)

	// Create persists record as a new challenge. Implementations must wrap
	// ErrConflict when record.ID already exists.
	Create(ctx context.Context, record challenge.Challenge) error

	// CompareAndSwap replaces previous with next when the currently stored
	// challenge matches previous exactly. Implementations must wrap ErrConflict
	// when the stored challenge differs from previous and wrap ErrNotFound when
	// previous.ID does not exist.
	CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error
}

// ValidateComparableChallenges reports whether previous and next are suitable
// for one ChallengeStore compare-and-swap call.
func ValidateComparableChallenges(previous challenge.Challenge, next challenge.Challenge) error {
	if err := previous.Validate(); err != nil {
		return fmt.Errorf("previous challenge: %w", err)
	}
	if err := next.Validate(); err != nil {
		return fmt.Errorf("next challenge: %w", err)
	}
	if previous.ID != next.ID {
		return fmt.Errorf("challenge compare-and-swap ids must match: %q != %q", previous.ID, next.ID)
	}

	return nil
}
@@ -0,0 +1,9 @@
package ports

import "time"

// Clock returns current UTC time for the auth/session application layer.
type Clock interface {
	// Now returns the current service time.
	Now() time.Time
}
@@ -0,0 +1,8 @@
package ports

// CodeGenerator generates cleartext confirmation codes for new auth
// challenges.
type CodeGenerator interface {
	// Generate returns one fresh cleartext confirmation code.
	Generate() (string, error)
}
@@ -0,0 +1,11 @@
package ports

// CodeHasher hashes cleartext confirmation codes and compares later user input
// against stored hashes.
type CodeHasher interface {
	// Hash returns the stored representation for code.
	Hash(code string) ([]byte, error)

	// Compare reports whether hash matches code.
	Compare(hash []byte, code string) (bool, error)
}
@@ -0,0 +1,42 @@
package ports

import (
	"context"
	"errors"
	"fmt"
)

// ConfigProvider returns dynamic auth/session configuration required by later
// service workflows.
type ConfigProvider interface {
	// LoadSessionLimit returns the current active-session-limit configuration.
	// A nil ActiveSessionLimit means that the limit is disabled.
	LoadSessionLimit(ctx context.Context) (SessionLimitConfig, error)
}

// SessionLimitConfig stores the active-session-limit configuration in a form
// that preserves “limit absent” as a first-class state.
type SessionLimitConfig struct {
	// ActiveSessionLimit stores the configured limit when one is present. Nil
	// means that no active-session limit is configured.
	ActiveSessionLimit *int
}

// Validate reports whether SessionLimitConfig contains a valid limit value
// when one is configured.
func (c SessionLimitConfig) Validate() error {
	if c.ActiveSessionLimit != nil && *c.ActiveSessionLimit <= 0 {
		return errors.New("session limit config active session limit must be positive when configured")
	}

	return nil
}

// String returns a debug-friendly representation of SessionLimitConfig.
func (c SessionLimitConfig) String() string {
	if c.ActiveSessionLimit == nil {
		return "session_limit=disabled"
	}

	return fmt.Sprintf("session_limit=%d", *c.ActiveSessionLimit)
}
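The `*int` field is the key design choice here: a nil pointer is the explicit "no limit configured" state, distinct from any integer value, so zero can be rejected as invalid instead of being overloaded to mean "disabled". A tiny self-contained sketch of the same pattern (`describeLimit` is an illustrative helper, not part of the service):

```go
package main

import "fmt"

// describeLimit mirrors SessionLimitConfig.String: nil means "limit absent",
// while any non-nil pointer carries a configured value.
func describeLimit(limit *int) string {
	if limit == nil {
		return "session_limit=disabled"
	}
	return fmt.Sprintf("session_limit=%d", *limit)
}

func main() {
	n := 5
	fmt.Println(describeLimit(nil)) // session_limit=disabled
	fmt.Println(describeLimit(&n))  // session_limit=5
}
```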
@@ -0,0 +1,16 @@
|
||||
// Package ports defines the storage-agnostic and transport-agnostic service
|
||||
// boundaries used by the auth/session application layer.
|
||||
package ports
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNotFound reports that a requested source-of-truth record or remote
|
||||
// subject does not exist in the dependency behind the port.
|
||||
ErrNotFound = errors.New("ports: record not found")
|
||||
|
||||
// ErrConflict reports that a create or compare-and-swap style mutation
|
||||
// cannot be applied because the current dependency state no longer matches
|
||||
// the caller expectation.
|
||||
ErrConflict = errors.New("ports: conflict")
|
||||
)
|
||||
@@ -0,0 +1,13 @@
package ports

import "galaxy/authsession/internal/domain/common"

// IDGenerator generates stable domain identifiers for new challenges and
// device sessions.
type IDGenerator interface {
	// NewChallengeID returns a fresh challenge identifier.
	NewChallengeID() (common.ChallengeID, error)

	// NewDeviceSessionID returns a fresh device-session identifier.
	NewDeviceSessionID() (common.DeviceSessionID, error)
}
@@ -0,0 +1,86 @@
package ports

import (
	"context"
	"errors"
	"fmt"
	"strings"

	"galaxy/authsession/internal/domain/common"
)

// MailSender delivers the public login code or intentionally suppresses
// outward delivery while keeping the auth flow success-shaped.
type MailSender interface {
	// SendLoginCode attempts delivery for one generated login code. Explicit
	// delivery failure is reported through error, while sent vs suppressed is
	// returned in the result.
	SendLoginCode(ctx context.Context, input SendLoginCodeInput) (SendLoginCodeResult, error)
}

// SendLoginCodeInput describes one mail-delivery request generated by the auth
// flow.
type SendLoginCodeInput struct {
	// Email identifies the normalized target e-mail address.
	Email common.Email

	// Code stores the cleartext login code that should be delivered to Email.
	Code string
}

// Validate reports whether SendLoginCodeInput contains a complete delivery
// request.
func (i SendLoginCodeInput) Validate() error {
	if err := i.Email.Validate(); err != nil {
		return fmt.Errorf("send login code input email: %w", err)
	}
	switch {
	case strings.TrimSpace(i.Code) == "":
		return errors.New("send login code input code must not be empty")
	case strings.TrimSpace(i.Code) != i.Code:
		return errors.New("send login code input code must not contain surrounding whitespace")
	default:
		return nil
	}
}

// SendLoginCodeOutcome identifies the coarse mail-delivery outcome reported
// back to the auth flow.
type SendLoginCodeOutcome string

const (
	// SendLoginCodeOutcomeSent reports that delivery was attempted and accepted.
	SendLoginCodeOutcomeSent SendLoginCodeOutcome = "sent"

	// SendLoginCodeOutcomeSuppressed reports that outward behavior remains
	// success-shaped while actual delivery is intentionally skipped.
	SendLoginCodeOutcomeSuppressed SendLoginCodeOutcome = "suppressed"
)

// IsKnown reports whether SendLoginCodeOutcome is supported by the current
// mail-sender contract.
func (o SendLoginCodeOutcome) IsKnown() bool {
	switch o {
	case SendLoginCodeOutcomeSent, SendLoginCodeOutcomeSuppressed:
		return true
	default:
		return false
	}
}

// SendLoginCodeResult describes the stable outcome returned by MailSender for
// one delivery request.
type SendLoginCodeResult struct {
	// Outcome reports whether delivery was sent or intentionally suppressed.
	Outcome SendLoginCodeOutcome
}

// Validate reports whether SendLoginCodeResult satisfies the mail-sender
// contract invariants.
func (r SendLoginCodeResult) Validate() error {
	if !r.Outcome.IsKnown() {
		return fmt.Errorf("send login code result outcome %q is unsupported", r.Outcome)
	}

	return nil
}
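The "suppressed" outcome exists so a stub adapter can keep the auth flow success-shaped without sending anything. A sketch of such a stub, with a simplified signature (the real `MailSender` port takes a context and structured input/result types):

```go
package main

import "fmt"

// Outcome and its constants mirror SendLoginCodeOutcome.
type Outcome string

const (
	OutcomeSent       Outcome = "sent"
	OutcomeSuppressed Outcome = "suppressed"
)

// stubSender never contacts a mail provider: it reports the suppressed
// outcome with a nil error, so the caller treats the attempt as successful
// while nothing leaves the process.
type stubSender struct{}

func (stubSender) SendLoginCode(email, code string) (Outcome, error) {
	return OutcomeSuppressed, nil
}

func main() {
	out, err := stubSender{}.SendLoginCode("pilot@example.com", "654321")
	fmt.Println(out, err) // suppressed <nil>
}
```

Separating "delivery failed" (error) from "delivery skipped" (suppressed result) keeps stub and real adapters interchangeable behind the same interface.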
@@ -0,0 +1,371 @@
package ports

import (
	"github.com/stretchr/testify/require"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/domain/userresolution"
)

func TestRevokeSessionOutcomeIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value RevokeSessionOutcome
		want  bool
	}{
		{name: "revoked", value: RevokeSessionOutcomeRevoked, want: true},
		{name: "already revoked", value: RevokeSessionOutcomeAlreadyRevoked, want: true},
		{name: "unknown", value: RevokeSessionOutcome("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := tt.value.IsKnown(); got != tt.want {
				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestRevokeUserSessionsOutcomeIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value RevokeUserSessionsOutcome
		want  bool
	}{
		{name: "revoked", value: RevokeUserSessionsOutcomeRevoked, want: true},
		{name: "no active sessions", value: RevokeUserSessionsOutcomeNoActiveSessions, want: true},
		{name: "unknown", value: RevokeUserSessionsOutcome("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := tt.value.IsKnown(); got != tt.want {
				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestEnsureUserOutcomeIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value EnsureUserOutcome
		want  bool
	}{
		{name: "existing", value: EnsureUserOutcomeExisting, want: true},
		{name: "created", value: EnsureUserOutcomeCreated, want: true},
		{name: "blocked", value: EnsureUserOutcomeBlocked, want: true},
		{name: "unknown", value: EnsureUserOutcome("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := tt.value.IsKnown(); got != tt.want {
				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestBlockUserOutcomeIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value BlockUserOutcome
		want  bool
	}{
		{name: "blocked", value: BlockUserOutcomeBlocked, want: true},
		{name: "already blocked", value: BlockUserOutcomeAlreadyBlocked, want: true},
		{name: "unknown", value: BlockUserOutcome("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := tt.value.IsKnown(); got != tt.want {
				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestSendLoginCodeOutcomeIsKnown(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value SendLoginCodeOutcome
		want  bool
	}{
		{name: "sent", value: SendLoginCodeOutcomeSent, want: true},
		{name: "suppressed", value: SendLoginCodeOutcomeSuppressed, want: true},
		{name: "unknown", value: SendLoginCodeOutcome("unknown"), want: false},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			if got := tt.value.IsKnown(); got != tt.want {
				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestSessionLimitConfigValidate(t *testing.T) {
	t.Parallel()

	positive := 3
	zero := 0

	tests := []struct {
		name    string
		value   SessionLimitConfig
		wantErr bool
	}{
		{name: "absent", value: SessionLimitConfig{}},
		{name: "positive", value: SessionLimitConfig{ActiveSessionLimit: &positive}},
		{name: "zero", value: SessionLimitConfig{ActiveSessionLimit: &zero}, wantErr: true},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.value.Validate()
			if tt.wantErr && err == nil {
				require.FailNow(t, "Validate() returned nil error")
			}
			if !tt.wantErr && err != nil {
				require.Failf(t, "test failed", "Validate() returned error: %v", err)
			}
		})
	}
}

func TestRevokeSessionInputValidate(t *testing.T) {
	t.Parallel()

	input := RevokeSessionInput{
		DeviceSessionID: common.DeviceSessionID("device-session-1"),
		Revocation: devicesession.Revocation{
			At:         time.Unix(10, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonLogoutAll,
			ActorType:  common.RevokeActorType("system"),
		},
	}

	if err := input.Validate(); err != nil {
		require.Failf(t, "test failed", "Validate() returned error: %v", err)
	}
}

func TestRevokeSessionResultValidate(t *testing.T) {
	t.Parallel()

	result := RevokeSessionResult{
		Outcome: RevokeSessionOutcomeRevoked,
		Session: revokedSessionFixture(),
	}

	if err := result.Validate(); err != nil {
		require.Failf(t, "test failed", "Validate() returned error: %v", err)
	}
}

func TestRevokeUserSessionsResultValidate(t *testing.T) {
	t.Parallel()

	result := RevokeUserSessionsResult{
		Outcome: RevokeUserSessionsOutcomeRevoked,
		UserID:  common.UserID("user-1"),
		Sessions: []devicesession.Session{
			revokedSessionFixture(),
		},
	}

	if err := result.Validate(); err != nil {
		require.Failf(t, "test failed", "Validate() returned error: %v", err)
	}
}

func TestEnsureUserResultValidate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		value   EnsureUserResult
		wantErr bool
	}{
		{
			name: "existing",
			value: EnsureUserResult{
				Outcome: EnsureUserOutcomeExisting,
				UserID:  common.UserID("user-1"),
			},
		},
		{
			name: "created",
			value: EnsureUserResult{
				Outcome: EnsureUserOutcomeCreated,
				UserID:  common.UserID("user-2"),
			},
		},
		{
			name: "blocked",
			value: EnsureUserResult{
				Outcome:         EnsureUserOutcomeBlocked,
				BlockReasonCode: userresolution.BlockReasonCode("policy_block"),
			},
		},
		{
			name: "blocked with user id",
			value: EnsureUserResult{
				Outcome:         EnsureUserOutcomeBlocked,
				UserID:          common.UserID("user-1"),
				BlockReasonCode: userresolution.BlockReasonCode("policy_block"),
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := tt.value.Validate()
			if tt.wantErr && err == nil {
				require.FailNow(t, "Validate() returned nil error")
			}
			if !tt.wantErr && err != nil {
				require.Failf(t, "test failed", "Validate() returned error: %v", err)
			}
		})
	}
}

func TestBlockUserInputsAndResultValidate(t *testing.T) {
	t.Parallel()

	byID := BlockUserByIDInput{
		UserID:     common.UserID("user-1"),
		ReasonCode: userresolution.BlockReasonCode("policy_block"),
	}
	if err := byID.Validate(); err != nil {
		require.Failf(t, "test failed", "BlockUserByIDInput.Validate() returned error: %v", err)
	}

	byEmail := BlockUserByEmailInput{
		Email:      common.Email("pilot@example.com"),
		ReasonCode: userresolution.BlockReasonCode("policy_block"),
	}
	if err := byEmail.Validate(); err != nil {
		require.Failf(t, "test failed", "BlockUserByEmailInput.Validate() returned error: %v", err)
	}

	result := BlockUserResult{
		Outcome: BlockUserOutcomeBlocked,
		UserID:  common.UserID("user-1"),
	}
	if err := result.Validate(); err != nil {
		require.Failf(t, "test failed", "BlockUserResult.Validate() returned error: %v", err)
	}
}

func TestSendLoginCodeInputAndResultValidate(t *testing.T) {
	t.Parallel()

	input := SendLoginCodeInput{
		Email: common.Email("pilot@example.com"),
		Code:  "654321",
	}
	if err := input.Validate(); err != nil {
		require.Failf(t, "test failed", "SendLoginCodeInput.Validate() returned error: %v", err)
	}

	result := SendLoginCodeResult{Outcome: SendLoginCodeOutcomeSent}
	if err := result.Validate(); err != nil {
		require.Failf(t, "test failed", "SendLoginCodeResult.Validate() returned error: %v", err)
	}
}

func TestValidateComparableChallenges(t *testing.T) {
	t.Parallel()

	previous := challengeFixture()
	next := challengeFixture()
	next.Status = challenge.StatusSent
	next.DeliveryState = challenge.DeliverySent

	if err := ValidateComparableChallenges(previous, next); err != nil {
		require.Failf(t, "test failed", "ValidateComparableChallenges() returned error: %v", err)
	}
}

func challengeFixture() challenge.Challenge {
	timestamp := time.Unix(10, 0).UTC()
	return challenge.Challenge{
		ID:            common.ChallengeID("challenge-1"),
		Email:         common.Email("pilot@example.com"),
		CodeHash:      []byte("hash"),
		Status:        challenge.StatusPendingSend,
		DeliveryState: challenge.DeliveryPending,
		CreatedAt:     timestamp,
		ExpiresAt:     timestamp.Add(5 * time.Minute),
	}
}

func revokedSessionFixture() devicesession.Session {
	timestamp := time.Unix(10, 0).UTC()
	key, err := common.NewClientPublicKey(make([]byte, 32))
	if err != nil {
		panic(err)
	}

	return devicesession.Session{
		ID:              common.DeviceSessionID("device-session-1"),
		UserID:          common.UserID("user-1"),
		ClientPublicKey: key,
		Status:          devicesession.StatusRevoked,
		CreatedAt:       timestamp.Add(-time.Minute),
		Revocation: &devicesession.Revocation{
			At:         timestamp,
			ReasonCode: devicesession.RevokeReasonLogoutAll,
			ActorType:  common.RevokeActorType("system"),
		},
	}
}
@@ -0,0 +1,15 @@
package ports

import (
	"context"

	"galaxy/authsession/internal/domain/gatewayprojection"
)

// GatewaySessionProjectionPublisher publishes gateway-facing session snapshots
// after source-of-truth session changes.
type GatewaySessionProjectionPublisher interface {
	// PublishSession writes or propagates snapshot in the gateway-facing
	// projection model.
	PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error
}
@@ -0,0 +1,100 @@
package ports

import (
	"context"
	"fmt"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
)

// SendEmailCodeAbuseProtector decides whether one public send-email-code
// attempt may proceed immediately or must be throttled by the auth-side resend
// cooldown.
type SendEmailCodeAbuseProtector interface {
	// CheckAndReserve validates input, checks the current resend cooldown
	// decision for input.Email, and reserves a new cooldown window immediately
	// when the outcome is allowed.
	CheckAndReserve(ctx context.Context, input SendEmailCodeAbuseInput) (SendEmailCodeAbuseResult, error)
}

// SendEmailCodeAbuseInput describes one resend-throttle decision request for
// a normalized public send-email-code attempt.
type SendEmailCodeAbuseInput struct {
	// Email identifies the normalized e-mail address addressed by the public
	// request.
	Email common.Email

	// Now records when the send attempt is being evaluated.
	Now time.Time
}

// Validate reports whether SendEmailCodeAbuseInput contains a complete resend
// cooldown decision request.
func (i SendEmailCodeAbuseInput) Validate() error {
	if err := i.Email.Validate(); err != nil {
		return fmt.Errorf("send email code abuse input email: %w", err)
	}
	if i.Now.IsZero() {
		return fmt.Errorf("send email code abuse input now must not be zero")
	}

	return nil
}

// SendEmailCodeAbuseOutcome identifies the coarse resend-throttle decision for
// one public send-email-code attempt.
type SendEmailCodeAbuseOutcome string

const (
	// SendEmailCodeAbuseOutcomeAllowed reports that the attempt may proceed and
	// that the cooldown window has been reserved immediately.
	SendEmailCodeAbuseOutcomeAllowed SendEmailCodeAbuseOutcome = "allowed"

	// SendEmailCodeAbuseOutcomeThrottled reports that the cooldown window is
	// still active and that the caller must not extend it.
	SendEmailCodeAbuseOutcomeThrottled SendEmailCodeAbuseOutcome = "throttled"
)

// IsKnown reports whether SendEmailCodeAbuseOutcome belongs to the stable
// Stage-17 resend-throttle contract.
func (o SendEmailCodeAbuseOutcome) IsKnown() bool {
	switch o {
	case SendEmailCodeAbuseOutcomeAllowed, SendEmailCodeAbuseOutcomeThrottled:
		return true
	default:
		return false
	}
}

// SendEmailCodeAbuseResult describes one resend-throttle decision returned by
// SendEmailCodeAbuseProtector.
type SendEmailCodeAbuseResult struct {
	// Outcome reports whether the current send attempt may proceed or must be
	// throttled.
	Outcome SendEmailCodeAbuseOutcome
}

// Validate reports whether SendEmailCodeAbuseResult satisfies the resend
// cooldown contract.
func (r SendEmailCodeAbuseResult) Validate() error {
	if !r.Outcome.IsKnown() {
		return fmt.Errorf("send email code abuse result outcome %q is unsupported", r.Outcome)
	}

	return nil
}

// SendEmailCodeThrottleStatusToChallengeStatus maps one resend-throttle
// outcome to the challenge lifecycle state used by sendemailcode.
func SendEmailCodeThrottleStatusToChallengeStatus(outcome SendEmailCodeAbuseOutcome) (challenge.Status, challenge.DeliveryState, error) {
	switch outcome {
	case SendEmailCodeAbuseOutcomeAllowed:
		return challenge.StatusPendingSend, challenge.DeliveryPending, nil
	case SendEmailCodeAbuseOutcomeThrottled:
		return challenge.StatusDeliveryThrottled, challenge.DeliveryThrottled, nil
	default:
		return "", "", fmt.Errorf("map send email code abuse outcome %q: unsupported outcome", outcome)
	}
}
@@ -0,0 +1,47 @@
package ports

import (
	"testing"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestSendEmailCodeAbuseOutcomeIsKnown(t *testing.T) {
	t.Parallel()

	assert.True(t, SendEmailCodeAbuseOutcomeAllowed.IsKnown())
	assert.True(t, SendEmailCodeAbuseOutcomeThrottled.IsKnown())
	assert.False(t, SendEmailCodeAbuseOutcome("unknown").IsKnown())
}

func TestSendEmailCodeAbuseInputAndResultValidate(t *testing.T) {
	t.Parallel()

	input := SendEmailCodeAbuseInput{
		Email: common.Email("pilot@example.com"),
		Now:   time.Unix(10, 0).UTC(),
	}
	require.NoError(t, input.Validate())

	result := SendEmailCodeAbuseResult{Outcome: SendEmailCodeAbuseOutcomeThrottled}
	require.NoError(t, result.Validate())
}

func TestSendEmailCodeThrottleStatusToChallengeStatus(t *testing.T) {
	t.Parallel()

	status, deliveryState, err := SendEmailCodeThrottleStatusToChallengeStatus(SendEmailCodeAbuseOutcomeAllowed)
	require.NoError(t, err)
	assert.Equal(t, challenge.StatusPendingSend, status)
	assert.Equal(t, challenge.DeliveryPending, deliveryState)

	status, deliveryState, err = SendEmailCodeThrottleStatusToChallengeStatus(SendEmailCodeAbuseOutcomeThrottled)
	require.NoError(t, err)
	assert.Equal(t, challenge.StatusDeliveryThrottled, status)
	assert.Equal(t, challenge.DeliveryThrottled, deliveryState)
}
@@ -0,0 +1,214 @@
package ports

import (
	"context"
	"errors"
	"fmt"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
)

// SessionStore provides source-of-truth persistence for device sessions
// without exposing storage-specific encoding or transaction primitives.
type SessionStore interface {
	// Get returns the stored session for deviceSessionID. Implementations must
	// wrap ErrNotFound when deviceSessionID does not exist.
	Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error)

	// ListByUserID returns every stored session for userID in newest-first
	// order. Implementations must return an empty slice, not ErrNotFound, when
	// userID has no stored sessions.
	ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error)

	// CountActiveByUserID returns the number of active sessions currently stored
	// for userID.
	CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error)

	// Create persists record as a new device session. Implementations must wrap
	// ErrConflict when record.ID already exists.
	Create(ctx context.Context, record devicesession.Session) error

	// Revoke stores a revoked view of one target session. Implementations must
	// wrap ErrNotFound when input.DeviceSessionID does not exist.
	Revoke(ctx context.Context, input RevokeSessionInput) (RevokeSessionResult, error)

	// RevokeAllByUserID stores revoked views for all currently active sessions
	// owned by input.UserID.
	RevokeAllByUserID(ctx context.Context, input RevokeUserSessionsInput) (RevokeUserSessionsResult, error)
}

// RevokeSessionInput describes one single-session revoke mutation requested
// from SessionStore.
type RevokeSessionInput struct {
	// DeviceSessionID identifies the session that should be revoked.
	DeviceSessionID common.DeviceSessionID

	// Revocation stores the audit metadata that must be attached to the revoked
	// session.
	Revocation devicesession.Revocation
}

// Validate reports whether RevokeSessionInput contains a complete revoke
// request.
func (i RevokeSessionInput) Validate() error {
	if err := i.DeviceSessionID.Validate(); err != nil {
		return fmt.Errorf("revoke session input device session id: %w", err)
	}
	if err := i.Revocation.Validate(); err != nil {
		return fmt.Errorf("revoke session input revocation: %w", err)
	}

	return nil
}

// RevokeSessionOutcome identifies the coarse outcome of revoking one device
// session.
type RevokeSessionOutcome string

const (
	// RevokeSessionOutcomeRevoked reports that an active session was moved to
	// the revoked state by the current mutation.
	RevokeSessionOutcomeRevoked RevokeSessionOutcome = "revoked"

	// RevokeSessionOutcomeAlreadyRevoked reports that the requested session had
	// already been revoked before the current mutation.
	RevokeSessionOutcomeAlreadyRevoked RevokeSessionOutcome = "already_revoked"
)

// IsKnown reports whether RevokeSessionOutcome is supported by the current
// session-store contract.
func (o RevokeSessionOutcome) IsKnown() bool {
	switch o {
	case RevokeSessionOutcomeRevoked, RevokeSessionOutcomeAlreadyRevoked:
		return true
	default:
		return false
	}
}

// RevokeSessionResult describes the stable outcome returned by SessionStore
// after a single-session revoke attempt.
type RevokeSessionResult struct {
	// Outcome reports whether the session was revoked just now or had already
	// been revoked.
	Outcome RevokeSessionOutcome

	// Session stores the current source-of-truth session state after the revoke
	// attempt.
	Session devicesession.Session
}

// Validate reports whether RevokeSessionResult satisfies the session-store
// contract invariants.
func (r RevokeSessionResult) Validate() error {
	if !r.Outcome.IsKnown() {
		return fmt.Errorf("revoke session result outcome %q is unsupported", r.Outcome)
	}
	if err := r.Session.Validate(); err != nil {
		return fmt.Errorf("revoke session result session: %w", err)
	}
	if r.Session.Status != devicesession.StatusRevoked {
		return errors.New("revoke session result session must be revoked")
	}

	return nil
}

// RevokeUserSessionsInput describes one bulk user-session revoke mutation
// requested from SessionStore.
type RevokeUserSessionsInput struct {
	// UserID identifies the owner whose active sessions should be revoked.
	UserID common.UserID

	// Revocation stores the audit metadata that must be attached to every
	// revoked session.
	Revocation devicesession.Revocation
}

// Validate reports whether RevokeUserSessionsInput contains a complete bulk
// revoke request.
func (i RevokeUserSessionsInput) Validate() error {
	if err := i.UserID.Validate(); err != nil {
		return fmt.Errorf("revoke user sessions input user id: %w", err)
	}
	if err := i.Revocation.Validate(); err != nil {
		return fmt.Errorf("revoke user sessions input revocation: %w", err)
	}

	return nil
}

// RevokeUserSessionsOutcome identifies the coarse outcome of revoking all
// active sessions of one user.
type RevokeUserSessionsOutcome string

const (
	// RevokeUserSessionsOutcomeRevoked reports that one or more active sessions
	// were revoked by the current mutation.
	RevokeUserSessionsOutcomeRevoked RevokeUserSessionsOutcome = "revoked"

	// RevokeUserSessionsOutcomeNoActiveSessions reports that the target user did
	// not currently own any active sessions.
	RevokeUserSessionsOutcomeNoActiveSessions RevokeUserSessionsOutcome = "no_active_sessions"
)

// IsKnown reports whether RevokeUserSessionsOutcome is supported by the
// current session-store contract.
func (o RevokeUserSessionsOutcome) IsKnown() bool {
	switch o {
	case RevokeUserSessionsOutcomeRevoked, RevokeUserSessionsOutcomeNoActiveSessions:
		return true
	default:
		return false
	}
}

// RevokeUserSessionsResult describes the stable outcome returned by
// SessionStore after one bulk revoke attempt.
type RevokeUserSessionsResult struct {
	// Outcome reports whether at least one active session was revoked.
	Outcome RevokeUserSessionsOutcome

	// UserID identifies the owner whose sessions were evaluated.
	UserID common.UserID

	// Sessions stores the current source-of-truth session states for every
	// session affected by the bulk revoke operation.
	Sessions []devicesession.Session
}

// Validate reports whether RevokeUserSessionsResult satisfies the bulk
// session-store contract invariants.
func (r RevokeUserSessionsResult) Validate() error {
	if !r.Outcome.IsKnown() {
		return fmt.Errorf("revoke user sessions result outcome %q is unsupported", r.Outcome)
	}
	if err := r.UserID.Validate(); err != nil {
		return fmt.Errorf("revoke user sessions result user id: %w", err)
	}
	for index, session := range r.Sessions {
		if err := session.Validate(); err != nil {
			return fmt.Errorf("revoke user sessions result session %d: %w", index, err)
		}
		if session.Status != devicesession.StatusRevoked {
			return fmt.Errorf("revoke user sessions result session %d must be revoked", index)
		}
		if session.UserID != r.UserID {
			return fmt.Errorf("revoke user sessions result session %d belongs to %q, want %q", index, session.UserID, r.UserID)
		}
	}

	switch r.Outcome {
	case RevokeUserSessionsOutcomeRevoked:
		if len(r.Sessions) == 0 {
			return errors.New("revoke user sessions result must include sessions when outcome is revoked")
		}
	case RevokeUserSessionsOutcomeNoActiveSessions:
		if len(r.Sessions) != 0 {
			return errors.New("revoke user sessions result must not include sessions when outcome is no_active_sessions")
		}
	}

	return nil
}
@@ -0,0 +1,203 @@
package ports

import (
	"context"
	"errors"
	"fmt"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/userresolution"
)

// UserDirectory provides the auth/session boundary to user ownership,
// registration, and block-policy decisions.
type UserDirectory interface {
	// ResolveByEmail returns the current resolution state for email without
	// creating any new user record.
	ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error)

	// ExistsByUserID reports whether userID currently identifies a stored user
	// record.
	ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error)

	// EnsureUserByEmail returns an existing user for email, creates a new user
	// when registration is allowed, or reports a blocked outcome when the
	// address may not continue through confirm flow.
	EnsureUserByEmail(ctx context.Context, email common.Email) (EnsureUserResult, error)

	// BlockByUserID applies a block state to the user identified by
	// input.UserID. Implementations must wrap ErrNotFound when input.UserID does
	// not exist.
	BlockByUserID(ctx context.Context, input BlockUserByIDInput) (BlockUserResult, error)

	// BlockByEmail applies a block state to input.Email, even when no user
	// record currently exists for that e-mail address.
	BlockByEmail(ctx context.Context, input BlockUserByEmailInput) (BlockUserResult, error)
}

// EnsureUserOutcome identifies the coarse outcome of ensuring a user record
// for one normalized e-mail address.
type EnsureUserOutcome string

const (
	// EnsureUserOutcomeExisting reports that the e-mail already belonged to a
	// stored user.
	EnsureUserOutcomeExisting EnsureUserOutcome = "existing"

	// EnsureUserOutcomeCreated reports that a new user was created for the
	// e-mail address.
	EnsureUserOutcomeCreated EnsureUserOutcome = "created"

	// EnsureUserOutcomeBlocked reports that the e-mail cannot be used for login
	// or registration.
	EnsureUserOutcomeBlocked EnsureUserOutcome = "blocked"
)

// IsKnown reports whether EnsureUserOutcome is supported by the current
// user-directory contract.
func (o EnsureUserOutcome) IsKnown() bool {
	switch o {
	case EnsureUserOutcomeExisting, EnsureUserOutcomeCreated, EnsureUserOutcomeBlocked:
		return true
	default:
		return false
	}
}

// EnsureUserResult describes the stable outcome returned by UserDirectory
// after one ensure-user attempt.
type EnsureUserResult struct {
	// Outcome reports whether the user already existed, was created, or is
	// blocked by policy.
	Outcome EnsureUserOutcome

	// UserID is present when Outcome is EnsureUserOutcomeExisting or
	// EnsureUserOutcomeCreated.
	UserID common.UserID

	// BlockReasonCode is present only when Outcome is EnsureUserOutcomeBlocked.
	BlockReasonCode userresolution.BlockReasonCode
}

// Validate reports whether EnsureUserResult satisfies the user-directory
// contract invariants.
func (r EnsureUserResult) Validate() error {
	if !r.Outcome.IsKnown() {
		return fmt.Errorf("ensure user result outcome %q is unsupported", r.Outcome)
	}

	switch r.Outcome {
	case EnsureUserOutcomeExisting, EnsureUserOutcomeCreated:
		if err := r.UserID.Validate(); err != nil {
			return fmt.Errorf("ensure user result user id: %w", err)
		}
		if !r.BlockReasonCode.IsZero() {
			return errors.New("ensure user result must not contain block reason code for existing or created outcomes")
		}
	case EnsureUserOutcomeBlocked:
		if !r.UserID.IsZero() {
			return errors.New("ensure user result must not contain user id for blocked outcome")
		}
		if err := r.BlockReasonCode.Validate(); err != nil {
			return fmt.Errorf("ensure user result block reason code: %w", err)
		}
	}

	return nil
}

// BlockUserByIDInput describes one block mutation targeted by stable user id.
type BlockUserByIDInput struct {
	// UserID identifies the user that should be blocked.
	UserID common.UserID

	// ReasonCode stores the machine-readable block reason to apply.
	ReasonCode userresolution.BlockReasonCode
}

// Validate reports whether BlockUserByIDInput contains a complete block
// request.
func (i BlockUserByIDInput) Validate() error {
	if err := i.UserID.Validate(); err != nil {
		return fmt.Errorf("block user by id input user id: %w", err)
	}
	if err := i.ReasonCode.Validate(); err != nil {
		return fmt.Errorf("block user by id input reason code: %w", err)
	}

	return nil
}

// BlockUserByEmailInput describes one block mutation targeted by normalized
// e-mail address.
type BlockUserByEmailInput struct {
	// Email identifies the e-mail address that should be blocked.
	Email common.Email

	// ReasonCode stores the machine-readable block reason to apply.
	ReasonCode userresolution.BlockReasonCode
}

// Validate reports whether BlockUserByEmailInput contains a complete block
// request.
func (i BlockUserByEmailInput) Validate() error {
	if err := i.Email.Validate(); err != nil {
		return fmt.Errorf("block user by email input email: %w", err)
	}
	if err := i.ReasonCode.Validate(); err != nil {
		return fmt.Errorf("block user by email input reason code: %w", err)
	}

	return nil
}

// BlockUserOutcome identifies the coarse outcome of blocking one user or
// e-mail subject.
type BlockUserOutcome string

const (
	// BlockUserOutcomeBlocked reports that the current mutation applied a new
	// block state.
	BlockUserOutcomeBlocked BlockUserOutcome = "blocked"

	// BlockUserOutcomeAlreadyBlocked reports that the target subject had already
	// been blocked before the current mutation.
	BlockUserOutcomeAlreadyBlocked BlockUserOutcome = "already_blocked"
)

// IsKnown reports whether BlockUserOutcome is supported by the current
// user-directory contract.
func (o BlockUserOutcome) IsKnown() bool {
	switch o {
	case BlockUserOutcomeBlocked, BlockUserOutcomeAlreadyBlocked:
		return true
	default:
		return false
	}
}

// BlockUserResult describes the stable outcome returned by UserDirectory after
// one block attempt.
type BlockUserResult struct {
	// Outcome reports whether the current mutation applied a new block state.
	Outcome BlockUserOutcome

	// UserID optionally stores the stable user identifier resolved for the
	// blocked subject when one exists.
	UserID common.UserID
}

// Validate reports whether BlockUserResult satisfies the user-directory
// contract invariants.
func (r BlockUserResult) Validate() error {
	if !r.Outcome.IsKnown() {
		return fmt.Errorf("block user result outcome %q is unsupported", r.Outcome)
	}
	if !r.UserID.IsZero() {
		if err := r.UserID.Validate(); err != nil {
			return fmt.Errorf("block user result user id: %w", err)
		}
	}

	return nil
}
@@ -0,0 +1,88 @@
package blockuser

import (
	"context"
	"errors"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/domain/userresolution"
	"galaxy/authsession/internal/service/shared"
	"galaxy/authsession/internal/testkit"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestExecuteRetriesProjectionPublishesForBlockFlow(t *testing.T) {
	t.Parallel()

	userDirectory := &testkit.InMemoryUserDirectory{}
	store := &testkit.InMemorySessionStore{}
	publisher := &testkit.RecordingProjectionPublisher{
		Errors: []error{errors.New("publish failed"), nil},
	}
	require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))

	service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
	require.NoError(t, err)

	result, err := service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "policy_block",
		ActorType:  "admin",
	})
	require.NoError(t, err)
	assert.Equal(t, "blocked", result.Outcome)
	assert.EqualValues(t, 1, result.AffectedSessionCount)
	require.Len(t, publisher.PublishedSnapshots(), 2)
}

func TestExecuteRepairsProjectionOnRepeatedAlreadyBlockedRequest(t *testing.T) {
	t.Parallel()

	userDirectory := &testkit.InMemoryUserDirectory{}
	store := &testkit.InMemorySessionStore{}
	publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
	require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))

	service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
	require.NoError(t, err)

	_, err = service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "policy_block",
		ActorType:  "admin",
	})
	require.Error(t, err)
	assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
	require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)

	sessionRecord, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
	require.NoError(t, getErr)
	require.NotNil(t, sessionRecord.Revocation)
	assert.Equal(t, devicesession.StatusRevoked, sessionRecord.Status)
	assert.Equal(t, devicesession.RevokeReasonUserBlocked, sessionRecord.Revocation.ReasonCode)

	resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
	require.NoError(t, resolveErr)
	assert.Equal(t, userresolution.KindBlocked, resolution.Kind)

	publisher.Err = nil

	result, err := service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "policy_block",
		ActorType:  "admin",
	})
	require.NoError(t, err)
	assert.Equal(t, "already_blocked", result.Outcome)
	assert.EqualValues(t, 0, result.AffectedSessionCount)
	require.NotNil(t, result.AffectedDeviceSessionIDs)
	assert.Empty(t, result.AffectedDeviceSessionIDs)
	require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}
@@ -0,0 +1,91 @@
package blockuser

import (
	"context"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/service/confirmemailcode"
	"galaxy/authsession/internal/service/sendemailcode"
	"galaxy/authsession/internal/service/shared"
	"galaxy/authsession/internal/testkit"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

const blockFlowPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="

func TestBlockUserAffectsLaterSendAndConfirmFlows(t *testing.T) {
	t.Parallel()

	challengeStore := &testkit.InMemoryChallengeStore{}
	sessionStore := &testkit.InMemorySessionStore{}
	userDirectory := &testkit.InMemoryUserDirectory{}
	publisher := &testkit.RecordingProjectionPublisher{}
	idGenerator := &testkit.SequenceIDGenerator{
		ChallengeIDs:     []common.ChallengeID{"challenge-1"},
		DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"},
	}
	hasher := testkit.DeterministicCodeHasher{}
	mailSender := &testkit.RecordingMailSender{}
	now := time.Unix(20, 0).UTC()
	clock := testkit.FixedClock{Time: now}

	blockService, err := New(userDirectory, sessionStore, publisher, clock)
	require.NoError(t, err)

	_, err = blockService.Execute(context.Background(), Input{
		Email:      "pilot@example.com",
		ReasonCode: "policy_block",
		ActorType:  "admin",
	})
	require.NoError(t, err)

	sendService, err := sendemailcode.New(
		challengeStore,
		userDirectory,
		idGenerator,
		testkit.FixedCodeGenerator{Code: "654321"},
		hasher,
		mailSender,
		clock,
	)
	require.NoError(t, err)

	sendResult, err := sendService.Execute(context.Background(), sendemailcode.Input{Email: "pilot@example.com"})
	require.NoError(t, err)
	assert.Equal(t, "challenge-1", sendResult.ChallengeID)
	assert.Empty(t, mailSender.RecordedInputs())

	challengeRecord, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
	require.NoError(t, err)
	assert.Equal(t, challenge.StatusDeliverySuppressed, challengeRecord.Status)
	assert.Equal(t, challenge.DeliverySuppressed, challengeRecord.DeliveryState)

	confirmService, err := confirmemailcode.New(
		challengeStore,
		sessionStore,
		userDirectory,
		testkit.StaticConfigProvider{},
		publisher,
		idGenerator,
		hasher,
		clock,
	)
	require.NoError(t, err)

	_, err = confirmService.Execute(context.Background(), confirmemailcode.Input{
		ChallengeID:     "challenge-1",
		Code:            "654321",
		ClientPublicKey: blockFlowPublicKey,
	})
	require.Error(t, err)
	assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err))

	updatedChallenge, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
	require.NoError(t, getErr)
	assert.Equal(t, challenge.StatusFailed, updatedChallenge.Status)
}
@@ -0,0 +1,64 @@
package blockuser

import (
	"bytes"
	"context"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/testkit"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

func TestExecuteLogsSafeOutcomeFields(t *testing.T) {
	t.Parallel()

	userDirectory := &testkit.InMemoryUserDirectory{}
	require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))

	sessionStore := &testkit.InMemorySessionStore{}
	require.NoError(t, sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))

	logger, buffer := newObservedServiceLogger()
	service, err := NewWithObservability(
		userDirectory,
		sessionStore,
		&testkit.RecordingProjectionPublisher{},
		testkit.FixedClock{Time: time.Unix(20, 0).UTC()},
		logger,
		nil,
	)
	require.NoError(t, err)

	_, err = service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "policy_block",
		ActorType:  "admin",
	})
	require.NoError(t, err)

	logOutput := buffer.String()
	assert.Contains(t, logOutput, "block_user")
	assert.Contains(t, logOutput, "\"user_id\":\"user-1\"")
	assert.Contains(t, logOutput, "\"reason_code\":\"policy_block\"")
	assert.NotContains(t, logOutput, "pilot@example.com")
}

func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) {
	buffer := &bytes.Buffer{}
	encoderConfig := zap.NewProductionEncoderConfig()
	encoderConfig.TimeKey = ""

	core := zapcore.NewCore(
		zapcore.NewJSONEncoder(encoderConfig),
		zapcore.AddSync(buffer),
		zap.DebugLevel,
	)

	return zap.New(core), buffer
}
@@ -0,0 +1,294 @@
// Package blockuser implements the trusted internal block-user use case.
package blockuser

import (
	"context"
	"errors"
	"fmt"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/domain/userresolution"
	"galaxy/authsession/internal/ports"
	"galaxy/authsession/internal/service/shared"
	"galaxy/authsession/internal/telemetry"

	"go.uber.org/zap"
)

const (
	// SubjectKindUserID identifies a block request addressed by stable user id.
	SubjectKindUserID = "user_id"

	// SubjectKindEmail identifies a block request addressed by normalized e-mail
	// address.
	SubjectKindEmail = "email"
)

// Input describes one trusted internal block-user request.
type Input struct {
	// UserID identifies the subject to block when the request is user-id based.
	UserID string

	// Email identifies the subject to block when the request is e-mail based.
	Email string

	// ReasonCode stores the machine-readable block reason code applied to the
	// user directory.
	ReasonCode string

	// ActorType stores the machine-readable actor type for any derived session
	// revocation.
	ActorType string

	// ActorID stores the optional stable actor identifier for any derived
	// session revocation.
	ActorID string
}

// Result describes the frozen internal block-user acknowledgement.
type Result struct {
	// Outcome reports whether the block state was newly applied or already
	// existed.
	Outcome string

	// SubjectKind reports whether the request targeted `user_id` or `email`.
	SubjectKind string

	// SubjectValue stores the normalized subject value addressed by the
	// operation.
	SubjectValue string

	// AffectedSessionCount reports how many sessions changed state during the
	// current call.
	AffectedSessionCount int64

	// AffectedDeviceSessionIDs lists every session identifier affected during
	// the current call.
	AffectedDeviceSessionIDs []string
}

// Service executes the trusted internal block-user use case.
type Service struct {
	userDirectory ports.UserDirectory
	sessionStore  ports.SessionStore
	publisher     ports.GatewaySessionProjectionPublisher
	clock         ports.Clock
	logger        *zap.Logger
	telemetry     *telemetry.Runtime
}

// New returns a block-user service wired to the required ports.
func New(userDirectory ports.UserDirectory, sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
	return NewWithObservability(userDirectory, sessionStore, publisher, clock, nil, nil)
}

// NewWithObservability returns a block-user service wired to the required
// ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
	userDirectory ports.UserDirectory,
	sessionStore ports.SessionStore,
	publisher ports.GatewaySessionProjectionPublisher,
	clock ports.Clock,
	logger *zap.Logger,
	telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
	switch {
	case userDirectory == nil:
		return nil, fmt.Errorf("blockuser: user directory must not be nil")
	case sessionStore == nil:
		return nil, fmt.Errorf("blockuser: session store must not be nil")
	case publisher == nil:
		return nil, fmt.Errorf("blockuser: projection publisher must not be nil")
	case clock == nil:
		return nil, fmt.Errorf("blockuser: clock must not be nil")
	default:
		return &Service{
			userDirectory: userDirectory,
			sessionStore:  sessionStore,
			publisher:     publisher,
			clock:         clock,
			logger:        namedLogger(logger, "block_user"),
			telemetry:     telemetryRuntime,
		}, nil
	}
}

// Execute applies the requested block state and revokes any active sessions of
// the resolved user when one exists.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
	logFields := []zap.Field{
		zap.String("component", "service"),
		zap.String("use_case", "block_user"),
	}
	defer func() {
		if result.Outcome != "" {
			logFields = append(logFields, zap.String("outcome", result.Outcome))
		}
		if result.SubjectKind != "" {
			logFields = append(logFields, zap.String("subject_kind", result.SubjectKind))
		}
		if result.AffectedSessionCount > 0 {
			logFields = append(logFields, zap.Int64("affected_session_count", result.AffectedSessionCount))
		}
		shared.LogServiceOutcome(s.logger, ctx, "block user completed", err, logFields...)
	}()

	subjectKind, subjectValue, storeResult, err := s.blockSubject(ctx, input)
	if err != nil {
		return Result{}, err
	}
	logFields = append(logFields, zap.String("reason_code", shared.NormalizeString(input.ReasonCode)))
	if !storeResult.UserID.IsZero() {
		logFields = append(logFields, zap.String("user_id", storeResult.UserID.String()))
	}

	affectedDeviceSessionIDs := []string{}
	affectedSessionCount := int64(0)
	if !storeResult.UserID.IsZero() {
		revocation, err := shared.BuildRevocation(
			devicesession.RevokeReasonUserBlocked.String(),
			input.ActorType,
			input.ActorID,
			s.clock.Now(),
		)
		if err != nil {
			return Result{}, err
		}

		revokeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{
			UserID:     storeResult.UserID,
			Revocation: revocation,
		})
		if err != nil {
			return Result{}, shared.ServiceUnavailable(err)
		}
		if err := revokeResult.Validate(); err != nil {
			return Result{}, shared.InternalError(err)
		}

		for _, record := range revokeResult.Sessions {
			if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user"); err != nil {
				return Result{}, err
			}
			affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String())
		}
		if revokeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions {
			if err := s.republishCurrentRevokedSessions(ctx, storeResult.UserID); err != nil {
				return Result{}, err
			}
		}
		affectedSessionCount = int64(len(revokeResult.Sessions))
		if affectedSessionCount > 0 {
			s.telemetry.RecordSessionRevocations(ctx, "block_user", devicesession.RevokeReasonUserBlocked.String(), affectedSessionCount)
		}
	}

	result = Result{
		Outcome:                  string(storeResult.Outcome),
		SubjectKind:              subjectKind,
		SubjectValue:             subjectValue,
		AffectedSessionCount:     affectedSessionCount,
		AffectedDeviceSessionIDs: affectedDeviceSessionIDs,
	}

	return result, nil
}

func (s *Service) blockSubject(ctx context.Context, input Input) (string, string, ports.BlockUserResult, error) {
	userID := shared.NormalizeString(input.UserID)
	email := shared.NormalizeString(input.Email)

	switch {
	case userID == "" && email == "":
		return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
	case userID != "" && email != "":
		return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
	case userID != "":
		parsedUserID, err := shared.ParseUserID(userID)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}
		reasonCode, err := parseBlockReasonCode(input.ReasonCode)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}

		result, err := s.userDirectory.BlockByUserID(ctx, ports.BlockUserByIDInput{
			UserID:     parsedUserID,
			ReasonCode: reasonCode,
		})
		if err != nil {
			switch {
			case errors.Is(err, ports.ErrNotFound):
				return "", "", ports.BlockUserResult{}, shared.SubjectNotFound()
			default:
				return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
			}
		}
		if err := result.Validate(); err != nil {
			return "", "", ports.BlockUserResult{}, shared.InternalError(err)
		}
		s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_user_id", string(result.Outcome))

		return SubjectKindUserID, parsedUserID.String(), result, nil
	default:
		parsedEmail, err := shared.ParseEmail(email)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}
		reasonCode, err := parseBlockReasonCode(input.ReasonCode)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}

		result, err := s.userDirectory.BlockByEmail(ctx, ports.BlockUserByEmailInput{
			Email:      parsedEmail,
			ReasonCode: reasonCode,
		})
		if err != nil {
			return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
		}
		if err := result.Validate(); err != nil {
			return "", "", ports.BlockUserResult{}, shared.InternalError(err)
		}
		s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_email", string(result.Outcome))

		return SubjectKindEmail, parsedEmail.String(), result, nil
	}
}

func parseBlockReasonCode(value string) (userresolution.BlockReasonCode, error) {
	reasonCode := userresolution.BlockReasonCode(shared.NormalizeString(value))
	if err := reasonCode.Validate(); err != nil {
		return "", shared.InvalidRequest(err.Error())
	}

	return reasonCode, nil
}

func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error {
	records, err := s.sessionStore.ListByUserID(ctx, userID)
	if err != nil {
		return shared.ServiceUnavailable(err)
	}

	for _, record := range records {
		if record.Status != devicesession.StatusRevoked {
			continue
		}
		if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user_repair"); err != nil {
			return err
		}
	}

	return nil
}

func namedLogger(logger *zap.Logger, name string) *zap.Logger {
	if logger == nil {
		logger = zap.NewNop()
	}

	return logger.Named(name)
}
@@ -0,0 +1,237 @@
package blockuser

import (
	"context"
	"errors"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/domain/gatewayprojection"
	"galaxy/authsession/internal/domain/userresolution"
	"galaxy/authsession/internal/service/shared"
	"galaxy/authsession/internal/testkit"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestExecuteBlocksByUserIDAndRevokesSessions(t *testing.T) {
	t.Parallel()

	userDirectory := &testkit.InMemoryUserDirectory{}
	store := &testkit.InMemorySessionStore{}
	publisher := &testkit.RecordingProjectionPublisher{}
	require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))

	service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
	require.NoError(t, err)

	result, err := service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "policy_block",
		ActorType:  "admin",
	})
	require.NoError(t, err)
	assert.Equal(t, "blocked", result.Outcome)
	assert.EqualValues(t, 1, result.AffectedSessionCount)
	assert.Equal(t, SubjectKindUserID, result.SubjectKind)
	assert.Equal(t, "user-1", result.SubjectValue)
	assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)

	stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
	require.NoError(t, getErr)
	require.NotNil(t, stored.Revocation)
	assert.Equal(t, devicesession.StatusRevoked, stored.Status)
	assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
	assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)

	resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
	require.NoError(t, resolveErr)
	assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
	assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)

	published := publisher.PublishedSnapshots()
	require.Len(t, published, 1)
	assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
	assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
	assert.Equal(t, common.RevokeActorType("admin"), published[0].RevokeActorType)
}

func TestExecuteBlocksByEmailWithoutExistingUser(t *testing.T) {
	t.Parallel()

	userDirectory := &testkit.InMemoryUserDirectory{}
	publisher := &testkit.RecordingProjectionPublisher{}
	service, err := New(userDirectory, &testkit.InMemorySessionStore{}, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
	require.NoError(t, err)

	result, err := service.Execute(context.Background(), Input{
		Email:      "pilot@example.com",
		ReasonCode: "policy_block",
		ActorType:  "admin",
	})
	require.NoError(t, err)
	assert.Equal(t, "blocked", result.Outcome)
	assert.EqualValues(t, 0, result.AffectedSessionCount)
	assert.Equal(t, SubjectKindEmail, result.SubjectKind)
	assert.Equal(t, "pilot@example.com", result.SubjectValue)
	require.NotNil(t, result.AffectedDeviceSessionIDs)
	assert.Empty(t, result.AffectedDeviceSessionIDs)

	resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
|
||||
assert.Empty(t, publisher.PublishedSnapshots())
|
||||
}
|
||||
|
||||
func TestExecuteBlocksByEmailWithExistingUserAndRevokesSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
Email: "pilot@example.com",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "blocked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
|
||||
|
||||
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 1)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
|
||||
}
|
||||
|
||||
func TestExecuteReturnsSubjectNotFoundForUnknownUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(&testkit.InMemoryUserDirectory{}, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "missing",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
|
||||
}
|
||||
|
||||
func TestExecuteAlreadyBlockedStillRevokesLingeringSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
if err := userDirectory.SeedBlockedUser(common.Email("pilot@example.com"), common.UserID("user-1"), userresolution.BlockReasonCode("policy_block")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedBlockedUser() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
Email: "pilot@example.com",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "already_blocked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 1)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
|
||||
}
|
||||
|
||||
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
|
||||
|
||||
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
key, err := common.NewClientPublicKey(make([]byte, 32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: key,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
package blockuser

import (
	"context"
	"testing"
	"time"

	stubuserservice "galaxy/authsession/internal/adapters/userservice"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/testkit"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
	t.Parallel()

	t.Run("blocks by email through runtime stub", func(t *testing.T) {
		t.Parallel()

		userDirectory := &stubuserservice.StubDirectory{}
		require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))

		store := &testkit.InMemorySessionStore{}
		require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))

		service, err := New(userDirectory, store, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
		require.NoError(t, err)

		result, err := service.Execute(context.Background(), Input{
			Email:      "pilot@example.com",
			ReasonCode: "policy_block",
			ActorType:  "admin",
		})
		require.NoError(t, err)
		assert.Equal(t, SubjectKindEmail, result.SubjectKind)
		assert.Equal(t, "blocked", result.Outcome)
		assert.EqualValues(t, 1, result.AffectedSessionCount)
	})

	t.Run("blocks by user id through runtime stub", func(t *testing.T) {
		t.Parallel()

		userDirectory := &stubuserservice.StubDirectory{}
		require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))

		service, err := New(userDirectory, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
		require.NoError(t, err)

		result, err := service.Execute(context.Background(), Input{
			UserID:     "user-1",
			ReasonCode: "policy_block",
			ActorType:  "admin",
		})
		require.NoError(t, err)
		assert.Equal(t, SubjectKindUserID, result.SubjectKind)
		assert.Equal(t, "blocked", result.Outcome)
	})
}
@@ -0,0 +1,39 @@
package confirmemailcode

import (
	"context"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/service/shared"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestExecuteReturnsInvalidCodeForThrottledChallengeWithoutConsumingAttempts(t *testing.T) {
	t.Parallel()

	deps := newConfirmDeps(t)
	record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
	record.Status = challenge.StatusDeliveryThrottled
	record.DeliveryState = challenge.DeliveryThrottled
	require.NoError(t, record.Validate())
	require.NoError(t, deps.challengeStore.Create(context.Background(), record))

	service := mustNewConfirmService(t, deps)
	_, err := service.Execute(context.Background(), Input{
		ChallengeID:     "challenge-1",
		Code:            "654321",
		ClientPublicKey: publicKeyString(),
	})
	require.Error(t, err)
	assert.Equal(t, shared.ErrorCodeInvalidCode, shared.CodeOf(err))

	updated, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
	require.NoError(t, getErr)
	assert.Equal(t, 0, updated.Attempts.Confirm)
	assert.Equal(t, challenge.StatusDeliveryThrottled, updated.Status)
}
@@ -0,0 +1,106 @@
package confirmemailcode

import (
	"context"
	"errors"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/service/shared"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestExecuteConfirmsChallengeAfterTransientProjectionPublishFailures(t *testing.T) {
	t.Parallel()

	deps := newConfirmDeps(t)
	deps.publisher.Errors = []error{errors.New("publish failed"), nil}
	require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, deps.challengeStore.Create(
		context.Background(),
		sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
	))

	service := mustNewConfirmService(t, deps)
	result, err := service.Execute(context.Background(), Input{
		ChallengeID:     "challenge-1",
		Code:            "654321",
		ClientPublicKey: publicKeyString(),
	})
	require.NoError(t, err)
	assert.Equal(t, "device-session-1", result.DeviceSessionID)
	require.Len(t, deps.publisher.PublishedSnapshots(), 2)
}

func TestExecuteConfirmedRetryRepublishesAfterTransientProjectionPublishFailures(t *testing.T) {
	t.Parallel()

	deps := newConfirmDeps(t)
	deps.publisher.Errors = []error{errors.New("publish failed"), nil}
	key := mustClientPublicKey(t, publicKeyString())
	require.NoError(t, deps.challengeStore.Create(
		context.Background(),
		confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
	))
	require.NoError(t, deps.sessionStore.Create(
		context.Background(),
		activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute)),
	))

	service := mustNewConfirmService(t, deps)
	result, err := service.Execute(context.Background(), Input{
		ChallengeID:     "challenge-1",
		Code:            "654321",
		ClientPublicKey: publicKeyString(),
	})
	require.NoError(t, err)
	assert.Equal(t, "device-session-1", result.DeviceSessionID)
	require.Len(t, deps.publisher.PublishedSnapshots(), 2)
}

func TestExecuteRepairsProjectionOnIdenticalRetryAfterExhaustedPublishRetries(t *testing.T) {
	t.Parallel()

	deps := newConfirmDeps(t)
	deps.publisher.Err = errors.New("publish failed")
	require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, deps.challengeStore.Create(
		context.Background(),
		sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
	))

	service := mustNewConfirmService(t, deps)
	_, err := service.Execute(context.Background(), Input{
		ChallengeID:     "challenge-1",
		Code:            "654321",
		ClientPublicKey: publicKeyString(),
	})
	require.Error(t, err)
	assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
	require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)

	sessionRecord, getErr := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
	require.NoError(t, getErr)
	assert.Equal(t, devicesession.StatusActive, sessionRecord.Status)

	challengeRecord, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
	require.NoError(t, getErr)
	assert.Equal(t, challenge.StatusConfirmedPendingExpire, challengeRecord.Status)
	require.NotNil(t, challengeRecord.Confirmation)

	// Once publishing recovers, an identical retry repairs the projection.
	deps.publisher.Err = nil

	result, err := service.Execute(context.Background(), Input{
		ChallengeID:     "challenge-1",
		Code:            "654321",
		ClientPublicKey: publicKeyString(),
	})
	require.NoError(t, err)
	assert.Equal(t, "device-session-1", result.DeviceSessionID)
	require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}