diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index ab1c3e1..bf99cf6 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -37,7 +37,7 @@ flowchart LR ## Main Components -### 1. Edge Gateway +### 1. [Edge Gateway](./gateway/README.md) The gateway is the only public entry point for client traffic. @@ -58,7 +58,7 @@ Responsibilities: The gateway must not implement domain-specific business logic. -### 2. Auth / Session Service +### 2. [Auth / Session Service](./authsession/README.md) This service owns authentication and device session lifecycle. diff --git a/authsession/PLAN.md b/authsession/PLAN.md new file mode 100644 index 0000000..645731b --- /dev/null +++ b/authsession/PLAN.md @@ -0,0 +1,1152 @@ +# Auth / Session Service Implementation Plan + +This plan has already been implemented and remains here for historical reasons. + +It should NOT be treated as the source of truth for service functionality. + +## Purpose + +This plan describes a detailed, incremental implementation path for +[`Auth / Session Service`](README.md) that integrates with the existing +`Edge Gateway`. + +The plan is intentionally atomic. +Each stage should be small enough to implement, review, and test without +overloading development context. + +## Global Rules for the Entire Plan + +- keep domain logic independent from concrete storage backends; +- keep gateway projection separate from source-of-truth records; +- preserve the existing public auth contract expected by gateway: + - `send-email-code` -> `challenge_id` + - `confirm-email-code` -> `device_session_id` +- keep `confirm-email-code` synchronous; +- do not introduce a pending async session-provisioning model; +- use synchronous internal REST where an immediate answer is required; +- use Redis Streams / pub-sub only for session lifecycle propagation and other + event-style side effects; +- keep implementation idempotent where retries are expected; +- design Redis-backed stores behind interfaces so SQL migration remains possible. 
+ +## Milestone Structure + +Suggested milestones: + +1. Domain skeleton and ports +2. In-memory service behavior and tests +3. Redis-backed source-of-truth stores +4. Gateway projection publisher +5. Public HTTP API +6. Internal trusted API +7. Integration with user-service and config-provider ports +8. Revoke/block flows +9. Observability and hardening +10. End-to-end integration with gateway + +## Suggested Module Structure + +The structure described below is allowed to be changed +during the Plan steps implementation. + +```text +authsession/ +├── cmd/ +│ └── authsession/ +│ └── main.go +│ +├── internal/ +│ ├── app/ +│ │ ├── app.go +│ │ ├── bootstrap.go +│ │ └── wiring.go +│ │ +│ ├── config/ +│ │ ├── config.go +│ │ ├── env.go +│ │ └── validation.go +│ │ +│ ├── domain/ +│ │ ├── challenge/ +│ │ │ ├── model.go +│ │ │ ├── state.go +│ │ │ ├── policy.go +│ │ │ └── errors.go +│ │ │ +│ │ ├── devicesession/ +│ │ │ ├── model.go +│ │ │ ├── state.go +│ │ │ ├── revoke.go +│ │ │ └── errors.go +│ │ │ +│ │ ├── userresolution/ +│ │ │ ├── model.go +│ │ │ └── policy.go +│ │ │ +│ │ ├── sessionlimit/ +│ │ │ ├── model.go +│ │ │ └── policy.go +│ │ │ +│ │ └── common/ +│ │ ├── email.go +│ │ ├── time.go +│ │ ├── ids.go +│ │ └── types.go +│ │ +│ ├── ports/ +│ │ ├── challengestore.go +│ │ ├── sessionstore.go +│ │ ├── userdirectory.go +│ │ ├── configprovider.go +│ │ ├── mailsender.go +│ │ ├── projectionpublisher.go +│ │ ├── clock.go +│ │ ├── idgenerator.go +│ │ ├── codegenerator.go +│ │ └── codehasher.go +│ │ +│ ├── service/ +│ │ ├── sendemailcode/ +│ │ │ └── service.go +│ │ ├── confirmemailcode/ +│ │ │ └── service.go +│ │ ├── getsession/ +│ │ │ └── service.go +│ │ ├── listusersessions/ +│ │ │ └── service.go +│ │ ├── revokedevicesession/ +│ │ │ └── service.go +│ │ ├── revokeallusersessions/ +│ │ │ └── service.go +│ │ ├── blockuser/ +│ │ │ └── service.go +│ │ └── shared/ +│ │ ├── normalize.go +│ │ ├── projection.go +│ │ └── publicerrors.go +│ │ +│ ├── api/ +│ │ ├── publichttp/ +│ │ │ ├── 
handler_send_email_code.go +│ │ │ ├── handler_confirm_email_code.go +│ │ │ ├── dto.go +│ │ │ └── errors.go +│ │ │ +│ │ └── internalhttp/ +│ │ ├── handler_get_session.go +│ │ ├── handler_list_user_sessions.go +│ │ ├── handler_revoke_device_session.go +│ │ ├── handler_revoke_all_user_sessions.go +│ │ ├── handler_block_user.go +│ │ ├── dto.go +│ │ └── errors.go +│ │ +│ ├── adapters/ +│ │ ├── redis/ +│ │ │ ├── challengestore/ +│ │ │ │ └── store.go +│ │ │ ├── sessionstore/ +│ │ │ │ └── store.go +│ │ │ ├── configprovider/ +│ │ │ │ └── provider.go +│ │ │ └── gatewayprojection/ +│ │ │ ├── publisher.go +│ │ │ ├── snapshot.go +│ │ │ └── stream.go +│ │ │ +│ │ ├── userservice/ +│ │ │ ├── client.go +│ │ │ ├── mapper.go +│ │ │ └── stub.go +│ │ │ +│ │ ├── mail/ +│ │ │ ├── stub.go +│ │ │ └── rest_client.go +│ │ │ +│ │ ├── crypto/ +│ │ │ ├── codehasher.go +│ │ │ └── publickey.go +│ │ │ +│ │ ├── clock/ +│ │ │ └── system.go +│ │ │ +│ │ └── id/ +│ │ ├── challengeid.go +│ │ └── devicesessionid.go +│ │ +│ ├── observability/ +│ │ ├── logging.go +│ │ ├── metrics.go +│ │ └── tracing.go +│ │ +│ └── testkit/ +│ ├── fixtures.go +│ ├── fake_clock.go +│ ├── fake_idgen.go +│ ├── fake_mail.go +│ ├── fake_userdir.go +│ └── fake_projection.go +│ +├── api/ +│ ├── public-openapi.yaml +│ └── internal-openapi.yaml +│ +└── README.md +``` + +### Description + +- `cmd/authsession` — service entry point: process startup, configuration loading, application assembly, and HTTP server startup. + +- `internal/app` — top-level application orchestration layer: dependency initialization, runtime bootstrap, and component wiring. + +- `internal/config` — service configuration loading, normalization, and validation from environment and other sources. + +- `internal/domain/challenge` — domain model for the `send_email_code` / `confirm_email_code` challenge flow: states, transitions, TTL/retry policies, and domain errors. 
+ +- `internal/domain/devicesession` — domain model for `device_session`: session state, revocation, revoke reasons, and related domain errors. + +- `internal/domain/userresolution` — domain model for user resolution by email through user-service: existing user, allowed registration, or blocked user. + +- `internal/domain/sessionlimit` — domain model and policy rules for active `device_session` limits. + +- `internal/domain/common` — shared domain value objects and helper types: email, time, identifiers, and common primitive types. + +- `internal/ports` — interfaces for all external dependencies: source-of-truth stores, user-service, mail delivery, config, projection publisher, clock, generators, and hashing. + +- `internal/service/sendemailcode` — use case for sending a code: email normalization, challenge lifecycle, suppression/send decision, and success-shaped public response. + +- `internal/service/confirmemailcode` — use case for confirming a code: challenge validation, public-key validation, resolve/create user flow, session-limit enforcement, `device_session` creation, and projection publication. + +- `internal/service/getsession` — use case for reading a single `device_session` for the trusted internal API. + +- `internal/service/listusersessions` — use case for listing user sessions for the trusted internal API. + +- `internal/service/revokedevicesession` — use case for revoking a single device session and publishing the updated gateway projection. + +- `internal/service/revokeallusersessions` — use case for revoking all active sessions of a user and publishing the resulting updates. + +- `internal/service/blockuser` — use case for blocking a user/email and revoking active sessions according to policy. + +- `internal/service/shared` — shared application-layer code: normalization helpers, gateway projection builders, and public error mapping. 
+ +- `internal/api/publichttp` — public HTTP API for gateway integration: handlers, DTOs, and error mapping for `send_email_code` and `confirm_email_code`. + +- `internal/api/internalhttp` — trusted internal HTTP API: revoke/read/list/block endpoints, DTOs, and separate internal error policy. + +- `internal/adapters/redis/challengestore` — Redis adapter for source-of-truth challenge storage. + +- `internal/adapters/redis/sessionstore` — Redis adapter for source-of-truth `device_session` storage. + +- `internal/adapters/redis/configprovider` — Redis adapter for dynamic configuration, such as active-session limits. + +- `internal/adapters/redis/gatewayprojection` — Redis adapter for the `Edge Gateway` integration projection: KV snapshots and lifecycle updates in streams. + +- `internal/adapters/userservice` — user-service integration adapter: REST client, response-to-domain mapping, and stub implementation for early stages. + +- `internal/adapters/mail` — mail-delivery adapter: development stub and future REST mail-service client. + +- `internal/adapters/crypto` — cryptographic adapters: confirmation-code hashing and `client_public_key` validation/parsing. + +- `internal/adapters/clock` — system clock implementation. + +- `internal/adapters/id` — generation of stable domain identifiers such as `challenge_id` and `device_session_id`. + +- `internal/observability` — service logging, metrics, and tracing. + +- `internal/testkit` — test fixtures, fake/mock dependencies, and shared helpers for unit and integration tests. + +- `api/public-openapi.yaml` — formal specification of the public HTTP API. + +- `api/internal-openapi.yaml` — formal specification of the trusted internal HTTP API. + +- `README.md` — architectural service description covering its role in the system, contracts, domain rules, and integrations. + +--- + +## ~~Stage 1.~~ Freeze the Service Contract + +Status: implemented. 
+ +### Goal + +Write down the exact service-level contracts before implementation starts. + +### Tasks + +- freeze public auth use cases: + - `send_email_code` + - `confirm_email_code` +- freeze internal trusted use cases: + - `GetSession` + - `ListUserSessions` + - `RevokeDeviceSession` + - `RevokeAllUserSessions` + - `BlockUser` +- define canonical request/response DTOs for the service boundary; +- define client-safe error classes for the public auth API; +- define richer internal error classes for logs and internal API. + +### Deliverables + +- service contract notes in repo docs; +- initial error catalog; +- agreement on public vs internal API boundaries. + +### Exit Criteria + +- no unresolved ambiguity around public auth input/output shapes; +- no unresolved ambiguity around internal revoke/read operations. + +--- + +## ~~Stage 2.~~ Define Core Domain Types + +Status: implemented. + +### Goal + +Create the minimal domain model without any transport or storage code. + +### Tasks + +- define challenge aggregate concept; +- define device-session aggregate concept; +- define revoke reason model; +- define user resolution result model: + - existing user + - creatable user + - blocked user +- define session-limit decision model; +- define mail-delivery result model; +- define projection snapshot model for gateway integration; +- define domain statuses and allowed transitions. + +### Important Constraints + +- challenge and session models must not depend on Redis-specific encoding; +- gateway projection model must be separate from domain entities. + +### Deliverables + +- domain package with types only; +- transition invariants documented in code comments and tests. + +### Exit Criteria + +- domain package compiles without storage adapters; +- status transitions are covered by unit tests. + +--- + +## ~~Stage 3.~~ Define Service Ports + +Status: implemented. + +### Goal + +Create clean interfaces around every external dependency. 
+ +### Tasks + +Define interfaces conceptually equivalent to: + +- `ChallengeStore` +- `SessionStore` +- `UserDirectory` / `UserResolver` +- `ConfigProvider` +- `MailSender` +- `GatewaySessionProjectionPublisher` +- `Clock` +- `IDGenerator` +- `CodeGenerator` +- `CodeHasher` + +### Notes + +- `ChallengeStore` and `SessionStore` are source-of-truth ports; +- `GatewaySessionProjectionPublisher` is an integration port, not a domain + store; +- `UserDirectory` must support existing / creatable / blocked decisions and + user creation when allowed; +- `ConfigProvider` must support "limit absent" as a first-class case. + +### Deliverables + +- interface package or packages; +- port-level test doubles. + +### Exit Criteria + +- service layer can be implemented against interfaces only. + +--- + +## ~~Stage 4.~~ Implement Pure Domain Services In Memory + +Status: implemented. + +### Goal + +Implement the auth logic once, against in-memory stores and adapters. + +### Tasks + +Implement core use cases: + +- `SendEmailCode` +- `ConfirmEmailCode` +- `GetSession` +- `ListUserSessions` +- `RevokeDeviceSession` +- `RevokeAllUserSessions` +- `BlockUser` + +### Required Behaviors + +#### SendEmailCode + +- normalize email; +- consult `UserDirectory` policy if needed; +- create challenge; +- generate secure code; +- store only hashed code; +- attempt delivery or suppress it; +- always return a success-shaped result with `challenge_id`. + +#### ConfirmEmailCode + +- load challenge; +- validate expiration and status; +- validate code hash; +- validate `client_public_key` format; +- handle idempotent repeat confirm for same successful challenge and same key; +- resolve/create user through `UserDirectory`; +- reject blocked user; +- load session-limit config; +- count active sessions; +- reject if limit exceeded; +- create session; +- store session; +- move challenge into short-window confirmed state; +- publish session projection; +- return `device_session_id`. 
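The repeated-confirm and idempotency rules above can be condensed into one pure decision function. This is a hedged sketch with hypothetical types; the real use case resolves users, enforces session limits, and stores sessions around this core:

```go
package main

import "fmt"

// Illustrative challenge states and confirm outcomes, not the service's
// actual type names.

type ChallengeStatus int

const (
	StatusPending ChallengeStatus = iota
	StatusConfirmed // short idempotency window after a successful confirm
	StatusExpired
)

type Challenge struct {
	Status      ChallengeStatus
	CodeHash    string
	BoundKey    string // client_public_key captured on first successful confirm
	Attempts    int
	MaxAttempts int
}

type ConfirmOutcome int

const (
	OutcomeCreateSession ConfirmOutcome = iota // first successful confirm
	OutcomeSameSession                         // idempotent repeat, same session id
	OutcomeReject
)

// decideConfirm applies the repeated-confirm rules:
//   - expired challenge -> fail
//   - too many attempts -> fail
//   - confirmed + same code + same key -> same session id
//   - confirmed + different key        -> fail
//   - pending + matching code hash     -> create session
func decideConfirm(c Challenge, codeHash, clientKey string) ConfirmOutcome {
	switch {
	case c.Status == StatusExpired:
		return OutcomeReject
	case c.Attempts >= c.MaxAttempts:
		return OutcomeReject
	case c.Status == StatusConfirmed:
		if codeHash == c.CodeHash && clientKey == c.BoundKey {
			return OutcomeSameSession
		}
		return OutcomeReject
	case codeHash == c.CodeHash:
		return OutcomeCreateSession
	default:
		return OutcomeReject
	}
}

func main() {
	c := Challenge{Status: StatusPending, CodeHash: "h1", MaxAttempts: 5}
	fmt.Println(decideConfirm(c, "h1", "pk-A") == OutcomeCreateSession) // true
}
```

Keeping this decision pure is what lets Stage 4 test the full confirm behavior without Redis or HTTP.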
+ +#### Revoke Flows + +- update source of truth; +- publish revoked projection for every affected session. + +### Deliverables + +- service layer with in-memory dependencies; +- unit tests for every public behavior. + +### Exit Criteria + +- full service logic is testable without Redis or HTTP; +- edge cases are covered by unit tests. + +--- + +## ~~Stage 5.~~ Design Challenge Rules in Detail + +Status: implemented. + +### Goal + +Remove ambiguity from challenge handling before persistent adapters are written. + +### Tasks + +- define challenge TTL; +- define max confirm attempts; +- define resend behavior policy, if any; +- define short idempotency window after successful confirm; +- define state machine for: + - new challenge + - sent/suppressed + - confirmed + - expired + - failed +- define exact behavior for repeated confirms: + - same code + same key -> same session id + - same code + different key -> fail + - expired challenge -> fail + - too many attempts -> fail + +### Deliverables + +- explicit challenge policy spec in code comments/tests. + +### Exit Criteria + +- no hidden challenge behavior remains undecided. + +--- + +## ~~Stage 6.~~ Define Public Error Policy + +Status: implemented. + +### Goal + +Make public auth failures predictable and safe. + +### Tasks + +Decide exact client-safe categories for: + +- malformed e-mail; +- malformed `client_public_key`; +- unknown challenge; +- expired challenge; +- invalid code; +- blocked by policy at confirm stage; +- session limit exceeded; +- temporarily unavailable. + +### Additional Rules + +- `send_email_code` must not reveal whether the e-mail exists or is blocked; +- public errors should be normalized for gateway passthrough; +- internal logs and traces may keep richer reasons. + +### Deliverables + +- public error mapping table; +- internal error hierarchy. + +### Exit Criteria + +- gateway adapter behavior can be implemented without guesswork. 
+ +--- + +## ~~Stage 7.~~ Implement Redis ChallengeStore + +Status: implemented. + +### Goal + +Add the first persistent backend for challenges. + +### Tasks + +- implement challenge read/write/update operations in Redis KV; +- define Redis key scheme for challenges; +- store hashed codes only; +- store challenge status and timestamps; +- support atomic compare-and-set style updates where required; +- support expiration cleanup through TTL and/or explicit status. + +### Important Design Rule + +The interface must not expose Redis primitives directly. + +### Deliverables + +- Redis-backed challenge store adapter; +- adapter integration tests against Redis. + +### Exit Criteria + +- challenge lifecycle works against Redis under concurrent access assumptions. + +--- + +## ~~Stage 8.~~ Implement Redis SessionStore + +Status: implemented. + +### Goal + +Add the first persistent backend for sessions. + +### Tasks + +- implement create/read/list/revoke operations; +- define Redis key scheme for sessions; +- support listing all sessions for one user; +- support revoking one session; +- support revoking all sessions for one user; +- support block-related session revocation; +- support active-session counting for limit enforcement; +- store revoke reason and actor metadata. + +### Important Design Rule + +The session source-of-truth record must remain distinct from gateway projection +encoding. + +### Deliverables + +- Redis-backed session store adapter; +- adapter integration tests. + +### Exit Criteria + +- all session lifecycle operations are persistent and testable. + +--- + +## ~~Stage 9.~~ Implement Redis ConfigProvider + +Status: implemented. + +### Goal + +Support dynamic session-limit configuration. + +### Tasks + +- implement config lookup from Redis KV; +- define config key scheme for auth-service settings; +- support: + - limit present with integer value + - limit absent + - invalid config value +- define fallback behavior for invalid config read. 
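The absent/invalid cases for the session-limit lookup could be handled as in this sketch; the fail-safe fallback shown is only one of the policies the plan says must be chosen and documented explicitly:

```go
package main

import (
	"fmt"
	"strconv"
)

// sessionLimit interprets a raw config value read from the config backend.
// found=false models a missing key. enabled=false means no limit is
// enforced. Names and the fail-safe choice are illustrative assumptions.
func sessionLimit(raw string, found bool) (limit int, enabled bool, err error) {
	if !found {
		// missing config -> no session-count limit (first-class case)
		return 0, false, nil
	}
	n, perr := strconv.Atoi(raw)
	if perr != nil || n < 0 {
		// invalid config: fail safe here (treat as unlimited) but surface
		// the error for logging; failing closed is the equally valid
		// alternative policy.
		return 0, false, fmt.Errorf("invalid session limit config %q", raw)
	}
	return n, true, nil
}

func main() {
	l, on, _ := sessionLimit("5", true)
	fmt.Println(l, on)
}
```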
+ +### Required Behavior + +- missing config -> no session-count limit; +- invalid config -> fail closed or fail safe according to explicit decision; +- document the chosen policy. + +### Deliverables + +- Redis-backed config adapter; +- tests for absent, valid, and invalid values. + +### Exit Criteria + +- session-limit logic no longer depends on hard-coded constants. + +--- + +## ~~Stage 10.~~ Implement Gateway Session Projection Publisher + +Status: implemented. + +### Goal + +Bridge auth source-of-truth state into gateway-facing cache/projection state. + +### Tasks + +- define exact projection snapshot structure consumed by gateway; +- define Redis KV key scheme for gateway session lookup; +- define Redis Stream schema for session lifecycle updates; +- implement projection write on session create; +- implement projection update on session revoke; +- implement projection update for bulk revoke/all; +- make publication idempotent and retry-safe. + +### Important Constraints + +- projection publisher should accept domain session data and transform it; +- it must not force domain logic to know Redis snapshot shape. + +### Deliverables + +- Redis-backed projection publisher; +- integration tests that emulate gateway expectations. + +### Exit Criteria + +- created sessions appear in gateway-readable projection; +- revoked sessions produce gateway-readable invalidation/update records. + +--- + +## ~~Stage 11.~~ Implement Stub MailSender + +Status: implemented. + +### Goal + +Introduce the mail-delivery port without coupling auth logic to one concrete delivery transport. + +### Tasks + +- create a stub adapter with deterministic success/failure modes; +- record delivery attempts for tests; +- support explicit suppression mode for blocked/hidden flows; +- ensure service logic can distinguish: + - sent + - suppressed + - failed + +### Deliverables + +- stub mail adapter; +- tests around challenge delivery state transitions. 
+ +### Exit Criteria + +- auth logic is fully testable without real mail infrastructure. + +--- + +## ~~Stage 12.~~ Implement Stub UserDirectory + +Status: implemented. + +### Goal + +Introduce the user-service dependency before its real service exists. + +### Tasks + +- create an in-memory or stub REST-like adapter that can return: + - existing user + - creatable user + - blocked user +- support create-on-confirm behavior; +- support lookups by normalized email; +- support user block state. + +### Deliverables + +- stub user-service adapter; +- integration tests for auth flows. + +### Exit Criteria + +- auth-service no longer needs to fake user decisions internally. + +--- + +## ~~Stage 13.~~ Implement Public HTTP API + +Status: implemented. + +### Goal + +Expose the synchronous public auth flow expected by gateway. + +### Tasks + +- create HTTP handlers for: + - `send_email_code` + - `confirm_email_code` +- define JSON DTOs matching gateway expectations; +- implement request validation; +- implement response normalization; +- implement mapping from internal errors to public client-safe errors; +- add request timeout handling and structured logging. + +### Important Constraints + +- keep semantics aligned with gateway adapter expectations; +- do not expose internal admin/session methods on the public listener. + +### Deliverables + +- public HTTP server; +- handler tests; +- end-to-end tests through HTTP. + +### Exit Criteria + +- gateway can call the service through a real HTTP adapter. + +--- + +## ~~Stage 14.~~ Implement Internal Trusted API + +Status: implemented. + +### Goal + +Expose lifecycle and read operations for trusted internal callers. + +### Tasks + +Implement internal endpoints for: + +- `GetSession` +- `ListUserSessions` +- `RevokeDeviceSession` +- `RevokeAllUserSessions` +- `BlockUser` + +Optional additions later: + +- unblock flow; +- challenge inspection. 
+ +### Notes + +- this may use REST for simplicity; +- authentication/authorization of internal callers can be stubbed initially if + there is not yet a platform-wide internal auth mechanism. + +### Deliverables + +- internal HTTP API; +- handler tests. + +### Exit Criteria + +- session lifecycle can be driven without touching Redis manually. + +--- + +## ~~Stage 15.~~ Implement Revoke Logic Thoroughly + +Status: implemented. + +### Goal + +Make revoke behavior explicit and reliable. + +### Tasks + +For `RevokeDeviceSession`: + +- load target session; +- no-op or explicit result if already revoked; +- persist revoke metadata; +- publish revoked projection. + +For `RevokeAllUserSessions`: + +- list active sessions for user; +- revoke each relevant session; +- publish projection for each affected session; +- preserve reason metadata. + +For `BlockUser`: + +- mark user blocked through `UserDirectory` or trusted policy adapter; +- revoke all active sessions; +- ensure future auth flow is denied at confirm stage and mail can be suppressed + at send stage. + +### Deliverables + +- complete revoke implementation; +- tests for single, bulk, and block flows. + +### Exit Criteria + +- gateway-facing revoke propagation is available for all revoke models. + +--- + +## ~~Stage 16.~~ Add Consistency Safeguards + +Status: implemented. + +### Goal + +Reduce create/revoke drift between source of truth and gateway projection. + +### Tasks + +- identify all places where source-of-truth write and projection publish happen; +- add retry strategy for projection writes; +- make projection publication idempotent; +- define recovery behavior if projection publish fails after source-of-truth + success; +- add dead-letter or repair strategy placeholder if needed later; +- document the consistency model. 
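The retry and idempotency tasks above can be sketched as a bounded retry loop over an idempotent publisher; the interface and the flaky test double are hypothetical:

```go
package main

import (
	"errors"
	"fmt"
)

// ProjectionPublisher must be idempotent: re-publishing the same session
// snapshot is safe, so retries cannot corrupt the gateway projection.
type ProjectionPublisher interface {
	Publish(sessionID string) error
}

// publishWithRetry retries a failed projection write a bounded number of
// times. If every attempt fails, the caller applies the explicit
// failure-handling policy the plan requires (e.g. do not report auth
// success, or record for repair).
func publishWithRetry(p ProjectionPublisher, sessionID string, attempts int) error {
	var lastErr error
	for i := 0; i < attempts; i++ {
		if err := p.Publish(sessionID); err != nil {
			lastErr = err
			continue
		}
		return nil
	}
	return fmt.Errorf("projection publish failed after %d attempts: %w", attempts, lastErr)
}

// flakyPublisher fails a fixed number of times before succeeding,
// emulating a transient Redis error in partial-failure tests.
type flakyPublisher struct{ failures, calls int }

func (f *flakyPublisher) Publish(string) error {
	f.calls++
	if f.calls <= f.failures {
		return errors.New("transient redis error")
	}
	return nil
}

func main() {
	p := &flakyPublisher{failures: 2}
	err := publishWithRetry(p, "sess-1", 5)
	fmt.Println(err, p.calls)
}
```

Idempotency is what makes the retry loop safe: the same snapshot landing twice in the projection store must be indistinguishable from it landing once.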
+ +### Preferred Short-Term Outcome + +- source-of-truth success is never reported as auth success unless projection + write/publish reached the required success threshold, or the failure handling + policy is explicit and tested. + +### Deliverables + +- consistency policy document; +- tests for partial failure scenarios. + +### Exit Criteria + +- known failure windows are explicit and bounded. + +--- + +## ~~Stage 17.~~ Add Public Anti-Abuse Hooks + +Status: implemented. + +### Goal + +Prepare the auth service for safe interaction behind gateway public routing. + +### Tasks + +- add service-level hooks for challenge resend throttling; +- add max-attempt handling per challenge; +- add metrics for suppressed/blocked/sent flows; +- preserve soft anti-enumeration outward behavior. + +### Notes + +Gateway already applies public-edge rate limits. +This stage is about auth-specific flow protection, not replacing gateway limits. + +### Deliverables + +- abuse-control policy inside auth domain; +- tests for throttling and attempt exhaustion. + +### Exit Criteria + +- auth flow cannot be trivially abused through repeated confirm attempts. + +--- + +## ~~Stage 18.~~ Add Observability + +Status: implemented. + +### Goal + +Make the service operable from the beginning. + +### Tasks + +- structured logs for all major state transitions; +- metrics for all major operations; +- tracing spans for public auth flow and internal API; +- redact secrets and codes from logs; +- include stable identifiers such as challenge id, device session id, user id, + and reason codes where safe. + +### Minimum Metrics + +- challenges created; +- deliveries sent/suppressed/failed; +- confirm attempts; +- confirm successes/failures; +- sessions created; +- session limit rejections; +- sessions revoked by reason; +- projection publish failures; +- user-resolution outcomes. + +### Deliverables + +- metrics endpoint wiring if needed; +- logging/tracing middleware; +- observability tests where practical. 
+ +### Exit Criteria + +- production debugging is possible without adding ad hoc logs later. + +--- + +## ~~Stage 19.~~ Add Gateway-Compatibility Tests + +Status: implemented. + +### Goal + +Test auth-service not just in isolation, but against gateway expectations. + +### Tasks + +- verify public auth HTTP DTO compatibility; +- verify `confirm-email-code` returns ready `device_session_id`; +- verify created session projection is readable by a gateway-compatible reader; +- verify revoked projection invalidates session; +- verify repeated confirm returns same session id in idempotency window; +- verify blocked e-mail still keeps `send_email_code` outwardly success-shaped; +- verify session limit exceeded returns stable client-visible error; +- verify malformed `client_public_key` is rejected. + +### Deliverables + +- integration test suite focused on gateway contract. + +### Exit Criteria + +- no ambiguity remains about integration with existing gateway behavior. + +--- + +## ~~Stage 20.~~ Add Real REST Adapter to User Service Contract + +Status: implemented. + +### Goal + +Prepare for future extraction of `User Service`. + +### Tasks + +- define internal REST client for user resolution/create/block operations; +- keep stub implementation for tests; +- add timeout, retry, and error mapping policy; +- define normalized email rules at the boundary. + +### Deliverables + +- REST client adapter for future user-service; +- compatibility tests using stub server. + +### Exit Criteria + +- auth-service can later switch from stub to real user-service with no domain + rewrite. + +--- + +## ~~Stage 21.~~ Add Real Mail Adapter Contract + +Status: implemented. + +### Goal + +Prepare for later internal mail-service-backed delivery. + +### Tasks + +- define mail adapter request/response contract; +- preserve current stub for tests; +- define delivery timeout and error mapping; +- define how suppression vs explicit failure is represented. 
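One way to represent sent vs suppressed vs explicit failure as first-class outcomes, together with the kind of recording stub Stage 11 calls for; all names are illustrative:

```go
package main

import "fmt"

// DeliveryOutcome distinguishes the three delivery results the service
// logic must be able to tell apart.
type DeliveryOutcome int

const (
	DeliverySent       DeliveryOutcome = iota
	DeliverySuppressed // intentionally not sent (e.g. blocked e-mail); publicly still success-shaped
	DeliveryFailed     // transport error; internal-only detail
)

// MailSender is the delivery port; a real REST mail-service client and the
// stub below both satisfy it.
type MailSender interface {
	SendCode(email, code string) (DeliveryOutcome, error)
}

// stubSender is a deterministic test double that records delivery attempts.
type stubSender struct {
	suppress map[string]bool
	sent     []string
}

func (s *stubSender) SendCode(email, code string) (DeliveryOutcome, error) {
	if s.suppress[email] {
		return DeliverySuppressed, nil
	}
	s.sent = append(s.sent, email)
	return DeliverySent, nil
}

func main() {
	s := &stubSender{suppress: map[string]bool{"blocked@example.com": true}}
	out, _ := s.SendCode("user@example.com", "123456")
	fmt.Println(out == DeliverySent)
}
```

Separating suppression from failure is what lets `send_email_code` stay outwardly success-shaped for blocked flows while metrics still count each case distinctly.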
+ +### Deliverables + +- mail adapter interface finalized; +- optional HTTP client adapter skeleton. + +### Exit Criteria + +- auth flow is decoupled from the future mail implementation. + +--- + +## ~~Stage 22.~~ Production Hardening Pass + +Status: implemented. + +### Goal + +Review edge cases before calling the service implementation complete. + +### Tasks + +- test Redis reconnect behavior; +- test duplicate publish behavior; +- test crash/restart around confirm and revoke flows; +- test large numbers of active sessions per user; +- test concurrent confirms against the same challenge; +- test concurrent revoke and confirm races; +- test block-user during active auth flow; +- test expired challenge cleanup strategy. + +### Deliverables + +- hardening checklist; +- race-condition tests; +- operational notes. + +### Exit Criteria + +- no major known race remains undocumented. + +--- + +## ~~Stage 23.~~ Optional Cleanup and Migration Readiness + +Status: implemented. + +### Goal + +Make future SQL migration realistic. + +### Tasks + +- review whether domain services leak Redis assumptions; +- ensure all store interfaces are storage-agnostic; +- isolate key naming, stream naming, and projection serialization; +- add adapter contract tests reusable by future SQL backends. + +### Deliverables + +- backend-agnostic adapter tests; +- migration readiness notes. + +### Exit Criteria + +- a future SQL backend can be added without reworking service-layer logic. + +--- + +## Recommended First Working Slice + +If implementation needs an aggressively small first milestone, do this subset +first: + +1. domain types +2. service ports +3. in-memory service logic +4. stub `UserDirectory` +5. stub `MailSender` +6. public HTTP API +7. Redis `SessionStore` +8. Redis `ChallengeStore` +9. Redis projection publisher +10. 
gateway-compatibility tests for: + - send-email-code + - confirm-email-code + - session projection after confirm + +This gives an end-to-end happy path quickly, without waiting for revoke/admin +and full hardening. + +## Recommended Second Slice + +1. internal trusted API +2. session-limit config provider +3. revoke-device +4. revoke-all +5. block-user +6. observability +7. consistency safeguards +8. hardening tests + +## Final Acceptance Criteria + +The service can be considered implementation-ready when all of the following +are true: + +- gateway can call public auth routes synchronously; +- `confirm-email-code` returns a ready `device_session_id`; +- the created session appears in gateway-compatible projection storage; +- revoked sessions publish gateway-compatible revoke updates; +- repeated successful confirm returns the same session id during the short + idempotency window; +- session creation respects dynamic limit config; +- user block prevents future auth flow and can revoke active sessions; +- all storage is hidden behind interfaces; +- auth-service is not required on the authenticated command hot path; +- logs, metrics, and tests cover the full lifecycle. 
+ +## Implementation Order Summary + +```mermaid +flowchart TD + A["Freeze contracts"] + B["Domain model"] + C["Ports"] + D["In-memory service logic"] + E["Redis stores"] + F["Projection publisher"] + G["Public HTTP API"] + H["Internal trusted API"] + I["Revoke and block flows"] + J["Observability and hardening"] + K["Gateway compatibility tests"] + + A --> B --> C --> D --> E --> F --> G --> H --> I --> J --> K +``` diff --git a/authsession/README.md b/authsession/README.md new file mode 100644 index 0000000..b36c41a --- /dev/null +++ b/authsession/README.md @@ -0,0 +1,459 @@ +# Auth / Session Service + +## Run and Dependencies + +`cmd/authsession` starts two HTTP listeners: + +- public REST on `AUTHSESSION_PUBLIC_HTTP_ADDR` with default `:8080` +- trusted internal REST on `AUTHSESSION_INTERNAL_HTTP_ADDR` with default `:8081` + +Startup requires: + +- one reachable Redis deployment configured by `AUTHSESSION_REDIS_ADDR` + +That Redis deployment is used for: + +- source-of-truth challenges +- source-of-truth device sessions +- dynamic active-session limit config +- gateway session projection cache and stream updates +- send-email-code resend throttling + +Optional integrations: + +- `AUTHSESSION_USER_SERVICE_MODE=stub|rest` +- `AUTHSESSION_MAIL_SERVICE_MODE=stub|rest` +- OTLP telemetry through standard `OTEL_*` variables +- stdout telemetry through + `AUTHSESSION_OTEL_STDOUT_TRACES_ENABLED` and + `AUTHSESSION_OTEL_STDOUT_METRICS_ENABLED` + +Operational caveats: + +- the service exposes no `/healthz`, `/readyz`, or `/metrics` endpoints +- user-service and mail-service default to in-process stub adapters until + `rest` mode is configured +- startup performs bounded Redis `PING` checks for every Redis-backed adapter + and fails fast if Redis or runtime config is invalid + +Additional module docs: + +- [Public REST contract](api/public-openapi.yaml) +- [Internal REST contract](api/internal-openapi.yaml) +- [Documentation index](docs/README.md) +- [Edge Gateway 
README](../gateway/README.md) + +## Purpose + +`Auth / Session Service` owns e-mail-code authentication and the lifecycle of +device sessions. + +It is the source of truth for: + +- authentication challenges +- device sessions +- revoke and block state +- publication of session lifecycle updates consumed by + [`Edge Gateway`](../gateway/README.md) + +The service is intentionally not on the hot path for every authenticated +request. The gateway authenticates steady-state traffic from its own session +cache, kept current by session-lifecycle updates, rather than by synchronous +round-trips back to auth for each command. + +## Responsibilities + +The service is responsible for: + +- public auth commands: + - `send-email-code` + - `confirm-email-code` +- creating device sessions after successful confirmation +- registering the client public key for a newly created session +- revoking one device session +- revoking all sessions of one user +- blocking a user or e-mail subject for future auth flows +- persisting source-of-truth session state +- projecting session state into gateway-consumable Redis data +- exposing a trusted internal REST API for read, revoke, and block operations + +The service is not responsible for: + +- verifying authenticated transport signatures on every business request +- gateway anti-replay for authenticated command traffic +- downstream business authorization +- direct push delivery to clients +- long-lived hot-path session caching inside gateway +- mail-service implementation details beyond the mail-delivery contract + +## Position in the System + +```mermaid +flowchart LR + Client["Client"] + Gateway["Edge Gateway"] + Auth["Auth / Session Service"] + User["User Service"] + Mail["Mail Service"] + Redis["Redis"] + Business["Business Services"] + + Client --> Gateway + Gateway --> Auth + Gateway --> Business + Auth --> User + Auth --> Mail + Auth --> Redis + Redis --> Gateway +``` + +## Main Principles + +- public auth stays synchronous +- `send-email-code` 
returns `challenge_id` +- `confirm-email-code` returns a ready `device_session_id` +- no pending async session-provisioning stage exists +- source-of-truth session state and gateway-facing projection remain separate +- Redis is the initial backend, but the domain and service layers stay storage + agnostic behind ports +- `send-email-code` stays success-shaped for existing, new, blocked, and + throttled e-mail flows +- `confirm-email-code` supports short-window idempotent retry for the same + confirmed challenge and the same `client_public_key` +- active-session limits are configuration driven: + - absent limit means disabled + - limit overflow rejects new session creation explicitly + - the service does not evict existing sessions to make room + +## Gateway-Facing Public Contract + +Gateway already exposes the public REST auth surface and delegates it to this +service: + +- `POST /api/v1/public/auth/send-email-code` +- `POST /api/v1/public/auth/confirm-email-code` + +The effective DTO contract is: + +| Operation | Request | Success response | +| --- | --- | --- | +| `POST /api/v1/public/auth/send-email-code` | `{ "email": string }` | `{ "challenge_id": string }` | +| `POST /api/v1/public/auth/confirm-email-code` | `{ "challenge_id": string, "code": string, "client_public_key": string }` | `{ "device_session_id": string }` | + +`client_public_key` is the standard base64-encoded raw 32-byte Ed25519 public +key registered for the created device session. 
+ +Public boundary rules: + +- requests and responses are JSON only +- request DTOs reject unknown fields +- empty bodies, malformed JSON, trailing JSON input, and unknown fields return + `400 invalid_request` +- surrounding ASCII and Unicode whitespace is trimmed from input string fields + before validation +- `send-email-code` remains success-shaped for existing, new, blocked, and + throttled e-mail paths +- `confirm-email-code` returns a ready `device_session_id` synchronously on + success + +Stable public business-error contract: + +| HTTP status | `error.code` | Stable `error.message` | +| --- | --- | --- | +| `400` | `invalid_request` | field-specific validation detail | +| `400` | `invalid_code` | `confirmation code is invalid` | +| `400` | `invalid_client_public_key` | `client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key` | +| `403` | `blocked_by_policy` | `authentication is blocked by policy` | +| `404` | `challenge_not_found` | `challenge not found` | +| `409` | `session_limit_exceeded` | `active session limit would be exceeded` | +| `410` | `challenge_expired` | `challenge expired` | +| `503` | `service_unavailable` | `service is unavailable` | + +The public error envelope is always: + +```json +{ + "error": { + "code": "string", + "message": "string" + } +} +``` + +## Trusted Internal API + +The trusted internal REST surface lives under `/api/v1/internal` and is +documented in [`api/internal-openapi.yaml`](api/internal-openapi.yaml). 
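The public and internal surfaces share the error envelope shown above. One way a trusted caller might decode it under the documented boundary rules (JSON only, unknown fields rejected, no trailing input) is sketched below; the helper name is illustrative, not part of the service code:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
)

// errorEnvelope mirrors the stable {"error":{"code","message"}} shape shared
// by the public and internal error contracts.
type errorEnvelope struct {
	Error struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error"`
}

// decodeEnvelope decodes strictly in the spirit of the boundary rules:
// unknown fields and trailing JSON input are both rejected.
func decodeEnvelope(body []byte) (errorEnvelope, error) {
	dec := json.NewDecoder(bytes.NewReader(body))
	dec.DisallowUnknownFields()
	var env errorEnvelope
	if err := dec.Decode(&env); err != nil {
		return env, err
	}
	// A second Decode must hit io.EOF; anything else means extra JSON follows.
	var trailing json.RawMessage
	if err := dec.Decode(&trailing); err != io.EOF {
		return env, fmt.Errorf("trailing JSON input")
	}
	return env, nil
}

func main() {
	env, err := decodeEnvelope([]byte(`{"error":{"code":"challenge_expired","message":"challenge expired"}}`))
	fmt.Println(env.Error.Code, err == nil)

	// A second top-level JSON value after the envelope is rejected.
	_, err = decodeEnvelope([]byte(`{"error":{"code":"x","message":"y"}} {"z":1}`))
	fmt.Println(err != nil)
}
```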
+ +Implemented endpoints: + +- `GET /api/v1/internal/sessions/{device_session_id}` +- `GET /api/v1/internal/users/{user_id}/sessions` +- `POST /api/v1/internal/sessions/{device_session_id}/revoke` +- `POST /api/v1/internal/users/{user_id}/sessions/revoke-all` +- `POST /api/v1/internal/user-blocks` + +Key internal API properties: + +- all bodies are JSON only +- `ListUserSessions` is newest-first and unpaginated in v1 +- revoke and block mutations require audit metadata as `reason_code` and + `actor` +- `BlockUser` accepts exactly one of `user_id` or `email` +- mutating operations are idempotent and return explicit acknowledgement + payloads rather than empty `204` responses + +Stable internal error surface: + +| HTTP status | `error.code` | Stable `error.message` | +| --- | --- | --- | +| `400` | `invalid_request` | field-specific validation detail | +| `404` | `session_not_found` | `session not found` | +| `404` | `subject_not_found` | `subject not found` | +| `500` | `internal_error` | `internal server error` | +| `503` | `service_unavailable` | `service is unavailable` | + +## Challenge Model + +A challenge represents one short-lived public e-mail-code flow. 
+ +Core fields: + +- `challenge_id` +- normalized e-mail +- hashed confirmation code +- `status` +- `delivery_state` +- creation and expiration timestamps +- send and confirm attempt counters +- minimal abuse metadata +- optional confirmation metadata used for idempotent retry + +### Challenge States + +Supported `challenge.Status` values: + +- `pending_send` +- `sent` +- `delivery_suppressed` +- `delivery_throttled` +- `confirmed_pending_expire` +- `expired` +- `failed` +- `cancelled` + +Supported `challenge.DeliveryState` values: + +- `pending` +- `sent` +- `suppressed` +- `throttled` +- `failed` + +Policy rules: + +- initial challenge TTL is `5m` +- confirmed-challenge retention for idempotent retry is `5m` +- max invalid confirm attempts is `5` +- every `send-email-code` call creates a fresh challenge +- resend throttling is e-mail scoped with a fixed `1m` cooldown +- a throttled send still creates a fresh challenge in + `status=delivery_throttled` and `delivery_state=throttled` +- throttled sends do not call `UserDirectory` and do not call `MailSender` +- blocked sends outside the throttle path become `delivery_suppressed` + +Fresh confirm semantics: + +- only `sent` and `delivery_suppressed` accept a first successful confirm +- `pending_send`, `delivery_throttled`, `failed`, and `cancelled` return + `invalid_code` +- expired challenges return `challenge_expired` while the Redis grace window + keeps the record present, then `challenge_not_found` after cleanup removes + the key + +Idempotent retry semantics: + +- a repeated confirm with the same `challenge_id`, valid `code`, and identical + `client_public_key` on `confirmed_pending_expire` returns the same + `device_session_id` +- the same confirmed challenge with a different `client_public_key` fails as + `invalid_code` +- idempotent retry republishes the stored gateway session view + +## Device Session And Revoke Model + +A device session is created only after successful confirmation. 
+ +Core fields: + +- `device_session_id` +- `user_id` +- parsed client public key +- `status` +- `created_at` +- optional revocation metadata + +Supported session states: + +- `active` +- `revoked` + +Built-in revoke reason codes: + +- `device_logout` +- `logout_all` +- `admin_revoke` +- `user_blocked` +- `confirm_race_repair` for best-effort cleanup of superseded sessions created + during a confirm race + +Revoke behavior is intentionally separated by use case: + +- revoke one device session +- revoke all sessions of one user +- block a subject and revoke active sessions implied by that subject + +Internal mutation responses report only sessions changed by the current call, +so repeated idempotent operations may return: + +- `already_revoked` with `affected_session_count=0` +- `no_active_sessions` with `affected_session_count=0` +- `already_blocked` with `affected_session_count=0` + +## User Resolution And Session Limits + +`Auth / Session Service` does not own durable user records. It delegates to +`UserDirectory` for: + +- resolve-by-email without mutation +- ensure existing-or-created user during confirm +- existence checks for stable `user_id` +- block-by-user-id and block-by-email operations + +Supported user-resolution outcomes: + +- `existing` +- `creatable` +- `blocked` + +Supported ensure-user outcomes: + +- `existing` +- `created` +- `blocked` + +Session-limit rules: + +- the value is loaded from a shared config provider +- absent value means the limit is disabled +- active sessions are counted before creating a new one +- limit overflow returns `session_limit_exceeded` +- the service never silently revokes an existing session to satisfy the limit + +## Gateway Projection Model + +Gateway-facing session projection is separate from source-of-truth +`devicesession.Session`. 
+ +Each successful projection publish writes: + +- one Redis KV snapshot under the per-session cache key +- one full-snapshot Redis Stream event under the session-events stream + +The default gateway-facing namespaces are: + +- cache key prefix: `gateway:session:` +- session-events stream: `gateway:session_events` + +Projected fields are intentionally limited to what gateway consumes: + +- `device_session_id` +- `user_id` +- `client_public_key` +- `status` +- optional `revoked_at_ms` + +Revoke reason and actor metadata stay in authsession source of truth and are +not projected to gateway. + +## Consistency Model + +Source of truth is written first. Gateway projection is published only after +the source-of-truth write succeeds. + +Caller-visible rules: + +- if projection publication does not reach its required success threshold, the + public or internal call returns `service_unavailable` +- already-written source-of-truth state is intentionally preserved +- the documented repair path is to repeat the same confirm or revoke command + +Projection publish rules: + +- request-path projection publish uses a bounded retry loop with `3` total + attempts +- repeated publishes are safe because the cache snapshot is overwritten and + duplicate full-snapshot stream events remain valid under gateway's + later-event-wins model +- `confirm-email-code` rereads the stored session after the challenge CAS + succeeds and republishes that current view so a concurrent revoke or block + cannot overwrite source of truth with a stale active projection +- idempotent confirm retry also republishes the stored session view +- best-effort cleanup of superseded confirm-race sessions uses the same + publish helper but is not part of the caller-visible success contract + +## Runtime Summary + +Runtime wiring is implemented in [`internal/app`](internal/app) and +[`cmd/authsession`](cmd/authsession/main.go). 
+ +Process-local collaborators: + +- system UTC clock +- crypto-random `challenge_id` and `device_session_id` generators +- crypto-random 6-digit confirmation-code generator +- bcrypt-backed code hashing +- structured logging through `zap` +- process telemetry through OpenTelemetry + +Redis-backed adapters: + +- challenge store +- session store +- session-limit config provider +- gateway projection publisher +- send-email-code abuse protector + +External service adapters: + +- user-service: + - default `stub` + - optional REST adapter with one retry for read-style methods on transport + errors and HTTP `502`, `503`, or `504` + - mutation methods do not auto-retry +- mail-service: + - default `stub` + - optional REST adapter with no automatic retry on transport or upstream + failure, to avoid duplicate deliveries + +Listener defaults: + +- public HTTP: `:8080` +- internal HTTP: `:8081` +- read-header timeout: `2s` +- read timeout: `10s` +- idle timeout: `1m` +- per-request use-case timeout: `3s` + +For detailed runtime behavior, configuration groups, operational notes, and +examples, see [`docs/README.md`](docs/README.md). + +## Non-Goals + +- making authsession a hot synchronous dependency for every authenticated + gateway command +- moving business authorization into authsession +- exposing revoke or read operations as public unauthenticated routes +- introducing short-lived access-token or refresh-token flows +- adding pending async session provisioning after confirm diff --git a/authsession/api/internal-openapi.yaml b/authsession/api/internal-openapi.yaml new file mode 100644 index 0000000..7e6b1a1 --- /dev/null +++ b/authsession/api/internal-openapi.yaml @@ -0,0 +1,456 @@ +openapi: 3.0.3 +info: + title: Galaxy Auth / Session Service Internal API + version: v1 + description: | + This specification documents the implemented `galaxy/authsession` v1 + trusted internal REST contract. 
+ + Contract rules: + - the internal surface lives under `/api/v1/internal`; + - all request and response bodies are JSON only; + - read operations return canonical session DTO wrappers; + - mutating operations return explicit `200` JSON acknowledgements; + - mutation requests carry audit metadata as `reason_code` and `actor`; + - `BlockUser` accepts exactly one of `user_id` or `email`; + - `ListUserSessions` is newest-first and unpaginated in v1. +tags: + - name: InternalAuthSession + description: Trusted internal session read, revoke, and block operations. +paths: + /api/v1/internal/sessions/{device_session_id}: + get: + tags: + - InternalAuthSession + operationId: getSession + summary: Read one device session + parameters: + - $ref: "#/components/parameters/DeviceSessionID" + responses: + "200": + description: The requested device session. + content: + application/json: + schema: + $ref: "#/components/schemas/GetSessionResponse" + "404": + $ref: "#/components/responses/SessionNotFoundError" + "500": + $ref: "#/components/responses/InternalError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" + /api/v1/internal/users/{user_id}/sessions: + get: + tags: + - InternalAuthSession + operationId: listUserSessions + summary: List all active and revoked sessions of one user + description: | + Returns the full v1 session list for one user. Results are ordered from + newest to oldest and are intentionally unpaginated in v1. + parameters: + - $ref: "#/components/parameters/UserID" + responses: + "200": + description: | + Sessions belonging to the requested user. Returns an empty array + when the user has no stored sessions, including unknown `user_id` + values. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ListUserSessionsResponse" + "500": + $ref: "#/components/responses/InternalError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" + /api/v1/internal/sessions/{device_session_id}/revoke: + post: + tags: + - InternalAuthSession + operationId: revokeDeviceSession + summary: Revoke one device session + parameters: + - $ref: "#/components/parameters/DeviceSessionID" + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RevokeDeviceSessionRequest" + responses: + "200": + description: Explicit idempotent acknowledgement of the revoke result. + content: + application/json: + schema: + $ref: "#/components/schemas/RevokeDeviceSessionResponse" + "400": + $ref: "#/components/responses/InvalidRequestError" + "404": + $ref: "#/components/responses/SessionNotFoundError" + "500": + $ref: "#/components/responses/InternalError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" + /api/v1/internal/users/{user_id}/sessions/revoke-all: + post: + tags: + - InternalAuthSession + operationId: revokeAllUserSessions + summary: Revoke all sessions of one user + parameters: + - $ref: "#/components/parameters/UserID" + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RevokeAllUserSessionsRequest" + responses: + "200": + description: Explicit idempotent acknowledgement of the bulk revoke result. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/RevokeAllUserSessionsResponse" + "400": + $ref: "#/components/responses/InvalidRequestError" + "404": + $ref: "#/components/responses/SubjectNotFoundError" + "500": + $ref: "#/components/responses/InternalError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" + /api/v1/internal/user-blocks: + post: + tags: + - InternalAuthSession + operationId: blockUser + summary: Block future auth flow for one subject and revoke active sessions + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/BlockUserRequest" + responses: + "200": + description: Explicit idempotent acknowledgement of the block result. + content: + application/json: + schema: + $ref: "#/components/schemas/BlockUserResponse" + "400": + $ref: "#/components/responses/InvalidRequestError" + "404": + $ref: "#/components/responses/SubjectNotFoundError" + "500": + $ref: "#/components/responses/InternalError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" +components: + parameters: + DeviceSessionID: + name: device_session_id + in: path + required: true + description: Stable identifier of one device session. + schema: + type: string + UserID: + name: user_id + in: path + required: true + description: Stable identifier of one user. + schema: + type: string + schemas: + Actor: + type: object + additionalProperties: false + required: + - type + properties: + type: + type: string + description: Machine-readable actor type such as `system`, `service`, or `admin`. + id: + type: string + description: Optional stable identifier of the initiating actor. + ErrorResponse: + type: object + additionalProperties: false + required: + - error + properties: + error: + $ref: "#/components/schemas/ErrorBody" + ErrorBody: + type: object + additionalProperties: false + required: + - code + - message + properties: + code: + type: string + description: Stable internal API error code. 
+ message: + type: string + description: Human-readable error description safe for trusted internal callers. + Session: + type: object + additionalProperties: false + required: + - device_session_id + - user_id + - client_public_key + - status + - created_at + properties: + device_session_id: + type: string + user_id: + type: string + client_public_key: + type: string + description: Standard base64-encoded raw 32-byte Ed25519 public key of the device session. + status: + type: string + enum: + - active + - revoked + created_at: + type: string + format: date-time + description: RFC3339 UTC timestamp when the session was created. + revoked_at: + type: string + format: date-time + nullable: true + description: RFC3339 UTC timestamp when the session was revoked. + revoke_reason_code: + type: string + nullable: true + description: Machine-readable revoke reason code when the session is revoked. + revoke_actor_type: + type: string + nullable: true + description: Actor type that initiated the revoke. + revoke_actor_id: + type: string + nullable: true + description: Optional stable actor identifier that initiated the revoke. + GetSessionResponse: + type: object + additionalProperties: false + required: + - session + properties: + session: + $ref: "#/components/schemas/Session" + ListUserSessionsResponse: + type: object + additionalProperties: false + required: + - sessions + properties: + sessions: + type: array + description: Full newest-first session list for the requested user. + items: + $ref: "#/components/schemas/Session" + RevokeDeviceSessionRequest: + type: object + additionalProperties: false + required: + - reason_code + - actor + properties: + reason_code: + type: string + description: Machine-readable revoke reason code. 
+ actor: + $ref: "#/components/schemas/Actor" + RevokeDeviceSessionResponse: + type: object + additionalProperties: false + required: + - outcome + - device_session_id + - affected_session_count + properties: + outcome: + type: string + enum: + - revoked + - already_revoked + device_session_id: + type: string + affected_session_count: + type: integer + format: int64 + minimum: 0 + RevokeAllUserSessionsRequest: + type: object + additionalProperties: false + required: + - reason_code + - actor + properties: + reason_code: + type: string + description: Machine-readable bulk revoke reason code. + actor: + $ref: "#/components/schemas/Actor" + RevokeAllUserSessionsResponse: + type: object + additionalProperties: false + required: + - outcome + - user_id + - affected_session_count + - affected_device_session_ids + properties: + outcome: + type: string + enum: + - revoked + - no_active_sessions + user_id: + type: string + affected_session_count: + type: integer + format: int64 + minimum: 0 + affected_device_session_ids: + type: array + items: + type: string + BlockUserRequest: + oneOf: + - $ref: "#/components/schemas/BlockUserByUserIDRequest" + - $ref: "#/components/schemas/BlockUserByEmailRequest" + BlockUserByUserIDRequest: + type: object + additionalProperties: false + required: + - user_id + - reason_code + - actor + properties: + user_id: + type: string + reason_code: + type: string + description: Machine-readable block reason code. + actor: + $ref: "#/components/schemas/Actor" + BlockUserByEmailRequest: + type: object + additionalProperties: false + required: + - email + - reason_code + - actor + properties: + email: + type: string + format: email + reason_code: + type: string + description: Machine-readable block reason code. 
+ actor: + $ref: "#/components/schemas/Actor" + BlockUserResponse: + type: object + additionalProperties: false + required: + - outcome + - subject_kind + - subject_value + - affected_session_count + - affected_device_session_ids + properties: + outcome: + type: string + enum: + - blocked + - already_blocked + subject_kind: + type: string + enum: + - user_id + - email + subject_value: + type: string + affected_session_count: + type: integer + format: int64 + minimum: 0 + affected_device_session_ids: + type: array + items: + type: string + responses: + InvalidRequestError: + description: Request path, parameters, or body fields are invalid. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + invalidRequest: + value: + error: + code: invalid_request + message: reason_code must not be empty + SessionNotFoundError: + description: The referenced device session does not exist. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + sessionNotFound: + value: + error: + code: session_not_found + message: session not found + SubjectNotFoundError: + description: The referenced internal block or bulk-revoke subject does not exist. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + subjectNotFound: + value: + error: + code: subject_not_found + message: subject not found + ServiceUnavailableError: + description: A required dependency is temporarily unavailable. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + unavailable: + value: + error: + code: service_unavailable + message: service is unavailable + InternalError: + description: Unexpected internal failure while processing the request. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + internal: + value: + error: + code: internal_error + message: internal server error diff --git a/authsession/api/public-openapi.yaml b/authsession/api/public-openapi.yaml new file mode 100644 index 0000000..c3b027b --- /dev/null +++ b/authsession/api/public-openapi.yaml @@ -0,0 +1,284 @@ +openapi: 3.0.3 +info: + title: Galaxy Auth / Session Service Public API + version: v1 + description: | + This specification documents the implemented `galaxy/authsession` v1 + public REST contract for the e-mail-code flow consumed by + `galaxy/gateway`. + + Implemented public operations: + - `POST /api/v1/public/auth/send-email-code` + - `POST /api/v1/public/auth/confirm-email-code` + + Contract rules: + - requests and responses are JSON only; + - request schemas reject unknown fields via `additionalProperties: false`; + - empty bodies, malformed JSON, multiple JSON objects, and unknown fields + are rejected as `400 invalid_request`; + - surrounding ASCII/Unicode whitespace is trimmed from input string fields + before validation; + - `send-email-code` remains success-shaped for existing, new, and blocked + e-mail addresses; + - `confirm-email-code` returns a ready `device_session_id` synchronously on + success. +tags: + - name: PublicAuth + description: Public unauthenticated e-mail-code authentication endpoints. +paths: + /api/v1/public/auth/send-email-code: + post: + tags: + - PublicAuth + operationId: sendEmailCode + summary: Start a public e-mail login challenge + description: | + Accepts one client e-mail address and starts the public challenge flow. + The outward result remains success-shaped even when the underlying + policy suppresses mail delivery for anti-enumeration purposes. 
+ security: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/SendEmailCodeRequest" + examples: + default: + value: + email: pilot@example.com + responses: + "200": + description: The login challenge was accepted. + content: + application/json: + schema: + $ref: "#/components/schemas/SendEmailCodeResponse" + examples: + accepted: + value: + challenge_id: challenge-123 + "400": + $ref: "#/components/responses/SendEmailCodeBadRequestError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" + /api/v1/public/auth/confirm-email-code: + post: + tags: + - PublicAuth + operationId: confirmEmailCode + summary: Confirm a public e-mail login challenge + description: | + Completes a previously issued `challenge_id`, validates the submitted + verification code, registers the standard base64-encoded raw 32-byte + Ed25519 `client_public_key`, and returns the created + `device_session_id`. + security: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ConfirmEmailCodeRequest" + examples: + default: + value: + challenge_id: challenge-123 + code: "123456" + client_public_key: 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= + responses: + "200": + description: The device session was created and is ready for use. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ConfirmEmailCodeResponse" + examples: + accepted: + value: + device_session_id: device-session-123 + "400": + $ref: "#/components/responses/ConfirmEmailCodeBadRequestError" + "403": + $ref: "#/components/responses/BlockedByPolicyError" + "404": + $ref: "#/components/responses/ChallengeNotFoundError" + "409": + $ref: "#/components/responses/SessionLimitExceededError" + "410": + $ref: "#/components/responses/ChallengeExpiredError" + "503": + $ref: "#/components/responses/ServiceUnavailableError" +components: + schemas: + SendEmailCodeRequest: + type: object + additionalProperties: false + required: + - email + properties: + email: + type: string + description: Single client e-mail address that should receive the login code. + format: email + SendEmailCodeResponse: + type: object + additionalProperties: false + required: + - challenge_id + properties: + challenge_id: + type: string + description: Opaque challenge identifier returned by the Auth / Session Service. + ConfirmEmailCodeRequest: + type: object + additionalProperties: false + required: + - challenge_id + - code + - client_public_key + properties: + challenge_id: + type: string + description: Opaque challenge identifier previously returned by send-email-code. + code: + type: string + description: Verification code delivered to the client. + client_public_key: + type: string + description: Standard base64-encoded raw 32-byte Ed25519 public key registered for the new device session. + ConfirmEmailCodeResponse: + type: object + additionalProperties: false + required: + - device_session_id + properties: + device_session_id: + type: string + description: Stable identifier of the created device session. 
+ ErrorResponse: + type: object + additionalProperties: false + required: + - error + properties: + error: + $ref: "#/components/schemas/ErrorBody" + ErrorBody: + type: object + additionalProperties: false + required: + - code + - message + properties: + code: + type: string + description: | + Stable gateway-generated or client-safe auth-adapter-projected + error code. Gateway-generated values include `invalid_request`, + `not_found`, `method_not_allowed`, `request_too_large`, + `rate_limited`, `internal_error`, and `service_unavailable`. + message: + type: string + description: Human-readable client-safe error description. + responses: + SendEmailCodeBadRequestError: + description: | + Request body or field values are invalid. This includes empty bodies, + malformed JSON, multiple JSON objects, unknown fields, and invalid + `email`. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + invalidRequest: + value: + error: + code: invalid_request + message: email must be a single valid email address + ConfirmEmailCodeBadRequestError: + description: | + Request body or field values are invalid. This includes malformed + request payloads, invalid confirmation codes, and malformed + `client_public_key` values. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + invalidRequest: + value: + error: + code: invalid_request + message: challenge_id must not be empty + invalidCode: + value: + error: + code: invalid_code + message: confirmation code is invalid + invalidClientPublicKey: + value: + error: + code: invalid_client_public_key + message: client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key + ChallengeNotFoundError: + description: The referenced challenge does not exist. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + notFound: + value: + error: + code: challenge_not_found + message: challenge not found + ChallengeExpiredError: + description: The referenced challenge has expired and can no longer be confirmed. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + expired: + value: + error: + code: challenge_expired + message: challenge expired + BlockedByPolicyError: + description: The auth flow is denied by account or registration policy. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + blocked: + value: + error: + code: blocked_by_policy + message: authentication is blocked by policy + SessionLimitExceededError: + description: Creating another active device session would exceed the configured limit. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + limitExceeded: + value: + error: + code: session_limit_exceeded + message: active session limit would be exceeded + ServiceUnavailableError: + description: The service is temporarily unable to serve the request safely. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + examples: + unavailable: + value: + error: + code: service_unavailable + message: service is unavailable diff --git a/authsession/cmd/authsession/main.go b/authsession/cmd/authsession/main.go new file mode 100644 index 0000000..c728324 --- /dev/null +++ b/authsession/cmd/authsession/main.go @@ -0,0 +1,72 @@ +package main + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + + "galaxy/authsession/internal/app" + "galaxy/authsession/internal/config" + "galaxy/authsession/internal/logging" + "galaxy/authsession/internal/telemetry" +) + +func main() { + if err := run(); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "authsession: %v\n", err) + os.Exit(1) + } +} + +func run() error { + cfg, err := config.LoadFromEnv() + if err != nil { + return err + } + + logger, err := logging.New(cfg.Logging.Level) + if err != nil { + return fmt.Errorf("build logger: %w", err) + } + defer func() { + _ = logging.Sync(logger) + }() + + rootCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + telemetryRuntime, err := telemetry.NewProcess(rootCtx, telemetry.ProcessConfig{ + ServiceName: cfg.Telemetry.ServiceName, + TracesExporter: cfg.Telemetry.TracesExporter, + MetricsExporter: cfg.Telemetry.MetricsExporter, + TracesProtocol: cfg.Telemetry.TracesProtocol, + MetricsProtocol: cfg.Telemetry.MetricsProtocol, + StdoutTracesEnabled: cfg.Telemetry.StdoutTracesEnabled, + StdoutMetricsEnabled: cfg.Telemetry.StdoutMetricsEnabled, + }, logger) + if err != nil { + return fmt.Errorf("build telemetry runtime: %w", err) + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout) + defer cancel() + _ = telemetryRuntime.Shutdown(shutdownCtx) + }() + + runtime, err := app.NewRuntime(rootCtx, cfg, logger, telemetryRuntime) + if err != nil { + return err + } + defer func() { + _ = runtime.Close() + 
}() + + if err := runtime.App.Run(rootCtx); err != nil && !errors.Is(err, context.Canceled) { + return err + } + + return nil +} diff --git a/authsession/contract_openapi_test.go b/authsession/contract_openapi_test.go new file mode 100644 index 0000000..2c4296c --- /dev/null +++ b/authsession/contract_openapi_test.go @@ -0,0 +1,418 @@ +package authsession + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "path/filepath" + "runtime" + "slices" + "testing" + + "galaxy/authsession/internal/service/shared" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +func TestPublicOpenAPISpecValidates(t *testing.T) { + t.Parallel() + + loadSpec(t, "api", "public-openapi.yaml") +} + +func TestInternalOpenAPISpecValidates(t *testing.T) { + t.Parallel() + + loadSpec(t, "api", "internal-openapi.yaml") +} + +func TestPublicOpenAPISpecMatchesGatewayPublicAuthContract(t *testing.T) { + t.Parallel() + + authDoc := loadSpec(t, "api", "public-openapi.yaml") + gatewayDoc := loadSpec(t, "..", "gateway", "openapi.yaml") + authErrorEnvelope := componentSchemaRef(t, authDoc, "ErrorResponse") + gatewayProjectedEnvelope := defaultResponseSchemaRef(t, getOperation(t, gatewayDoc, "/api/v1/public/auth/send-email-code", http.MethodPost)) + const errorResponseRef = "#/components/schemas/ErrorResponse" + + paths := []string{ + "/api/v1/public/auth/send-email-code", + "/api/v1/public/auth/confirm-email-code", + } + + for _, path := range paths { + authOperation := getOperation(t, authDoc, path, http.MethodPost) + gatewayOperation := getOperation(t, gatewayDoc, path, http.MethodPost) + + if authOperation.OperationID != gatewayOperation.OperationID { + require.Failf(t, "test failed", "operation %s: got operationId %q, want %q", path, authOperation.OperationID, gatewayOperation.OperationID) + } + + compareSchemaRefs( + t, + requestSchemaRef(t, authOperation), + requestSchemaRef(t, gatewayOperation), + "path "+path+" request schema", + ) + 
compareSchemaRefs( + t, + responseSchemaRef(t, authOperation, http.StatusOK), + responseSchemaRef(t, gatewayOperation, http.StatusOK), + "path "+path+" success response schema", + ) + + for _, status := range publicErrorStatuses(path) { + assertSchemaRef(t, responseSchemaRef(t, authOperation, status), errorResponseRef, "path "+path+" error response "+http.StatusText(status)+" envelope") + } + } + + compareSchemaRefs( + t, + authErrorEnvelope, + componentSchemaRef(t, gatewayDoc, "ErrorResponse"), + "ErrorResponse schema", + ) + compareSchemaRefs( + t, + componentSchemaRef(t, authDoc, "ErrorBody"), + componentSchemaRef(t, gatewayDoc, "ErrorBody"), + "ErrorBody schema", + ) + assertSchemaRef(t, gatewayProjectedEnvelope, errorResponseRef, "projected gateway auth error envelope") +} + +func TestPublicOpenAPISpecErrorExamplesMatchStablePublicErrors(t *testing.T) { + t.Parallel() + + doc := loadSpec(t, "api", "public-openapi.yaml") + + tests := []struct { + name string + responseName string + exampleName string + projection shared.PublicErrorProjection + }{ + { + name: "send invalid request", + responseName: "SendEmailCodeBadRequestError", + exampleName: "invalidRequest", + projection: shared.ProjectPublicError(shared.InvalidRequest("email must be a single valid email address")), + }, + { + name: "confirm invalid request", + responseName: "ConfirmEmailCodeBadRequestError", + exampleName: "invalidRequest", + projection: shared.ProjectPublicError(shared.InvalidRequest("challenge_id must not be empty")), + }, + { + name: "confirm invalid code", + responseName: "ConfirmEmailCodeBadRequestError", + exampleName: "invalidCode", + projection: shared.ProjectPublicError(shared.InvalidCode()), + }, + { + name: "confirm invalid client public key", + responseName: "ConfirmEmailCodeBadRequestError", + exampleName: "invalidClientPublicKey", + projection: shared.ProjectPublicError(shared.InvalidClientPublicKey()), + }, + { + name: "challenge not found", + responseName: 
"ChallengeNotFoundError", + exampleName: "notFound", + projection: shared.ProjectPublicError(shared.ChallengeNotFound()), + }, + { + name: "challenge expired", + responseName: "ChallengeExpiredError", + exampleName: "expired", + projection: shared.ProjectPublicError(shared.ChallengeExpired()), + }, + { + name: "blocked by policy", + responseName: "BlockedByPolicyError", + exampleName: "blocked", + projection: shared.ProjectPublicError(shared.BlockedByPolicy()), + }, + { + name: "session limit exceeded", + responseName: "SessionLimitExceededError", + exampleName: "limitExceeded", + projection: shared.ProjectPublicError(shared.SessionLimitExceeded()), + }, + { + name: "service unavailable", + responseName: "ServiceUnavailableError", + exampleName: "unavailable", + projection: shared.ProjectPublicError(shared.ServiceUnavailable(nil)), + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := responseExampleValue(t, doc, tt.responseName, tt.exampleName) + want := map[string]any{ + "error": map[string]any{ + "code": tt.projection.Code, + "message": tt.projection.Message, + }, + } + + require.JSONEq(t, string(mustJSON(t, want)), string(mustJSON(t, got))) + }) + } +} + +func TestInternalOpenAPISpecFreezesMutationContracts(t *testing.T) { + t.Parallel() + + doc := loadSpec(t, "api", "internal-openapi.yaml") + + blockUser := componentSchemaRef(t, doc, "BlockUserRequest") + if got := len(blockUser.Value.OneOf); got != 2 { + require.Failf(t, "test failed", "BlockUserRequest oneOf length = %d, want 2", got) + } + + refs := []string{ + blockUser.Value.OneOf[0].Ref, + blockUser.Value.OneOf[1].Ref, + } + slices.Sort(refs) + wantRefs := []string{ + "#/components/schemas/BlockUserByEmailRequest", + "#/components/schemas/BlockUserByUserIDRequest", + } + if !slices.Equal(refs, wantRefs) { + require.Failf(t, "test failed", "BlockUserRequest oneOf refs = %v, want %v", refs, wantRefs) + } + + assertRequiredFields(t, 
componentSchemaRef(t, doc, "BlockUserByUserIDRequest"), "reason_code", "actor", "user_id") + assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserByEmailRequest"), "reason_code", "actor", "email") + assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "device_session_id", "affected_session_count") + assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "user_id", "affected_session_count", "affected_device_session_ids") + assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "subject_kind", "subject_value", "affected_session_count", "affected_device_session_ids") + + assertStringEnum(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "revoked", "already_revoked") + assertStringEnum(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "revoked", "no_active_sessions") + assertStringEnum(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "blocked", "already_blocked") +} + +func loadSpec(t *testing.T, pathElems ...string) *openapi3.T { + t.Helper() + + _, thisFile, _, ok := runtime.Caller(0) + if !ok { + require.FailNow(t, "runtime.Caller failed") + } + + specPath := filepath.Join(append([]string{filepath.Dir(thisFile)}, pathElems...)...) 
+ loader := openapi3.NewLoader() + doc, err := loader.LoadFromFile(specPath) + if err != nil { + require.Failf(t, "test failed", "load spec %s: %v", specPath, err) + } + if doc == nil { + require.Failf(t, "test failed", "load spec %s: returned nil document", specPath) + } + if doc.Info == nil { + require.Failf(t, "test failed", "load spec %s: missing info section", specPath) + } + if doc.Info.Version != "v1" { + require.Failf(t, "test failed", "spec %s version = %q, want v1", specPath, doc.Info.Version) + } + if err := doc.Validate(context.Background()); err != nil { + require.Failf(t, "test failed", "validate spec %s: %v", specPath, err) + } + + return doc +} + +func getOperation(t *testing.T, doc *openapi3.T, path string, method string) *openapi3.Operation { + t.Helper() + + if doc.Paths == nil { + require.Failf(t, "test failed", "spec is missing paths while looking up %s %s", method, path) + } + pathItem := doc.Paths.Value(path) + if pathItem == nil { + require.Failf(t, "test failed", "spec is missing path %s", path) + } + operation := pathItem.GetOperation(method) + if operation == nil { + require.Failf(t, "test failed", "spec is missing %s operation for path %s", method, path) + } + + return operation +} + +func requestSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef { + t.Helper() + + if operation.RequestBody == nil || operation.RequestBody.Value == nil { + require.FailNow(t, "operation is missing request body") + } + mediaType := operation.RequestBody.Value.Content.Get("application/json") + if mediaType == nil || mediaType.Schema == nil { + require.FailNow(t, "operation is missing application/json request schema") + } + + return mediaType.Schema +} + +func responseSchemaRef(t *testing.T, operation *openapi3.Operation, status int) *openapi3.SchemaRef { + t.Helper() + + if operation.Responses == nil { + require.Failf(t, "test failed", "operation is missing responses for status %d", status) + } + response := 
operation.Responses.Status(status) + if response == nil || response.Value == nil { + require.Failf(t, "test failed", "operation is missing response for status %d", status) + } + mediaType := response.Value.Content.Get("application/json") + if mediaType == nil || mediaType.Schema == nil { + require.Failf(t, "test failed", "operation response %d is missing application/json schema", status) + } + + return mediaType.Schema +} + +func defaultResponseSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef { + t.Helper() + + if operation.Responses == nil { + require.FailNow(t, "operation is missing default responses") + } + response := operation.Responses.Default() + if response == nil || response.Value == nil { + require.FailNow(t, "operation is missing default response") + } + mediaType := response.Value.Content.Get("application/json") + if mediaType == nil || mediaType.Schema == nil { + require.FailNow(t, "operation default response is missing application/json schema") + } + + return mediaType.Schema +} + +func componentSchemaRef(t *testing.T, doc *openapi3.T, name string) *openapi3.SchemaRef { + t.Helper() + + if doc.Components == nil { + require.Failf(t, "test failed", "spec is missing components while looking up schema %s", name) + } + schema := doc.Components.Schemas[name] + if schema == nil || schema.Value == nil { + require.Failf(t, "test failed", "spec is missing schema %s", name) + } + + return schema +} + +func responseExampleValue(t *testing.T, doc *openapi3.T, responseName string, exampleName string) any { + t.Helper() + + if doc.Components == nil { + require.Failf(t, "test failed", "spec is missing components while looking up response %s", responseName) + } + response := doc.Components.Responses[responseName] + if response == nil || response.Value == nil { + require.Failf(t, "test failed", "spec is missing response %s", responseName) + } + mediaType := response.Value.Content.Get("application/json") + if mediaType == nil { + 
require.Failf(t, "test failed", "response %s is missing application/json content", responseName) + } + example := mediaType.Examples[exampleName] + if example == nil || example.Value == nil { + require.Failf(t, "test failed", "response %s is missing example %s", responseName, exampleName) + } + + return example.Value.Value +} + +func compareSchemaRefs(t *testing.T, got *openapi3.SchemaRef, want *openapi3.SchemaRef, name string) { + t.Helper() + + gotJSON := mustJSON(t, got) + wantJSON := mustJSON(t, want) + if !bytes.Equal(gotJSON, wantJSON) { + require.Failf(t, "test failed", "%s mismatch:\n got: %s\nwant: %s", name, gotJSON, wantJSON) + } +} + +func assertSchemaRef(t *testing.T, schemaRef *openapi3.SchemaRef, want string, name string) { + t.Helper() + + if schemaRef.Ref != want { + require.Failf(t, "test failed", "%s ref = %q, want %q", name, schemaRef.Ref, want) + } +} + +func assertRequiredFields(t *testing.T, schemaRef *openapi3.SchemaRef, fields ...string) { + t.Helper() + + required := append([]string(nil), schemaRef.Value.Required...) + slices.Sort(required) + want := append([]string(nil), fields...) 
+ slices.Sort(want) + if !slices.Equal(required, want) { + require.Failf(t, "test failed", "schema required fields = %v, want %v", required, want) + } +} + +func assertStringEnum(t *testing.T, schemaRef *openapi3.SchemaRef, property string, values ...string) { + t.Helper() + + prop := schemaRef.Value.Properties[property] + if prop == nil || prop.Value == nil { + require.Failf(t, "test failed", "schema is missing property %s", property) + } + + got := make([]string, 0, len(prop.Value.Enum)) + for _, raw := range prop.Value.Enum { + value, ok := raw.(string) + if !ok { + require.Failf(t, "test failed", "property %s enum contains non-string value %T", property, raw) + } + got = append(got, value) + } + + if !slices.Equal(got, values) { + require.Failf(t, "test failed", "property %s enum = %v, want %v", property, got, values) + } +} + +func mustJSON(t *testing.T, value any) []byte { + t.Helper() + + data, err := json.Marshal(value) + if err != nil { + require.Failf(t, "test failed", "marshal JSON: %v", err) + } + + return data +} + +func publicErrorStatuses(path string) []int { + switch path { + case "/api/v1/public/auth/send-email-code": + return []int{http.StatusBadRequest, http.StatusServiceUnavailable} + case "/api/v1/public/auth/confirm-email-code": + return []int{ + http.StatusBadRequest, + http.StatusForbidden, + http.StatusNotFound, + http.StatusConflict, + http.StatusGone, + http.StatusServiceUnavailable, + } + default: + panic("unexpected public auth path: " + path) + } +} diff --git a/authsession/docs/README.md b/authsession/docs/README.md new file mode 100644 index 0000000..957a122 --- /dev/null +++ b/authsession/docs/README.md @@ -0,0 +1,22 @@ +# Auth / Session Service Docs + +This directory keeps service-local documentation that is too detailed for the +root architecture document and too operational for the OpenAPI specs. 
+ +Sections: + +- [Runtime and components](runtime.md) +- [Auth, revoke, and repair flows](flows.md) +- [Operator runbook](runbook.md) +- [Configuration and contract examples](examples.md) + +Primary references: + +- [`../README.md`](../README.md) for service scope, contracts, and core domain + rules +- [`../api/public-openapi.yaml`](../api/public-openapi.yaml) for the public + REST contract +- [`../api/internal-openapi.yaml`](../api/internal-openapi.yaml) for the + trusted internal REST contract +- [`../../gateway/README.md`](../../gateway/README.md) for the downstream + consumer of authsession's public DTOs and Redis session projection diff --git a/authsession/docs/examples.md b/authsession/docs/examples.md new file mode 100644 index 0000000..dbaef02 --- /dev/null +++ b/authsession/docs/examples.md @@ -0,0 +1,194 @@ +# Configuration And Contract Examples + +The examples below are illustrative. Values such as keys, codes, and IDs are +placeholders unless explicitly stated otherwise. + +## Example Environment + +Minimal local-development shape: + +```dotenv +AUTHSESSION_REDIS_ADDR=127.0.0.1:6379 +AUTHSESSION_PUBLIC_HTTP_ADDR=:8080 +AUTHSESSION_INTERNAL_HTTP_ADDR=:8081 + +AUTHSESSION_USER_SERVICE_MODE=stub +AUTHSESSION_MAIL_SERVICE_MODE=stub + +OTEL_SERVICE_NAME=galaxy-authsession +OTEL_TRACES_EXPORTER=none +OTEL_METRICS_EXPORTER=none +``` + +Example REST-backed integration shape: + +```dotenv +AUTHSESSION_REDIS_ADDR=127.0.0.1:6379 + +AUTHSESSION_USER_SERVICE_MODE=rest +AUTHSESSION_USER_SERVICE_BASE_URL=http://127.0.0.1:8091 +AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT=1s + +AUTHSESSION_MAIL_SERVICE_MODE=rest +AUTHSESSION_MAIL_SERVICE_BASE_URL=http://127.0.0.1:8092 +AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT=1s +``` + +## Public Auth HTTP Examples + +Start an e-mail challenge: + +```bash +curl -X POST http://127.0.0.1:8080/api/v1/public/auth/send-email-code \ + -H 'Content-Type: application/json' \ + -d '{"email":"pilot@example.com"}' +``` + +Example response: + +```json 
+{ + "challenge_id": "challenge-123" +} +``` + +Confirm the challenge and register the device public key: + +```bash +curl -X POST http://127.0.0.1:8080/api/v1/public/auth/confirm-email-code \ + -H 'Content-Type: application/json' \ + -d '{ + "challenge_id": "challenge-123", + "code": "123456", + "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=" + }' +``` + +Example response: + +```json +{ + "device_session_id": "device-session-123" +} +``` + +Stable public error example: + +```json +{ + "error": { + "code": "challenge_expired", + "message": "challenge expired" + } +} +``` + +## Trusted Internal HTTP Examples + +Read one session: + +```bash +curl http://127.0.0.1:8081/api/v1/internal/sessions/device-session-123 +``` + +Example response: + +```json +{ + "session": { + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=", + "status": "active", + "created_at": "2026-04-05T12:00:00Z" + } +} +``` + +Revoke one session: + +```bash +curl -X POST http://127.0.0.1:8081/api/v1/internal/sessions/device-session-123/revoke \ + -H 'Content-Type: application/json' \ + -d '{"reason_code":"admin_revoke","actor":{"type":"system"}}' +``` + +Example response: + +```json +{ + "outcome": "revoked", + "device_session_id": "device-session-123", + "affected_session_count": 1 +} +``` + +Block by e-mail: + +```bash +curl -X POST http://127.0.0.1:8081/api/v1/internal/user-blocks \ + -H 'Content-Type: application/json' \ + -d '{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin","id":"admin-1"}}' +``` + +Example response: + +```json +{ + "outcome": "blocked", + "subject_kind": "email", + "subject_value": "pilot@example.com", + "affected_session_count": 0, + "affected_device_session_ids": [] +} +``` + +## Redis Projection Examples + +### Gateway Session Cache Record + +Example Redis key and JSON value written by authsession for gateway: + +```text 
+gateway:session:device-session-123 +``` + +```json +{ + "device_session_id": "device-session-123", + "user_id": "user-123", + "client_public_key": "11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no=", + "status": "active" +} +``` + +### Gateway Session-Event Stream Entry + +Active snapshot: + +```bash +redis-cli XADD gateway:session_events '*' \ + device_session_id device-session-123 \ + user_id user-123 \ + client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \ + status active +``` + +Revoked snapshot: + +```bash +redis-cli XADD gateway:session_events '*' \ + device_session_id device-session-123 \ + user_id user-123 \ + client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \ + status revoked \ + revoked_at_ms 1775121700000 +``` + +Notes: + +- projected field values are strings in the Redis Stream payload +- `revoked_at_ms` is written only for revoked snapshots +- duplicate full-snapshot stream events are acceptable +- the cache snapshot and stream event intentionally omit revoke reason and + actor metadata because gateway does not consume them diff --git a/authsession/docs/flows.md b/authsession/docs/flows.md new file mode 100644 index 0000000..8d0a016 --- /dev/null +++ b/authsession/docs/flows.md @@ -0,0 +1,119 @@ +# Auth, Revoke, and Repair Flows + +## Public Auth Flow + +```mermaid +sequenceDiagram + participant Client + participant Gateway + participant Auth + participant Abuse as Resend throttle + participant User as UserDirectory + participant Mail as MailSender + participant Challenge as ChallengeStore + participant Session as SessionStore + participant Config as ConfigProvider + participant Projection as Gateway projection publisher + + Client->>Gateway: POST /api/v1/public/auth/send-email-code + Gateway->>Auth: POST /api/v1/public/auth/send-email-code + Auth->>Abuse: check and reserve cooldown + alt throttled + Abuse-->>Auth: throttled + Auth->>Challenge: create delivery_throttled challenge + Auth-->>Gateway: 200 {challenge_id} + else 
allowed + Abuse-->>Auth: allowed + Auth->>User: ResolveByEmail(email) + User-->>Auth: existing / creatable / blocked + Auth->>Challenge: create pending challenge + alt blocked + Auth->>Challenge: mark delivery_suppressed + else not blocked + Auth->>Mail: SendLoginCode(email, code) + Mail-->>Auth: sent / suppressed / failure + Auth->>Challenge: persist final delivery outcome + end + Auth-->>Gateway: 200 {challenge_id} + end + + Client->>Gateway: POST /api/v1/public/auth/confirm-email-code + Gateway->>Auth: POST /api/v1/public/auth/confirm-email-code + Auth->>Challenge: load and validate challenge + Auth->>User: EnsureUserByEmail(email) + User-->>Auth: existing / created / blocked + Auth->>Config: LoadSessionLimit() + Auth->>Session: CountActiveByUserID(user_id) + Auth->>Session: create device session + Auth->>Challenge: CAS to confirmed_pending_expire + Auth->>Session: reread current stored session view + Auth->>Projection: publish gateway snapshot + Auth-->>Gateway: 200 {device_session_id} +``` + +## Revoke and Block Flow + +```mermaid +sequenceDiagram + participant Caller as Trusted internal caller + participant Auth + participant User as UserDirectory + participant Session as SessionStore + participant Projection as Gateway projection publisher + participant Gateway + + Caller->>Auth: revoke or block request + alt block by user or email + Auth->>User: apply block mutation + User-->>Auth: blocked / already_blocked + end + Auth->>Session: revoke one or many sessions + Session-->>Auth: updated source-of-truth sessions + loop each affected session + Auth->>Projection: publish revoked snapshot + end + Auth-->>Caller: 200 acknowledgement + Projection-->>Gateway: revoked session snapshot +``` + +## Projection Repair On Retry + +Projection writes happen after source-of-truth updates. If projection publish +fails after state is already stored, the caller sees `service_unavailable`, and +the repair path is to repeat the same request. 
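The store-then-publish ordering can be sketched as a simplified in-memory model. The type and function names below (`confirmService`, `flakyPublisher`, `errServiceUnavailable`) are hypothetical, and the real flow goes through the `ChallengeStore`, `SessionStore`, and `ProjectionPublisher` ports rather than a plain map:

```go
package main

import (
	"errors"
	"fmt"
)

// publisher is a stand-in for the gateway projection publisher port.
type publisher interface{ Publish(sessionID string) error }

// flakyPublisher fails a configurable number of times before succeeding,
// simulating Redis being briefly unreachable during projection writes.
type flakyPublisher struct{ failures int }

func (p *flakyPublisher) Publish(sessionID string) error {
	if p.failures > 0 {
		p.failures--
		return errors.New("redis unavailable")
	}
	return nil
}

var errServiceUnavailable = errors.New("service_unavailable")

// confirmService models only the ordering that matters here: the
// source-of-truth write happens first, the projection publish second.
type confirmService struct {
	pub      publisher
	sessions map[string]string // challengeID -> stored device session ID
}

// Confirm stores state before publishing. A publish failure after the store
// write surfaces as service_unavailable, and repeating the identical request
// republishes the already-stored session instead of creating a new one.
func (s *confirmService) Confirm(challengeID string) (string, error) {
	sessionID, ok := s.sessions[challengeID]
	if !ok {
		sessionID = "device-session-" + challengeID
		s.sessions[challengeID] = sessionID // source-of-truth write wins first
	}
	if err := s.pub.Publish(sessionID); err != nil {
		return "", errServiceUnavailable // state is kept; caller retries
	}
	return sessionID, nil
}

func main() {
	svc := &confirmService{pub: &flakyPublisher{failures: 1}, sessions: map[string]string{}}

	if _, err := svc.Confirm("challenge-123"); err != nil {
		fmt.Println("first attempt:", err) // first attempt: service_unavailable
	}
	id, err := svc.Confirm("challenge-123") // identical retry repairs the projection
	fmt.Println("retry:", id, err == nil)   // retry: device-session-challenge-123 true
}
```

The key property the sketch demonstrates is that the retry returns the same `device_session_id` as the failed attempt would have, which is why the repair path is simply "repeat the same request".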
+ +```mermaid +sequenceDiagram + participant Client + participant Auth + participant Challenge as ChallengeStore + participant Session as SessionStore + participant Projection as Gateway projection publisher + + Client->>Auth: confirm-email-code + Auth->>Challenge: validate challenge + Auth->>Session: create session + Auth->>Challenge: persist confirmed_pending_expire + Auth->>Projection: publish snapshot + Projection-->>Auth: failure + Auth-->>Client: 503 service_unavailable + + Client->>Auth: repeat same confirm-email-code + Auth->>Challenge: load confirmed_pending_expire challenge + Auth->>Session: load stored session from confirmation metadata + Auth->>Projection: republish current stored session view + Projection-->>Auth: success + Auth-->>Client: 200 {device_session_id} +``` + +## Confirm-Race Cleanup + +Concurrent identical confirms are allowed to race at the store level, but the +service converges them back to one surviving active session. + +- the winning CAS stores challenge confirmation metadata and publishes the + surviving session snapshot +- a superseded session created by a losing racing request is revoked + best-effort with `reason_code=confirm_race_repair` +- cleanup uses the same projection helper, but cleanup failure is not part of + the caller-visible success contract diff --git a/authsession/docs/runbook.md b/authsession/docs/runbook.md new file mode 100644 index 0000000..50082e4 --- /dev/null +++ b/authsession/docs/runbook.md @@ -0,0 +1,157 @@ +# Operator Runbook + +This runbook covers the checks that matter most during startup, steady-state +verification, shutdown, and common authsession incidents. 
+
+## Startup Checks
+
+Before starting the process, confirm:
+
+- `AUTHSESSION_REDIS_ADDR` points to the Redis deployment used for authsession
+  source-of-truth data, resend throttling, and gateway projection
+- the configured Redis ACL, DB, TLS, and key-prefix settings match the target
+  environment
+- if `AUTHSESSION_USER_SERVICE_MODE=rest`, both
+  `AUTHSESSION_USER_SERVICE_BASE_URL` and
+  `AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT` are configured
+- if `AUTHSESSION_MAIL_SERVICE_MODE=rest`, both
+  `AUTHSESSION_MAIL_SERVICE_BASE_URL` and
+  `AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT` are configured
+- gateway and authsession agree on:
+  - `gateway:session:` cache key prefix
+  - `gateway:session_events` stream name
+
+At startup the process performs bounded `PING` checks for:
+
+- challenge store
+- session store
+- config provider
+- gateway projection publisher
+- resend-throttle protector
+
+Startup fails fast if any of those checks fail.
+
+Expected listener state after a healthy start:
+
+- public HTTP on `AUTHSESSION_PUBLIC_HTTP_ADDR` or default `:8080`
+- internal HTTP on `AUTHSESSION_INTERNAL_HTTP_ADDR` or default `:8081`
+
+Known startup caveats:
+
+- there is no health, readiness, or metrics endpoint to probe directly
+- the stub user-service and stub mail-service modes are valid only for
+  development and isolated testing, never for real production environments
+
+## Steady-State Verification
+
+Because the service intentionally exposes no `/healthz` or `/readyz`, practical
+verification is:
+
+1. confirm the process emitted startup logs for both listeners
+2. open a TCP connection to the configured public and internal listener
+   addresses
+3. send one smoke request to the public auth surface and one to the trusted
+   internal surface when a non-destructive path is available
+4.
confirm Redis connectivity and namespace configuration out of band + +Recommended smoke requests: + +- public: malformed `send-email-code` request and expect `400 invalid_request` +- internal: `GET /api/v1/internal/users/{unknown}/sessions` and expect `200` + with an empty list + +## Shutdown + +The process handles `SIGINT` and `SIGTERM`. + +Shutdown behavior: + +- the per-component shutdown budget is controlled by + `AUTHSESSION_SHUTDOWN_TIMEOUT` +- both HTTP listeners are stopped through the coordinated app shutdown +- Redis and HTTP-client resources are closed after the app stops +- telemetry providers are flushed and shut down after the process begins + exiting + +During planned restarts: + +1. send `SIGTERM` +2. wait for the listener shutdown logs +3. restart the process with the same Redis configuration +4. re-run the steady-state verification steps above + +## Incident Triage + +### Confirm Returns `503` But A Later Retry Succeeds + +Interpret this as a projection-publication failure after source-of-truth state +was already written. + +Check: + +1. whether the challenge moved to `confirmed_pending_expire` +2. whether the created session exists in source of truth +3. whether Redis was reachable for gateway projection writes at the time of + failure +4. whether a repeated identical confirm repaired the gateway projection + +Expected behavior: + +- the first request returns `503 service_unavailable` +- the same confirm retried during the idempotency window returns the same + `device_session_id` + +### Revocation Does Not Reach Gateway + +If a revoked session still authenticates through gateway: + +1. verify the authsession source-of-truth record is revoked +2. verify a gateway projection snapshot was written under + `gateway:session:` +3. verify a matching snapshot event was appended to `gateway:session_events` +4. verify gateway is pointed at the same Redis address, DB, and stream name +5. 
check whether a later active snapshot overwrote the revoked view
+
+### Send Flow Is Unexpectedly Throttled
+
+If repeated `send-email-code` calls return challenge IDs but no mail is sent:
+
+1. check the resend-throttle key namespace
+2. confirm the same normalized e-mail address is being reused
+3. verify the requests are inside the fixed `1m` cooldown window
+4. confirm authsession is creating `delivery_throttled` challenges rather than
+   `delivery_suppressed` ones
+
+Expected throttled behavior:
+
+- a fresh `challenge_id` is still returned
+- `UserDirectory` is not called
+- `MailSender` is not called
+
+### User-Service Or Mail-Service REST Failures
+
+If `rest` mode is enabled and calls begin failing:
+
+1. verify the configured base URL
+2. verify outbound connectivity from the authsession process
+3. confirm request timeouts are large enough for the environment
+4. for user-service reads, remember the client retries only once on transport
+   errors and `502`/`503`/`504`
+5. for mail-service sends, remember the client never auto-retries
+
+Observed behavior:
+
+- public auth flows usually surface these failures as `503 service_unavailable`
+- internal revoke and block flows surface them as `503 service_unavailable`
+
+### Expired Challenge Questions
+
+When callers report mixed `challenge_expired` and `challenge_not_found`
+responses:
+
+- `challenge_expired` means the record still exists and has crossed the
+  expiration boundary
+- `challenge_not_found` means the record is absent, including after Redis TTL
+  cleanup removes it
+
+That difference is expected and should not be treated as contract drift.
diff --git a/authsession/docs/runtime.md b/authsession/docs/runtime.md
new file mode 100644
index 0000000..05b321b
--- /dev/null
+++ b/authsession/docs/runtime.md
@@ -0,0 +1,176 @@
+# Runtime and Components
+
+The diagram below focuses on the deployed `galaxy/authsession` process and its
+runtime dependencies.
+ +```mermaid +flowchart LR + subgraph Clients + Gateway["Edge Gateway"] + Internal["Trusted internal callers"] + end + + subgraph Authsession["Auth / Session Service process"] + PublicHTTP["Public HTTP listener\n/api/v1/public/auth/*"] + InternalHTTP["Trusted internal listener\n/api/v1/internal/*"] + Services["Application services"] + Runtime["Clock, IDs, code generation, hashing"] + Telemetry["Logs, traces, metrics"] + end + + Redis["Redis\nchallenges + sessions + config + projection + throttle"] + User["User Service\nstub or REST"] + Mail["Mail Service\nstub or REST"] + GatewayCache["Gateway session cache\nand session-events stream"] + + Gateway --> PublicHTTP + Internal --> InternalHTTP + PublicHTTP --> Services + InternalHTTP --> Services + Services --> Runtime + Services --> Redis + Services --> User + Services --> Mail + Services --> GatewayCache + PublicHTTP --> Telemetry + InternalHTTP --> Telemetry +``` + +## Listeners + +`authsession` exposes exactly two HTTP listeners: + +| Listener | Default addr | Purpose | +| --- | --- | --- | +| Public HTTP | `:8080` | Unauthenticated public auth routes consumed directly or through gateway | +| Internal HTTP | `:8081` | Trusted read, revoke, and block operations | + +Shared listener defaults: + +- read-header timeout: `2s` +- read timeout: `10s` +- idle timeout: `1m` +- per-request application timeout: `3s` + +Intentional omissions: + +- no `/healthz` +- no `/readyz` +- no `/metrics` +- no separate admin listener + +## Startup Wiring + +`cmd/authsession` loads process config, builds the logger and telemetry +runtime, then assembles the application through `internal/app.NewRuntime`. 
+ +`NewRuntime` wires: + +- Redis-backed `ChallengeStore` +- Redis-backed `SessionStore` +- Redis-backed `ConfigProvider` +- Redis-backed gateway `ProjectionPublisher` +- Redis-backed resend-throttle `SendEmailCodeAbuseProtector` +- local runtime helpers for clock, ID generation, code generation, and code + hashing +- user-service adapter selected by `AUTHSESSION_USER_SERVICE_MODE` +- mail-service adapter selected by `AUTHSESSION_MAIL_SERVICE_MODE` +- public and internal HTTP servers + +Before startup completes, the process performs bounded `PING` checks for every +Redis-backed adapter listed above. Startup fails fast if any Redis-backed +dependency is unavailable or misconfigured. + +## Redis Namespaces + +Default Redis naming: + +- challenges: `authsession:challenge:` +- sessions: `authsession:session:` +- user-to-session index: `authsession:user-sessions:` +- user-to-active-session index: `authsession:user-active-sessions:` +- session limit key: `authsession:config:active-session-limit` +- send-email-code throttle keys: `authsession:send-email-code-throttle:` +- gateway session cache keys: `gateway:session:` +- gateway session-events stream: `gateway:session_events` + +The authsession process owns the source-of-truth namespaces and writes the +gateway-facing projection namespaces as a derived integration view. 
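As a consumer-side illustration of the derived projection view, a gateway-style reader can decode the `gateway:session:` payload strictly. The field names below follow the projection wire contract exercised elsewhere in this repository (`device_session_id`, `user_id`, `client_public_key`, `status`, optional `revoked_at_ms`); the type and function names are illustrative, not the gateway's actual implementation:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// sessionCacheRecord mirrors the gateway-facing projection payload stored
// under the gateway:session:<device_session_id> keys. The type name is
// illustrative; field names follow the projection wire contract.
type sessionCacheRecord struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	ClientPublicKey string `json:"client_public_key"`
	Status          string `json:"status"`
	RevokedAtMS     *int64 `json:"revoked_at_ms,omitempty"`
}

// decodeSessionCacheRecord rejects unknown fields so contract drift in the
// projection payload fails loudly instead of being silently ignored.
func decodeSessionCacheRecord(payload []byte) (sessionCacheRecord, error) {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	var record sessionCacheRecord
	if err := decoder.Decode(&record); err != nil {
		return sessionCacheRecord{}, err
	}
	return record, nil
}

func main() {
	payload := []byte(`{"device_session_id":"device-session-1","user_id":"user-1","client_public_key":"AQID","status":"active"}`)
	record, err := decodeSessionCacheRecord(payload)
	if err != nil {
		panic(err)
	}
	fmt.Println(record.DeviceSessionID, record.Status) // prints "device-session-1 active"
}
```

Strict decoding is a design choice here: because authsession owns the projection namespaces, any unexpected field in a cached record indicates a producer/consumer version mismatch rather than legitimate data.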
+ +## Configuration Groups + +Required for all process starts: + +- `AUTHSESSION_REDIS_ADDR` + +Core process config: + +- `AUTHSESSION_SHUTDOWN_TIMEOUT` +- `AUTHSESSION_LOG_LEVEL` + +Public HTTP config: + +- `AUTHSESSION_PUBLIC_HTTP_ADDR` +- `AUTHSESSION_PUBLIC_HTTP_READ_HEADER_TIMEOUT` +- `AUTHSESSION_PUBLIC_HTTP_READ_TIMEOUT` +- `AUTHSESSION_PUBLIC_HTTP_IDLE_TIMEOUT` +- `AUTHSESSION_PUBLIC_HTTP_REQUEST_TIMEOUT` + +Internal HTTP config: + +- `AUTHSESSION_INTERNAL_HTTP_ADDR` +- `AUTHSESSION_INTERNAL_HTTP_READ_HEADER_TIMEOUT` +- `AUTHSESSION_INTERNAL_HTTP_READ_TIMEOUT` +- `AUTHSESSION_INTERNAL_HTTP_IDLE_TIMEOUT` +- `AUTHSESSION_INTERNAL_HTTP_REQUEST_TIMEOUT` + +Redis connectivity and namespace config: + +- `AUTHSESSION_REDIS_USERNAME` +- `AUTHSESSION_REDIS_PASSWORD` +- `AUTHSESSION_REDIS_DB` +- `AUTHSESSION_REDIS_TLS_ENABLED` +- `AUTHSESSION_REDIS_OPERATION_TIMEOUT` +- `AUTHSESSION_REDIS_CHALLENGE_KEY_PREFIX` +- `AUTHSESSION_REDIS_SESSION_KEY_PREFIX` +- `AUTHSESSION_REDIS_USER_SESSIONS_KEY_PREFIX` +- `AUTHSESSION_REDIS_USER_ACTIVE_SESSIONS_KEY_PREFIX` +- `AUTHSESSION_REDIS_SESSION_LIMIT_KEY` +- `AUTHSESSION_REDIS_GATEWAY_SESSION_CACHE_KEY_PREFIX` +- `AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM` +- `AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM_MAX_LEN` +- `AUTHSESSION_REDIS_SEND_EMAIL_CODE_THROTTLE_KEY_PREFIX` + +User-service integration: + +- `AUTHSESSION_USER_SERVICE_MODE=stub|rest` +- `AUTHSESSION_USER_SERVICE_BASE_URL` +- `AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT` + +Mail-service integration: + +- `AUTHSESSION_MAIL_SERVICE_MODE=stub|rest` +- `AUTHSESSION_MAIL_SERVICE_BASE_URL` +- `AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT` + +Telemetry: + +- `OTEL_SERVICE_NAME` +- `OTEL_TRACES_EXPORTER` +- `OTEL_METRICS_EXPORTER` +- `OTEL_EXPORTER_OTLP_PROTOCOL` +- `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` +- `OTEL_EXPORTER_OTLP_METRICS_PROTOCOL` +- `AUTHSESSION_OTEL_STDOUT_TRACES_ENABLED` +- `AUTHSESSION_OTEL_STDOUT_METRICS_ENABLED` + +## Runtime Notes + +- user-service and 
mail-service default to `stub`, which keeps local startup + backward-compatible and does not require external URLs +- read-style user-service REST methods retry once on transport errors and HTTP + `502`, `503`, or `504` +- user-service mutation methods do not auto-retry +- mail-service REST requests do not auto-retry, to avoid duplicate delivery +- authsession exports telemetry through OTel providers only; it does not serve + Prometheus text exposition directly diff --git a/authsession/gateway_compatibility_test.go b/authsession/gateway_compatibility_test.go new file mode 100644 index 0000000..e40de8b --- /dev/null +++ b/authsession/gateway_compatibility_test.go @@ -0,0 +1,720 @@ +package authsession + +import ( + "bytes" + "context" + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "galaxy/authsession/internal/adapters/mail" + "galaxy/authsession/internal/adapters/redis/challengestore" + "galaxy/authsession/internal/adapters/redis/configprovider" + "galaxy/authsession/internal/adapters/redis/projectionpublisher" + "galaxy/authsession/internal/adapters/redis/sessionstore" + "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/api/internalhttp" + "galaxy/authsession/internal/api/publichttp" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/testkit" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +const ( + gatewayCompatibilityChallengeKeyPrefix = "authsession:challenge:" + gatewayCompatibilitySessionKeyPrefix = "authsession:session:" + gatewayCompatibilityUserSessionsKeyPrefix = "authsession:user-sessions:" + gatewayCompatibilityUserActiveKeyPrefix = "authsession:user-active-sessions:" + gatewayCompatibilitySessionLimitKey = "authsession:config:active-session-limit" + gatewayCompatibilitySessionCacheKeyPrefix = "gateway:session:" + gatewayCompatibilitySessionEventsStream = "gateway:session_events" + gatewayCompatibilityStreamMaxLen int64 = 128 + + gatewayCompatibilityEmail = "pilot@example.com" + gatewayCompatibilityCode = "123456" +) + +var gatewayCompatibilityClientPublicKey = mustGatewayCompatibilityClientPublicKeyBase64() + +func TestGatewayCompatibilityConfirmReturnsGatewayReadableSessionProjection(t *testing.T) { + t.Parallel() + + app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{}) + + sendResponse := gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, sendResponse.StatusCode) + + var sendBody struct { + ChallengeID string `json:"challenge_id"` + } + require.NoError(t, json.Unmarshal([]byte(sendResponse.Body), &sendBody)) + assert.Equal(t, "challenge-1", sendBody.ChallengeID) + + attempts := app.mailSender.RecordedAttempts() + require.Len(t, attempts, 1) + + confirmResponse := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": sendBody.ChallengeID, + "code": attempts[0].Input.Code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusOK, confirmResponse.StatusCode) + + var confirmBody struct { + DeviceSessionID string `json:"device_session_id"` + } + require.NoError(t, json.Unmarshal([]byte(confirmResponse.Body), &confirmBody)) + assert.Equal(t, 
"device-session-1", confirmBody.DeviceSessionID) + + record := app.mustReadGatewayCacheRecord(t, confirmBody.DeviceSessionID) + assert.Equal(t, gatewayCacheRecord{ + DeviceSessionID: "device-session-1", + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "active", + }, record) + + events := app.mustReadGatewaySessionEvents(t, confirmBody.DeviceSessionID) + require.NotEmpty(t, events) + assert.Equal(t, gatewaySessionEventRecord{ + DeviceSessionID: "device-session-1", + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "active", + }, events[len(events)-1]) +} + +func TestGatewayCompatibilityRevokePublishesRevokedGatewayProjection(t *testing.T) { + t.Parallel() + + app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{}) + + sessionID := app.createSessionThroughPublicFlow(t) + + revokeResponse := gatewayCompatibilityPostJSON( + t, + app.internalBaseURL+"/api/v1/internal/sessions/"+sessionID+"/revoke", + `{"reason_code":"admin_revoke","actor":{"type":"system"}}`, + ) + assert.Equal(t, http.StatusOK, revokeResponse.StatusCode) + assert.JSONEq(t, `{"outcome":"revoked","device_session_id":"`+sessionID+`","affected_session_count":1}`, revokeResponse.Body) + + record := app.mustReadGatewayCacheRecord(t, sessionID) + require.NotNil(t, record.RevokedAtMS) + assert.Equal(t, gatewayCacheRecord{ + DeviceSessionID: sessionID, + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "revoked", + RevokedAtMS: int64Pointer(app.now.UnixMilli()), + }, record) + + events := app.mustReadGatewaySessionEvents(t, sessionID) + require.NotEmpty(t, events) + last := events[len(events)-1] + require.NotNil(t, last.RevokedAtMS) + assert.Equal(t, gatewaySessionEventRecord{ + DeviceSessionID: sessionID, + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "revoked", + RevokedAtMS: int64Pointer(app.now.UnixMilli()), + }, last) +} + +func 
TestGatewayCompatibilityRepeatedConfirmReturnsSameSessionID(t *testing.T) { + t.Parallel() + + app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{}) + + sendResponse := gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, sendResponse.StatusCode) + + var sendBody struct { + ChallengeID string `json:"challenge_id"` + } + require.NoError(t, json.Unmarshal([]byte(sendResponse.Body), &sendBody)) + + attempts := app.mailSender.RecordedAttempts() + require.Len(t, attempts, 1) + + requestBody := map[string]string{ + "challenge_id": sendBody.ChallengeID, + "code": attempts[0].Input.Code, + "client_public_key": gatewayCompatibilityClientPublicKey, + } + + first := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", requestBody) + second := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", requestBody) + assert.Equal(t, http.StatusOK, first.StatusCode) + assert.Equal(t, http.StatusOK, second.StatusCode) + + var firstBody struct { + DeviceSessionID string `json:"device_session_id"` + } + var secondBody struct { + DeviceSessionID string `json:"device_session_id"` + } + require.NoError(t, json.Unmarshal([]byte(first.Body), &firstBody)) + require.NoError(t, json.Unmarshal([]byte(second.Body), &secondBody)) + assert.Equal(t, firstBody.DeviceSessionID, secondBody.DeviceSessionID) + + record := app.mustReadGatewayCacheRecord(t, firstBody.DeviceSessionID) + assert.Equal(t, gatewayCacheRecord{ + DeviceSessionID: firstBody.DeviceSessionID, + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "active", + }, record) +} + +func TestGatewayCompatibilityBlockedEmailSendRemainsSuccessShaped(t *testing.T) { + t.Parallel() + + app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{ + SeedBlockedEmail: true, + }) + + response := 
gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, response.StatusCode) + + var body map[string]string + require.NoError(t, json.Unmarshal([]byte(response.Body), &body)) + assert.Equal(t, map[string]string{"challenge_id": "challenge-1"}, body) +} + +func TestGatewayCompatibilitySessionLimitExceededReturnsStableClientError(t *testing.T) { + t.Parallel() + + limit := 1 + app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{ + SeedExistingUser: true, + SessionLimit: &limit, + SeedActiveSessions: []devicesession.Session{ + gatewayCompatibilityActiveSession( + t, + "device-session-existing", + "user-1", + gatewayCompatibilityClientPublicKey, + time.Date(2026, 4, 5, 11, 58, 0, 0, time.UTC), + ), + }, + }) + + sendResponse := gatewayCompatibilityPostJSON(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, sendResponse.StatusCode) + + attempts := app.mailSender.RecordedAttempts() + require.Len(t, attempts, 1) + + confirmResponse := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "challenge-1", + "code": attempts[0].Input.Code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusConflict, confirmResponse.StatusCode) + assert.JSONEq(t, `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`, confirmResponse.Body) +} + +func TestGatewayCompatibilityMalformedClientPublicKeyReturnsStableError(t *testing.T) { + t.Parallel() + + app := newGatewayCompatibilityHarness(t, gatewayCompatibilityOptions{}) + + response := gatewayCompatibilityPostJSON( + t, + app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", + `{"challenge_id":"challenge-123","code":"123456","client_public_key":"invalid"}`, + ) + assert.Equal(t, 
http.StatusBadRequest, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key"}}`, response.Body) +} + +type gatewayCompatibilityOptions struct { + SeedBlockedEmail bool + SeedExistingUser bool + SessionLimit *int + SeedActiveSessions []devicesession.Session +} + +// gatewayCompatibilityHarness owns one gateway-focused integration test setup +// with real HTTP servers and real Redis-backed authsession adapters. +type gatewayCompatibilityHarness struct { + publicBaseURL string + internalBaseURL string + mailSender *mail.StubSender + redisClient *redis.Client + now time.Time +} + +func newGatewayCompatibilityHarness(t *testing.T, options gatewayCompatibilityOptions) gatewayCompatibilityHarness { + t.Helper() + + now := time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC) + redisServer := miniredis.RunT(t) + redisClient := redis.NewClient(&redis.Options{ + Addr: redisServer.Addr(), + Protocol: 2, + DisableIdentity: true, + }) + t.Cleanup(func() { + assert.NoError(t, redisClient.Close()) + }) + + if options.SessionLimit != nil { + redisServer.Set(gatewayCompatibilitySessionLimitKey, strconv.Itoa(*options.SessionLimit)) + } + + challengeStore, err := challengestore.New(challengestore.Config{ + Addr: redisServer.Addr(), + DB: 0, + KeyPrefix: gatewayCompatibilityChallengeKeyPrefix, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, challengeStore.Close()) + }) + + sessionStore, err := sessionstore.New(sessionstore.Config{ + Addr: redisServer.Addr(), + DB: 0, + SessionKeyPrefix: gatewayCompatibilitySessionKeyPrefix, + UserSessionsKeyPrefix: gatewayCompatibilityUserSessionsKeyPrefix, + UserActiveSessionsKeyPrefix: gatewayCompatibilityUserActiveKeyPrefix, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, sessionStore.Close()) + }) + + 
configStore, err := configprovider.New(configprovider.Config{ + Addr: redisServer.Addr(), + DB: 0, + SessionLimitKey: gatewayCompatibilitySessionLimitKey, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, configStore.Close()) + }) + + publisher, err := projectionpublisher.New(projectionpublisher.Config{ + Addr: redisServer.Addr(), + DB: 0, + SessionCacheKeyPrefix: gatewayCompatibilitySessionCacheKeyPrefix, + SessionEventsStream: gatewayCompatibilitySessionEventsStream, + StreamMaxLen: gatewayCompatibilityStreamMaxLen, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, publisher.Close()) + }) + + userDirectory := &userservice.StubDirectory{} + if options.SeedBlockedEmail { + require.NoError(t, userDirectory.SeedBlockedEmail(common.Email(gatewayCompatibilityEmail), "policy_blocked")) + } + if options.SeedExistingUser { + require.NoError(t, userDirectory.SeedExisting(common.Email(gatewayCompatibilityEmail), common.UserID("user-1"))) + } + for _, session := range options.SeedActiveSessions { + require.NoError(t, sessionStore.Create(context.Background(), session)) + } + + mailSender := &mail.StubSender{} + idGenerator := &testkit.SequenceIDGenerator{} + codeGenerator := testkit.FixedCodeGenerator{Code: gatewayCompatibilityCode} + codeHasher := testkit.DeterministicCodeHasher{} + clock := testkit.FixedClock{Time: now} + + sendEmailCodeService, err := sendemailcode.NewWithObservability( + challengeStore, + userDirectory, + idGenerator, + codeGenerator, + codeHasher, + mailSender, + nil, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + confirmEmailCodeService, err := confirmemailcode.NewWithObservability( + challengeStore, + sessionStore, + userDirectory, + configStore, + publisher, + idGenerator, + codeHasher, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + getSessionService, err := 
getsession.New(sessionStore) + require.NoError(t, err) + listUserSessionsService, err := listusersessions.New(sessionStore) + require.NoError(t, err) + revokeDeviceSessionService, err := revokedevicesession.NewWithObservability(sessionStore, publisher, clock, zap.NewNop(), nil) + require.NoError(t, err) + revokeAllUserSessionsService, err := revokeallusersessions.NewWithObservability(sessionStore, userDirectory, publisher, clock, zap.NewNop(), nil) + require.NoError(t, err) + blockUserService, err := blockuser.NewWithObservability(userDirectory, sessionStore, publisher, clock, zap.NewNop(), nil) + require.NoError(t, err) + + publicCfg := publichttp.DefaultConfig() + publicCfg.Addr = gatewayCompatibilityFreeAddr(t) + publicServer, err := publichttp.NewServer(publicCfg, publichttp.Dependencies{ + SendEmailCode: sendEmailCodeService, + ConfirmEmailCode: confirmEmailCodeService, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + internalCfg := internalhttp.DefaultConfig() + internalCfg.Addr = gatewayCompatibilityFreeAddr(t) + internalServer, err := internalhttp.NewServer(internalCfg, internalhttp.Dependencies{ + GetSession: getSessionService, + ListUserSessions: listUserSessionsService, + RevokeDeviceSession: revokeDeviceSessionService, + RevokeAllUserSessions: revokeAllUserSessionsService, + BlockUser: blockUserService, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + gatewayCompatibilityRunServer(t, publicServer.Run, publicServer.Shutdown, publicCfg.Addr) + gatewayCompatibilityRunServer(t, internalServer.Run, internalServer.Shutdown, internalCfg.Addr) + + return gatewayCompatibilityHarness{ + publicBaseURL: "http://" + publicCfg.Addr, + internalBaseURL: "http://" + internalCfg.Addr, + mailSender: mailSender, + redisClient: redisClient, + now: now, + } +} + +func (h gatewayCompatibilityHarness) createSessionThroughPublicFlow(t *testing.T) string { + t.Helper() + + sendResponse := gatewayCompatibilityPostJSON(t, 
h.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, sendResponse.StatusCode) + + var sendBody struct { + ChallengeID string `json:"challenge_id"` + } + require.NoError(t, json.Unmarshal([]byte(sendResponse.Body), &sendBody)) + + attempts := h.mailSender.RecordedAttempts() + require.Len(t, attempts, 1) + + confirmResponse := gatewayCompatibilityPostJSONValue(t, h.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": sendBody.ChallengeID, + "code": attempts[0].Input.Code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusOK, confirmResponse.StatusCode) + + var confirmBody struct { + DeviceSessionID string `json:"device_session_id"` + } + require.NoError(t, json.Unmarshal([]byte(confirmResponse.Body), &confirmBody)) + + return confirmBody.DeviceSessionID +} + +// gatewayCacheRecord mirrors the strict gateway Redis session-cache wire +// contract used on the authenticated hot path. 
+type gatewayCacheRecord struct { + DeviceSessionID string `json:"device_session_id"` + UserID string `json:"user_id"` + ClientPublicKey string `json:"client_public_key"` + Status string `json:"status"` + RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"` +} + +func (h gatewayCompatibilityHarness) mustReadGatewayCacheRecord(t *testing.T, deviceSessionID string) gatewayCacheRecord { + t.Helper() + + payload, err := h.redisClient.Get(context.Background(), gatewayCompatibilitySessionCacheKeyPrefix+deviceSessionID).Bytes() + require.NoError(t, err) + + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.DisallowUnknownFields() + + var record gatewayCacheRecord + require.NoError(t, decoder.Decode(&record)) + + err = decoder.Decode(&struct{}{}) + require.ErrorIs(t, err, io.EOF) + + require.NotEmpty(t, record.DeviceSessionID) + require.Equal(t, deviceSessionID, record.DeviceSessionID) + require.NotEmpty(t, record.UserID) + require.NotEmpty(t, record.ClientPublicKey) + require.Contains(t, []string{"active", "revoked"}, record.Status) + + return record +} + +// gatewaySessionEventRecord mirrors the strict gateway Redis Stream event +// contract for full session snapshots. 
+type gatewaySessionEventRecord struct { + DeviceSessionID string + UserID string + ClientPublicKey string + Status string + RevokedAtMS *int64 +} + +func (h gatewayCompatibilityHarness) mustReadGatewaySessionEvents(t *testing.T, deviceSessionID string) []gatewaySessionEventRecord { + t.Helper() + + entries, err := h.redisClient.XRange(context.Background(), gatewayCompatibilitySessionEventsStream, "-", "+").Result() + require.NoError(t, err) + + records := make([]gatewaySessionEventRecord, 0, len(entries)) + for _, entry := range entries { + record := decodeGatewaySessionEvent(t, entry.Values) + if record.DeviceSessionID == deviceSessionID { + records = append(records, record) + } + } + require.NotEmpty(t, records) + + return records +} + +func decodeGatewaySessionEvent(t *testing.T, values map[string]any) gatewaySessionEventRecord { + t.Helper() + + requiredKeys := map[string]struct{}{ + "device_session_id": {}, + "user_id": {}, + "client_public_key": {}, + "status": {}, + } + optionalKeys := map[string]struct{}{ + "revoked_at_ms": {}, + } + + for key := range values { + if _, ok := requiredKeys[key]; ok { + continue + } + if _, ok := optionalKeys[key]; ok { + continue + } + + require.Failf(t, "test failed", "decode gateway session event: unsupported field %q", key) + } + + record := gatewaySessionEventRecord{ + DeviceSessionID: gatewayCompatibilityRequiredStringField(t, values, "device_session_id"), + UserID: gatewayCompatibilityRequiredStringField(t, values, "user_id"), + ClientPublicKey: gatewayCompatibilityRequiredStringField(t, values, "client_public_key"), + Status: gatewayCompatibilityRequiredStringField(t, values, "status"), + } + require.Contains(t, []string{"active", "revoked"}, record.Status) + + if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok { + parsed := gatewayCompatibilityParseInt64Field(t, rawRevokedAtMS, "revoked_at_ms") + record.RevokedAtMS = &parsed + } + + return record +} + +func gatewayCompatibilityRequiredStringField(t *testing.T, 
values map[string]any, field string) string { + t.Helper() + + value, ok := values[field] + require.Truef(t, ok, "decode gateway session event: missing %s", field) + + stringValue := gatewayCompatibilityCoerceString(t, value, field) + require.NotEmptyf(t, strings.TrimSpace(stringValue), "decode gateway session event: %s must not be empty", field) + + return stringValue +} + +func gatewayCompatibilityParseInt64Field(t *testing.T, value any, field string) int64 { + t.Helper() + + stringValue := gatewayCompatibilityCoerceString(t, value, field) + parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64) + require.NoErrorf(t, err, "decode gateway session event: %s", field) + + return parsed +} + +func gatewayCompatibilityCoerceString(t *testing.T, value any, field string) string { + t.Helper() + + switch typed := value.(type) { + case string: + return typed + case []byte: + return string(typed) + case fmt.Stringer: + return typed.String() + case int: + return strconv.Itoa(typed) + case int64: + return strconv.FormatInt(typed, 10) + case uint64: + return strconv.FormatUint(typed, 10) + default: + require.Failf(t, "test failed", "decode gateway session event: %s: unsupported value type %T", field, value) + return "" + } +} + +func gatewayCompatibilityRunServer( + t *testing.T, + run func(context.Context) error, + shutdown func(context.Context) error, + addr string, +) { + t.Helper() + + errCh := make(chan error, 1) + go func() { + errCh <- run(context.Background()) + }() + + gatewayCompatibilityWaitForTCP(t, addr) + t.Cleanup(func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + assert.NoError(t, shutdown(shutdownCtx)) + assert.NoError(t, <-errCh) + }) +} + +func gatewayCompatibilityWaitForTCP(t *testing.T, addr string) { + t.Helper() + + require.Eventually(t, func() bool { + conn, err := net.DialTimeout("tcp", addr, 50*time.Millisecond) + if err != nil { + return false + } + _ = conn.Close() + return 
true + }, 5*time.Second, 25*time.Millisecond) +} + +func gatewayCompatibilityFreeAddr(t *testing.T) string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + assert.NoError(t, listener.Close()) + }() + + return listener.Addr().String() +} + +type gatewayCompatibilityHTTPResponse struct { + StatusCode int + Body string +} + +func gatewayCompatibilityPostJSON(t *testing.T, url string, body string) gatewayCompatibilityHTTPResponse { + t.Helper() + + request, err := http.NewRequest(http.MethodPost, url, bytes.NewBufferString(body)) + require.NoError(t, err) + request.Header.Set("Content-Type", "application/json") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + require.NoError(t, err) + + return gatewayCompatibilityHTTPResponse{ + StatusCode: response.StatusCode, + Body: string(payload), + } +} + +func gatewayCompatibilityPostJSONValue(t *testing.T, url string, value any) gatewayCompatibilityHTTPResponse { + t.Helper() + + payload, err := json.Marshal(value) + require.NoError(t, err) + + return gatewayCompatibilityPostJSON(t, url, string(payload)) +} + +func gatewayCompatibilityActiveSession( + t *testing.T, + deviceSessionID string, + userID string, + clientPublicKeyBase64 string, + createdAt time.Time, +) devicesession.Session { + t.Helper() + + keyBytes, err := base64.StdEncoding.DecodeString(clientPublicKeyBase64) + require.NoError(t, err) + + clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey(keyBytes)) + require.NoError(t, err) + + session := devicesession.Session{ + ID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID(userID), + ClientPublicKey: clientPublicKey, + Status: devicesession.StatusActive, + CreatedAt: createdAt, + } + require.NoError(t, session.Validate()) + + return session +} + +func mustGatewayCompatibilityClientPublicKeyBase64() string { + key := 
make([]byte, ed25519.PublicKeySize) + for index := range key { + key[index] = byte(index + 1) + } + + return base64.StdEncoding.EncodeToString(key) +} + +func int64Pointer(value int64) *int64 { + return &value +} diff --git a/authsession/go.mod b/authsession/go.mod new file mode 100644 index 0000000..7a50199 --- /dev/null +++ b/authsession/go.mod @@ -0,0 +1,84 @@ +module galaxy/authsession + +go 1.26.0 + +require ( + github.com/alicebob/miniredis/v2 v2.37.0 + github.com/getkin/kin-openapi v0.134.0 + github.com/gin-gonic/gin v1.12.0 + github.com/redis/go-redis/v9 v9.18.0 + github.com/stretchr/testify v1.11.1 + go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0 + go.opentelemetry.io/otel v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 + go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0 + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0 + go.opentelemetry.io/otel/metric v1.42.0 + go.opentelemetry.io/otel/sdk v1.42.0 + go.opentelemetry.io/otel/sdk/metric v1.42.0 + go.opentelemetry.io/otel/trace v1.42.0 + go.uber.org/zap v1.27.1 + golang.org/x/crypto v0.49.0 +) + +require ( + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/gabriel-vasile/mimetype v1.4.13 // indirect + github.com/gin-contrib/sse v1.1.0 // indirect + 
github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.30.1 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/goccy/go-yaml v1.19.2 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c // indirect + github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/quic-go/qpack v0.6.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.3.1 // indirect + github.com/woodsbury/decimal128 v1.3.0 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect + go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 // indirect + go.opentelemetry.io/proto/otlp v1.10.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + 
go.uber.org/multierr v1.10.0 // indirect + golang.org/x/arch v0.24.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect + google.golang.org/grpc v1.80.0 // indirect + google.golang.org/protobuf v1.36.11 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/authsession/go.sum b/authsession/go.sum new file mode 100644 index 0000000..fbeb233 --- /dev/null +++ b/authsession/go.sum @@ -0,0 +1,196 @@ +github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68= +github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
+github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
+github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
+github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
+github.com/getkin/kin-openapi v0.134.0 h1:/L5+1+kfe6dXh8Ot/wqiTgUkjOIEJiC0bbYVziHB8rU=
+github.com/getkin/kin-openapi v0.134.0/go.mod h1:wK6ZLG/VgoETO9pcLJ/VmAtIcl/DNlMayNTb716EUxE=
+github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
+github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
+github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
+github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
+github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
+github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
+github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
+github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
+github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
+github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
+github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
+github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
+github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
+github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
+github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
+github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
+github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
+github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
+github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
+github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
+github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
+github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
+github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
+github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
+github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
+github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
+github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
+github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
+github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
+github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
+github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
+github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
+github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
+github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
+github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
+github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
+github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
+github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
+github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
+github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
+github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c h1:7ACFcSaQsrWtrH4WHHfUqE1C+f8r2uv8KGaW0jTNjus=
+github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c/go.mod h1:JKox4Gszkxt57kj27u7rvi7IFoIULvCZHUsBTUmQM/s=
+github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b h1:vivRhVUAa9t1q0Db4ZmezBP8pWQWnXHFokZj0AOea2g=
+github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
+github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
+github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
+github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
+github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
+github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
+github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
+github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
+github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
+github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
+github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
+github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
+github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
+github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
+github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
+github.com/woodsbury/decimal128 v1.3.0 h1:8pffMNWIlC0O5vbyHWFZAt5yWvWcrHA+3ovIIjVWss0=
+github.com/woodsbury/decimal128 v1.3.0/go.mod h1:C5UTmyTjW3JftjUFzOVhC20BEQa2a4ZKOB5I6Zjb+ds=
+github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
+github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
+github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
+github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
+go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
+go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
+go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
+go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
+go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0 h1:E7DmskpIO7ZR6QI6zKSEKIDNUYoKw9oHXP23gzbCdU0=
+go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.67.0/go.mod h1:WB2cS9y+AwqqKhoo9gw6/ZxlSjFBUQGZ8BQOaD3FVXM=
+go.opentelemetry.io/contrib/propagators/b3 v1.42.0 h1:B2Pew5ufEtgkjLF+tSkXjgYZXQr9m7aCm1wLKB0URbU=
+go.opentelemetry.io/contrib/propagators/b3 v1.42.0/go.mod h1:iPgUcSEF5DORW6+yNbdw/YevUy+QqJ508ncjhrRSCjc=
+go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
+go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 h1:MdKucPl/HbzckWWEisiNqMPhRrAOQX8r4jTuGr636gk=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0/go.mod h1:RolT8tWtfHcjajEH5wFIZ4Dgh5jpPdFXYV9pTAk/qjc=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 h1:H7O6RlGOMTizyl3R08Kn5pdM06bnH8oscSj7o11tmLA=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0/go.mod h1:mBFWu/WOVDkWWsR7Tx7h6EpQB8wsv7P0Yrh0Pb7othc=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 h1:THuZiwpQZuHPul65w4WcwEnkX2QIuMT+UFoOrygtoJw=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0/go.mod h1:J2pvYM5NGHofZ2/Ru6zw/TNWnEQp5crgyDeSrYpXkAw=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 h1:zWWrB1U6nqhS/k6zYB74CjRpuiitRtLLi68VcgmOEto=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0/go.mod h1:2qXPNBX1OVRC0IwOnfo1ljoid+RD0QK3443EaqVlsOU=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 h1:uLXP+3mghfMf7XmV4PkGfFhFKuNWoCvvx5wP/wOXo0o=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0/go.mod h1:v0Tj04armyT59mnURNUJf7RCKcKzq+lgJs6QSjHjaTc=
+go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0 h1:lSZHgNHfbmQTPfuTmWVkEu8J8qXaQwuV30pjCcAUvP8=
+go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0/go.mod h1:so9ounLcuoRDu033MW/E0AD4hhUjVqswrMF5FoZlBcw=
+go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0 h1:s/1iRkCKDfhlh1JF26knRneorus8aOwVIDhvYx9WoDw=
+go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0/go.mod h1:UI3wi0FXg1Pofb8ZBiBLhtMzgoTm1TYkMvn71fAqDzs=
+go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
+go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
+go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
+go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
+go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
+go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
+go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
+go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
+go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
+go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
+go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
+go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
+go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
+go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
+go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
+go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
+go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
+go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
+go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
+golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
+golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
+golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
+golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
+golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
+golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
+golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
+gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
+gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
+google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
+google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
+google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
+google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
+google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
+google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/authsession/internal/adapters/antiabuse/send_email_code_protector.go b/authsession/internal/adapters/antiabuse/send_email_code_protector.go
new file mode 100644
index 0000000..1e5a3db
--- /dev/null
+++ b/authsession/internal/adapters/antiabuse/send_email_code_protector.go
@@ -0,0 +1,56 @@
+// Package antiabuse provides runtime in-process adapters for auth-specific
+// public abuse controls.
+package antiabuse
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/ports"
+)
+
+// SendEmailCodeProtector is a concurrency-safe in-process resend-throttle
+// adapter for public send-email-code attempts.
+type SendEmailCodeProtector struct {
+	mu            sync.Mutex
+	reservedUntil map[common.Email]time.Time
+}
+
+// CheckAndReserve applies the fixed Stage-17 resend cooldown using input.Now
+// as the authoritative decision timestamp.
+func (p *SendEmailCodeProtector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
+	if ctx == nil {
+		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: nil context")
+	}
+	if err := ctx.Err(); err != nil {
+		return ports.SendEmailCodeAbuseResult{}, err
+	}
+	if err := input.Validate(); err != nil {
+		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err)
+	}
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	if p.reservedUntil == nil {
+		p.reservedUntil = make(map[common.Email]time.Time)
+	}
+
+	reservedUntil, exists := p.reservedUntil[input.Email]
+	if exists && input.Now.Before(reservedUntil) {
+		return ports.SendEmailCodeAbuseResult{
+			Outcome: ports.SendEmailCodeAbuseOutcomeThrottled,
+		}, nil
+	}
+
+	p.reservedUntil[input.Email] = input.Now.UTC().Add(challenge.ResendThrottleCooldown)
+	return ports.SendEmailCodeAbuseResult{
+		Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
+	}, nil
+}
+
+var _ ports.SendEmailCodeAbuseProtector = (*SendEmailCodeProtector)(nil)
diff --git a/authsession/internal/adapters/antiabuse/send_email_code_protector_test.go b/authsession/internal/adapters/antiabuse/send_email_code_protector_test.go
new file mode 100644
index 0000000..ad1a2af
--- /dev/null
+++ b/authsession/internal/adapters/antiabuse/send_email_code_protector_test.go
@@ -0,0 +1,64 @@
+package antiabuse
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSendEmailCodeProtectorCheckAndReserve(t *testing.T) {
+	t.Parallel()
+
+	protector := &SendEmailCodeProtector{}
+	email := common.Email("pilot@example.com")
+	now := time.Unix(10, 0).UTC()
+
+	result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
+		Email: email,
+		Now:   now,
+	})
+	require.NoError(t, err)
+	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
+
+	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
+		Email: email,
+		Now:   now.Add(30 * time.Second),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
+
+	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
+		Email: email,
+		Now:   now.Add(time.Minute),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
+}
+
+func TestSendEmailCodeProtectorNilOrCanceledContext(t *testing.T) {
+	t.Parallel()
+
+	protector := &SendEmailCodeProtector{}
+	_, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{
+		Email: common.Email("pilot@example.com"),
+		Now:   time.Unix(10, 0).UTC(),
+	})
+	require.Error(t, err)
+	assert.ErrorContains(t, err, "nil context")
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	_, err = protector.CheckAndReserve(ctx, ports.SendEmailCodeAbuseInput{
+		Email: common.Email("pilot@example.com"),
+		Now:   time.Unix(10, 0).UTC(),
+	})
+	require.Error(t, err)
+	assert.ErrorIs(t, err, context.Canceled)
+}
diff --git a/authsession/internal/adapters/contracttest/challenge_store.go b/authsession/internal/adapters/contracttest/challenge_store.go
new file mode 100644
index 0000000..855715e
--- /dev/null
+++ b/authsession/internal/adapters/contracttest/challenge_store.go
@@ -0,0 +1,206 @@
+// Package contracttest provides reusable adapter conformance suites that
+// exercise storage-agnostic port contracts without depending on one concrete
+// backend implementation.
+package contracttest
+
+import (
+	"context"
+	"crypto/ed25519"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// ChallengeStoreFactory constructs a fresh ChallengeStore instance suitable
+// for one isolated contract subtest.
+type ChallengeStoreFactory func(t *testing.T) ports.ChallengeStore
+
+// RunChallengeStoreContractTests executes the backend-agnostic ChallengeStore
+// contract suite against newStore.
+func RunChallengeStoreContractTests(t *testing.T, newStore ChallengeStoreFactory) {
+	t.Helper()
+
+	t.Run("create and get", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		record := contractConfirmedChallenge(t, time.Unix(1_775_130_000, 0).UTC())
+
+		require.NoError(t, store.Create(context.Background(), record))
+
+		got, err := store.Get(context.Background(), record.ID)
+		require.NoError(t, err)
+		assert.Equal(t, record, got)
+	})
+
+	t.Run("get not found", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+
+		_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
+		require.Error(t, err)
+		assert.ErrorIs(t, err, ports.ErrNotFound)
+	})
+
+	t.Run("create conflict", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		record := contractPendingChallenge(time.Unix(1_775_130_100, 0).UTC())
+
+		require.NoError(t, store.Create(context.Background(), record))
+
+		err := store.Create(context.Background(), record)
+		require.Error(t, err)
+		assert.ErrorIs(t, err, ports.ErrConflict)
+	})
+
+	t.Run("compare and swap success", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		now := time.Unix(1_775_130_200, 0).UTC()
+		previous := contractPendingChallenge(now)
+		next := previous
+		next.Status = challenge.StatusSent
+		next.DeliveryState = challenge.DeliverySent
+		next.Attempts.Send = 1
+		next.Abuse.LastAttemptAt = contractTimePointer(now.Add(time.Minute))
+		require.NoError(t, next.Validate())
+
+		require.NoError(t, store.Create(context.Background(), previous))
+		require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))
+
+		got, err := store.Get(context.Background(), previous.ID)
+		require.NoError(t, err)
+		assert.Equal(t, next, got)
+	})
+
+	t.Run("compare and swap conflict", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		now := time.Unix(1_775_130_300, 0).UTC()
+		stored := contractPendingChallenge(now)
+		previous := stored
+		previous.Attempts.Send = 99
+		require.NoError(t, previous.Validate())
+		next := stored
+		next.Status = challenge.StatusSent
+		next.DeliveryState = challenge.DeliverySent
+		require.NoError(t, next.Validate())
+
+		require.NoError(t, store.Create(context.Background(), stored))
+
+		err := store.CompareAndSwap(context.Background(), previous, next)
+		require.Error(t, err)
+		assert.ErrorIs(t, err, ports.ErrConflict)
+	})
+
+	t.Run("compare and swap not found", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		now := time.Unix(1_775_130_400, 0).UTC()
+		previous := contractPendingChallenge(now)
+		next := previous
+		next.Status = challenge.StatusSent
+		next.DeliveryState = challenge.DeliverySent
+		require.NoError(t, next.Validate())
+
+		err := store.CompareAndSwap(context.Background(), previous, next)
+		require.Error(t, err)
+		assert.ErrorIs(t, err, ports.ErrNotFound)
+	})
+
+	t.Run("get returns defensive copies", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		record := contractConfirmedChallenge(t, time.Unix(1_775_130_500, 0).UTC())
+
+		require.NoError(t, store.Create(context.Background(), record))
+
+		got, err := store.Get(context.Background(), record.ID)
+		require.NoError(t, err)
+		require.NotEmpty(t, got.CodeHash)
+		got.CodeHash[0] = 0xFF
+		if got.Confirmation != nil {
+			keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
+			if len(keyBytes) > 0 {
+				keyBytes[0] = 0xFE
+			}
+		}
+
+		again, err := store.Get(context.Background(), record.ID)
+		require.NoError(t, err)
+		assert.Equal(t, record.CodeHash, again.CodeHash)
+		require.NotNil(t, again.Confirmation)
+		assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
+	})
+}
+
+func contractPendingChallenge(now time.Time) challenge.Challenge {
+	record := challenge.Challenge{
+		ID:            common.ChallengeID("challenge-pending"),
+		Email:         common.Email("pilot@example.com"),
+		CodeHash:      []byte("hashed-pending-code"),
+		Status:        challenge.StatusPendingSend,
+		DeliveryState: challenge.DeliveryPending,
+		CreatedAt:     now,
+		ExpiresAt:     now.Add(challenge.InitialTTL),
+	}
+	if err := record.Validate(); err != nil {
+		panic(err)
+	}
+
+	return record
+}
+
+func contractConfirmedChallenge(t *testing.T, now time.Time) challenge.Challenge {
+	t.Helper()
+
+	clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
+		0, 1, 2, 3, 4, 5, 6, 7,
+		8, 9, 10, 11, 12, 13, 14, 15,
+		16, 17, 18, 19, 20, 21, 22, 23,
+		24, 25, 26, 27, 28, 29, 30, 31,
+	})
+	require.NoError(t, err)
+
+	record := challenge.Challenge{
+		ID:            common.ChallengeID("challenge-confirmed"),
+		Email:         common.Email("pilot@example.com"),
+		CodeHash:      []byte("hashed-code"),
+		Status:        challenge.StatusConfirmedPendingExpire,
+		DeliveryState: challenge.DeliverySent,
+		CreatedAt:     now,
+		ExpiresAt:     now.Add(challenge.ConfirmedRetention),
+		Attempts: challenge.AttemptCounters{
+			Send:    1,
+			Confirm: 2,
+		},
+		Abuse: challenge.AbuseMetadata{
+			LastAttemptAt: contractTimePointer(now.Add(30 * time.Second)),
+		},
+		Confirmation: &challenge.Confirmation{
+			SessionID:       common.DeviceSessionID("device-session-1"),
+			ClientPublicKey: clientPublicKey,
+			ConfirmedAt:     now.Add(time.Minute),
+		},
+	}
+	require.NoError(t, record.Validate())
+
+	return record
+}
+
+func contractTimePointer(value time.Time) *time.Time {
+	return &value
+}
diff --git a/authsession/internal/adapters/contracttest/config_provider.go b/authsession/internal/adapters/contracttest/config_provider.go
new file mode 100644
index 0000000..7541290
--- /dev/null
+++ b/authsession/internal/adapters/contracttest/config_provider.go
@@ -0,0 +1,65 @@
+package contracttest
+
+import (
+	"context"
+	"testing"
+
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// ConfigProviderHarnessFactory constructs a fresh semantic ConfigProvider
+// harness suitable for one isolated contract subtest.
+type ConfigProviderHarnessFactory func(t *testing.T) ConfigProviderHarness
+
+// ConfigProviderHarness bundles one semantic ConfigProvider instance with the
+// seed hooks needed by the backend-agnostic contract suite.
+type ConfigProviderHarness struct {
+	// Provider is the semantic ConfigProvider under test.
+	Provider ports.ConfigProvider
+
+	// SeedDisabled prepares storage so LoadSessionLimit observes "limit absent".
+	SeedDisabled func(t *testing.T)
+
+	// SeedLimit prepares storage so LoadSessionLimit observes a valid positive
+	// configured limit.
+	SeedLimit func(t *testing.T, limit int)
+}
+
+// RunConfigProviderContractTests executes the backend-agnostic ConfigProvider
+// semantic contract suite against newHarness.
+func RunConfigProviderContractTests(t *testing.T, newHarness ConfigProviderHarnessFactory) {
+	t.Helper()
+
+	t.Run("limit absent means disabled", func(t *testing.T) {
+		t.Parallel()
+
+		harness := newHarness(t)
+		require.NotNil(t, harness.Provider)
+		require.NotNil(t, harness.SeedDisabled)
+
+		harness.SeedDisabled(t)
+
+		got, err := harness.Provider.LoadSessionLimit(context.Background())
+		require.NoError(t, err)
+		assert.Equal(t, ports.SessionLimitConfig{}, got)
+	})
+
+	t.Run("valid positive limit means configured", func(t *testing.T) {
+		t.Parallel()
+
+		harness := newHarness(t)
+		require.NotNil(t, harness.Provider)
+		require.NotNil(t, harness.SeedLimit)
+
+		want := 5
+		harness.SeedLimit(t, want)
+
+		got, err := harness.Provider.LoadSessionLimit(context.Background())
+		require.NoError(t, err)
+		require.NotNil(t, got.ActiveSessionLimit)
+		assert.Equal(t, want, *got.ActiveSessionLimit)
+	})
+}
diff --git a/authsession/internal/adapters/contracttest/session_store.go b/authsession/internal/adapters/contracttest/session_store.go
new file mode 100644
index 0000000..b43e1df
--- /dev/null
+++ b/authsession/internal/adapters/contracttest/session_store.go
@@ -0,0 +1,283 @@
+package contracttest
+
+import (
+	"context"
+	"crypto/ed25519"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// SessionStoreFactory constructs a fresh SessionStore instance suitable for
+// one isolated contract subtest.
+type SessionStoreFactory func(t *testing.T) ports.SessionStore
+
+// RunSessionStoreContractTests executes the backend-agnostic SessionStore
+// contract suite against newStore.
+func RunSessionStoreContractTests(t *testing.T, newStore SessionStoreFactory) {
+	t.Helper()
+
+	t.Run("create and get", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
+
+		require.NoError(t, store.Create(context.Background(), record))
+
+		got, err := store.Get(context.Background(), record.ID)
+		require.NoError(t, err)
+		assert.Equal(t, record, got)
+	})
+
+	t.Run("create conflict", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_050, 0).UTC())
+
+		require.NoError(t, store.Create(context.Background(), record))
+
+		err := store.Create(context.Background(), record)
+		require.Error(t, err)
+		assert.ErrorIs(t, err, ports.ErrConflict)
+	})
+
+	t.Run("get not found", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+
+		_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
+		require.Error(t, err)
+		assert.ErrorIs(t, err, ports.ErrNotFound)
+	})
+
+	t.Run("list by user id returns newest first", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		older := contractActiveSession(t, "device-session-old", "user-1", time.Unix(10, 0).UTC())
+		newer := contractActiveSession(t, "device-session-new", "user-1", time.Unix(20, 0).UTC())
+		revoked := contractRevokedSession(t, "device-session-revoked", "user-1", time.Unix(15, 0).UTC())
+		otherUser := contractActiveSession(t, "device-session-other", "user-2", time.Unix(30, 0).UTC())
+
+		for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
+			require.NoError(t, store.Create(context.Background(), record))
+		}
+
+		got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
+		require.NoError(t, err)
+		require.Len(t, got, 3)
+		assert.Equal(
+			t,
+			[]common.DeviceSessionID{newer.ID, revoked.ID, older.ID},
+			[]common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID},
+		)
+	})
+
+	t.Run("list by user id returns empty slice for unknown user", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+
+		got, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
+		require.NoError(t, err)
+		require.NotNil(t, got)
+		assert.Empty(t, got)
+	})
+
+	t.Run("count active by user id", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		activeOne := contractActiveSession(t, "device-session-1", "user-1", time.Unix(40, 0).UTC())
+		activeTwo := contractActiveSession(t, "device-session-2", "user-1", time.Unix(50, 0).UTC())
+		revoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(60, 0).UTC())
+		otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(70, 0).UTC())
+
+		for _, record := range []devicesession.Session{activeOne, activeTwo, revoked, otherUser} {
+			require.NoError(t, store.Create(context.Background(), record))
+		}
+
+		count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
+		require.NoError(t, err)
+		assert.Equal(t, 2, count)
+	})
+
+	t.Run("revoke active session", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
+		require.NoError(t, store.Create(context.Background(), record))
+
+		revocation := contractRevocation(time.Unix(200, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", "")
+		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
+			DeviceSessionID: record.ID,
+			Revocation:      revocation,
+		})
+		require.NoError(t, err)
+		assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
+		require.NotNil(t, result.Session.Revocation)
+		assert.Equal(t, revocation, *result.Session.Revocation)
+
+		count, err := store.CountActiveByUserID(context.Background(), record.UserID)
+		require.NoError(t, err)
+		assert.Zero(t, count)
+	})
+
+	t.Run("revoke already revoked preserves stored revocation", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		record := contractRevokedSession(t, "device-session-2", "user-1", time.Unix(110, 0).UTC())
+		require.NoError(t, store.Create(context.Background(), record))
+
+		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
+			DeviceSessionID: record.ID,
+			Revocation:      contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1"),
+		})
+		require.NoError(t, err)
+		assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
+		require.NotNil(t, result.Session.Revocation)
+		assert.Equal(t, *record.Revocation, *result.Session.Revocation)
+	})
+
+	t.Run("revoke not found", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+
+		_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
+			DeviceSessionID: common.DeviceSessionID("missing-session"),
+			Revocation:      contractRevocation(time.Unix(210, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", ""),
+		})
+		require.Error(t, err)
+		assert.ErrorIs(t, err, ports.ErrNotFound)
+	})
+
+	t.Run("revoke all by user id revokes active sessions newest first", func(t *testing.T) {
+		t.Parallel()
+
+		store := newStore(t)
+		older := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
+		newer := contractActiveSession(t, "device-session-2", "user-1", time.Unix(200, 0).UTC())
+		alreadyRevoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(150, 0).UTC())
+		otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(250, 0).UTC())
+
+		for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
+			require.NoError(t, store.Create(context.Background(), record))
+		}
+
+		revocation := contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1")
+		result, err :=
store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{ + UserID: common.UserID("user-1"), + Revocation: revocation, + }) + require.NoError(t, err) + assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome) + require.Len(t, result.Sessions, 2) + assert.Equal( + t, + []common.DeviceSessionID{newer.ID, older.ID}, + []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID}, + ) + assert.Equal(t, revocation, *result.Sessions[0].Revocation) + assert.Equal(t, revocation, *result.Sessions[1].Revocation) + + count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + assert.Zero(t, count) + }) + + t.Run("revoke all by user id reports no active sessions", func(t *testing.T) { + t.Parallel() + + store := newStore(t) + record := contractRevokedSession(t, "device-session-5", "user-1", time.Unix(120, 0).UTC()) + require.NoError(t, store.Create(context.Background(), record)) + + result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{ + UserID: common.UserID("user-1"), + Revocation: contractRevocation(time.Unix(400, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", ""), + }) + require.NoError(t, err) + assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome) + require.NotNil(t, result.Sessions) + assert.Empty(t, result.Sessions) + }) + + t.Run("get returns defensive copies", func(t *testing.T) { + t.Parallel() + + store := newStore(t) + record := contractRevokedSession(t, "device-session-copy", "user-1", time.Unix(130, 0).UTC()) + require.NoError(t, store.Create(context.Background(), record)) + + got, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + require.NotNil(t, got.Revocation) + got.Revocation.ActorID = "mutated" + + again, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + require.NotNil(t, again.Revocation) + assert.Equal(t, record, again) + }) +} 
+ +func contractActiveSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { + t.Helper() + + clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{ + 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + }) + require.NoError(t, err) + + record := devicesession.Session{ + ID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID(userID), + ClientPublicKey: clientPublicKey, + Status: devicesession.StatusActive, + CreatedAt: createdAt, + } + require.NoError(t, record.Validate()) + + return record +} + +func contractRevokedSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { + t.Helper() + + record := contractActiveSession(t, deviceSessionID, userID, createdAt) + revocation := contractRevocation(createdAt.Add(time.Minute), devicesession.RevokeReasonDeviceLogout, "user", "user-actor") + record.Status = devicesession.StatusRevoked + record.Revocation = &revocation + require.NoError(t, record.Validate()) + + return record +} + +func contractRevocation(at time.Time, reasonCode common.RevokeReasonCode, actorType string, actorID string) devicesession.Revocation { + record := devicesession.Revocation{ + At: at, + ReasonCode: reasonCode, + ActorType: common.RevokeActorType(actorType), + ActorID: actorID, + } + if err := record.Validate(); err != nil { + panic(err) + } + + return record +} diff --git a/authsession/internal/adapters/local/runtime.go b/authsession/internal/adapters/local/runtime.go new file mode 100644 index 0000000..824de6f --- /dev/null +++ b/authsession/internal/adapters/local/runtime.go @@ -0,0 +1,139 @@ +// Package local provides small in-process runtime implementations for +// authsession ports that do not require network dependencies. 
+package local + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "math/big" + "strings" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "golang.org/x/crypto/bcrypt" +) + +const ( + challengeIDPrefix = "challenge-" + deviceSessionIDPrefix = "device-session-" + codeDigits = 6 +) + +// Clock implements ports.Clock using the local system clock in UTC. +type Clock struct{} + +// Now returns the current system time normalized to UTC. +func (Clock) Now() time.Time { + return time.Now().UTC() +} + +// IDGenerator implements ports.IDGenerator with cryptographically random +// opaque identifiers. +type IDGenerator struct{} + +// NewChallengeID returns a fresh random challenge identifier. +func (IDGenerator) NewChallengeID() (common.ChallengeID, error) { + value, err := newOpaqueIDString(challengeIDPrefix) + if err != nil { + return "", err + } + + return common.ChallengeID(value), nil +} + +// NewDeviceSessionID returns a fresh random device-session identifier. +func (IDGenerator) NewDeviceSessionID() (common.DeviceSessionID, error) { + value, err := newOpaqueIDString(deviceSessionIDPrefix) + if err != nil { + return "", err + } + + return common.DeviceSessionID(value), nil +} + +// CodeGenerator implements ports.CodeGenerator with random 6-digit decimal +// confirmation codes. +type CodeGenerator struct{} + +// Generate returns one fresh random 6-digit decimal code. +func (CodeGenerator) Generate() (string, error) { + var builder strings.Builder + builder.Grow(codeDigits) + + for idx := 0; idx < codeDigits; idx++ { + digit, err := rand.Int(rand.Reader, big.NewInt(10)) + if err != nil { + return "", fmt.Errorf("generate confirmation code: %w", err) + } + builder.WriteByte(byte('0' + digit.Int64())) + } + + return builder.String(), nil +} + +// CodeHasher implements ports.CodeHasher with bcrypt-backed hashes. +type CodeHasher struct{} + +// Hash returns the bcrypt hash of code. 
+func (CodeHasher) Hash(code string) ([]byte, error) { + if err := validateCode(code); err != nil { + return nil, err + } + + hash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("hash confirmation code: %w", err) + } + + return hash, nil +} + +// Compare reports whether hash matches code. +func (CodeHasher) Compare(hash []byte, code string) (bool, error) { + if err := validateCode(code); err != nil { + return false, err + } + if len(hash) == 0 { + return false, nil + } + + err := bcrypt.CompareHashAndPassword(hash, []byte(code)) + switch err { + case nil: + return true, nil + case bcrypt.ErrMismatchedHashAndPassword: + return false, nil + default: + return false, fmt.Errorf("compare confirmation code hash: %w", err) + } +} + +func newOpaqueIDString(prefix string) (string, error) { + randomBytes := make([]byte, 16) + if _, err := rand.Read(randomBytes); err != nil { + return "", fmt.Errorf("generate opaque identifier: %w", err) + } + + return prefix + base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +func validateCode(code string) error { + switch { + case strings.TrimSpace(code) == "": + return fmt.Errorf("code must not be empty") + case strings.TrimSpace(code) != code: + return fmt.Errorf("code must not contain surrounding whitespace") + default: + return nil + } +} + +var ( + _ ports.Clock = Clock{} + _ ports.IDGenerator = IDGenerator{} + _ ports.CodeGenerator = CodeGenerator{} + _ ports.CodeHasher = CodeHasher{} +) diff --git a/authsession/internal/adapters/local/runtime_test.go b/authsession/internal/adapters/local/runtime_test.go new file mode 100644 index 0000000..7ee57f8 --- /dev/null +++ b/authsession/internal/adapters/local/runtime_test.go @@ -0,0 +1,60 @@ +package local + +import ( + "regexp" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClockNowReturnsUTC(t *testing.T) { + t.Parallel() + + now := 
Clock{}.Now() + + assert.Equal(t, time.UTC, now.Location()) +} + +func TestIDGeneratorProducesValidOpaqueIDs(t *testing.T) { + t.Parallel() + + generator := IDGenerator{} + + challengeID, err := generator.NewChallengeID() + require.NoError(t, err) + require.NoError(t, challengeID.Validate()) + assert.Regexp(t, regexp.MustCompile(`^challenge-[A-Za-z0-9_-]+$`), challengeID.String()) + + deviceSessionID, err := generator.NewDeviceSessionID() + require.NoError(t, err) + require.NoError(t, deviceSessionID.Validate()) + assert.Regexp(t, regexp.MustCompile(`^device-session-[A-Za-z0-9_-]+$`), deviceSessionID.String()) +} + +func TestCodeGeneratorProducesSixDigitNumericCodes(t *testing.T) { + t.Parallel() + + code, err := CodeGenerator{}.Generate() + require.NoError(t, err) + assert.Regexp(t, regexp.MustCompile(`^\d{6}$`), code) +} + +func TestCodeHasherHashesAndComparesCodes(t *testing.T) { + t.Parallel() + + hasher := CodeHasher{} + + hash, err := hasher.Hash("123456") + require.NoError(t, err) + require.NotEmpty(t, hash) + + match, err := hasher.Compare(hash, "123456") + require.NoError(t, err) + assert.True(t, match) + + match, err = hasher.Compare(hash, "000000") + require.NoError(t, err) + assert.False(t, match) +} diff --git a/authsession/internal/adapters/mail/rest_client.go b/authsession/internal/adapters/mail/rest_client.go new file mode 100644 index 0000000..1563926 --- /dev/null +++ b/authsession/internal/adapters/mail/rest_client.go @@ -0,0 +1,182 @@ +// Package mail provides runtime mail-delivery adapters for the auth/session +// service. +package mail + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "galaxy/authsession/internal/ports" +) + +const sendLoginCodePath = "/api/v1/internal/login-code-deliveries" + +// Config configures one HTTP-based mail-delivery client. +type Config struct { + // BaseURL is the absolute base URL of the internal mail-service HTTP API. 
+ BaseURL string + + // RequestTimeout bounds each outbound mail-service request. + RequestTimeout time.Duration +} + +// RESTClient implements ports.MailSender over the frozen internal REST mail +// contract. +type RESTClient struct { + baseURL string + requestTimeout time.Duration + httpClient *http.Client +} + +// NewRESTClient constructs a REST-backed MailSender adapter from cfg. +func NewRESTClient(cfg Config) (*RESTClient, error) { + transport := http.DefaultTransport.(*http.Transport).Clone() + + return newRESTClient(cfg, &http.Client{Transport: transport}) +} + +func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) { + switch { + case strings.TrimSpace(cfg.BaseURL) == "": + return nil, errors.New("new mail service REST client: base URL must not be empty") + case cfg.RequestTimeout <= 0: + return nil, errors.New("new mail service REST client: request timeout must be positive") + case httpClient == nil: + return nil, errors.New("new mail service REST client: http client must not be nil") + } + + parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")) + if err != nil { + return nil, fmt.Errorf("new mail service REST client: parse base URL: %w", err) + } + if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" { + return nil, errors.New("new mail service REST client: base URL must be absolute") + } + + return &RESTClient{ + baseURL: parsedBaseURL.String(), + requestTimeout: cfg.RequestTimeout, + httpClient: httpClient, + }, nil +} + +// Close releases idle HTTP connections owned by the client transport. +func (c *RESTClient) Close() error { + if c == nil || c.httpClient == nil { + return nil + } + + type idleCloser interface { + CloseIdleConnections() + } + + if transport, ok := c.httpClient.Transport.(idleCloser); ok { + transport.CloseIdleConnections() + } + + return nil +} + +// SendLoginCode submits one delivery request to the internal mail service +// without retrying transport or upstream failures. 
+func (c *RESTClient) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) { + if err := validateRESTContext(ctx, "send login code"); err != nil { + return ports.SendLoginCodeResult{}, err + } + if err := input.Validate(); err != nil { + return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err) + } + + payload, statusCode, err := c.doRequest(ctx, "send login code", map[string]string{ + "email": input.Email.String(), + "code": input.Code, + }) + if err != nil { + return ports.SendLoginCodeResult{}, err + } + if statusCode != http.StatusOK { + return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: unexpected HTTP status %d", statusCode) + } + + var response struct { + Outcome ports.SendLoginCodeOutcome `json:"outcome"` + } + if err := decodeJSONPayload(payload, &response); err != nil { + return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err) + } + + result := ports.SendLoginCodeResult{Outcome: response.Outcome} + if err := result.Validate(); err != nil { + return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err) + } + + return result, nil +} + +func (c *RESTClient) doRequest(ctx context.Context, operation string, requestBody any) ([]byte, int, error) { + bodyBytes, err := json.Marshal(requestBody) + if err != nil { + return nil, 0, fmt.Errorf("%s: marshal request body: %w", operation, err) + } + + attemptCtx, cancel := context.WithTimeout(ctx, c.requestTimeout) + defer cancel() + + request, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, c.baseURL+sendLoginCodePath, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, 0, fmt.Errorf("%s: build request: %w", operation, err) + } + request.Header.Set("Content-Type", "application/json") + + response, err := c.httpClient.Do(request) + if err != nil { + return nil, 0, fmt.Errorf("%s: %w", operation, err) + } + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + if 
err != nil { + return nil, 0, fmt.Errorf("%s: read response body: %w", operation, err) + } + + return payload, response.StatusCode, nil +} + +func decodeJSONPayload(payload []byte, target any) error { + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.DisallowUnknownFields() + + if err := decoder.Decode(target); err != nil { + return fmt.Errorf("decode response body: %w", err) + } + if err := decoder.Decode(&struct{}{}); err != io.EOF { + if err == nil { + return errors.New("decode response body: unexpected trailing JSON input") + } + + return fmt.Errorf("decode response body: %w", err) + } + + return nil +} + +func validateRESTContext(ctx context.Context, operation string) error { + if ctx == nil { + return fmt.Errorf("%s: nil context", operation) + } + if err := ctx.Err(); err != nil { + return fmt.Errorf("%s: %w", operation, err) + } + + return nil +} + +var _ ports.MailSender = (*RESTClient)(nil) diff --git a/authsession/internal/adapters/mail/rest_client_test.go b/authsession/internal/adapters/mail/rest_client_test.go new file mode 100644 index 0000000..c43d3bf --- /dev/null +++ b/authsession/internal/adapters/mail/rest_client_test.go @@ -0,0 +1,394 @@ +package mail + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRESTClient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg Config + wantErr string + }{ + { + name: "valid config", + cfg: Config{ + BaseURL: "http://127.0.0.1:8080", + RequestTimeout: time.Second, + }, + }, + { + name: "empty base url", + cfg: Config{ + RequestTimeout: time.Second, + }, + wantErr: "base URL must not be empty", + }, + { + name: "relative base url", + cfg: Config{ + BaseURL: "/relative", + RequestTimeout: 
time.Second, + }, + wantErr: "base URL must be absolute", + }, + { + name: "non positive timeout", + cfg: Config{ + BaseURL: "http://127.0.0.1:8080", + }, + wantErr: "request timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client, err := NewRESTClient(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.NoError(t, client.Close()) + }) + } +} + +func TestRESTClientSendLoginCodeSuccessCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + response string + wantOutcome ports.SendLoginCodeOutcome + }{ + { + name: "sent", + response: `{"outcome":"sent"}`, + wantOutcome: ports.SendLoginCodeOutcomeSent, + }, + { + name: "suppressed", + response: `{"outcome":"suppressed"}`, + wantOutcome: ports.SendLoginCodeOutcomeSuppressed, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var requestsMu sync.Mutex + var requests []capturedRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestsMu.Lock() + requests = append(requests, captureRequest(t, r)) + requestsMu.Unlock() + + writeJSON(t, w, http.StatusOK, json.RawMessage(tt.response)) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + + result, err := client.SendLoginCode(context.Background(), validInput()) + require.NoError(t, err) + assert.Equal(t, tt.wantOutcome, result.Outcome) + + requestsMu.Lock() + defer requestsMu.Unlock() + + require.Len(t, requests, 1) + assert.Equal(t, http.MethodPost, requests[0].Method) + assert.Equal(t, sendLoginCodePath, requests[0].Path) + assert.Equal(t, "application/json", requests[0].ContentType) + assert.JSONEq(t, `{"email":"pilot@example.com","code":"654321"}`, requests[0].Body) + }) + } +} + +func 
TestRESTClientPreservesNormalizedEmailAndCodeExactly(t *testing.T) {
+	t.Parallel()
+
+	captured := make(chan string, 1)
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		captured <- captureRequest(t, r).Body
+		writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
+	}))
+	defer server.Close()
+
+	client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
+
+	result, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
+		Email: common.Email("Pilot+Alias@Example.com"),
+		Code:  "123456",
+	})
+	require.NoError(t, err)
+	assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)
+	assert.JSONEq(t, `{"email":"Pilot+Alias@Example.com","code":"123456"}`, <-captured)
+}
+
+func TestRESTClientSendLoginCodeDoesNotRetry(t *testing.T) {
+	t.Parallel()
+
+	t.Run("no retry on 503", func(t *testing.T) {
+		t.Parallel()
+
+		var calls atomic.Int64
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			calls.Add(1)
+			http.Error(w, "temporary", http.StatusServiceUnavailable)
+		}))
+		defer server.Close()
+
+		client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
+
+		_, err := client.SendLoginCode(context.Background(), validInput())
+		require.Error(t, err)
+		assert.ErrorContains(t, err, "unexpected HTTP status 503")
+		assert.EqualValues(t, 1, calls.Load())
+	})
+
+	t.Run("no retry on transport failure", func(t *testing.T) {
+		t.Parallel()
+
+		var calls atomic.Int64
+		client, err := newRESTClient(Config{
+			BaseURL:        "http://127.0.0.1:8080",
+			RequestTimeout: 250 * time.Millisecond,
+		}, &http.Client{
+			Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
+				calls.Add(1)
+				return nil, errors.New("temporary transport failure")
+			}),
+		})
+		require.NoError(t, err)
+
+		_, err = client.SendLoginCode(context.Background(), validInput())
+		require.Error(t, err)
+		assert.ErrorContains(t, err, "temporary transport failure")
+		
assert.EqualValues(t, 1, calls.Load()) + }) +} + +func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + body string + wantErrText string + }{ + { + name: "rejects unknown field", + statusCode: http.StatusOK, + body: `{"outcome":"sent","extra":true}`, + wantErrText: "decode response body", + }, + { + name: "rejects unsupported outcome", + statusCode: http.StatusOK, + body: `{"outcome":"queued"}`, + wantErrText: "unsupported", + }, + { + name: "rejects missing outcome", + statusCode: http.StatusOK, + body: `{}`, + wantErrText: "unsupported", + }, + { + name: "rejects trailing json", + statusCode: http.StatusOK, + body: `{"outcome":"sent"}{}`, + wantErrText: "unexpected trailing JSON input", + }, + { + name: "rejects unexpected status", + statusCode: http.StatusBadGateway, + body: `{"error":"temporary"}`, + wantErrText: "unexpected HTTP status 502", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tt.statusCode) + _, err := io.WriteString(w, tt.body) + require.NoError(t, err) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + + _, err := client.SendLoginCode(context.Background(), validInput()) + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErrText) + }) + } +} + +func TestRESTClientRequestTimeout(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(40 * time.Millisecond) + writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"}) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 10*time.Millisecond) + + _, err := client.SendLoginCode(context.Background(), validInput()) + 
require.Error(t, err)
+	assert.ErrorContains(t, err, "context deadline exceeded")
+}
+
+func TestRESTClientContextAndValidation(t *testing.T) {
+	t.Parallel()
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		t.Errorf("unexpected upstream call")
+	}))
+	defer server.Close()
+
+	client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
+	cancelledCtx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	tests := []struct {
+		name string
+		run  func() error
+	}{
+		{
+			name: "nil context",
+			run: func() error {
+				_, err := client.SendLoginCode(nil, validInput())
+				return err
+			},
+		},
+		{
+			name: "cancelled context",
+			run: func() error {
+				_, err := client.SendLoginCode(cancelledCtx, validInput())
+				return err
+			},
+		},
+		{
+			name: "invalid email",
+			run: func() error {
+				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
+					Email: common.Email(" bad@example.com "),
+					Code:  "123456",
+				})
+				return err
+			},
+		},
+		{
+			name: "invalid code",
+			run: func() error {
+				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
+					Email: common.Email("pilot@example.com"),
+					Code:  " 123456 ",
+				})
+				return err
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			err := tt.run()
+			require.Error(t, err)
+		})
+	}
+}
+
+type capturedRequest struct {
+	Method      string
+	Path        string
+	ContentType string
+	Body        string
+}
+
+func captureRequest(t *testing.T, request *http.Request) capturedRequest {
+	t.Helper()
+
+	body, err := io.ReadAll(request.Body)
+	require.NoError(t, err)
+
+	return capturedRequest{
+		Method:      request.Method,
+		Path:        request.URL.Path,
+		ContentType: request.Header.Get("Content-Type"),
+		Body:        strings.TrimSpace(string(body)),
+	}
+}
+
+func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
+	t.Helper()
+
+	payload, err := json.Marshal(value)
+	require.NoError(t, err)
+
+	
writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(statusCode) + _, err = writer.Write(payload) + require.NoError(t, err) +} + +func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient { + t.Helper() + + client, err := NewRESTClient(Config{ + BaseURL: baseURL, + RequestTimeout: timeout, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, client.Close()) + }) + + return client +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) { + return fn(request) +} diff --git a/authsession/internal/adapters/mail/stub_sender.go b/authsession/internal/adapters/mail/stub_sender.go new file mode 100644 index 0000000..530be15 --- /dev/null +++ b/authsession/internal/adapters/mail/stub_sender.go @@ -0,0 +1,179 @@ +// Package mail provides runtime mail-delivery adapters for the auth/session +// service. +package mail + +import ( + "context" + "errors" + "fmt" + "sync" + + "galaxy/authsession/internal/ports" +) + +var errForcedFailure = errors.New("stub mail sender: forced failure") + +// StubMode identifies the deterministic outcome used by StubSender for one +// delivery attempt. +type StubMode string + +const ( + // StubModeSent reports that the adapter accepts delivery and returns the + // stable sent outcome expected by the auth flow. + StubModeSent StubMode = "sent" + + // StubModeSuppressed reports that the adapter intentionally suppresses + // outward delivery while still returning a successful suppressed outcome. + StubModeSuppressed StubMode = "suppressed" + + // StubModeFailed reports that the adapter returns an explicit delivery + // failure instead of a successful outcome. + StubModeFailed StubMode = "failed" +) + +// IsKnown reports whether mode is one of the supported stub delivery modes. 
+func (mode StubMode) IsKnown() bool { + switch mode { + case StubModeSent, StubModeSuppressed, StubModeFailed: + return true + default: + return false + } +} + +// StubStep overrides the default stub behavior for one queued delivery +// attempt. +type StubStep struct { + // Mode selects the delivery behavior for this queued step. + Mode StubMode + + // Err optionally overrides the failure returned when Mode is StubModeFailed. + Err error +} + +// Validate reports whether step contains one supported queued behavior. +func (step StubStep) Validate() error { + if !step.Mode.IsKnown() { + return fmt.Errorf("stub mail step mode %q is unsupported", step.Mode) + } + + return nil +} + +// Attempt records one validated delivery request handled by StubSender. +type Attempt struct { + // Input stores the validated cleartext mail-delivery request exactly as it + // was passed into SendLoginCode. + Input ports.SendLoginCodeInput + + // Mode stores the resolved stub mode after queued overrides were applied. + Mode StubMode +} + +// StubSender is a deterministic runtime MailSender implementation intended +// for development, local integration, and explicit stub-based tests. +// +// The zero value is ready to use and defaults to StubModeSent. +type StubSender struct { + // DefaultMode controls the fallback behavior when Script is empty. The zero + // value is treated as StubModeSent so the zero-value sender is usable + // without extra configuration. + DefaultMode StubMode + + // DefaultError optionally overrides the failure returned when DefaultMode + // resolves to StubModeFailed. + DefaultError error + + // Script stores queued one-shot overrides consumed in FIFO order before the + // default behavior is used. + Script []StubStep + + mu sync.Mutex + attempts []Attempt +} + +// SendLoginCode records one validated delivery request and returns the +// deterministic stub outcome selected by the queued script or the default +// mode. 
+func (s *StubSender) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) { + if ctx == nil { + return ports.SendLoginCodeResult{}, errors.New("stub mail sender: nil context") + } + if err := ctx.Err(); err != nil { + return ports.SendLoginCodeResult{}, err + } + if err := input.Validate(); err != nil { + return ports.SendLoginCodeResult{}, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + mode, errOverride, err := s.resolveNextStepLocked() + if err != nil { + return ports.SendLoginCodeResult{}, err + } + + s.attempts = append(s.attempts, Attempt{ + Input: input, + Mode: mode, + }) + + switch mode { + case StubModeSent: + return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSent}, nil + case StubModeSuppressed: + return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSuppressed}, nil + case StubModeFailed: + if errOverride != nil { + return ports.SendLoginCodeResult{}, errOverride + } + return ports.SendLoginCodeResult{}, errForcedFailure + default: + return ports.SendLoginCodeResult{}, fmt.Errorf("stub mail sender: unsupported resolved mode %q", mode) + } +} + +// RecordedAttempts returns a stable defensive copy of every validated delivery +// attempt handled by the stub. +func (s *StubSender) RecordedAttempts() []Attempt { + s.mu.Lock() + defer s.mu.Unlock() + + return append([]Attempt(nil), s.attempts...) +} + +func (s *StubSender) resolveNextStepLocked() (StubMode, error, error) { + if len(s.Script) > 0 { + step := s.Script[0] + s.Script = append([]StubStep(nil), s.Script[1:]...) 
+ if err := step.Validate(); err != nil { + return "", nil, fmt.Errorf("stub mail sender: %w", err) + } + if step.Mode == StubModeFailed { + if step.Err != nil { + return step.Mode, step.Err, nil + } + return step.Mode, errForcedFailure, nil + } + return step.Mode, nil, nil + } + + mode := s.DefaultMode + if mode == "" { + mode = StubModeSent + } + if !mode.IsKnown() { + return "", nil, fmt.Errorf("stub mail sender: default mode %q is unsupported", mode) + } + if mode == StubModeFailed { + if s.DefaultError != nil { + return mode, s.DefaultError, nil + } + return mode, errForcedFailure, nil + } + + return mode, nil, nil +} + +var _ ports.MailSender = (*StubSender)(nil) diff --git a/authsession/internal/adapters/mail/stub_sender_test.go b/authsession/internal/adapters/mail/stub_sender_test.go new file mode 100644 index 0000000..1c0a502 --- /dev/null +++ b/authsession/internal/adapters/mail/stub_sender_test.go @@ -0,0 +1,198 @@ +package mail + +import ( + "context" + "errors" + "testing" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStubSenderSendLoginCode(t *testing.T) { + t.Parallel() + + t.Run("zero value defaults to sent", func(t *testing.T) { + t.Parallel() + + sender := &StubSender{} + + result, err := sender.SendLoginCode(context.Background(), validInput()) + require.NoError(t, err) + assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome) + + attempts := sender.RecordedAttempts() + require.Len(t, attempts, 1) + assert.Equal(t, StubModeSent, attempts[0].Mode) + assert.Equal(t, validInput(), attempts[0].Input) + }) + + t.Run("default suppressed", func(t *testing.T) { + t.Parallel() + + sender := &StubSender{DefaultMode: StubModeSuppressed} + + result, err := sender.SendLoginCode(context.Background(), validInput()) + require.NoError(t, err) + assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, result.Outcome) + + 
attempts := sender.RecordedAttempts() + require.Len(t, attempts, 1) + assert.Equal(t, StubModeSuppressed, attempts[0].Mode) + }) + + t.Run("default failed uses configured error", func(t *testing.T) { + t.Parallel() + + wantErr := errors.New("delivery refused") + sender := &StubSender{ + DefaultMode: StubModeFailed, + DefaultError: wantErr, + } + + result, err := sender.SendLoginCode(context.Background(), validInput()) + require.Error(t, err) + assert.ErrorIs(t, err, wantErr) + assert.Equal(t, ports.SendLoginCodeResult{}, result) + + attempts := sender.RecordedAttempts() + require.Len(t, attempts, 1) + assert.Equal(t, StubModeFailed, attempts[0].Mode) + }) + + t.Run("default failed uses stable fallback error", func(t *testing.T) { + t.Parallel() + + sender := &StubSender{DefaultMode: StubModeFailed} + + _, err := sender.SendLoginCode(context.Background(), validInput()) + require.Error(t, err) + assert.EqualError(t, err, "stub mail sender: forced failure") + }) + + t.Run("script overrides default and is consumed fifo", func(t *testing.T) { + t.Parallel() + + wantErr := errors.New("step failed") + sender := &StubSender{ + DefaultMode: StubModeSent, + Script: []StubStep{ + {Mode: StubModeSuppressed}, + {Mode: StubModeFailed, Err: wantErr}, + }, + } + + first, err := sender.SendLoginCode(context.Background(), validInput()) + require.NoError(t, err) + assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, first.Outcome) + + second, err := sender.SendLoginCode(context.Background(), validInput()) + require.Error(t, err) + assert.ErrorIs(t, err, wantErr) + assert.Equal(t, ports.SendLoginCodeResult{}, second) + + third, err := sender.SendLoginCode(context.Background(), validInput()) + require.NoError(t, err) + assert.Equal(t, ports.SendLoginCodeOutcomeSent, third.Outcome) + + attempts := sender.RecordedAttempts() + require.Len(t, attempts, 3) + assert.Equal(t, []StubMode{StubModeSuppressed, StubModeFailed, StubModeSent}, []StubMode{ + attempts[0].Mode, + attempts[1].Mode, + 
attempts[2].Mode, + }) + assert.Empty(t, sender.Script) + }) + + t.Run("invalid default mode returns adapter error", func(t *testing.T) { + t.Parallel() + + sender := &StubSender{DefaultMode: StubMode("queued")} + + _, err := sender.SendLoginCode(context.Background(), validInput()) + require.Error(t, err) + assert.ErrorContains(t, err, `default mode "queued" is unsupported`) + assert.Empty(t, sender.RecordedAttempts()) + }) + + t.Run("invalid scripted mode returns adapter error", func(t *testing.T) { + t.Parallel() + + sender := &StubSender{ + Script: []StubStep{ + {Mode: StubMode("queued")}, + }, + } + + _, err := sender.SendLoginCode(context.Background(), validInput()) + require.Error(t, err) + assert.ErrorContains(t, err, `mode "queued" is unsupported`) + assert.Empty(t, sender.RecordedAttempts()) + assert.Empty(t, sender.Script) + }) +} + +func TestStubSenderRecordedAttemptsAreDefensive(t *testing.T) { + t.Parallel() + + sender := &StubSender{} + + _, err := sender.SendLoginCode(context.Background(), validInput()) + require.NoError(t, err) + + attempts := sender.RecordedAttempts() + require.Len(t, attempts, 1) + attempts[0].Mode = StubModeFailed + attempts[0].Input.Code = "000000" + + again := sender.RecordedAttempts() + require.Len(t, again, 1) + assert.Equal(t, StubModeSent, again[0].Mode) + assert.Equal(t, "654321", again[0].Input.Code) +} + +func TestStubSenderSendLoginCodeNilContext(t *testing.T) { + t.Parallel() + + sender := &StubSender{} + + _, err := sender.SendLoginCode(nil, validInput()) + require.Error(t, err) + assert.ErrorContains(t, err, "nil context") + assert.Empty(t, sender.RecordedAttempts()) +} + +func TestStubSenderSendLoginCodeCancelledContext(t *testing.T) { + t.Parallel() + + sender := &StubSender{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := sender.SendLoginCode(ctx, validInput()) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + assert.Empty(t, sender.RecordedAttempts()) +} + 
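The queued-script behavior these tests exercise reduces to a small FIFO pattern: scripted one-shot steps are consumed in order before the default mode applies. The standalone sketch below illustrates that resolution order outside the package (all names here are illustrative, not part of `StubSender`'s API):

```go
package main

import (
	"errors"
	"fmt"
)

// mode mirrors the StubMode idea: a string enum selecting stub behavior.
type mode string

const (
	modeSent       mode = "sent"
	modeSuppressed mode = "suppressed"
	modeFailed     mode = "failed"
)

// step is a one-shot scripted override, analogous to StubStep.
type step struct {
	mode mode
	err  error
}

// stub consumes scripted steps in FIFO order before falling back to a
// default mode, mirroring how StubSender resolves each delivery attempt.
type stub struct {
	defaultMode mode
	script      []step
}

// next pops the oldest scripted step, or synthesizes one from the default.
func (s *stub) next() step {
	if len(s.script) > 0 {
		head := s.script[0]
		s.script = s.script[1:]
		return head
	}
	d := s.defaultMode
	if d == "" {
		d = modeSent // zero value stays usable, as with StubSender
	}
	return step{mode: d}
}

// run resolves three attempts: two scripted, one falling back to the default.
func run() []string {
	s := &stub{
		defaultMode: modeSent,
		script: []step{
			{mode: modeSuppressed},
			{mode: modeFailed, err: errors.New("step failed")},
		},
	}

	var out []string
	for i := 0; i < 3; i++ {
		st := s.next()
		if st.err != nil {
			out = append(out, fmt.Sprintf("%s: %v", st.mode, st.err))
			continue
		}
		out = append(out, string(st.mode))
	}
	return out
}

func main() {
	for _, line := range run() {
		fmt.Println(line)
	}
}
```

The same ordering is what the "script overrides default and is consumed fifo" test pins down: suppressed, then the scripted failure, then the default once the script is exhausted.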
+func TestStubSenderSendLoginCodeInvalidInput(t *testing.T) { + t.Parallel() + + sender := &StubSender{} + + _, err := sender.SendLoginCode(context.Background(), ports.SendLoginCodeInput{}) + require.Error(t, err) + assert.ErrorContains(t, err, "send login code input email") + assert.Empty(t, sender.RecordedAttempts()) +} + +func validInput() ports.SendLoginCodeInput { + return ports.SendLoginCodeInput{ + Email: common.Email("pilot@example.com"), + Code: "654321", + } +} diff --git a/authsession/internal/adapters/redis/challengestore/store.go b/authsession/internal/adapters/redis/challengestore/store.go new file mode 100644 index 0000000..3476487 --- /dev/null +++ b/authsession/internal/adapters/redis/challengestore/store.go @@ -0,0 +1,484 @@ +// Package challengestore implements ports.ChallengeStore with Redis-backed +// strict JSON challenge records. +package challengestore + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "strings" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "github.com/redis/go-redis/v9" +) + +const expirationGracePeriod = 5 * time.Minute + +// Config configures one Redis-backed challenge store instance. +type Config struct { + // Addr is the Redis network address in host:port form. + Addr string + + // Username is the optional Redis ACL username. + Username string + + // Password is the optional Redis ACL password. + Password string + + // DB is the Redis logical database index. + DB int + + // TLSEnabled enables TLS with a conservative minimum protocol version. + TLSEnabled bool + + // KeyPrefix is the namespace prefix applied to every challenge key. + KeyPrefix string + + // OperationTimeout bounds each Redis round trip performed by the adapter. + OperationTimeout time.Duration +} + +// Store persists challenges as one strict JSON value per Redis key. 
+type Store struct { + client *redis.Client + keyPrefix string + operationTimeout time.Duration +} + +type redisRecord struct { + ChallengeID string `json:"challenge_id"` + Email string `json:"email"` + CodeHashBase64 string `json:"code_hash_base64"` + Status challenge.Status `json:"status"` + DeliveryState challenge.DeliveryState `json:"delivery_state"` + CreatedAt string `json:"created_at"` + ExpiresAt string `json:"expires_at"` + SendAttemptCount int `json:"send_attempt_count"` + ConfirmAttemptCount int `json:"confirm_attempt_count"` + LastAttemptAt *string `json:"last_attempt_at,omitempty"` + ConfirmedSessionID string `json:"confirmed_session_id,omitempty"` + ConfirmedClientPublicKey string `json:"confirmed_client_public_key,omitempty"` + ConfirmedAt *string `json:"confirmed_at,omitempty"` +} + +// New constructs a Redis-backed challenge store from cfg. +func New(cfg Config) (*Store, error) { + if strings.TrimSpace(cfg.Addr) == "" { + return nil, errors.New("new redis challenge store: redis addr must not be empty") + } + if cfg.DB < 0 { + return nil, errors.New("new redis challenge store: redis db must not be negative") + } + if strings.TrimSpace(cfg.KeyPrefix) == "" { + return nil, errors.New("new redis challenge store: redis key prefix must not be empty") + } + if cfg.OperationTimeout <= 0 { + return nil, errors.New("new redis challenge store: operation timeout must be positive") + } + + options := &redis.Options{ + Addr: cfg.Addr, + Username: cfg.Username, + Password: cfg.Password, + DB: cfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if cfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + + return &Store{ + client: redis.NewClient(options), + keyPrefix: cfg.KeyPrefix, + operationTimeout: cfg.OperationTimeout, + }, nil +} + +// Close releases the underlying Redis client resources. 
+func (s *Store) Close() error { + if s == nil || s.client == nil { + return nil + } + + return s.client.Close() +} + +// Ping verifies that the configured Redis backend is reachable within the +// adapter operation timeout budget. +func (s *Store) Ping(ctx context.Context) error { + operationCtx, cancel, err := s.operationContext(ctx, "ping redis challenge store") + if err != nil { + return err + } + defer cancel() + + if err := s.client.Ping(operationCtx).Err(); err != nil { + return fmt.Errorf("ping redis challenge store: %w", err) + } + + return nil +} + +// Get returns the stored challenge for challengeID. +func (s *Store) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) { + if err := challengeID.Validate(); err != nil { + return challenge.Challenge{}, fmt.Errorf("get challenge from redis: %w", err) + } + + operationCtx, cancel, err := s.operationContext(ctx, "get challenge from redis") + if err != nil { + return challenge.Challenge{}, err + } + defer cancel() + + payload, err := s.client.Get(operationCtx, s.lookupKey(challengeID)).Bytes() + switch { + case errors.Is(err, redis.Nil): + return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, ports.ErrNotFound) + case err != nil: + return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err) + } + + record, err := decodeChallengeRecord(challengeID, payload) + if err != nil { + return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err) + } + + return record, nil +} + +// Create persists record as a new challenge. 
+func (s *Store) Create(ctx context.Context, record challenge.Challenge) error { + if err := record.Validate(); err != nil { + return fmt.Errorf("create challenge in redis: %w", err) + } + + payload, err := marshalChallengeRecord(record) + if err != nil { + return fmt.Errorf("create challenge in redis: %w", err) + } + + operationCtx, cancel, err := s.operationContext(ctx, "create challenge in redis") + if err != nil { + return err + } + defer cancel() + + created, err := s.client.SetNX(operationCtx, s.lookupKey(record.ID), payload, redisTTL(record.ExpiresAt)).Result() + if err != nil { + return fmt.Errorf("create challenge %q in redis: %w", record.ID, err) + } + if !created { + return fmt.Errorf("create challenge %q in redis: %w", record.ID, ports.ErrConflict) + } + + return nil +} + +// CompareAndSwap replaces previous with next when the currently stored +// challenge matches previous exactly in canonical Redis representation. +func (s *Store) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error { + if err := ports.ValidateComparableChallenges(previous, next); err != nil { + return fmt.Errorf("compare and swap challenge in redis: %w", err) + } + + nextPayload, err := marshalChallengeRecord(next) + if err != nil { + return fmt.Errorf("compare and swap challenge in redis: %w", err) + } + + operationCtx, cancel, err := s.operationContext(ctx, "compare and swap challenge in redis") + if err != nil { + return err + } + defer cancel() + + key := s.lookupKey(previous.ID) + watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { + payload, err := tx.Get(operationCtx, key).Bytes() + switch { + case errors.Is(err, redis.Nil): + return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrNotFound) + case err != nil: + return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) + } + + current, err := decodeChallengeRecord(previous.ID, payload) + if err != nil { + return 
fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) + } + + matches, err := equalStoredChallenges(current, previous) + if err != nil { + return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) + } + if !matches { + return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict) + } + + _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { + pipe.Set(operationCtx, key, nextPayload, redisTTL(next.ExpiresAt)) + return nil + }) + if err != nil { + return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) + } + + return nil + }, key) + + switch { + case errors.Is(watchErr, redis.TxFailedErr): + return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict) + case watchErr != nil: + return watchErr + default: + return nil + } +} + +func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) { + if s == nil || s.client == nil { + return nil, nil, fmt.Errorf("%s: nil store", operation) + } + if ctx == nil { + return nil, nil, fmt.Errorf("%s: nil context", operation) + } + + operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout) + return operationCtx, cancel, nil +} + +func (s *Store) lookupKey(challengeID common.ChallengeID) string { + return s.keyPrefix + encodeKeyComponent(challengeID.String()) +} + +func encodeKeyComponent(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func marshalChallengeRecord(record challenge.Challenge) ([]byte, error) { + stored, err := redisRecordFromChallenge(record) + if err != nil { + return nil, err + } + + payload, err := json.Marshal(stored) + if err != nil { + return nil, fmt.Errorf("encode redis challenge record: %w", err) + } + + return payload, nil +} + +func redisRecordFromChallenge(record challenge.Challenge) (redisRecord, error) { + if err := record.Validate(); err != 
nil { + return redisRecord{}, fmt.Errorf("encode redis challenge record: %w", err) + } + + stored := redisRecord{ + ChallengeID: record.ID.String(), + Email: record.Email.String(), + CodeHashBase64: base64.StdEncoding.EncodeToString(record.CodeHash), + Status: record.Status, + DeliveryState: record.DeliveryState, + CreatedAt: formatTimestamp(record.CreatedAt), + ExpiresAt: formatTimestamp(record.ExpiresAt), + SendAttemptCount: record.Attempts.Send, + ConfirmAttemptCount: record.Attempts.Confirm, + LastAttemptAt: formatOptionalTimestamp(record.Abuse.LastAttemptAt), + } + if record.Confirmation != nil { + stored.ConfirmedSessionID = record.Confirmation.SessionID.String() + stored.ConfirmedClientPublicKey = record.Confirmation.ClientPublicKey.String() + stored.ConfirmedAt = formatOptionalTimestamp(&record.Confirmation.ConfirmedAt) + } + + return stored, nil +} + +func decodeChallengeRecord(expectedChallengeID common.ChallengeID, payload []byte) (challenge.Challenge, error) { + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.DisallowUnknownFields() + + var stored redisRecord + if err := decoder.Decode(&stored); err != nil { + return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err) + } + if err := decoder.Decode(&struct{}{}); err != io.EOF { + if err == nil { + return challenge.Challenge{}, errors.New("decode redis challenge record: unexpected trailing JSON input") + } + return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err) + } + + record, err := challengeFromRedisRecord(stored) + if err != nil { + return challenge.Challenge{}, err + } + if record.ID != expectedChallengeID { + return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: challenge_id %q does not match requested %q", record.ID, expectedChallengeID) + } + + return record, nil +} + +func challengeFromRedisRecord(stored redisRecord) (challenge.Challenge, error) { + createdAt, err := parseTimestamp("created_at", 
stored.CreatedAt) + if err != nil { + return challenge.Challenge{}, err + } + expiresAt, err := parseTimestamp("expires_at", stored.ExpiresAt) + if err != nil { + return challenge.Challenge{}, err + } + lastAttemptAt, err := parseOptionalTimestamp("last_attempt_at", stored.LastAttemptAt) + if err != nil { + return challenge.Challenge{}, err + } + + codeHash, err := base64.StdEncoding.Strict().DecodeString(stored.CodeHashBase64) + if err != nil { + return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: code_hash_base64: %w", err) + } + + record := challenge.Challenge{ + ID: common.ChallengeID(stored.ChallengeID), + Email: common.Email(stored.Email), + CodeHash: codeHash, + Status: stored.Status, + DeliveryState: stored.DeliveryState, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + Attempts: challenge.AttemptCounters{ + Send: stored.SendAttemptCount, + Confirm: stored.ConfirmAttemptCount, + }, + Abuse: challenge.AbuseMetadata{ + LastAttemptAt: lastAttemptAt, + }, + } + + confirmation, err := parseConfirmation(stored) + if err != nil { + return challenge.Challenge{}, err + } + record.Confirmation = confirmation + + if err := record.Validate(); err != nil { + return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err) + } + + return record, nil +} + +func parseConfirmation(stored redisRecord) (*challenge.Confirmation, error) { + hasSessionID := strings.TrimSpace(stored.ConfirmedSessionID) != "" + hasClientPublicKey := strings.TrimSpace(stored.ConfirmedClientPublicKey) != "" + hasConfirmedAt := stored.ConfirmedAt != nil + + if !hasSessionID && !hasClientPublicKey && !hasConfirmedAt { + return nil, nil + } + if !hasSessionID || !hasClientPublicKey || !hasConfirmedAt { + return nil, errors.New("decode redis challenge record: confirmation metadata must be either fully present or fully absent") + } + + confirmedAt, err := parseTimestamp("confirmed_at", *stored.ConfirmedAt) + if err != nil { + return nil, err + } + 
rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ConfirmedClientPublicKey) + if err != nil { + return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err) + } + clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey) + if err != nil { + return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err) + } + + return &challenge.Confirmation{ + SessionID: common.DeviceSessionID(stored.ConfirmedSessionID), + ClientPublicKey: clientPublicKey, + ConfirmedAt: confirmedAt, + }, nil +} + +func parseOptionalTimestamp(fieldName string, value *string) (*time.Time, error) { + if value == nil { + return nil, nil + } + + parsed, err := parseTimestamp(fieldName, *value) + if err != nil { + return nil, err + } + + return &parsed, nil +} + +func parseTimestamp(fieldName string, value string) (time.Time, error) { + if strings.TrimSpace(value) == "" { + return time.Time{}, fmt.Errorf("decode redis challenge record: %s must not be empty", fieldName) + } + + parsed, err := time.Parse(time.RFC3339Nano, value) + if err != nil { + return time.Time{}, fmt.Errorf("decode redis challenge record: %s: %w", fieldName, err) + } + + canonical := parsed.UTC().Format(time.RFC3339Nano) + if value != canonical { + return time.Time{}, fmt.Errorf("decode redis challenge record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName) + } + + return parsed.UTC(), nil +} + +func formatTimestamp(value time.Time) string { + return value.UTC().Format(time.RFC3339Nano) +} + +func formatOptionalTimestamp(value *time.Time) *string { + if value == nil { + return nil + } + + formatted := formatTimestamp(*value) + return &formatted +} + +func redisTTL(expiresAt time.Time) time.Duration { + ttl := time.Until(expiresAt.UTC()) + if ttl < 0 { + ttl = 0 + } + + return ttl + expirationGracePeriod +} + +func equalStoredChallenges(left challenge.Challenge, right challenge.Challenge) (bool, error) { + leftRecord, err 
:= redisRecordFromChallenge(left) + if err != nil { + return false, err + } + rightRecord, err := redisRecordFromChallenge(right) + if err != nil { + return false, err + } + + return reflect.DeepEqual(leftRecord, rightRecord), nil +} + +var _ ports.ChallengeStore = (*Store)(nil) diff --git a/authsession/internal/adapters/redis/challengestore/store_test.go b/authsession/internal/adapters/redis/challengestore/store_test.go new file mode 100644 index 0000000..d0b1510 --- /dev/null +++ b/authsession/internal/adapters/redis/challengestore/store_test.go @@ -0,0 +1,531 @@ +package challengestore + +import ( + "context" + "crypto/ed25519" + "encoding/json" + "testing" + "time" + + "galaxy/authsession/internal/adapters/contracttest" + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStoreContract(t *testing.T) { + t.Parallel() + + contracttest.RunChallengeStoreContractTests(t, func(t *testing.T) ports.ChallengeStore { + t.Helper() + + server := miniredis.RunT(t) + return newTestStore(t, server, Config{}) + }) +} + +func TestNew(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + tests := []struct { + name string + cfg Config + wantErr string + }{ + { + name: "valid config", + cfg: Config{ + Addr: server.Addr(), + DB: 2, + KeyPrefix: "authsession:challenge:", + OperationTimeout: 250 * time.Millisecond, + }, + }, + { + name: "empty addr", + cfg: Config{ + KeyPrefix: "authsession:challenge:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis addr must not be empty", + }, + { + name: "negative db", + cfg: Config{ + Addr: server.Addr(), + DB: -1, + KeyPrefix: "authsession:challenge:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis db must not be negative", + }, + { + name: "empty key prefix", + cfg: Config{ + Addr: 
server.Addr(), + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis key prefix must not be empty", + }, + { + name: "non-positive operation timeout", + cfg: Config{ + Addr: server.Addr(), + KeyPrefix: "authsession:challenge:", + }, + wantErr: "operation timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store, err := New(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + }) + } +} + +func TestStorePing(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + require.NoError(t, store.Ping(context.Background())) +} + +func TestStoreCreateAndGet(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + now := time.Unix(1_775_130_000, 0).UTC() + + record := testChallenge(now) + + require.NoError(t, store.Create(context.Background(), record)) + + got, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + assert.Equal(t, record, got) + + got.CodeHash[0] = 0xFF + keyBytes := got.Confirmation.ClientPublicKey.PublicKey() + keyBytes[0] = 0xFE + + again, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + assert.Equal(t, record.CodeHash, again.CodeHash) + require.NotNil(t, again.Confirmation) + assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String()) +} + +func TestStoreCreateAndGetPendingChallenge(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + now := time.Unix(1_775_130_100, 0).UTC() + + record := testPendingChallenge(now) + + require.NoError(t, store.Create(context.Background(), record)) + + got, err := store.Get(context.Background(), record.ID) + 
require.NoError(t, err) + assert.Equal(t, record, got) +} + +func TestStoreCreateAndGetThrottledChallenge(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + now := time.Unix(1_775_130_150, 0).UTC() + + record := testPendingChallenge(now) + record.Status = challenge.StatusDeliveryThrottled + record.DeliveryState = challenge.DeliveryThrottled + record.Attempts.Send = 1 + record.Abuse.LastAttemptAt = timePointer(now) + require.NoError(t, record.Validate()) + + require.NoError(t, store.Create(context.Background(), record)) + + got, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + assert.Equal(t, record, got) +} + +func TestStoreGetStrictDecode(t *testing.T) { + t.Parallel() + + now := time.Unix(1_775_130_200, 0).UTC() + baseRecord := testChallenge(now) + baseStored, err := redisRecordFromChallenge(baseRecord) + require.NoError(t, err) + + tests := []struct { + name string + mutate func(redisRecord) string + wantErrText string + }{ + { + name: "malformed json", + mutate: func(_ redisRecord) string { + return "{" + }, + wantErrText: "decode redis challenge record", + }, + { + name: "trailing json input", + mutate: func(record redisRecord) string { + return mustMarshalJSON(t, record) + "{}" + }, + wantErrText: "unexpected trailing JSON input", + }, + { + name: "unknown field", + mutate: func(record redisRecord) string { + payload := map[string]any{ + "challenge_id": record.ChallengeID, + "email": record.Email, + "code_hash_base64": record.CodeHashBase64, + "status": record.Status, + "delivery_state": record.DeliveryState, + "created_at": record.CreatedAt, + "expires_at": record.ExpiresAt, + "send_attempt_count": record.SendAttemptCount, + "confirm_attempt_count": record.ConfirmAttemptCount, + "last_attempt_at": record.LastAttemptAt, + "confirmed_session_id": record.ConfirmedSessionID, + "confirmed_client_public_key": record.ConfirmedClientPublicKey, + "confirmed_at": 
record.ConfirmedAt, + "unexpected": true, + } + return mustMarshalJSON(t, payload) + }, + wantErrText: "unknown field", + }, + { + name: "unsupported status", + mutate: func(record redisRecord) string { + record.Status = challenge.Status("paused") + return mustMarshalJSON(t, record) + }, + wantErrText: `status "paused" is unsupported`, + }, + { + name: "unsupported delivery state", + mutate: func(record redisRecord) string { + record.DeliveryState = challenge.DeliveryState("queued") + return mustMarshalJSON(t, record) + }, + wantErrText: `delivery state "queued" is unsupported`, + }, + { + name: "missing required email", + mutate: func(record redisRecord) string { + record.Email = "" + return mustMarshalJSON(t, record) + }, + wantErrText: "challenge email", + }, + { + name: "challenge id mismatch", + mutate: func(record redisRecord) string { + record.ChallengeID = "other-challenge" + return mustMarshalJSON(t, record) + }, + wantErrText: `does not match requested`, + }, + { + name: "non canonical utc timestamp", + mutate: func(record redisRecord) string { + record.CreatedAt = "2026-04-04T12:00:00+03:00" + return mustMarshalJSON(t, record) + }, + wantErrText: "canonical UTC RFC3339Nano timestamp", + }, + { + name: "partial confirmation metadata", + mutate: func(record redisRecord) string { + record.ConfirmedAt = nil + return mustMarshalJSON(t, record) + }, + wantErrText: "confirmation metadata must be either fully present or fully absent", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + server.Set(store.lookupKey(baseRecord.ID), tt.mutate(baseStored)) + + _, err := store.Get(context.Background(), baseRecord.ID) + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErrText) + }) + } +} + +func TestStoreKeySchemeAndTTL(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, 
Config{KeyPrefix: "authsession:challenge:"}) + now := time.Now().UTC() + + prefixed := testPendingChallenge(now) + prefixed.ID = common.ChallengeID("challenge:opaque/id?value") + require.NoError(t, store.Create(context.Background(), prefixed)) + + key := store.lookupKey(prefixed.ID) + assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key) + assert.True(t, server.Exists(key)) + + freshTTL := server.TTL(key) + assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod) + assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second) + + expired := testPendingChallenge(now.Add(-10 * time.Minute)) + expired.ID = common.ChallengeID("expired-challenge") + expired.CreatedAt = now.Add(-20 * time.Minute) + expired.ExpiresAt = now.Add(-1 * time.Minute) + require.NoError(t, store.Create(context.Background(), expired)) + + expiredTTL := server.TTL(store.lookupKey(expired.ID)) + assert.LessOrEqual(t, expiredTTL, expirationGracePeriod) + assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second) +} + +func TestStoreCreateConflict(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC()) + + require.NoError(t, store.Create(context.Background(), record)) + + err := store.Create(context.Background(), record) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrConflict) +} + +func TestStoreGetNotFound(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + _, err := store.Get(context.Background(), common.ChallengeID("missing-challenge")) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrNotFound) +} + +func TestStoreCompareAndSwap(t *testing.T) { + t.Parallel() + + now := time.Unix(1_775_130_400, 0).UTC() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := 
newTestStore(t, server, Config{}) + previous := testPendingChallenge(now) + next := previous + next.Status = challenge.StatusSent + next.DeliveryState = challenge.DeliverySent + next.Attempts.Send = 1 + next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute)) + + require.NoError(t, store.Create(context.Background(), previous)) + require.NoError(t, store.CompareAndSwap(context.Background(), previous, next)) + + got, err := store.Get(context.Background(), previous.ID) + require.NoError(t, err) + assert.Equal(t, next, got) + }) + + t.Run("conflict when stored record differs", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + stored := testPendingChallenge(now) + previous := stored + previous.Attempts.Send = 99 + next := stored + next.Status = challenge.StatusSent + next.DeliveryState = challenge.DeliverySent + + require.NoError(t, store.Create(context.Background(), stored)) + + err := store.CompareAndSwap(context.Background(), previous, next) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrConflict) + }) + + t.Run("not found", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + previous := testPendingChallenge(now) + next := previous + next.Status = challenge.StatusSent + next.DeliveryState = challenge.DeliverySent + + err := store.CompareAndSwap(context.Background(), previous, next) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrNotFound) + }) + + t.Run("corrupt stored record returns adapter error", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + previous := testPendingChallenge(now) + next := previous + next.Status = challenge.StatusSent + next.DeliveryState = challenge.DeliverySent + + server.Set(store.lookupKey(previous.ID), "{") + + err := store.CompareAndSwap(context.Background(), previous, next) + require.Error(t, err) + assert.NotErrorIs(t, 
err, ports.ErrConflict) + assert.ErrorContains(t, err, "decode redis challenge record") + }) +} + +func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store { + t.Helper() + + if cfg.Addr == "" { + cfg.Addr = server.Addr() + } + if cfg.KeyPrefix == "" { + cfg.KeyPrefix = "authsession:challenge:" + } + if cfg.OperationTimeout == 0 { + cfg.OperationTimeout = 250 * time.Millisecond + } + + store, err := New(cfg) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + + return store +} + +func testPendingChallenge(now time.Time) challenge.Challenge { + return challenge.Challenge{ + ID: common.ChallengeID("challenge-pending"), + Email: common.Email("pilot@example.com"), + CodeHash: []byte("hashed-pending-code"), + Status: challenge.StatusPendingSend, + DeliveryState: challenge.DeliveryPending, + CreatedAt: now, + ExpiresAt: now.Add(challenge.InitialTTL), + } +} + +func testChallenge(now time.Time) challenge.Challenge { + clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{ + 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + }) + if err != nil { + panic(err) + } + + return challenge.Challenge{ + ID: common.ChallengeID("challenge-confirmed"), + Email: common.Email("pilot@example.com"), + CodeHash: []byte("hashed-code"), + Status: challenge.StatusConfirmedPendingExpire, + DeliveryState: challenge.DeliverySent, + CreatedAt: now, + ExpiresAt: now.Add(challenge.ConfirmedRetention), + Attempts: challenge.AttemptCounters{ + Send: 1, + Confirm: 2, + }, + Abuse: challenge.AbuseMetadata{ + LastAttemptAt: timePointer(now.Add(30 * time.Second)), + }, + Confirmation: &challenge.Confirmation{ + SessionID: common.DeviceSessionID("device-session-1"), + ClientPublicKey: clientPublicKey, + ConfirmedAt: now.Add(1 * time.Minute), + }, + } +} + +func timePointer(value time.Time) *time.Time { + return &value +} + +func mustMarshalJSON(t 
*testing.T, value any) string { + t.Helper() + + payload, err := json.Marshal(value) + require.NoError(t, err) + + return string(payload) +} + +func TestStorePingNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + err := store.Ping(nil) + require.Error(t, err) + assert.ErrorContains(t, err, "nil context") +} + +func TestStoreGetNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + _, err := store.Get(nil, common.ChallengeID("challenge")) + require.Error(t, err) + assert.ErrorContains(t, err, "nil context") +} diff --git a/authsession/internal/adapters/redis/configprovider/store.go b/authsession/internal/adapters/redis/configprovider/store.go new file mode 100644 index 0000000..7f0c096 --- /dev/null +++ b/authsession/internal/adapters/redis/configprovider/store.go @@ -0,0 +1,169 @@ +// Package configprovider implements ports.ConfigProvider with Redis-backed +// dynamic auth/session configuration. +package configprovider + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "galaxy/authsession/internal/ports" + + "github.com/redis/go-redis/v9" +) + +// Config configures one Redis-backed config provider instance. +type Config struct { + // Addr is the Redis network address in host:port form. + Addr string + + // Username is the optional Redis ACL username. + Username string + + // Password is the optional Redis ACL password. + Password string + + // DB is the Redis logical database index. + DB int + + // TLSEnabled enables TLS with a conservative minimum protocol version. + TLSEnabled bool + + // SessionLimitKey identifies the single Redis string key that stores the + // active-session-limit configuration value. + SessionLimitKey string + + // OperationTimeout bounds each Redis round trip performed by the adapter. 
+ OperationTimeout time.Duration +} + +// Store reads dynamic auth/session configuration from Redis. +type Store struct { + client *redis.Client + sessionLimitKey string + operationTimeout time.Duration +} + +// New constructs a Redis-backed config provider from cfg. +func New(cfg Config) (*Store, error) { + switch { + case strings.TrimSpace(cfg.Addr) == "": + return nil, errors.New("new redis config provider: redis addr must not be empty") + case cfg.DB < 0: + return nil, errors.New("new redis config provider: redis db must not be negative") + case strings.TrimSpace(cfg.SessionLimitKey) == "": + return nil, errors.New("new redis config provider: session limit key must not be empty") + case cfg.OperationTimeout <= 0: + return nil, errors.New("new redis config provider: operation timeout must be positive") + } + + options := &redis.Options{ + Addr: cfg.Addr, + Username: cfg.Username, + Password: cfg.Password, + DB: cfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if cfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + + return &Store{ + client: redis.NewClient(options), + sessionLimitKey: cfg.SessionLimitKey, + operationTimeout: cfg.OperationTimeout, + }, nil +} + +// Close releases the underlying Redis client resources. +func (s *Store) Close() error { + if s == nil || s.client == nil { + return nil + } + + return s.client.Close() +} + +// Ping verifies that the configured Redis backend is reachable within the +// adapter operation timeout budget. +func (s *Store) Ping(ctx context.Context) error { + operationCtx, cancel, err := s.operationContext(ctx, "ping redis config provider") + if err != nil { + return err + } + defer cancel() + + if err := s.client.Ping(operationCtx).Err(); err != nil { + return fmt.Errorf("ping redis config provider: %w", err) + } + + return nil +} + +// LoadSessionLimit returns the current active-session-limit configuration. +// Missing or invalid Redis values are treated as “limit absent” by policy. 
+func (s *Store) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) { + operationCtx, cancel, err := s.operationContext(ctx, "load session limit from redis") + if err != nil { + return ports.SessionLimitConfig{}, err + } + defer cancel() + + value, err := s.client.Get(operationCtx, s.sessionLimitKey).Result() + switch { + case errors.Is(err, redis.Nil): + return ports.SessionLimitConfig{}, nil + case err != nil: + return ports.SessionLimitConfig{}, fmt.Errorf("load session limit from redis: %w", err) + } + + config, valid := parseSessionLimitConfig(value) + if !valid { + return ports.SessionLimitConfig{}, nil + } + if err := config.Validate(); err != nil { + return ports.SessionLimitConfig{}, nil + } + + return config, nil +} + +func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) { + if s == nil || s.client == nil { + return nil, nil, fmt.Errorf("%s: nil store", operation) + } + if ctx == nil { + return nil, nil, fmt.Errorf("%s: nil context", operation) + } + + operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout) + return operationCtx, cancel, nil +} + +func parseSessionLimitConfig(raw string) (ports.SessionLimitConfig, bool) { + if strings.TrimSpace(raw) == "" || strings.TrimSpace(raw) != raw { + return ports.SessionLimitConfig{}, false + } + for _, symbol := range raw { + if symbol < '0' || symbol > '9' { + return ports.SessionLimitConfig{}, false + } + } + + parsed, err := strconv.ParseInt(raw, 10, strconv.IntSize) + if err != nil || parsed <= 0 { + return ports.SessionLimitConfig{}, false + } + + limit := int(parsed) + return ports.SessionLimitConfig{ + ActiveSessionLimit: &limit, + }, true +} + +var _ ports.ConfigProvider = (*Store)(nil) diff --git a/authsession/internal/adapters/redis/configprovider/store_test.go b/authsession/internal/adapters/redis/configprovider/store_test.go new file mode 100644 index 0000000..fe88ac6 --- /dev/null +++ 
b/authsession/internal/adapters/redis/configprovider/store_test.go @@ -0,0 +1,283 @@ +package configprovider + +import ( + "context" + "strconv" + "testing" + "time" + + "galaxy/authsession/internal/adapters/contracttest" + "galaxy/authsession/internal/ports" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStoreContract(t *testing.T) { + t.Parallel() + + contracttest.RunConfigProviderContractTests(t, func(t *testing.T) contracttest.ConfigProviderHarness { + t.Helper() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + return contracttest.ConfigProviderHarness{ + Provider: store, + SeedDisabled: func(t *testing.T) { + t.Helper() + server.Del(store.sessionLimitKey) + }, + SeedLimit: func(t *testing.T, limit int) { + t.Helper() + server.Set(store.sessionLimitKey, strconv.Itoa(limit)) + }, + } + }) +} + +func TestNew(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + tests := []struct { + name string + cfg Config + wantErr string + }{ + { + name: "valid config", + cfg: Config{ + Addr: server.Addr(), + DB: 2, + SessionLimitKey: "authsession:config:active-session-limit", + OperationTimeout: 250 * time.Millisecond, + }, + }, + { + name: "empty addr", + cfg: Config{ + SessionLimitKey: "authsession:config:active-session-limit", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis addr must not be empty", + }, + { + name: "negative db", + cfg: Config{ + Addr: server.Addr(), + DB: -1, + SessionLimitKey: "authsession:config:active-session-limit", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis db must not be negative", + }, + { + name: "empty session limit key", + cfg: Config{ + Addr: server.Addr(), + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "session limit key must not be empty", + }, + { + name: "non positive timeout", + cfg: Config{ + Addr: server.Addr(), + SessionLimitKey: 
"authsession:config:active-session-limit", + }, + wantErr: "operation timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store, err := New(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + }) + } +} + +func TestStorePing(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + require.NoError(t, store.Ping(context.Background())) +} + +func TestStoreLoadSessionLimit(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*testing.T, *miniredis.Miniredis, *Store) + wantConfig ports.SessionLimitConfig + }{ + { + name: "missing key means disabled", + wantConfig: ports.SessionLimitConfig{}, + }, + { + name: "valid positive integer", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + server.Set(store.sessionLimitKey, "5") + }, + wantConfig: configWithLimit(5), + }, + { + name: "empty string is invalid and disabled", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + server.Set(store.sessionLimitKey, "") + }, + wantConfig: ports.SessionLimitConfig{}, + }, + { + name: "whitespace only is invalid and disabled", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + server.Set(store.sessionLimitKey, " ") + }, + wantConfig: ports.SessionLimitConfig{}, + }, + { + name: "whitespace padded integer is invalid and disabled", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + server.Set(store.sessionLimitKey, " 5 ") + }, + wantConfig: ports.SessionLimitConfig{}, + }, + { + name: "non integer text is invalid and disabled", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + 
server.Set(store.sessionLimitKey, "five") + }, + wantConfig: ports.SessionLimitConfig{}, + }, + { + name: "zero is invalid and disabled", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + server.Set(store.sessionLimitKey, "0") + }, + wantConfig: ports.SessionLimitConfig{}, + }, + { + name: "negative integer is invalid and disabled", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + server.Set(store.sessionLimitKey, "-3") + }, + wantConfig: ports.SessionLimitConfig{}, + }, + { + name: "overflow is invalid and disabled", + seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) { + t.Helper() + server.Set(store.sessionLimitKey, "999999999999999999999999999999") + }, + wantConfig: ports.SessionLimitConfig{}, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + if tt.seed != nil { + tt.seed(t, server, store) + } + + got, err := store.LoadSessionLimit(context.Background()) + require.NoError(t, err) + assert.Equal(t, tt.wantConfig, got) + }) + } +} + +func TestStoreLoadSessionLimitBackendFailure(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + server.Close() + + _, err := store.LoadSessionLimit(context.Background()) + require.Error(t, err) + assert.ErrorContains(t, err, "load session limit from redis") +} + +func TestStoreLoadSessionLimitNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + _, err := store.LoadSessionLimit(nil) + require.Error(t, err) + assert.ErrorContains(t, err, "nil context") +} + +func TestStorePingNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + err := store.Ping(nil) + require.Error(t, err) + assert.ErrorContains(t, err, "nil 
context") +} + +func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store { + t.Helper() + + if cfg.Addr == "" { + cfg.Addr = server.Addr() + } + if cfg.SessionLimitKey == "" { + cfg.SessionLimitKey = "authsession:config:active-session-limit" + } + if cfg.OperationTimeout == 0 { + cfg.OperationTimeout = 250 * time.Millisecond + } + + store, err := New(cfg) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + + return store +} + +func configWithLimit(limit int) ports.SessionLimitConfig { + return ports.SessionLimitConfig{ + ActiveSessionLimit: &limit, + } +} diff --git a/authsession/internal/adapters/redis/projectionpublisher/publisher.go b/authsession/internal/adapters/redis/projectionpublisher/publisher.go new file mode 100644 index 0000000..7897de0 --- /dev/null +++ b/authsession/internal/adapters/redis/projectionpublisher/publisher.go @@ -0,0 +1,223 @@ +// Package projectionpublisher implements +// ports.GatewaySessionProjectionPublisher with Redis-backed gateway-compatible +// cache snapshots and session lifecycle events. +package projectionpublisher + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/ports" + + "github.com/redis/go-redis/v9" +) + +// Config configures one Redis-backed gateway session projection publisher. +type Config struct { + // Addr is the Redis network address in host:port form. + Addr string + + // Username is the optional Redis ACL username. + Username string + + // Password is the optional Redis ACL password. + Password string + + // DB is the Redis logical database index. + DB int + + // TLSEnabled enables TLS with a conservative minimum protocol version. + TLSEnabled bool + + // SessionCacheKeyPrefix is the namespace prefix applied to gateway session + // cache keys. The raw device session identifier is appended directly. 
+ SessionCacheKeyPrefix string + + // SessionEventsStream identifies the gateway session lifecycle Redis Stream. + SessionEventsStream string + + // StreamMaxLen bounds the session lifecycle stream with approximate + // trimming via XADD MAXLEN ~. + StreamMaxLen int64 + + // OperationTimeout bounds each Redis round trip performed by the adapter. + OperationTimeout time.Duration +} + +// Publisher publishes gateway-compatible session projections into Redis cache +// and stream namespaces. +type Publisher struct { + client *redis.Client + sessionCacheKeyPrefix string + sessionEventsStream string + streamMaxLen int64 + operationTimeout time.Duration +} + +type cacheRecord struct { + DeviceSessionID string `json:"device_session_id"` + UserID string `json:"user_id"` + ClientPublicKey string `json:"client_public_key"` + Status gatewayprojection.Status `json:"status"` + RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"` +} + +// New constructs a Redis-backed gateway session projection publisher from +// cfg. 
+func New(cfg Config) (*Publisher, error) { + switch { + case strings.TrimSpace(cfg.Addr) == "": + return nil, errors.New("new redis projection publisher: redis addr must not be empty") + case cfg.DB < 0: + return nil, errors.New("new redis projection publisher: redis db must not be negative") + case strings.TrimSpace(cfg.SessionCacheKeyPrefix) == "": + return nil, errors.New("new redis projection publisher: session cache key prefix must not be empty") + case strings.TrimSpace(cfg.SessionEventsStream) == "": + return nil, errors.New("new redis projection publisher: session events stream must not be empty") + case cfg.StreamMaxLen <= 0: + return nil, errors.New("new redis projection publisher: stream max len must be positive") + case cfg.OperationTimeout <= 0: + return nil, errors.New("new redis projection publisher: operation timeout must be positive") + } + + options := &redis.Options{ + Addr: cfg.Addr, + Username: cfg.Username, + Password: cfg.Password, + DB: cfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if cfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + + return &Publisher{ + client: redis.NewClient(options), + sessionCacheKeyPrefix: cfg.SessionCacheKeyPrefix, + sessionEventsStream: cfg.SessionEventsStream, + streamMaxLen: cfg.StreamMaxLen, + operationTimeout: cfg.OperationTimeout, + }, nil +} + +// Close releases the underlying Redis client resources. +func (p *Publisher) Close() error { + if p == nil || p.client == nil { + return nil + } + + return p.client.Close() +} + +// Ping verifies that the configured Redis backend is reachable within the +// adapter operation timeout budget. 
+func (p *Publisher) Ping(ctx context.Context) error { + operationCtx, cancel, err := p.operationContext(ctx, "ping redis projection publisher") + if err != nil { + return err + } + defer cancel() + + if err := p.client.Ping(operationCtx).Err(); err != nil { + return fmt.Errorf("ping redis projection publisher: %w", err) + } + + return nil +} + +// PublishSession writes one gateway-compatible session snapshot into the +// gateway cache namespace and appends the same snapshot to the gateway session +// event stream within one Redis transaction. +func (p *Publisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error { + if err := snapshot.Validate(); err != nil { + return fmt.Errorf("publish session projection to redis: %w", err) + } + + payload, err := marshalCacheRecord(snapshot) + if err != nil { + return fmt.Errorf("publish session projection to redis: %w", err) + } + values := buildStreamValues(snapshot) + + operationCtx, cancel, err := p.operationContext(ctx, "publish session projection to redis") + if err != nil { + return err + } + defer cancel() + + key := p.sessionCacheKey(snapshot.DeviceSessionID) + _, err = p.client.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { + pipe.Set(operationCtx, key, payload, 0) + pipe.XAdd(operationCtx, &redis.XAddArgs{ + Stream: p.sessionEventsStream, + MaxLen: p.streamMaxLen, + Approx: true, + Values: values, + }) + return nil + }) + if err != nil { + return fmt.Errorf("publish session projection %q to redis: %w", snapshot.DeviceSessionID, err) + } + + return nil +} + +func (p *Publisher) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) { + if p == nil || p.client == nil { + return nil, nil, fmt.Errorf("%s: nil publisher", operation) + } + if ctx == nil { + return nil, nil, fmt.Errorf("%s: nil context", operation) + } + + operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout) + return operationCtx, cancel, nil +} + +func 
(p *Publisher) sessionCacheKey(deviceSessionID interface{ String() string }) string { + return p.sessionCacheKeyPrefix + deviceSessionID.String() +} + +func marshalCacheRecord(snapshot gatewayprojection.Snapshot) ([]byte, error) { + record := cacheRecord{ + DeviceSessionID: snapshot.DeviceSessionID.String(), + UserID: snapshot.UserID.String(), + ClientPublicKey: snapshot.ClientPublicKey, + Status: snapshot.Status, + } + if snapshot.RevokedAt != nil { + revokedAtMS := snapshot.RevokedAt.UTC().UnixMilli() + record.RevokedAtMS = &revokedAtMS + } + + payload, err := json.Marshal(record) + if err != nil { + return nil, fmt.Errorf("marshal gateway session cache record: %w", err) + } + + return payload, nil +} + +func buildStreamValues(snapshot gatewayprojection.Snapshot) map[string]any { + values := map[string]any{ + "device_session_id": snapshot.DeviceSessionID.String(), + "user_id": snapshot.UserID.String(), + "client_public_key": snapshot.ClientPublicKey, + "status": string(snapshot.Status), + } + if snapshot.RevokedAt != nil { + values["revoked_at_ms"] = fmt.Sprint(snapshot.RevokedAt.UTC().UnixMilli()) + } + + return values +} + +var _ ports.GatewaySessionProjectionPublisher = (*Publisher)(nil) diff --git a/authsession/internal/adapters/redis/projectionpublisher/publisher_test.go b/authsession/internal/adapters/redis/projectionpublisher/publisher_test.go new file mode 100644 index 0000000..5b3db20 --- /dev/null +++ b/authsession/internal/adapters/redis/projectionpublisher/publisher_test.go @@ -0,0 +1,442 @@ +package projectionpublisher + +import ( + "bytes" + "context" + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/gatewayprojection" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + 
tests := []struct { + name string + cfg Config + wantErr string + }{ + { + name: "valid config", + cfg: Config{ + Addr: server.Addr(), + DB: 3, + SessionCacheKeyPrefix: "gateway:session:", + SessionEventsStream: "gateway:session_events", + StreamMaxLen: 1024, + OperationTimeout: 250 * time.Millisecond, + }, + }, + { + name: "empty addr", + cfg: Config{ + SessionCacheKeyPrefix: "gateway:session:", + SessionEventsStream: "gateway:session_events", + StreamMaxLen: 1024, + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis addr must not be empty", + }, + { + name: "negative db", + cfg: Config{ + Addr: server.Addr(), + DB: -1, + SessionCacheKeyPrefix: "gateway:session:", + SessionEventsStream: "gateway:session_events", + StreamMaxLen: 1024, + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis db must not be negative", + }, + { + name: "empty session cache key prefix", + cfg: Config{ + Addr: server.Addr(), + SessionEventsStream: "gateway:session_events", + StreamMaxLen: 1024, + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "session cache key prefix must not be empty", + }, + { + name: "empty session events stream", + cfg: Config{ + Addr: server.Addr(), + SessionCacheKeyPrefix: "gateway:session:", + StreamMaxLen: 1024, + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "session events stream must not be empty", + }, + { + name: "non positive stream max len", + cfg: Config{ + Addr: server.Addr(), + SessionCacheKeyPrefix: "gateway:session:", + SessionEventsStream: "gateway:session_events", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "stream max len must be positive", + }, + { + name: "non positive timeout", + cfg: Config{ + Addr: server.Addr(), + SessionCacheKeyPrefix: "gateway:session:", + SessionEventsStream: "gateway:session_events", + StreamMaxLen: 1024, + }, + wantErr: "operation timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + 
t.Parallel() + + publisher, err := New(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, publisher.Close()) + }) + }) + } +} + +func TestPublisherPing(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{}) + + require.NoError(t, publisher.Ping(context.Background())) +} + +func TestPublisherPublishSessionActive(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{}) + snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil) + + require.NoError(t, publisher.PublishSession(context.Background(), snapshot)) + + key := publisher.sessionCacheKey(snapshot.DeviceSessionID) + assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key) + assert.True(t, server.Exists(key)) + assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String()))) + + payload, err := server.Get(key) + require.NoError(t, err) + record := decodeCachePayload(t, payload) + assert.Equal(t, cacheRecord{ + DeviceSessionID: snapshot.DeviceSessionID.String(), + UserID: snapshot.UserID.String(), + ClientPublicKey: snapshot.ClientPublicKey, + Status: gatewayprojection.StatusActive, + }, record) + assert.Zero(t, server.TTL(key)) + + entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result() + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, map[string]string{ + "device_session_id": snapshot.DeviceSessionID.String(), + "user_id": snapshot.UserID.String(), + "client_public_key": snapshot.ClientPublicKey, + "status": string(gatewayprojection.StatusActive), + }, stringifyValues(entries[0].Values)) +} + +func TestPublisherPublishSessionRevoked(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + 
publisher := newTestPublisher(t, server, Config{}) + revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC() + snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt) + + require.NoError(t, publisher.PublishSession(context.Background(), snapshot)) + + key := publisher.sessionCacheKey(snapshot.DeviceSessionID) + payload, err := server.Get(key) + require.NoError(t, err) + record := decodeCachePayload(t, payload) + require.NotNil(t, record.RevokedAtMS) + assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS) + assert.Equal(t, cacheRecord{ + DeviceSessionID: snapshot.DeviceSessionID.String(), + UserID: snapshot.UserID.String(), + ClientPublicKey: snapshot.ClientPublicKey, + Status: gatewayprojection.StatusRevoked, + RevokedAtMS: int64Pointer(revokedAt.UnixMilli()), + }, record) + + entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result() + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, map[string]string{ + "device_session_id": snapshot.DeviceSessionID.String(), + "user_id": snapshot.UserID.String(), + "client_public_key": snapshot.ClientPublicKey, + "status": string(gatewayprojection.StatusRevoked), + "revoked_at_ms": "1776000123456", + }, stringifyValues(entries[0].Values)) +} + +func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8}) + deviceSessionID := "device-session-456" + + active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil) + revokedAt := time.Unix(1_776_010_000, 0).UTC() + revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt) + + require.NoError(t, publisher.PublishSession(context.Background(), active)) + require.NoError(t, publisher.PublishSession(context.Background(), revoked)) + + payload, err := 
server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID)) + require.NoError(t, err) + record := decodeCachePayload(t, payload) + require.NotNil(t, record.RevokedAtMS) + assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS) + assert.Equal(t, gatewayprojection.StatusRevoked, record.Status) + + entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result() + require.NoError(t, err) + require.Len(t, entries, 2) + assert.Equal(t, map[string]string{ + "device_session_id": active.DeviceSessionID.String(), + "user_id": active.UserID.String(), + "client_public_key": active.ClientPublicKey, + "status": string(gatewayprojection.StatusActive), + }, stringifyValues(entries[0].Values)) + assert.Equal(t, map[string]string{ + "device_session_id": revoked.DeviceSessionID.String(), + "user_id": revoked.UserID.String(), + "client_public_key": revoked.ClientPublicKey, + "status": string(gatewayprojection.StatusRevoked), + "revoked_at_ms": "1776010000000", + }, stringifyValues(entries[1].Values)) +} + +func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8}) + snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil) + + require.NoError(t, publisher.PublishSession(context.Background(), snapshot)) + require.NoError(t, publisher.PublishSession(context.Background(), snapshot)) + + payload, err := server.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID)) + require.NoError(t, err) + record := decodeCachePayload(t, payload) + assert.Equal(t, cacheRecord{ + DeviceSessionID: snapshot.DeviceSessionID.String(), + UserID: snapshot.UserID.String(), + ClientPublicKey: snapshot.ClientPublicKey, + Status: gatewayprojection.StatusActive, + }, record) + + entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result() + 
require.NoError(t, err) + require.Len(t, entries, 2) + assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values)) +} + +func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{StreamMaxLen: 2}) + + for index := range 6 { + snapshot := testSnapshot( + common.DeviceSessionID("device-session-"+string(rune('a'+index))).String(), + gatewayprojection.StatusActive, + nil, + ) + require.NoError(t, publisher.PublishSession(context.Background(), snapshot)) + } + + streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result() + require.NoError(t, err) + assert.LessOrEqual(t, streamLength, int64(2)) +} + +func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{}) + snapshot := gatewayprojection.Snapshot{ + DeviceSessionID: common.DeviceSessionID("device-session-123"), + UserID: common.UserID("user-123"), + Status: gatewayprojection.StatusActive, + } + + err := publisher.PublishSession(context.Background(), snapshot) + require.Error(t, err) + assert.ErrorContains(t, err, "gateway projection client public key") + assert.Empty(t, server.Keys()) +} + +func TestPublisherPublishSessionNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{}) + + err := publisher.PublishSession(nil, testSnapshot("device-session-123", gatewayprojection.StatusActive, nil)) + require.Error(t, err) + assert.ErrorContains(t, err, "nil context") +} + +func TestPublisherPublishSessionBackendFailure(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{}) + server.Close() + + err := publisher.PublishSession(context.Background(), testSnapshot("device-session-123", gatewayprojection.StatusActive, nil)) + 
require.Error(t, err) + assert.ErrorContains(t, err, "publish session projection") +} + +func TestPublisherPingNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + publisher := newTestPublisher(t, server, Config{}) + + err := publisher.Ping(nil) + require.Error(t, err) + assert.ErrorContains(t, err, "nil context") +} + +func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher { + t.Helper() + + if cfg.Addr == "" { + cfg.Addr = server.Addr() + } + if cfg.SessionCacheKeyPrefix == "" { + cfg.SessionCacheKeyPrefix = "gateway:session:" + } + if cfg.SessionEventsStream == "" { + cfg.SessionEventsStream = "gateway:session_events" + } + if cfg.StreamMaxLen == 0 { + cfg.StreamMaxLen = 1024 + } + if cfg.OperationTimeout == 0 { + cfg.OperationTimeout = 250 * time.Millisecond + } + + publisher, err := New(cfg) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, publisher.Close()) + }) + + return publisher +} + +func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot { + raw := make(ed25519.PublicKey, ed25519.PublicKeySize) + for index := range raw { + raw[index] = byte(index + 1) + } + + snapshot := gatewayprojection.Snapshot{ + DeviceSessionID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID("user-123"), + ClientPublicKey: base64.StdEncoding.EncodeToString(raw), + Status: status, + RevokedAt: revokedAt, + } + if status == gatewayprojection.StatusRevoked { + snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked") + snapshot.RevokeActorType = common.RevokeActorType("system") + } + + return snapshot +} + +func decodeCachePayload(t *testing.T, payload string) cacheRecord { + t.Helper() + + decoder := json.NewDecoder(bytes.NewReader([]byte(payload))) + decoder.DisallowUnknownFields() + + var record cacheRecord + require.NoError(t, decoder.Decode(&record)) + err := decoder.Decode(&struct{}{}) + if err == nil { + 
require.FailNow(t, "expected cache payload EOF after first JSON value") + } + require.ErrorIs(t, err, io.EOF) + + var fieldSet map[string]json.RawMessage + require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet)) + expectedFields := map[string]struct{}{ + "device_session_id": {}, + "user_id": {}, + "client_public_key": {}, + "status": {}, + } + if record.RevokedAtMS != nil { + expectedFields["revoked_at_ms"] = struct{}{} + } + assert.Equal(t, len(expectedFields), len(fieldSet)) + for field := range fieldSet { + _, ok := expectedFields[field] + assert.Truef(t, ok, "unexpected cache payload field %q", field) + } + + return record +} + +func stringifyValues(values map[string]any) map[string]string { + stringified := make(map[string]string, len(values)) + for key, value := range values { + stringified[key] = fmt.Sprint(value) + } + return stringified +} + +func encodeBase64URL(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func int64Pointer(value int64) *int64 { + return &value +} diff --git a/authsession/internal/adapters/redis/sendemailcodeabuse/protector.go b/authsession/internal/adapters/redis/sendemailcodeabuse/protector.go new file mode 100644 index 0000000..ff52de0 --- /dev/null +++ b/authsession/internal/adapters/redis/sendemailcodeabuse/protector.go @@ -0,0 +1,152 @@ +// Package sendemailcodeabuse implements ports.SendEmailCodeAbuseProtector with +// one Redis TTL key per normalized e-mail address. +package sendemailcodeabuse + +import ( + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "strings" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "github.com/redis/go-redis/v9" +) + +// Config configures one Redis-backed send-email-code abuse protector. +type Config struct { + // Addr is the Redis network address in host:port form. + Addr string + + // Username is the optional Redis ACL username. 
+ Username string + + // Password is the optional Redis ACL password. + Password string + + // DB is the Redis logical database index. + DB int + + // TLSEnabled enables TLS with a conservative minimum protocol version. + TLSEnabled bool + + // KeyPrefix is the namespace prefix applied to every resend-throttle key. + KeyPrefix string + + // OperationTimeout bounds each Redis round trip performed by the adapter. + OperationTimeout time.Duration +} + +// Protector applies the fixed resend cooldown with one Redis key per +// normalized e-mail address. +type Protector struct { + client *redis.Client + keyPrefix string + operationTimeout time.Duration +} + +// New constructs a Redis-backed resend-throttle protector from cfg. +func New(cfg Config) (*Protector, error) { + switch { + case strings.TrimSpace(cfg.Addr) == "": + return nil, errors.New("new redis send email code abuse protector: redis addr must not be empty") + case cfg.DB < 0: + return nil, errors.New("new redis send email code abuse protector: redis db must not be negative") + case strings.TrimSpace(cfg.KeyPrefix) == "": + return nil, errors.New("new redis send email code abuse protector: redis key prefix must not be empty") + case cfg.OperationTimeout <= 0: + return nil, errors.New("new redis send email code abuse protector: operation timeout must be positive") + } + + options := &redis.Options{ + Addr: cfg.Addr, + Username: cfg.Username, + Password: cfg.Password, + DB: cfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if cfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + + return &Protector{ + client: redis.NewClient(options), + keyPrefix: cfg.KeyPrefix, + operationTimeout: cfg.OperationTimeout, + }, nil +} + +// Close releases the underlying Redis client resources. 
+func (p *Protector) Close() error { + if p == nil || p.client == nil { + return nil + } + + return p.client.Close() +} + +// Ping verifies that the configured Redis backend is reachable within the +// adapter operation timeout budget. +func (p *Protector) Ping(ctx context.Context) error { + operationCtx, cancel, err := p.operationContext(ctx, "ping redis send email code abuse protector") + if err != nil { + return err + } + defer cancel() + + if err := p.client.Ping(operationCtx).Err(); err != nil { + return fmt.Errorf("ping redis send email code abuse protector: %w", err) + } + + return nil +} + +// CheckAndReserve applies the fixed resend cooldown using one TTL key per +// normalized e-mail address. +func (p *Protector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) { + if err := input.Validate(); err != nil { + return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err) + } + + operationCtx, cancel, err := p.operationContext(ctx, "check and reserve send email code abuse") + if err != nil { + return ports.SendEmailCodeAbuseResult{}, err + } + defer cancel() + + key := p.lookupKey(input.Email) + value := input.Now.UTC().Add(challenge.ResendThrottleCooldown).Format(time.RFC3339Nano) + created, err := p.client.SetNX(operationCtx, key, value, challenge.ResendThrottleCooldown).Result() + if err != nil { + return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse for %q: %w", input.Email, err) + } + if created { + return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeAllowed}, nil + } + + return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeThrottled}, nil +} + +func (p *Protector) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) { + if p == nil || p.client == nil { + return nil, nil, fmt.Errorf("%s: nil protector", 
operation) + } + if ctx == nil { + return nil, nil, fmt.Errorf("%s: nil context", operation) + } + + operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout) + return operationCtx, cancel, nil +} + +func (p *Protector) lookupKey(email common.Email) string { + return p.keyPrefix + base64.RawURLEncoding.EncodeToString([]byte(email.String())) +} + +var _ ports.SendEmailCodeAbuseProtector = (*Protector)(nil) diff --git a/authsession/internal/adapters/redis/sendemailcodeabuse/protector_test.go b/authsession/internal/adapters/redis/sendemailcodeabuse/protector_test.go new file mode 100644 index 0000000..c791dca --- /dev/null +++ b/authsession/internal/adapters/redis/sendemailcodeabuse/protector_test.go @@ -0,0 +1,176 @@ +package sendemailcodeabuse + +import ( + "context" + "testing" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + tests := []struct { + name string + cfg Config + wantErr string + }{ + { + name: "valid config", + cfg: Config{ + Addr: server.Addr(), + DB: 1, + KeyPrefix: "authsession:send-email-code-throttle:", + OperationTimeout: 250 * time.Millisecond, + }, + }, + { + name: "empty addr", + cfg: Config{ + KeyPrefix: "authsession:send-email-code-throttle:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis addr must not be empty", + }, + { + name: "negative db", + cfg: Config{ + Addr: server.Addr(), + DB: -1, + KeyPrefix: "authsession:send-email-code-throttle:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis db must not be negative", + }, + { + name: "empty key prefix", + cfg: Config{ + Addr: server.Addr(), + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis key prefix must not be empty", + }, + { + 
name: "non-positive timeout", + cfg: Config{ + Addr: server.Addr(), + KeyPrefix: "authsession:send-email-code-throttle:", + }, + wantErr: "operation timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + protector, err := New(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, protector.Close()) + }) + }) + } +} + +func TestProtectorPing(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + protector := newTestProtector(t, server, Config{}) + + require.NoError(t, protector.Ping(context.Background())) +} + +func TestProtectorCheckAndReserve(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + protector := newTestProtector(t, server, Config{}) + email := common.Email("pilot@example.com") + now := time.Unix(10, 0).UTC() + + result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{ + Email: email, + Now: now, + }) + require.NoError(t, err) + assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome) + + key := protector.lookupKey(email) + assert.True(t, server.Exists(key)) + ttl := server.TTL(key) + assert.LessOrEqual(t, ttl, challenge.ResendThrottleCooldown) + assert.GreaterOrEqual(t, ttl, challenge.ResendThrottleCooldown-2*time.Second) + + result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{ + Email: email, + Now: now.Add(30 * time.Second), + }) + require.NoError(t, err) + assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome) + ttlAfterThrottle := server.TTL(key) + assert.LessOrEqual(t, ttlAfterThrottle, ttl) + + server.FastForward(challenge.ResendThrottleCooldown) + + result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{ + Email: email, + Now: now.Add(challenge.ResendThrottleCooldown), + }) 
+ require.NoError(t, err) + assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome) +} + +func TestProtectorNilContext(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + protector := newTestProtector(t, server, Config{}) + + _, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{ + Email: common.Email("pilot@example.com"), + Now: time.Unix(10, 0).UTC(), + }) + require.Error(t, err) + assert.ErrorContains(t, err, "nil context") +} + +func newTestProtector(t *testing.T, server *miniredis.Miniredis, cfg Config) *Protector { + t.Helper() + + if cfg.Addr == "" { + cfg.Addr = server.Addr() + } + if cfg.KeyPrefix == "" { + cfg.KeyPrefix = "authsession:send-email-code-throttle:" + } + if cfg.OperationTimeout == 0 { + cfg.OperationTimeout = 250 * time.Millisecond + } + + protector, err := New(cfg) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, protector.Close()) + }) + + return protector +} diff --git a/authsession/internal/adapters/redis/sessionstore/store.go b/authsession/internal/adapters/redis/sessionstore/store.go new file mode 100644 index 0000000..7827e42 --- /dev/null +++ b/authsession/internal/adapters/redis/sessionstore/store.go @@ -0,0 +1,723 @@ +// Package sessionstore implements ports.SessionStore with Redis-backed strict +// JSON source-of-truth session records and per-user indexes. +package sessionstore + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "slices" + "strings" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/ports" + + "github.com/redis/go-redis/v9" +) + +const mutationRetryLimit = 3 + +// Config configures one Redis-backed session store instance. +type Config struct { + // Addr is the Redis network address in host:port form. + Addr string + + // Username is the optional Redis ACL username. 
+ Username string + + // Password is the optional Redis ACL password. + Password string + + // DB is the Redis logical database index. + DB int + + // TLSEnabled enables TLS with a conservative minimum protocol version. + TLSEnabled bool + + // SessionKeyPrefix is the namespace prefix applied to primary session keys. + SessionKeyPrefix string + + // UserSessionsKeyPrefix is the namespace prefix applied to all-session user + // indexes. + UserSessionsKeyPrefix string + + // UserActiveSessionsKeyPrefix is the namespace prefix applied to active + // session user indexes. + UserActiveSessionsKeyPrefix string + + // OperationTimeout bounds each Redis round trip performed by the adapter. + OperationTimeout time.Duration +} + +// Store persists source-of-truth sessions in Redis and maintains user-scoped +// indexes for list and count operations. +type Store struct { + client *redis.Client + sessionKeyPrefix string + userSessionsKeyPrefix string + userActiveSessionsKeyPrefix string + operationTimeout time.Duration +} + +type redisRecord struct { + DeviceSessionID string `json:"device_session_id"` + UserID string `json:"user_id"` + ClientPublicKeyBase64 string `json:"client_public_key_base64"` + Status devicesession.Status `json:"status"` + CreatedAt string `json:"created_at"` + RevokedAt *string `json:"revoked_at,omitempty"` + RevokeReasonCode string `json:"revoke_reason_code,omitempty"` + RevokeActorType string `json:"revoke_actor_type,omitempty"` + RevokeActorID string `json:"revoke_actor_id,omitempty"` +} + +// New constructs a Redis-backed session store from cfg. 
+func New(cfg Config) (*Store, error) { + switch { + case strings.TrimSpace(cfg.Addr) == "": + return nil, errors.New("new redis session store: redis addr must not be empty") + case cfg.DB < 0: + return nil, errors.New("new redis session store: redis db must not be negative") + case strings.TrimSpace(cfg.SessionKeyPrefix) == "": + return nil, errors.New("new redis session store: session key prefix must not be empty") + case strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "": + return nil, errors.New("new redis session store: user sessions key prefix must not be empty") + case strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "": + return nil, errors.New("new redis session store: user active sessions key prefix must not be empty") + case cfg.OperationTimeout <= 0: + return nil, errors.New("new redis session store: operation timeout must be positive") + } + + options := &redis.Options{ + Addr: cfg.Addr, + Username: cfg.Username, + Password: cfg.Password, + DB: cfg.DB, + Protocol: 2, + DisableIdentity: true, + } + if cfg.TLSEnabled { + options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + + return &Store{ + client: redis.NewClient(options), + sessionKeyPrefix: cfg.SessionKeyPrefix, + userSessionsKeyPrefix: cfg.UserSessionsKeyPrefix, + userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix, + operationTimeout: cfg.OperationTimeout, + }, nil +} + +// Close releases the underlying Redis client resources. +func (s *Store) Close() error { + if s == nil || s.client == nil { + return nil + } + + return s.client.Close() +} + +// Ping verifies that the configured Redis backend is reachable within the +// adapter operation timeout budget. 
+func (s *Store) Ping(ctx context.Context) error {
+	operationCtx, cancel, err := s.operationContext(ctx, "ping redis session store")
+	if err != nil {
+		return err
+	}
+	defer cancel()
+
+	if err := s.client.Ping(operationCtx).Err(); err != nil {
+		return fmt.Errorf("ping redis session store: %w", err)
+	}
+
+	return nil
+}
+
+// Get returns the stored session for deviceSessionID.
+func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
+	if err := deviceSessionID.Validate(); err != nil {
+		return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err)
+	}
+
+	operationCtx, cancel, err := s.operationContext(ctx, "get session from redis")
+	if err != nil {
+		return devicesession.Session{}, err
+	}
+	defer cancel()
+
+	record, err := s.loadSession(operationCtx, deviceSessionID)
+	if err != nil {
+		// loadSession already returns ports.ErrNotFound for missing keys, so a
+		// single wrap preserves errors.Is(err, ports.ErrNotFound) for callers.
+		return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err)
+	}
+
+	return record, nil
+}
+
+// ListByUserID returns every stored session for userID in newest-first order. 
+func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) { + if err := userID.Validate(); err != nil { + return nil, fmt.Errorf("list sessions by user id from redis: %w", err) + } + + operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis") + if err != nil { + return nil, err + } + defer cancel() + + deviceSessionIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result() + if err != nil { + return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err) + } + if len(deviceSessionIDs) == 0 { + return []devicesession.Session{}, nil + } + + records := make([]devicesession.Session, 0, len(deviceSessionIDs)) + for _, rawDeviceSessionID := range deviceSessionIDs { + deviceSessionID := common.DeviceSessionID(rawDeviceSessionID) + record, err := s.loadSession(operationCtx, deviceSessionID) + if err != nil { + switch { + case errors.Is(err, ports.ErrNotFound): + return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, deviceSessionID) + default: + return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, deviceSessionID, err) + } + } + if record.UserID != userID { + return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, deviceSessionID, record.UserID) + } + records = append(records, record) + } + + sortSessionsNewestFirst(records) + return records, nil +} + +// CountActiveByUserID returns the number of active sessions currently stored +// for userID. 
+func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) { + if err := userID.Validate(); err != nil { + return 0, fmt.Errorf("count active sessions by user id from redis: %w", err) + } + + operationCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis") + if err != nil { + return 0, err + } + defer cancel() + + count, err := s.client.ZCard(operationCtx, s.userActiveSessionsKey(userID)).Result() + if err != nil { + return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err) + } + + return int(count), nil +} + +// Create persists record as a new device session. +func (s *Store) Create(ctx context.Context, record devicesession.Session) error { + if err := record.Validate(); err != nil { + return fmt.Errorf("create session in redis: %w", err) + } + + payload, err := marshalSessionRecord(record) + if err != nil { + return fmt.Errorf("create session in redis: %w", err) + } + + deviceSessionKey := s.sessionKey(record.ID) + allSessionsKey := s.userSessionsKey(record.UserID) + activeSessionsKey := s.userActiveSessionsKey(record.UserID) + + operationCtx, cancel, err := s.operationContext(ctx, "create session in redis") + if err != nil { + return err + } + defer cancel() + + watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { + _, err := tx.Get(operationCtx, deviceSessionKey).Bytes() + switch { + case errors.Is(err, redis.Nil): + case err != nil: + return fmt.Errorf("create session %q in redis: %w", record.ID, err) + default: + return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict) + } + + _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { + pipe.Set(operationCtx, deviceSessionKey, payload, 0) + pipe.ZAdd(operationCtx, allSessionsKey, redis.Z{ + Score: createdAtScore(record.CreatedAt), + Member: record.ID.String(), + }) + if record.Status == devicesession.StatusActive { + pipe.ZAdd(operationCtx, activeSessionsKey, 
redis.Z{ + Score: createdAtScore(record.CreatedAt), + Member: record.ID.String(), + }) + } + return nil + }) + if err != nil { + return fmt.Errorf("create session %q in redis: %w", record.ID, err) + } + + return nil + }, deviceSessionKey) + + switch { + case errors.Is(watchErr, redis.TxFailedErr): + return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict) + case watchErr != nil: + return watchErr + default: + return nil + } +} + +// Revoke stores a revoked view of one target session. +func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) { + if err := input.Validate(); err != nil { + return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err) + } + + var result ports.RevokeSessionResult + err := s.runMutation(ctx, "revoke session in redis", func(operationCtx context.Context) error { + deviceSessionKey := s.sessionKey(input.DeviceSessionID) + + watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { + current, err := s.loadSessionWithGetter(operationCtx, input.DeviceSessionID, tx.Get) + if err != nil { + switch { + case errors.Is(err, ports.ErrNotFound): + return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, ports.ErrNotFound) + default: + return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) + } + } + + if current.Status == devicesession.StatusRevoked { + result = ports.RevokeSessionResult{ + Outcome: ports.RevokeSessionOutcomeAlreadyRevoked, + Session: current, + } + return result.Validate() + } + + next := current + next.Status = devicesession.StatusRevoked + revocation := input.Revocation + next.Revocation = &revocation + if err := next.Validate(); err != nil { + return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) + } + + payload, err := marshalSessionRecord(next) + if err != nil { + return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) + } + + _, 
err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { + pipe.Set(operationCtx, deviceSessionKey, payload, 0) + pipe.ZRem(operationCtx, s.userActiveSessionsKey(current.UserID), current.ID.String()) + return nil + }) + if err != nil { + return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) + } + + result = ports.RevokeSessionResult{ + Outcome: ports.RevokeSessionOutcomeRevoked, + Session: next, + } + return result.Validate() + }, deviceSessionKey) + + switch { + case errors.Is(watchErr, redis.TxFailedErr): + return errRetryMutation + case watchErr != nil: + return watchErr + default: + return nil + } + }) + if err != nil { + return ports.RevokeSessionResult{}, err + } + + return result, nil +} + +// RevokeAllByUserID stores revoked views for all currently active sessions +// owned by input.UserID. +func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) { + if err := input.Validate(); err != nil { + return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err) + } + + var result ports.RevokeUserSessionsResult + err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error { + activeSessionsKey := s.userActiveSessionsKey(input.UserID) + + watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { + deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result() + if err != nil { + return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err) + } + if len(deviceSessionIDs) == 0 { + // Force EXEC so WATCH observes concurrent active-index changes even + // for the no-op path. 
+ _, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { + pipe.ZCard(operationCtx, activeSessionsKey) + return nil + }) + if err != nil { + return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err) + } + + result = ports.RevokeUserSessionsResult{ + Outcome: ports.RevokeUserSessionsOutcomeNoActiveSessions, + UserID: input.UserID, + Sessions: []devicesession.Session{}, + } + return result.Validate() + } + + records := make([]devicesession.Session, 0, len(deviceSessionIDs)) + for _, rawDeviceSessionID := range deviceSessionIDs { + deviceSessionID := common.DeviceSessionID(rawDeviceSessionID) + record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, tx.Get) + if err != nil { + switch { + case errors.Is(err, ports.ErrNotFound): + return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID) + default: + return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err) + } + } + if record.UserID != input.UserID { + return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID) + } + if record.Status != devicesession.StatusActive { + return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status) + } + + next := record + next.Status = devicesession.StatusRevoked + revocation := input.Revocation + next.Revocation = &revocation + if err := next.Validate(); err != nil { + return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err) + } + records = append(records, next) + } + + _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { + for _, record := range records { + payload, err := marshalSessionRecord(record) + if err != nil { + return fmt.Errorf("session %q: %w", record.ID, err) + } + pipe.Set(operationCtx, 
s.sessionKey(record.ID), payload, 0) + pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String()) + } + return nil + }) + if err != nil { + return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err) + } + + sortSessionsNewestFirst(records) + result = ports.RevokeUserSessionsResult{ + Outcome: ports.RevokeUserSessionsOutcomeRevoked, + UserID: input.UserID, + Sessions: records, + } + return result.Validate() + }, activeSessionsKey) + + switch { + case errors.Is(watchErr, redis.TxFailedErr): + return errRetryMutation + case watchErr != nil: + return watchErr + default: + return nil + } + }) + if err != nil { + return ports.RevokeUserSessionsResult{}, err + } + + return result, nil +} + +var errRetryMutation = errors.New("redis session store: retry mutation") + +func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error { + for attempt := 0; attempt < mutationRetryLimit; attempt++ { + operationCtx, cancel, err := s.operationContext(ctx, operation) + if err != nil { + return err + } + + err = execute(operationCtx) + cancel() + + switch { + case errors.Is(err, errRetryMutation): + if attempt == mutationRetryLimit-1 { + return fmt.Errorf("%s: mutation retry limit exceeded", operation) + } + continue + default: + return err + } + } + + return fmt.Errorf("%s: mutation retry limit exceeded", operation) +} + +func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) { + if s == nil || s.client == nil { + return nil, nil, fmt.Errorf("%s: nil store", operation) + } + if ctx == nil { + return nil, nil, fmt.Errorf("%s: nil context", operation) + } + + operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout) + return operationCtx, cancel, nil +} + +func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) { + return s.loadSessionWithGetter(ctx, deviceSessionID, s.client.Get) +} 
+ +func (s *Store) loadSessionWithGetter( + ctx context.Context, + deviceSessionID common.DeviceSessionID, + getter func(context.Context, string) *redis.StringCmd, +) (devicesession.Session, error) { + payload, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes() + switch { + case errors.Is(err, redis.Nil): + return devicesession.Session{}, ports.ErrNotFound + case err != nil: + return devicesession.Session{}, err + } + + record, err := decodeSessionRecord(deviceSessionID, payload) + if err != nil { + return devicesession.Session{}, err + } + + return record, nil +} + +func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string { + return s.sessionKeyPrefix + encodeKeyComponent(deviceSessionID.String()) +} + +func (s *Store) userSessionsKey(userID common.UserID) string { + return s.userSessionsKeyPrefix + encodeKeyComponent(userID.String()) +} + +func (s *Store) userActiveSessionsKey(userID common.UserID) string { + return s.userActiveSessionsKeyPrefix + encodeKeyComponent(userID.String()) +} + +func encodeKeyComponent(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func marshalSessionRecord(record devicesession.Session) ([]byte, error) { + stored, err := redisRecordFromSession(record) + if err != nil { + return nil, err + } + + payload, err := json.Marshal(stored) + if err != nil { + return nil, fmt.Errorf("encode redis session record: %w", err) + } + + return payload, nil +} + +func redisRecordFromSession(record devicesession.Session) (redisRecord, error) { + if err := record.Validate(); err != nil { + return redisRecord{}, fmt.Errorf("encode redis session record: %w", err) + } + + stored := redisRecord{ + DeviceSessionID: record.ID.String(), + UserID: record.UserID.String(), + ClientPublicKeyBase64: record.ClientPublicKey.String(), + Status: record.Status, + CreatedAt: formatTimestamp(record.CreatedAt), + } + if record.Revocation != nil { + stored.RevokedAt = 
formatOptionalTimestamp(&record.Revocation.At) + stored.RevokeReasonCode = record.Revocation.ReasonCode.String() + stored.RevokeActorType = record.Revocation.ActorType.String() + stored.RevokeActorID = record.Revocation.ActorID + } + + return stored, nil +} + +func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) { + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.DisallowUnknownFields() + + var stored redisRecord + if err := decoder.Decode(&stored); err != nil { + return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err) + } + if err := decoder.Decode(&struct{}{}); err != io.EOF { + if err == nil { + return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input") + } + return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err) + } + + record, err := sessionFromRedisRecord(stored) + if err != nil { + return devicesession.Session{}, err + } + if record.ID != expectedDeviceSessionID { + return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", record.ID, expectedDeviceSessionID) + } + + return record, nil +} + +func sessionFromRedisRecord(stored redisRecord) (devicesession.Session, error) { + createdAt, err := parseTimestamp("created_at", stored.CreatedAt) + if err != nil { + return devicesession.Session{}, err + } + + rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ClientPublicKeyBase64) + if err != nil { + return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err) + } + clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey) + if err != nil { + return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err) + } + + record := devicesession.Session{ + ID: 
common.DeviceSessionID(stored.DeviceSessionID), + UserID: common.UserID(stored.UserID), + ClientPublicKey: clientPublicKey, + Status: stored.Status, + CreatedAt: createdAt, + } + + revocation, err := parseRevocation(stored) + if err != nil { + return devicesession.Session{}, err + } + record.Revocation = revocation + + if err := record.Validate(); err != nil { + return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err) + } + + return record, nil +} + +func parseRevocation(stored redisRecord) (*devicesession.Revocation, error) { + hasRevokedAt := stored.RevokedAt != nil + hasReasonCode := strings.TrimSpace(stored.RevokeReasonCode) != "" + hasActorType := strings.TrimSpace(stored.RevokeActorType) != "" + hasActorID := strings.TrimSpace(stored.RevokeActorID) != "" + + if !hasRevokedAt && !hasReasonCode && !hasActorType && !hasActorID { + return nil, nil + } + if !hasRevokedAt || !hasReasonCode || !hasActorType { + return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent") + } + + revokedAt, err := parseTimestamp("revoked_at", *stored.RevokedAt) + if err != nil { + return nil, err + } + + return &devicesession.Revocation{ + At: revokedAt, + ReasonCode: common.RevokeReasonCode(stored.RevokeReasonCode), + ActorType: common.RevokeActorType(stored.RevokeActorType), + ActorID: stored.RevokeActorID, + }, nil +} + +func parseTimestamp(fieldName string, value string) (time.Time, error) { + if strings.TrimSpace(value) == "" { + return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName) + } + + parsed, err := time.Parse(time.RFC3339Nano, value) + if err != nil { + return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err) + } + + canonical := parsed.UTC().Format(time.RFC3339Nano) + if value != canonical { + return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName) + } 
+ + return parsed.UTC(), nil +} + +func formatTimestamp(value time.Time) string { + return value.UTC().Format(time.RFC3339Nano) +} + +func formatOptionalTimestamp(value *time.Time) *string { + if value == nil { + return nil + } + + formatted := formatTimestamp(*value) + return &formatted +} + +func createdAtScore(createdAt time.Time) float64 { + return float64(createdAt.UTC().UnixMicro()) +} + +func sortSessionsNewestFirst(records []devicesession.Session) { + slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int { + switch { + case left.CreatedAt.Equal(right.CreatedAt): + return strings.Compare(left.ID.String(), right.ID.String()) + case left.CreatedAt.After(right.CreatedAt): + return -1 + default: + return 1 + } + }) +} + +var _ ports.SessionStore = (*Store)(nil) diff --git a/authsession/internal/adapters/redis/sessionstore/store_test.go b/authsession/internal/adapters/redis/sessionstore/store_test.go new file mode 100644 index 0000000..a7c2661 --- /dev/null +++ b/authsession/internal/adapters/redis/sessionstore/store_test.go @@ -0,0 +1,635 @@ +package sessionstore + +import ( + "context" + "crypto/ed25519" + "encoding/json" + "testing" + "time" + + "galaxy/authsession/internal/adapters/contracttest" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/ports" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStoreContract(t *testing.T) { + t.Parallel() + + contracttest.RunSessionStoreContractTests(t, func(t *testing.T) ports.SessionStore { + t.Helper() + + server := miniredis.RunT(t) + return newTestStore(t, server, Config{}) + }) +} + +func TestNew(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + + tests := []struct { + name string + cfg Config + wantErr string + }{ + { + name: "valid config", + cfg: Config{ + Addr: server.Addr(), + DB: 1, + 
SessionKeyPrefix: "authsession:session:", + UserSessionsKeyPrefix: "authsession:user-sessions:", + UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:", + OperationTimeout: 250 * time.Millisecond, + }, + }, + { + name: "empty addr", + cfg: Config{ + SessionKeyPrefix: "authsession:session:", + UserSessionsKeyPrefix: "authsession:user-sessions:", + UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis addr must not be empty", + }, + { + name: "negative db", + cfg: Config{ + Addr: server.Addr(), + DB: -1, + SessionKeyPrefix: "authsession:session:", + UserSessionsKeyPrefix: "authsession:user-sessions:", + UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "redis db must not be negative", + }, + { + name: "empty session prefix", + cfg: Config{ + Addr: server.Addr(), + UserSessionsKeyPrefix: "authsession:user-sessions:", + UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "session key prefix must not be empty", + }, + { + name: "empty all sessions prefix", + cfg: Config{ + Addr: server.Addr(), + SessionKeyPrefix: "authsession:session:", + UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "user sessions key prefix must not be empty", + }, + { + name: "empty active sessions prefix", + cfg: Config{ + Addr: server.Addr(), + SessionKeyPrefix: "authsession:session:", + UserSessionsKeyPrefix: "authsession:user-sessions:", + OperationTimeout: 250 * time.Millisecond, + }, + wantErr: "user active sessions key prefix must not be empty", + }, + { + name: "non positive timeout", + cfg: Config{ + Addr: server.Addr(), + SessionKeyPrefix: "authsession:session:", + UserSessionsKeyPrefix: "authsession:user-sessions:", + UserActiveSessionsKeyPrefix: 
"authsession:user-active-sessions:", + }, + wantErr: "operation timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store, err := New(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + }) + } +} + +func TestStorePing(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + require.NoError(t, store.Ping(context.Background())) +} + +func TestStoreCreateAndGetActive(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC()) + + require.NoError(t, store.Create(context.Background(), record)) + + got, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + assert.Equal(t, record, got) + + got.Revocation = &devicesession.Revocation{ + At: got.CreatedAt.Add(time.Minute), + ReasonCode: devicesession.RevokeReasonAdminRevoke, + ActorType: common.RevokeActorType("admin"), + } + + again, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + assert.Nil(t, again.Revocation) + assert.Equal(t, record, again) +} + +func TestStoreCreateAndGetRevoked(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := revokedSessionFixture("device-session-2", "user-1", time.Unix(1_775_240_100, 0).UTC()) + + require.NoError(t, store.Create(context.Background(), record)) + + got, err := store.Get(context.Background(), record.ID) + require.NoError(t, err) + assert.Equal(t, record, got) + + count, err := store.CountActiveByUserID(context.Background(), record.UserID) + require.NoError(t, err) + assert.Zero(t, count) +} + +func TestStoreGetNotFound(t *testing.T) { 
+ t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + _, err := store.Get(context.Background(), common.DeviceSessionID("missing-session")) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrNotFound) +} + +func TestStoreCreateConflict(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_200, 0).UTC()) + + require.NoError(t, store.Create(context.Background(), record)) + + err := store.Create(context.Background(), record) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrConflict) +} + +func TestStoreIndexesAndOrdering(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + older := activeSessionFixture("device-session-old", "user-1", time.Unix(10, 0).UTC()) + newer := activeSessionFixture("device-session-new", "user-1", time.Unix(20, 0).UTC()) + revoked := revokedSessionFixture("device-session-revoked", "user-1", time.Unix(15, 0).UTC()) + otherUser := activeSessionFixture("device-session-other", "user-2", time.Unix(30, 0).UTC()) + + for _, record := range []devicesession.Session{older, newer, revoked, otherUser} { + require.NoError(t, store.Create(context.Background(), record)) + } + + got, err := store.ListByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + require.Len(t, got, 3) + assert.Equal(t, []common.DeviceSessionID{newer.ID, revoked.ID, older.ID}, []common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID}) + + count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + assert.Equal(t, 2, count) + + unknown, err := store.ListByUserID(context.Background(), common.UserID("unknown-user")) + require.NoError(t, err) + assert.Empty(t, unknown) +} + +func TestStoreKeyPrefixesAndEncodedPrimaryKey(t *testing.T) { + t.Parallel() + + 
server := miniredis.RunT(t) + store := newTestStore(t, server, Config{ + SessionKeyPrefix: "custom:session:", + UserSessionsKeyPrefix: "custom:user-sessions:", + UserActiveSessionsKeyPrefix: "custom:user-active-sessions:", + }) + + record := activeSessionFixture("device/session:opaque?1", "user/opaque:1", time.Unix(40, 0).UTC()) + require.NoError(t, store.Create(context.Background(), record)) + + primaryKey := store.sessionKey(record.ID) + assert.Equal(t, "custom:session:"+encodeKeyComponent(record.ID.String()), primaryKey) + assert.True(t, server.Exists(primaryKey)) + + allSessionsKey := store.userSessionsKey(record.UserID) + activeSessionsKey := store.userActiveSessionsKey(record.UserID) + assert.Equal(t, "custom:user-sessions:"+encodeKeyComponent(record.UserID.String()), allSessionsKey) + assert.Equal(t, "custom:user-active-sessions:"+encodeKeyComponent(record.UserID.String()), activeSessionsKey) + + allMembers, err := server.ZMembers(allSessionsKey) + require.NoError(t, err) + assert.Equal(t, []string{record.ID.String()}, allMembers) + + activeMembers, err := server.ZMembers(activeSessionsKey) + require.NoError(t, err) + assert.Equal(t, []string{record.ID.String()}, activeMembers) +} + +func TestStoreRevoke(t *testing.T) { + t.Parallel() + + t.Run("active session", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC()) + require.NoError(t, store.Create(context.Background(), record)) + + revocation := devicesession.Revocation{ + At: time.Unix(200, 0).UTC(), + ReasonCode: devicesession.RevokeReasonLogoutAll, + ActorType: common.RevokeActorType("system"), + } + + result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{ + DeviceSessionID: record.ID, + Revocation: revocation, + }) + require.NoError(t, err) + assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome) + require.NotNil(t, 
result.Session.Revocation) + assert.Equal(t, revocation, *result.Session.Revocation) + + count, err := store.CountActiveByUserID(context.Background(), record.UserID) + require.NoError(t, err) + assert.Zero(t, count) + }) + + t.Run("already revoked keeps stored revocation", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := revokedSessionFixture("device-session-2", "user-1", time.Unix(100, 0).UTC()) + require.NoError(t, store.Create(context.Background(), record)) + + result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{ + DeviceSessionID: record.ID, + Revocation: devicesession.Revocation{ + At: time.Unix(300, 0).UTC(), + ReasonCode: devicesession.RevokeReasonAdminRevoke, + ActorType: common.RevokeActorType("admin"), + ActorID: "admin-1", + }, + }) + require.NoError(t, err) + assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome) + require.NotNil(t, result.Session.Revocation) + assert.Equal(t, *record.Revocation, *result.Session.Revocation) + }) + + t.Run("unknown session", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + _, err := store.Revoke(context.Background(), ports.RevokeSessionInput{ + DeviceSessionID: common.DeviceSessionID("missing-session"), + Revocation: devicesession.Revocation{ + At: time.Unix(200, 0).UTC(), + ReasonCode: devicesession.RevokeReasonLogoutAll, + ActorType: common.RevokeActorType("system"), + }, + }) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrNotFound) + }) +} + +func TestStoreRevokeAllByUserID(t *testing.T) { + t.Parallel() + + t.Run("revokes active sessions newest first and clears active index", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + + older := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC()) + newer := 
activeSessionFixture("device-session-2", "user-1", time.Unix(200, 0).UTC()) + alreadyRevoked := revokedSessionFixture("device-session-3", "user-1", time.Unix(150, 0).UTC()) + otherUser := activeSessionFixture("device-session-4", "user-2", time.Unix(250, 0).UTC()) + + for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} { + require.NoError(t, store.Create(context.Background(), record)) + } + + revocation := devicesession.Revocation{ + At: time.Unix(300, 0).UTC(), + ReasonCode: devicesession.RevokeReasonAdminRevoke, + ActorType: common.RevokeActorType("admin"), + ActorID: "admin-1", + } + + result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{ + UserID: common.UserID("user-1"), + Revocation: revocation, + }) + require.NoError(t, err) + assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome) + require.Len(t, result.Sessions, 2) + assert.Equal(t, []common.DeviceSessionID{newer.ID, older.ID}, []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID}) + assert.Equal(t, revocation, *result.Sessions[0].Revocation) + assert.Equal(t, revocation, *result.Sessions[1].Revocation) + + count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + assert.Zero(t, count) + + otherCount, err := store.CountActiveByUserID(context.Background(), common.UserID("user-2")) + require.NoError(t, err) + assert.Equal(t, 1, otherCount) + }) + + t.Run("no active sessions", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := revokedSessionFixture("device-session-5", "user-1", time.Unix(100, 0).UTC()) + require.NoError(t, store.Create(context.Background(), record)) + + result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{ + UserID: common.UserID("user-1"), + Revocation: devicesession.Revocation{ + At: time.Unix(400, 0).UTC(), + 
ReasonCode: devicesession.RevokeReasonAdminRevoke, + ActorType: common.RevokeActorType("admin"), + }, + }) + require.NoError(t, err) + assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome) + assert.Empty(t, result.Sessions) + }) +} + +func TestStoreStrictDecodeCorruption(t *testing.T) { + t.Parallel() + + now := time.Unix(1_775_240_300, 0).UTC() + baseRecord := revokedSessionFixture("device-session-corrupt", "user-1", now) + stored, err := redisRecordFromSession(baseRecord) + require.NoError(t, err) + + tests := []struct { + name string + mutate func(redisRecord) string + wantErrText string + }{ + { + name: "malformed json", + mutate: func(_ redisRecord) string { + return "{" + }, + wantErrText: "decode redis session record", + }, + { + name: "trailing json input", + mutate: func(record redisRecord) string { + return mustMarshalJSON(t, record) + "{}" + }, + wantErrText: "unexpected trailing JSON input", + }, + { + name: "unknown field", + mutate: func(record redisRecord) string { + payload := map[string]any{ + "device_session_id": record.DeviceSessionID, + "user_id": record.UserID, + "client_public_key_base64": record.ClientPublicKeyBase64, + "status": record.Status, + "created_at": record.CreatedAt, + "revoked_at": record.RevokedAt, + "revoke_reason_code": record.RevokeReasonCode, + "revoke_actor_type": record.RevokeActorType, + "revoke_actor_id": record.RevokeActorID, + "unexpected": true, + } + return mustMarshalJSON(t, payload) + }, + wantErrText: "unknown field", + }, + { + name: "unsupported status", + mutate: func(record redisRecord) string { + record.Status = devicesession.Status("paused") + return mustMarshalJSON(t, record) + }, + wantErrText: `status "paused" is unsupported`, + }, + { + name: "non canonical timestamp", + mutate: func(record redisRecord) string { + record.CreatedAt = "2026-04-04T12:00:00+03:00" + return mustMarshalJSON(t, record) + }, + wantErrText: "canonical UTC RFC3339Nano timestamp", + }, + { + name: 
"incomplete revocation metadata", + mutate: func(record redisRecord) string { + record.RevokeActorType = "" + return mustMarshalJSON(t, record) + }, + wantErrText: "revocation metadata must be either fully present or fully absent", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + server.Set(store.sessionKey(baseRecord.ID), tt.mutate(stored)) + + _, err := store.Get(context.Background(), baseRecord.ID) + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErrText) + }) + } +} + +func TestStoreListByUserIDDetectsCorruptIndexes(t *testing.T) { + t.Parallel() + + t.Run("missing primary record", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + userID := common.UserID("user-1") + _, err := server.ZAdd(store.userSessionsKey(userID), 100, "missing-session") + require.NoError(t, err) + + _, err = store.ListByUserID(context.Background(), userID) + require.Error(t, err) + assert.ErrorContains(t, err, "references missing session") + }) + + t.Run("wrong user id in primary record", func(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := activeSessionFixture("device-session-1", "user-2", time.Unix(100, 0).UTC()) + require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record)) + _, err := server.ZAdd(store.userSessionsKey(common.UserID("user-1")), createdAtScore(record.CreatedAt), record.ID.String()) + require.NoError(t, err) + + _, err = store.ListByUserID(context.Background(), common.UserID("user-1")) + require.Error(t, err) + assert.ErrorContains(t, err, `belongs to "user-2"`) + }) +} + +func TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex(t *testing.T) { + t.Parallel() + + server := miniredis.RunT(t) + store := newTestStore(t, server, Config{}) + record := 
revokedSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC()) + require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record)) + _, err := server.ZAdd(store.userActiveSessionsKey(record.UserID), createdAtScore(record.CreatedAt), record.ID.String()) + require.NoError(t, err) + + _, err = store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{ + UserID: record.UserID, + Revocation: devicesession.Revocation{ + At: time.Unix(200, 0).UTC(), + ReasonCode: devicesession.RevokeReasonAdminRevoke, + ActorType: common.RevokeActorType("admin"), + }, + }) + require.Error(t, err) + assert.ErrorContains(t, err, `is "revoked"`) +} + +func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store { + t.Helper() + + if cfg.Addr == "" { + cfg.Addr = server.Addr() + } + if cfg.SessionKeyPrefix == "" { + cfg.SessionKeyPrefix = "authsession:session:" + } + if cfg.UserSessionsKeyPrefix == "" { + cfg.UserSessionsKeyPrefix = "authsession:user-sessions:" + } + if cfg.UserActiveSessionsKeyPrefix == "" { + cfg.UserActiveSessionsKeyPrefix = "authsession:user-active-sessions:" + } + if cfg.OperationTimeout == 0 { + cfg.OperationTimeout = 250 * time.Millisecond + } + + store, err := New(cfg) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, store.Close()) + }) + + return store +} + +func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { + clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{ + 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + }) + if err != nil { + panic(err) + } + + return devicesession.Session{ + ID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID(userID), + ClientPublicKey: clientPublicKey, + Status: devicesession.StatusActive, + CreatedAt: createdAt, + } +} + +func revokedSessionFixture(deviceSessionID string, userID 
string, createdAt time.Time) devicesession.Session {
+	record := activeSessionFixture(deviceSessionID, userID, createdAt)
+	record.Status = devicesession.StatusRevoked
+	record.Revocation = &devicesession.Revocation{
+		At:         createdAt.Add(time.Minute),
+		ReasonCode: devicesession.RevokeReasonDeviceLogout,
+		ActorType:  common.RevokeActorType("user"),
+		ActorID:    "user-actor",
+	}
+	return record
+}
+
+// seedSessionRecord writes record directly into miniredis, bypassing the
+// store, so tests can construct corrupt or inconsistent states.
+func seedSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record devicesession.Session) error {
+	t.Helper()
+
+	stored, err := redisRecordFromSession(record)
+	require.NoError(t, err)
+
+	return server.Set(key, mustMarshalJSON(t, stored))
+}
+
+func mustMarshalJSON(t *testing.T, value any) string {
+	t.Helper()
+
+	payload, err := json.Marshal(value)
+	require.NoError(t, err)
+
+	return string(payload)
+}
diff --git a/authsession/internal/adapters/userservice/rest_client.go b/authsession/internal/adapters/userservice/rest_client.go
new file mode 100644
index 0000000..8777304
--- /dev/null
+++ b/authsession/internal/adapters/userservice/rest_client.go
@@ -0,0 +1,382 @@
+// Package userservice provides runtime user-directory adapters for the
+// auth/session service.
+package userservice
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/userresolution"
+	"galaxy/authsession/internal/ports"
+)
+
+const (
+	resolveByEmailPath = "/api/v1/internal/user-resolutions/by-email"
+	existsByUserIDPath = "/api/v1/internal/users/%s/exists"
+	ensureByEmailPath  = "/api/v1/internal/users/ensure-by-email"
+	blockByUserIDPath  = "/api/v1/internal/users/%s/block"
+	blockByEmailPath   = "/api/v1/internal/user-blocks/by-email"
+)
+
+// Config configures an HTTP-based UserDirectory client.
+type Config struct {
+	// BaseURL is the absolute base URL of the future user-service internal
+	// HTTP API.
+	BaseURL string
+
+	// RequestTimeout bounds each outbound user-service request.
+	RequestTimeout time.Duration
+}
+
+// RESTClient implements ports.UserDirectory over a frozen internal REST
+// contract.
+type RESTClient struct {
+	baseURL        string
+	requestTimeout time.Duration
+	httpClient     *http.Client
+}
+
+// NewRESTClient constructs a REST-backed UserDirectory adapter from cfg.
+func NewRESTClient(cfg Config) (*RESTClient, error) {
+	// Clone the default transport instead of sharing it, so Close can release
+	// this client's idle connections without touching global state. Guard the
+	// type assertion: http.DefaultTransport may have been replaced.
+	transport, ok := http.DefaultTransport.(*http.Transport)
+	if !ok {
+		return nil, errors.New("new user service REST client: http.DefaultTransport is not *http.Transport")
+	}
+
+	return newRESTClient(cfg, &http.Client{Transport: transport.Clone()})
+}
+
+func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) {
+	switch {
+	case strings.TrimSpace(cfg.BaseURL) == "":
+		return nil, errors.New("new user service REST client: base URL must not be empty")
+	case cfg.RequestTimeout <= 0:
+		return nil, errors.New("new user service REST client: request timeout must be positive")
+	case httpClient == nil:
+		return nil, errors.New("new user service REST client: http client must not be nil")
+	}
+
+	parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
+	if err != nil {
+		return nil, fmt.Errorf("new user service REST client: parse base URL: %w", err)
+	}
+	if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
+		return nil, errors.New("new user service REST client: base URL must be absolute")
+	}
+
+	return &RESTClient{
+		baseURL:        parsedBaseURL.String(),
+		requestTimeout: cfg.RequestTimeout,
+		httpClient:     httpClient,
+	}, nil
+}
+
+// Close releases idle HTTP connections owned by the client transport.
+func (c *RESTClient) Close() error {
+	if c == nil || c.httpClient == nil {
+		return nil
+	}
+
+	type idleCloser interface {
+		CloseIdleConnections()
+	}
+
+	if transport, ok := c.httpClient.Transport.(idleCloser); ok {
+		transport.CloseIdleConnections()
+	}
+
+	return nil
+}
+
+// ResolveByEmail returns the current coarse user-resolution state for email
+// without creating any new user record.
+func (c *RESTClient) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) { + if err := validateContext(ctx, "resolve by email"); err != nil { + return userresolution.Result{}, err + } + if err := email.Validate(); err != nil { + return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err) + } + + var response struct { + Kind userresolution.Kind `json:"kind"` + UserID string `json:"user_id,omitempty"` + BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"` + } + + if err := c.doJSON(ctx, "resolve by email", http.MethodPost, resolveByEmailPath, map[string]string{ + "email": email.String(), + }, &response, true); err != nil { + return userresolution.Result{}, err + } + + result := userresolution.Result{ + Kind: response.Kind, + UserID: common.UserID(response.UserID), + BlockReasonCode: response.BlockReasonCode, + } + if err := result.Validate(); err != nil { + return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err) + } + + return result, nil +} + +// ExistsByUserID reports whether userID currently identifies a stored user +// record. +func (c *RESTClient) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) { + if err := validateContext(ctx, "exists by user id"); err != nil { + return false, err + } + if err := userID.Validate(); err != nil { + return false, fmt.Errorf("exists by user id: %w", err) + } + + var response struct { + Exists bool `json:"exists"` + } + + if err := c.doJSON(ctx, "exists by user id", http.MethodGet, fmt.Sprintf(existsByUserIDPath, url.PathEscape(userID.String())), nil, &response, true); err != nil { + return false, err + } + + return response.Exists, nil +} + +// EnsureUserByEmail returns an existing user for email, creates a new user +// when registration is allowed, or reports a blocked outcome. 
+func (c *RESTClient) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) { + if err := validateContext(ctx, "ensure user by email"); err != nil { + return ports.EnsureUserResult{}, err + } + if err := email.Validate(); err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + + var response struct { + Outcome ports.EnsureUserOutcome `json:"outcome"` + UserID string `json:"user_id,omitempty"` + BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"` + } + + if err := c.doJSON(ctx, "ensure user by email", http.MethodPost, ensureByEmailPath, map[string]string{ + "email": email.String(), + }, &response, false); err != nil { + return ports.EnsureUserResult{}, err + } + + result := ports.EnsureUserResult{ + Outcome: response.Outcome, + UserID: common.UserID(response.UserID), + BlockReasonCode: response.BlockReasonCode, + } + if err := result.Validate(); err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + + return result, nil +} + +// BlockByUserID applies a block state to the user identified by input.UserID. +// Unknown user ids wrap ports.ErrNotFound. 
+func (c *RESTClient) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) { + if err := validateContext(ctx, "block by user id"); err != nil { + return ports.BlockUserResult{}, err + } + if err := input.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err) + } + + payload, statusCode, err := c.doRequest(ctx, "block by user id", http.MethodPost, fmt.Sprintf(blockByUserIDPath, url.PathEscape(input.UserID.String())), map[string]string{ + "reason_code": input.ReasonCode.String(), + }, false) + if err != nil { + return ports.BlockUserResult{}, err + } + if statusCode == http.StatusNotFound { + return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound) + } + if statusCode != http.StatusOK { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: unexpected HTTP status %d", statusCode) + } + + var response struct { + Outcome ports.BlockUserOutcome `json:"outcome"` + UserID string `json:"user_id,omitempty"` + } + if err := decodeJSONPayload(payload, &response); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err) + } + + result := ports.BlockUserResult{ + Outcome: response.Outcome, + UserID: common.UserID(response.UserID), + } + if err := result.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err) + } + + return result, nil +} + +// BlockByEmail applies a block state to input.Email even when no user record +// currently exists for that e-mail address. 
+func (c *RESTClient) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) { + if err := validateContext(ctx, "block by email"); err != nil { + return ports.BlockUserResult{}, err + } + if err := input.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err) + } + + var response struct { + Outcome ports.BlockUserOutcome `json:"outcome"` + UserID string `json:"user_id,omitempty"` + } + + if err := c.doJSON(ctx, "block by email", http.MethodPost, blockByEmailPath, map[string]string{ + "email": input.Email.String(), + "reason_code": input.ReasonCode.String(), + }, &response, false); err != nil { + return ports.BlockUserResult{}, err + } + + result := ports.BlockUserResult{ + Outcome: response.Outcome, + UserID: common.UserID(response.UserID), + } + if err := result.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err) + } + + return result, nil +} + +func (c *RESTClient) doJSON(ctx context.Context, operation string, method string, requestPath string, requestBody any, responseTarget any, retryRead bool) error { + payload, statusCode, err := c.doRequest(ctx, operation, method, requestPath, requestBody, retryRead) + if err != nil { + return err + } + if statusCode != http.StatusOK { + return fmt.Errorf("%s: unexpected HTTP status %d", operation, statusCode) + } + if err := decodeJSONPayload(payload, responseTarget); err != nil { + return fmt.Errorf("%s: %w", operation, err) + } + + return nil +} + +func (c *RESTClient) doRequest(ctx context.Context, operation string, method string, requestPath string, requestBody any, retryRead bool) ([]byte, int, error) { + bodyBytes, err := marshalOptionalRequestBody(requestBody) + if err != nil { + return nil, 0, fmt.Errorf("%s: %w", operation, err) + } + + attempts := 1 + if retryRead { + attempts = 2 + } + + var lastErr error + for attempt := 0; attempt < attempts; attempt++ { + attemptCtx, cancel := 
context.WithTimeout(ctx, c.requestTimeout) + + request, err := http.NewRequestWithContext(attemptCtx, method, c.baseURL+requestPath, bytes.NewReader(bodyBytes)) + if err != nil { + cancel() + return nil, 0, fmt.Errorf("%s: build request: %w", operation, err) + } + if method == http.MethodPost { + request.Header.Set("Content-Type", "application/json") + } + + response, err := c.httpClient.Do(request) + if err != nil { + cancel() + lastErr = fmt.Errorf("%s: %w", operation, err) + if retryRead && attempt == 0 && ctx.Err() == nil { + continue + } + + return nil, 0, lastErr + } + + payload, readErr := io.ReadAll(response.Body) + closeErr := response.Body.Close() + cancel() + if readErr != nil { + lastErr = fmt.Errorf("%s: read response body: %w", operation, readErr) + if retryRead && attempt == 0 && ctx.Err() == nil { + continue + } + + return nil, 0, lastErr + } + if closeErr != nil { + lastErr = fmt.Errorf("%s: close response body: %w", operation, closeErr) + if retryRead && attempt == 0 && ctx.Err() == nil { + continue + } + + return nil, 0, lastErr + } + + if retryRead && attempt == 0 && isRetriableUserServiceStatus(response.StatusCode) { + lastErr = fmt.Errorf("%s: unexpected HTTP status %d", operation, response.StatusCode) + continue + } + + return payload, response.StatusCode, nil + } + + return nil, 0, lastErr +} + +func marshalOptionalRequestBody(value any) ([]byte, error) { + if value == nil { + return nil, nil + } + + payload, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("marshal request body: %w", err) + } + + return payload, nil +} + +func decodeJSONPayload(payload []byte, target any) error { + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.DisallowUnknownFields() + + if err := decoder.Decode(target); err != nil { + return fmt.Errorf("decode response body: %w", err) + } + if err := decoder.Decode(&struct{}{}); err != io.EOF { + if err == nil { + return errors.New("decode response body: unexpected trailing JSON 
input") + } + + return fmt.Errorf("decode response body: %w", err) + } + + return nil +} + +func isRetriableUserServiceStatus(statusCode int) bool { + switch statusCode { + case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return true + default: + return false + } +} + +var _ ports.UserDirectory = (*RESTClient)(nil) diff --git a/authsession/internal/adapters/userservice/rest_client_test.go b/authsession/internal/adapters/userservice/rest_client_test.go new file mode 100644 index 0000000..01a5c36 --- /dev/null +++ b/authsession/internal/adapters/userservice/rest_client_test.go @@ -0,0 +1,622 @@ +package userservice + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/ports" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRESTClient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg Config + wantErr string + }{ + { + name: "valid config", + cfg: Config{ + BaseURL: "http://127.0.0.1:8080", + RequestTimeout: time.Second, + }, + }, + { + name: "empty base url", + cfg: Config{ + RequestTimeout: time.Second, + }, + wantErr: "base URL must not be empty", + }, + { + name: "relative base url", + cfg: Config{ + BaseURL: "/relative", + RequestTimeout: time.Second, + }, + wantErr: "base URL must be absolute", + }, + { + name: "non positive timeout", + cfg: Config{ + BaseURL: "http://127.0.0.1:8080", + }, + wantErr: "request timeout must be positive", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client, err := NewRESTClient(tt.cfg) + if tt.wantErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.NoError(t, 
client.Close()) + }) + } +} + +func TestRESTClientEndpointSuccessCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + run func(*testing.T, *RESTClient) + }{ + { + name: "resolve by email", + run: func(t *testing.T, client *RESTClient) { + result, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Case@example.com")) + require.NoError(t, err) + assert.Equal(t, userresolution.Result{ + Kind: userresolution.KindExisting, + UserID: common.UserID("user-123"), + }, result) + }, + }, + { + name: "exists by user id", + run: func(t *testing.T, client *RESTClient) { + exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123")) + require.NoError(t, err) + assert.True(t, exists) + }, + }, + { + name: "ensure user by email", + run: func(t *testing.T, client *RESTClient) { + result, err := client.EnsureUserByEmail(context.Background(), common.Email("created@example.com")) + require.NoError(t, err) + assert.Equal(t, ports.EnsureUserResult{ + Outcome: ports.EnsureUserOutcomeCreated, + UserID: common.UserID("user-234"), + }, result) + }, + }, + { + name: "block by user id", + run: func(t *testing.T, client *RESTClient) { + result, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("user-123"), + ReasonCode: userresolution.BlockReasonCode("policy_blocked"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeBlocked, + UserID: common.UserID("user-123"), + }, result) + }, + }, + { + name: "block by email", + run: func(t *testing.T, client *RESTClient) { + result, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("blocked@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_blocked"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeAlreadyBlocked, + UserID: common.UserID("user-345"), + }, result) 
+ }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var requestsMu sync.Mutex + var requests []capturedRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestsMu.Lock() + requests = append(requests, captureRequest(t, r)) + requestsMu.Unlock() + + switch { + case r.Method == http.MethodPost && r.URL.Path == resolveByEmailPath: + writeJSON(t, w, http.StatusOK, map[string]any{ + "kind": "existing", + "user_id": "user-123", + }) + case r.Method == http.MethodGet && r.URL.Path == "/api/v1/internal/users/user-123/exists": + writeJSON(t, w, http.StatusOK, map[string]any{"exists": true}) + case r.Method == http.MethodPost && r.URL.Path == ensureByEmailPath: + writeJSON(t, w, http.StatusOK, map[string]any{ + "outcome": "created", + "user_id": "user-234", + }) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/internal/users/user-123/block": + writeJSON(t, w, http.StatusOK, map[string]any{ + "outcome": "blocked", + "user_id": "user-123", + }) + case r.Method == http.MethodPost && r.URL.Path == blockByEmailPath: + writeJSON(t, w, http.StatusOK, map[string]any{ + "outcome": "already_blocked", + "user_id": "user-345", + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + tt.run(t, client) + + requestsMu.Lock() + defer requestsMu.Unlock() + + require.Len(t, requests, 1) + switch tt.name { + case "resolve by email": + assert.Equal(t, capturedRequest{ + Method: http.MethodPost, + Path: resolveByEmailPath, + ContentType: "application/json", + Body: `{"email":"Pilot+Case@example.com"}`, + }, requests[0]) + case "exists by user id": + assert.Equal(t, capturedRequest{ + Method: http.MethodGet, + Path: "/api/v1/internal/users/user-123/exists", + }, requests[0]) + case "ensure user by email": + assert.Equal(t, capturedRequest{ + Method: http.MethodPost, + Path: 
ensureByEmailPath, + ContentType: "application/json", + Body: `{"email":"created@example.com"}`, + }, requests[0]) + case "block by user id": + assert.Equal(t, capturedRequest{ + Method: http.MethodPost, + Path: "/api/v1/internal/users/user-123/block", + ContentType: "application/json", + Body: `{"reason_code":"policy_blocked"}`, + }, requests[0]) + case "block by email": + assert.Equal(t, capturedRequest{ + Method: http.MethodPost, + Path: blockByEmailPath, + ContentType: "application/json", + Body: `{"email":"blocked@example.com","reason_code":"policy_blocked"}`, + }, requests[0]) + } + }) + } +} + +func TestRESTClientPreservesNormalizedEmailExactly(t *testing.T) { + t.Parallel() + + var captured string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + request := captureRequest(t, r) + captured = request.Body + writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"}) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + + _, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Alias@Example.com")) + require.NoError(t, err) + assert.Equal(t, `{"email":"Pilot+Alias@Example.com"}`, captured) +} + +func TestRESTClientBlockByUserIDNotFound(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + + _, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("missing-user"), + ReasonCode: userresolution.BlockReasonCode("policy_blocked"), + }) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrNotFound) +} + +func TestRESTClientReadMethodsRetryOnce(t *testing.T) { + t.Parallel() + + t.Run("resolve by email retries on 503", func(t *testing.T) { + t.Parallel() + + var calls int + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + if calls == 1 { + http.Error(w, "temporary", http.StatusServiceUnavailable) + return + } + + writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"}) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + + result, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com")) + require.NoError(t, err) + assert.Equal(t, userresolution.KindCreatable, result.Kind) + assert.Equal(t, 2, calls) + }) + + t.Run("exists by user id retries on transport failure", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSON(t, w, http.StatusOK, map[string]any{"exists": true}) + })) + defer server.Close() + + baseTransport := server.Client().Transport + client, err := newRESTClient(Config{ + BaseURL: server.URL, + RequestTimeout: 250 * time.Millisecond, + }, &http.Client{ + Transport: &failOnceRoundTripper{ + next: baseTransport, + err: errors.New("temporary transport failure"), + }, + }) + require.NoError(t, err) + + exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123")) + require.NoError(t, err) + assert.True(t, exists) + }) +} + +func TestRESTClientMutationMethodsDoNotRetry(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + run func(*RESTClient) error + }{ + { + name: "ensure user by email", + run: func(client *RESTClient) error { + _, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com")) + return err + }, + }, + { + name: "block by user id", + run: func(client *RESTClient) error { + _, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("user-123"), + ReasonCode: userresolution.BlockReasonCode("policy_blocked"), + }) + return err + }, + }, + { + name: "block by email", + run: func(client 
*RESTClient) error { + _, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("pilot@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_blocked"), + }) + return err + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var calls int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + http.Error(w, "temporary", http.StatusServiceUnavailable) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + + err := tt.run(client) + require.Error(t, err) + assert.Equal(t, 1, calls) + }) + } +} + +func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + body string + wantErrText string + run func(*RESTClient) error + }{ + { + name: "resolve by email rejects unknown field", + statusCode: http.StatusOK, + body: `{"kind":"creatable","extra":true}`, + wantErrText: "decode response body", + run: func(client *RESTClient) error { + _, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com")) + return err + }, + }, + { + name: "ensure user by email rejects malformed outcome", + statusCode: http.StatusOK, + body: `{"outcome":"mystery"}`, + wantErrText: "unsupported", + run: func(client *RESTClient) error { + _, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com")) + return err + }, + }, + { + name: "ensure user by email rejects missing user id for created outcome", + statusCode: http.StatusOK, + body: `{"outcome":"created"}`, + wantErrText: "user id", + run: func(client *RESTClient) error { + _, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com")) + return err + }, + }, + { + name: "exists by user id rejects trailing json", + statusCode: http.StatusOK, + body: `{"exists":true}{}`, + 
wantErrText: "unexpected trailing JSON input", + run: func(client *RESTClient) error { + _, err := client.ExistsByUserID(context.Background(), common.UserID("user-123")) + return err + }, + }, + { + name: "block by email rejects unexpected status", + statusCode: http.StatusBadGateway, + body: `{"error":"temporary"}`, + wantErrText: "unexpected HTTP status 502", + run: func(client *RESTClient) error { + _, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("pilot@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_blocked"), + }) + return err + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tt.statusCode) + _, err := io.WriteString(w, tt.body) + require.NoError(t, err) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 250*time.Millisecond) + + err := tt.run(client) + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErrText) + }) + } +} + +func TestRESTClientRequestTimeout(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(40 * time.Millisecond) + writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"}) + })) + defer server.Close() + + client := newTestRESTClient(t, server.URL, 10*time.Millisecond) + + _, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com")) + require.Error(t, err) + assert.ErrorContains(t, err, "context deadline exceeded") +} + +func TestRESTClientContextAndValidation(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("unexpected upstream call") + })) + defer server.Close() + + client := newTestRESTClient(t, 
server.URL, 250*time.Millisecond) + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + tests := []struct { + name string + run func() error + }{ + { + name: "nil context", + run: func() error { + _, err := client.ResolveByEmail(nil, common.Email("pilot@example.com")) + return err + }, + }, + { + name: "cancelled context", + run: func() error { + _, err := client.ExistsByUserID(cancelledCtx, common.UserID("user-123")) + return err + }, + }, + { + name: "invalid email", + run: func() error { + _, err := client.EnsureUserByEmail(context.Background(), common.Email(" bad@example.com ")) + return err + }, + }, + { + name: "invalid user id", + run: func() error { + _, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID(" bad "), + ReasonCode: userresolution.BlockReasonCode("policy_blocked"), + }) + return err + }, + }, + { + name: "invalid reason code", + run: func() error { + _, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("pilot@example.com"), + ReasonCode: userresolution.BlockReasonCode(" bad "), + }) + return err + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + err := tt.run() + require.Error(t, err) + }) + } +} + +type capturedRequest struct { + Method string + Path string + ContentType string + Body string +} + +func captureRequest(t *testing.T, request *http.Request) capturedRequest { + t.Helper() + + body, err := io.ReadAll(request.Body) + require.NoError(t, err) + + return capturedRequest{ + Method: request.Method, + Path: request.URL.Path, + ContentType: request.Header.Get("Content-Type"), + Body: strings.TrimSpace(string(body)), + } +} + +func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) { + t.Helper() + + payload, err := json.Marshal(value) + require.NoError(t, err) + + writer.Header().Set("Content-Type", "application/json") + 
writer.WriteHeader(statusCode) + _, err = writer.Write(payload) + require.NoError(t, err) +} + +func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient { + t.Helper() + + client, err := NewRESTClient(Config{ + BaseURL: baseURL, + RequestTimeout: timeout, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, client.Close()) + }) + + return client +} + +type failOnceRoundTripper struct { + mu sync.Mutex + next http.RoundTripper + err error + done bool +} + +func (rt *failOnceRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + rt.mu.Lock() + if !rt.done { + rt.done = true + err := rt.err + rt.mu.Unlock() + return nil, err + } + next := rt.next + rt.mu.Unlock() + + return next.RoundTrip(request) +} diff --git a/authsession/internal/adapters/userservice/stub_directory.go b/authsession/internal/adapters/userservice/stub_directory.go new file mode 100644 index 0000000..58582ba --- /dev/null +++ b/authsession/internal/adapters/userservice/stub_directory.go @@ -0,0 +1,361 @@ +// Package userservice provides runtime user-directory adapters for the +// auth/session service. +package userservice + +import ( + "context" + "fmt" + "sync" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/ports" +) + +type entry struct { + userID common.UserID + blockReasonCode userresolution.BlockReasonCode +} + +// StubDirectory is a concurrency-safe in-process UserDirectory stub intended +// for development, local integration, and explicit stub-based tests. +// +// The zero value is ready to use. Unknown e-mail addresses resolve as +// creatable, unknown user identifiers do not exist, and EnsureUserByEmail +// creates deterministic user ids such as "user-1", "user-2", and so on. 
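+//
+// A minimal usage sketch (assumes the example literals pass validation and
+// that ctx is a live context):
+//
+//	directory := &StubDirectory{}
+//	_ = directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))
+//	result, _ := directory.ResolveByEmail(ctx, common.Email("pilot@example.com"))
+//	// result.Kind == userresolution.KindExisting, result.UserID == "user-1"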
+type StubDirectory struct { + mu sync.Mutex + byEmail map[common.Email]entry + emailByUserID map[common.UserID]common.Email + createdUserIDs []common.UserID + nextUserNumber int +} + +// ResolveByEmail returns the current coarse user-resolution state for email +// without creating any new user record. +func (d *StubDirectory) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) { + if err := validateContext(ctx, "resolve by email"); err != nil { + return userresolution.Result{}, err + } + if err := email.Validate(); err != nil { + return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + result, err := d.resolveLocked(email) + if err != nil { + return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err) + } + + return result, nil +} + +// ExistsByUserID reports whether userID currently identifies a stored user +// record. +func (d *StubDirectory) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) { + if err := validateContext(ctx, "exists by user id"); err != nil { + return false, err + } + if err := userID.Validate(); err != nil { + return false, fmt.Errorf("exists by user id: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + _, ok := d.emailByUserID[userID] + return ok, nil +} + +// EnsureUserByEmail returns an existing user for email, creates a new user +// when registration is allowed, or reports a blocked outcome. 
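+//
+// A sketch of how a caller might branch on the outcome (a hypothetical
+// consumer, not part of this package):
+//
+//	result, err := directory.EnsureUserByEmail(ctx, email)
+//	switch result.Outcome {
+//	case ports.EnsureUserOutcomeExisting, ports.EnsureUserOutcomeCreated:
+//		// result.UserID identifies the user.
+//	case ports.EnsureUserOutcomeBlocked:
+//		// result.BlockReasonCode carries the block reason.
+//	}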
+func (d *StubDirectory) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) { + if err := validateContext(ctx, "ensure user by email"); err != nil { + return ports.EnsureUserResult{}, err + } + if err := email.Validate(); err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + d.ensureMapsLocked() + + stored, ok := d.byEmail[email] + if ok { + if !stored.blockReasonCode.IsZero() { + result := ports.EnsureUserResult{ + Outcome: ports.EnsureUserOutcomeBlocked, + BlockReasonCode: stored.blockReasonCode, + } + if err := result.Validate(); err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + + return result, nil + } + + result := ports.EnsureUserResult{ + Outcome: ports.EnsureUserOutcomeExisting, + UserID: stored.userID, + } + if err := result.Validate(); err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + + return result, nil + } + + userID, err := d.nextCreatedUserIDLocked() + if err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + d.byEmail[email] = entry{userID: userID} + d.emailByUserID[userID] = email + + result := ports.EnsureUserResult{ + Outcome: ports.EnsureUserOutcomeCreated, + UserID: userID, + } + if err := result.Validate(); err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + + return result, nil +} + +// BlockByUserID applies a block state to the user identified by input.UserID. +// Unknown user ids wrap ports.ErrNotFound. 
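+//
+// Callers can detect the missing-user case with errors.Is (sketch; input is
+// an assumed ports.BlockUserByIDInput value):
+//
+//	if _, err := directory.BlockByUserID(ctx, input); errors.Is(err, ports.ErrNotFound) {
+//		// no user record exists for input.UserID
+//	}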
+func (d *StubDirectory) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) { + if err := validateContext(ctx, "block by user id"); err != nil { + return ports.BlockUserResult{}, err + } + if err := input.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + email, ok := d.emailByUserID[input.UserID] + if !ok { + return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound) + } + + stored := d.byEmail[email] + if !stored.blockReasonCode.IsZero() { + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeAlreadyBlocked, + UserID: input.UserID, + } + if err := result.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err) + } + + return result, nil + } + + stored.blockReasonCode = input.ReasonCode + d.byEmail[email] = stored + + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeBlocked, + UserID: input.UserID, + } + if err := result.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err) + } + + return result, nil +} + +// BlockByEmail applies a block state to input.Email even when no user record +// currently exists for that e-mail address. 
+func (d *StubDirectory) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) { + if err := validateContext(ctx, "block by email"); err != nil { + return ports.BlockUserResult{}, err + } + if err := input.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + d.ensureMapsLocked() + + stored := d.byEmail[input.Email] + if !stored.blockReasonCode.IsZero() { + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeAlreadyBlocked, + UserID: stored.userID, + } + if err := result.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err) + } + + return result, nil + } + + stored.blockReasonCode = input.ReasonCode + d.byEmail[input.Email] = stored + if !stored.userID.IsZero() { + d.emailByUserID[stored.userID] = input.Email + } + + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeBlocked, + UserID: stored.userID, + } + if err := result.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err) + } + + return result, nil +} + +// SeedExisting preloads one existing unblocked user record into the runtime +// stub. +func (d *StubDirectory) SeedExisting(email common.Email, userID common.UserID) error { + if err := email.Validate(); err != nil { + return fmt.Errorf("seed existing email: %w", err) + } + if err := userID.Validate(); err != nil { + return fmt.Errorf("seed existing user id: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + d.ensureMapsLocked() + d.byEmail[email] = entry{userID: userID} + d.emailByUserID[userID] = email + + return nil +} + +// SeedBlockedEmail preloads one blocked e-mail address that does not +// necessarily belong to an existing user record. 
+func (d *StubDirectory) SeedBlockedEmail(email common.Email, reasonCode userresolution.BlockReasonCode) error { + if err := email.Validate(); err != nil { + return fmt.Errorf("seed blocked email: %w", err) + } + if err := reasonCode.Validate(); err != nil { + return fmt.Errorf("seed blocked email reason code: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + d.ensureMapsLocked() + d.byEmail[email] = entry{blockReasonCode: reasonCode} + + return nil +} + +// SeedBlockedUser preloads one blocked existing user record into the runtime +// stub. +func (d *StubDirectory) SeedBlockedUser(email common.Email, userID common.UserID, reasonCode userresolution.BlockReasonCode) error { + if err := d.SeedExisting(email, userID); err != nil { + return err + } + + d.mu.Lock() + defer d.mu.Unlock() + + stored := d.byEmail[email] + stored.blockReasonCode = reasonCode + d.byEmail[email] = stored + + return nil +} + +// QueueCreatedUserIDs appends deterministic user identifiers that +// EnsureUserByEmail consumes before falling back to generated ids. +func (d *StubDirectory) QueueCreatedUserIDs(userIDs ...common.UserID) error { + for index, userID := range userIDs { + if err := userID.Validate(); err != nil { + return fmt.Errorf("queue created user id %d: %w", index, err) + } + } + + d.mu.Lock() + defer d.mu.Unlock() + + d.createdUserIDs = append(d.createdUserIDs, userIDs...) 
+ return nil +} + +func (d *StubDirectory) ensureMapsLocked() { + if d.byEmail == nil { + d.byEmail = make(map[common.Email]entry) + } + if d.emailByUserID == nil { + d.emailByUserID = make(map[common.UserID]common.Email) + } +} + +func (d *StubDirectory) resolveLocked(email common.Email) (userresolution.Result, error) { + stored, ok := d.byEmail[email] + if !ok { + result := userresolution.Result{Kind: userresolution.KindCreatable} + if err := result.Validate(); err != nil { + return userresolution.Result{}, err + } + + return result, nil + } + if !stored.blockReasonCode.IsZero() { + result := userresolution.Result{ + Kind: userresolution.KindBlocked, + BlockReasonCode: stored.blockReasonCode, + } + if err := result.Validate(); err != nil { + return userresolution.Result{}, err + } + + return result, nil + } + + result := userresolution.Result{ + Kind: userresolution.KindExisting, + UserID: stored.userID, + } + if err := result.Validate(); err != nil { + return userresolution.Result{}, err + } + + return result, nil +} + +func (d *StubDirectory) nextCreatedUserIDLocked() (common.UserID, error) { + if len(d.createdUserIDs) > 0 { + userID := d.createdUserIDs[0] + d.createdUserIDs = d.createdUserIDs[1:] + return userID, nil + } + + d.nextUserNumber++ + userID := common.UserID(fmt.Sprintf("user-%d", d.nextUserNumber)) + if err := userID.Validate(); err != nil { + return "", err + } + + return userID, nil +} + +func validateContext(ctx context.Context, operation string) error { + if ctx == nil { + return fmt.Errorf("%s: nil context", operation) + } + if err := ctx.Err(); err != nil { + return fmt.Errorf("%s: %w", operation, err) + } + + return nil +} + +var _ ports.UserDirectory = (*StubDirectory)(nil) diff --git a/authsession/internal/adapters/userservice/stub_directory_test.go b/authsession/internal/adapters/userservice/stub_directory_test.go new file mode 100644 index 0000000..d6916cd --- /dev/null +++ 
b/authsession/internal/adapters/userservice/stub_directory_test.go @@ -0,0 +1,329 @@ +package userservice + +import ( + "context" + "errors" + "testing" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/ports" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStubDirectoryResolveByEmail(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing"))) + require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block"))) + + tests := []struct { + name string + email common.Email + wantKind userresolution.Kind + wantUserID common.UserID + wantReasonCode userresolution.BlockReasonCode + }{ + { + name: "zero value unknown email is creatable", + email: common.Email("new@example.com"), + wantKind: userresolution.KindCreatable, + }, + { + name: "existing email", + email: common.Email("existing@example.com"), + wantKind: userresolution.KindExisting, + wantUserID: common.UserID("user-existing"), + }, + { + name: "blocked email", + email: common.Email("blocked@example.com"), + wantKind: userresolution.KindBlocked, + wantReasonCode: userresolution.BlockReasonCode("policy_block"), + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result, err := directory.ResolveByEmail(context.Background(), tt.email) + require.NoError(t, err) + assert.Equal(t, tt.wantKind, result.Kind) + assert.Equal(t, tt.wantUserID, result.UserID) + assert.Equal(t, tt.wantReasonCode, result.BlockReasonCode) + }) + } +} + +func TestStubDirectoryEnsureUserByEmail(t *testing.T) { + t.Parallel() + + t.Run("existing", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, 
directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing"))) + + result, err := directory.EnsureUserByEmail(context.Background(), common.Email("existing@example.com")) + require.NoError(t, err) + assert.Equal(t, ports.EnsureUserOutcomeExisting, result.Outcome) + assert.Equal(t, common.UserID("user-existing"), result.UserID) + }) + + t.Run("blocked", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block"))) + + result, err := directory.EnsureUserByEmail(context.Background(), common.Email("blocked@example.com")) + require.NoError(t, err) + assert.Equal(t, ports.EnsureUserOutcomeBlocked, result.Outcome) + assert.Equal(t, userresolution.BlockReasonCode("policy_block"), result.BlockReasonCode) + }) + + t.Run("created queued then existing", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, directory.QueueCreatedUserIDs(common.UserID("user-created"))) + + first, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com")) + require.NoError(t, err) + assert.Equal(t, ports.EnsureUserOutcomeCreated, first.Outcome) + assert.Equal(t, common.UserID("user-created"), first.UserID) + + second, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com")) + require.NoError(t, err) + assert.Equal(t, ports.EnsureUserOutcomeExisting, second.Outcome) + assert.Equal(t, common.UserID("user-created"), second.UserID) + }) + + t.Run("created fallback id", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + + result, err := directory.EnsureUserByEmail(context.Background(), common.Email("fallback@example.com")) + require.NoError(t, err) + assert.Equal(t, ports.EnsureUserOutcomeCreated, result.Outcome) + assert.Equal(t, common.UserID("user-1"), result.UserID) + }) +} + +func 
TestStubDirectoryExistsByUserID(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing"))) + + exists, err := directory.ExistsByUserID(context.Background(), common.UserID("user-existing")) + require.NoError(t, err) + assert.True(t, exists) + + exists, err = directory.ExistsByUserID(context.Background(), common.UserID("missing")) + require.NoError(t, err) + assert.False(t, exists) +} + +func TestStubDirectoryBlockByEmail(t *testing.T) { + t.Parallel() + + t.Run("unknown email becomes blocked without user id", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + + result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("blocked@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserOutcomeBlocked, result.Outcome) + assert.True(t, result.UserID.IsZero()) + + resolution, err := directory.ResolveByEmail(context.Background(), common.Email("blocked@example.com")) + require.NoError(t, err) + assert.Equal(t, userresolution.KindBlocked, resolution.Kind) + }) + + t.Run("existing user preserves linked user id and repeat is already blocked", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + + first, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("pilot@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome) + assert.Equal(t, common.UserID("user-1"), first.UserID) + + second, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("pilot@example.com"), + 
ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome) + assert.Equal(t, common.UserID("user-1"), second.UserID) + }) +} + +func TestStubDirectoryBlockByUserID(t *testing.T) { + t.Parallel() + + t.Run("unknown user wraps ErrNotFound", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + + _, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("missing"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.Error(t, err) + assert.ErrorIs(t, err, ports.ErrNotFound) + }) + + t.Run("existing user blocks then returns already blocked", func(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + + first, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("user-1"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome) + assert.Equal(t, common.UserID("user-1"), first.UserID) + + second, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("user-1"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome) + assert.Equal(t, common.UserID("user-1"), second.UserID) + }) +} + +func TestStubDirectoryContextAndValidation(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + tests := []struct { + name string + run func() error + want string + }{ + { + name: "resolve nil context", + run: func() error { + _, err := directory.ResolveByEmail(nil, 
common.Email("pilot@example.com")) + return err + }, + want: "nil context", + }, + { + name: "ensure cancelled context", + run: func() error { + _, err := directory.EnsureUserByEmail(cancelledCtx, common.Email("pilot@example.com")) + return err + }, + want: context.Canceled.Error(), + }, + { + name: "exists invalid user id", + run: func() error { + _, err := directory.ExistsByUserID(context.Background(), common.UserID(" bad ")) + return err + }, + want: "exists by user id", + }, + { + name: "block by email invalid email", + run: func() error { + _, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("bad"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + return err + }, + want: "block by email", + }, + { + name: "seed invalid user id", + run: func() error { + return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID(" bad ")) + }, + want: "seed existing user id", + }, + { + name: "queue invalid created user id", + run: func() error { + return directory.QueueCreatedUserIDs(common.UserID(" bad ")) + }, + want: "queue created user id 0", + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.run() + require.Error(t, err) + assert.ErrorContains(t, err, tt.want) + }) + } +} + +func TestStubDirectorySeedBlockedUser(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + require.NoError(t, directory.SeedBlockedUser( + common.Email("pilot@example.com"), + common.UserID("user-1"), + userresolution.BlockReasonCode("policy_block"), + )) + + result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: common.Email("pilot@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.NoError(t, err) + assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, result.Outcome) + assert.Equal(t, common.UserID("user-1"), result.UserID) +} + +func 
TestStubDirectoryCancelledContextWrapsContextError(t *testing.T) { + t.Parallel() + + directory := &StubDirectory{} + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := directory.BlockByUserID(cancelledCtx, ports.BlockUserByIDInput{ + UserID: common.UserID("user-1"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.ErrorContains(t, err, "block by user id") +} diff --git a/authsession/internal/api/internalhttp/doc.go b/authsession/internal/api/internalhttp/doc.go new file mode 100644 index 0000000..1e940b2 --- /dev/null +++ b/authsession/internal/api/internalhttp/doc.go @@ -0,0 +1,3 @@ +// Package internalhttp exposes the trusted internal HTTP API used for session +// read, revoke, and block operations. +package internalhttp diff --git a/authsession/internal/api/internalhttp/e2e_test.go b/authsession/internal/api/internalhttp/e2e_test.go new file mode 100644 index 0000000..b6a4e5b --- /dev/null +++ b/authsession/internal/api/internalhttp/e2e_test.go @@ -0,0 +1,286 @@ +package internalhttp + +import ( + "bytes" + "context" + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInternalHTTPEndToEndGetSession(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + require.NoError(t, 
app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)))) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := getJSON(t, server.URL+"/api/v1/internal/sessions/device-session-1") + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"session":{"device_session_id":"device-session-1","user_id":"user-1","client_public_key":"`+validClientPublicKey+`","status":"active","created_at":"2026-04-05T12:00:00Z"}}`, response.Body) +} + +func TestInternalHTTPEndToEndListUserSessions(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + key := testClientPublicKey(t, validClientPublicKey) + require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", key, time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)))) + require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-2", "user-1", key, time.Date(2026, 4, 5, 12, 1, 0, 0, time.UTC)))) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := getJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions") + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.Contains(t, response.Body, `"device_session_id":"device-session-2"`) + assert.Contains(t, response.Body, `"device_session_id":"device-session-1"`) + assert.Less(t, bytes.Index([]byte(response.Body), []byte(`"device_session_id":"device-session-2"`)), bytes.Index([]byte(response.Body), []byte(`"device_session_id":"device-session-1"`))) +} + +func TestInternalHTTPEndToEndListUserSessionsUnknownUserReturnsEmptyArray(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := getJSON(t, server.URL+"/api/v1/internal/users/unknown-user/sessions") + + assert.Equal(t, http.StatusOK, response.StatusCode) + 
assert.JSONEq(t, `{"sessions":[]}`, response.Body) +} + +func TestInternalHTTPEndToEndGetSessionNotFound(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := getJSON(t, server.URL+"/api/v1/internal/sessions/missing-session") + + assert.Equal(t, http.StatusNotFound, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"session_not_found","message":"session not found"}}`, response.Body) +} + +func TestInternalHTTPEndToEndRevokeDeviceSession(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)))) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/internal/sessions/device-session-1/revoke", `{"reason_code":"admin_revoke","actor":{"type":"system"}}`) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"outcome":"revoked","device_session_id":"device-session-1","affected_session_count":1}`, response.Body) +} + +func TestInternalHTTPEndToEndRevokeAllUserSessions(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + key := testClientPublicKey(t, validClientPublicKey) + require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", key, time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)))) + require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-2", "user-1", key, time.Date(2026, 4, 5, 12, 1, 0, 0, time.UTC)))) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions/revoke-all", 
`{"reason_code":"logout_all","actor":{"type":"system"}}`) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"outcome":"revoked","user_id":"user-1","affected_session_count":2,"affected_device_session_ids":["device-session-2","device-session-1"]}`, response.Body) +} + +func TestInternalHTTPEndToEndRevokeAllUserSessionsNoActiveSessions(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-1","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body) +} + +func TestInternalHTTPEndToEndRevokeAllUserSessionsUnknownUser(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/internal/users/missing-user/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`) + + assert.Equal(t, http.StatusNotFound, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"subject_not_found","message":"subject not found"}}`, response.Body) +} + +func TestInternalHTTPEndToEndBlockUserByEmail(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, 
`{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body) +} + +func TestInternalHTTPEndToEndBlockUserByUserID(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)))) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"user_id":"user-1","reason_code":"policy_blocked","actor":{"type":"admin"}}`) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"user_id","subject_value":"user-1","affected_session_count":1,"affected_device_session_ids":["device-session-1"]}`, response.Body) +} + +func TestInternalHTTPEndToEndBlockUserUnknownUserID(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"user_id":"missing-user","reason_code":"policy_blocked","actor":{"type":"admin"}}`) + + assert.Equal(t, http.StatusNotFound, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"subject_not_found","message":"subject not found"}}`, response.Body) +} + +type endToEndApp struct { + handler http.Handler + sessionStore *testkit.InMemorySessionStore + userDirectory *userservice.StubDirectory +} + +func newEndToEndApp(t *testing.T) endToEndApp { + t.Helper() + + sessionStore := &testkit.InMemorySessionStore{} + userDirectory := &userservice.StubDirectory{} + publisher := &testkit.RecordingProjectionPublisher{} + clock := testkit.FixedClock{Time: time.Date(2026, 4, 5, 12, 0, 0, 0, 
time.UTC)} + + getSessionService, err := getsession.New(sessionStore) + require.NoError(t, err) + listUserSessionsService, err := listusersessions.New(sessionStore) + require.NoError(t, err) + revokeDeviceSessionService, err := revokedevicesession.New(sessionStore, publisher, clock) + require.NoError(t, err) + revokeAllUserSessionsService, err := revokeallusersessions.New(sessionStore, userDirectory, publisher, clock) + require.NoError(t, err) + blockUserService, err := blockuser.New(userDirectory, sessionStore, publisher, clock) + require.NoError(t, err) + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionService, + ListUserSessions: listUserSessionsService, + RevokeDeviceSession: revokeDeviceSessionService, + RevokeAllUserSessions: revokeAllUserSessionsService, + BlockUser: blockUserService, + }) + + return endToEndApp{ + handler: handler, + sessionStore: sessionStore, + userDirectory: userDirectory, + } +} + +type httpResponse struct { + StatusCode int + Body string +} + +func getJSON(t *testing.T, url string) httpResponse { + t.Helper() + + response, err := http.Get(url) + require.NoError(t, err) + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + require.NoError(t, err) + + return httpResponse{StatusCode: response.StatusCode, Body: string(payload)} +} + +func postJSON(t *testing.T, url string, body string) httpResponse { + t.Helper() + + response, err := http.Post(url, "application/json", bytes.NewBufferString(body)) + require.NoError(t, err) + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + require.NoError(t, err) + + return httpResponse{StatusCode: response.StatusCode, Body: string(payload)} +} + +func postJSONValue(t *testing.T, url string, value any) httpResponse { + t.Helper() + + body, err := json.Marshal(value) + require.NoError(t, err) + return postJSON(t, url, string(body)) +} + +func activeSession(id string, userID string, key common.ClientPublicKey, createdAt 
time.Time) devicesession.Session { + return devicesession.Session{ + ID: common.DeviceSessionID(id), + UserID: common.UserID(userID), + ClientPublicKey: key, + Status: devicesession.StatusActive, + CreatedAt: createdAt, + } +} + +func testClientPublicKey(t *testing.T, encoded string) common.ClientPublicKey { + t.Helper() + + decoded, err := base64.StdEncoding.DecodeString(encoded) + require.NoError(t, err) + + key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded)) + require.NoError(t, err) + return key +} + +const validClientPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=" diff --git a/authsession/internal/api/internalhttp/handler.go b/authsession/internal/api/internalhttp/handler.go new file mode 100644 index 0000000..23c3919 --- /dev/null +++ b/authsession/internal/api/internalhttp/handler.go @@ -0,0 +1,513 @@ +package internalhttp + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/telemetry" + + "github.com/gin-gonic/gin" + "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" +) + +const jsonContentType = "application/json; charset=utf-8" + +const internalHTTPServiceName = "galaxy-authsession-internal" + +type errorResponse struct { + Error errorBody `json:"error"` +} + +type errorBody struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type actorRequest struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` +} + +type sessionResponseDTO struct { + DeviceSessionID string `json:"device_session_id"` + UserID string `json:"user_id"` + ClientPublicKey string `json:"client_public_key"` + 
Status string `json:"status"` + CreatedAt string `json:"created_at"` + RevokedAt *string `json:"revoked_at,omitempty"` + RevokeReasonCode *string `json:"revoke_reason_code,omitempty"` + RevokeActorType *string `json:"revoke_actor_type,omitempty"` + RevokeActorID *string `json:"revoke_actor_id,omitempty"` +} + +type getSessionResponse struct { + Session sessionResponseDTO `json:"session"` +} + +type listUserSessionsResponse struct { + Sessions []sessionResponseDTO `json:"sessions"` +} + +type revokeDeviceSessionRequest struct { + ReasonCode string `json:"reason_code"` + Actor actorRequest `json:"actor"` +} + +type revokeDeviceSessionResponse struct { + Outcome string `json:"outcome"` + DeviceSessionID string `json:"device_session_id"` + AffectedSessionCount int64 `json:"affected_session_count"` +} + +type revokeAllUserSessionsRequest struct { + ReasonCode string `json:"reason_code"` + Actor actorRequest `json:"actor"` +} + +type revokeAllUserSessionsResponse struct { + Outcome string `json:"outcome"` + UserID string `json:"user_id"` + AffectedSessionCount int64 `json:"affected_session_count"` + AffectedDeviceSessionIDs []string `json:"affected_device_session_ids"` +} + +type blockUserRequest struct { + UserID string `json:"user_id,omitempty"` + Email string `json:"email,omitempty"` + ReasonCode string `json:"reason_code"` + Actor actorRequest `json:"actor"` +} + +type blockUserResponse struct { + Outcome string `json:"outcome"` + SubjectKind string `json:"subject_kind"` + SubjectValue string `json:"subject_value"` + AffectedSessionCount int64 `json:"affected_session_count"` + AffectedDeviceSessionIDs []string `json:"affected_device_session_ids"` +} + +var configureGinModeOnce sync.Once + +func newHandlerWithConfig(cfg Config, deps Dependencies) (http.Handler, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + + normalizedDeps, err := normalizeDependencies(deps) + if err != nil { + return nil, err + } + + configureGinModeOnce.Do(func() { + 
gin.SetMode(gin.ReleaseMode) + }) + + engine := gin.New() + engine.Use(newOTelMiddleware(normalizedDeps.Telemetry)) + engine.Use(withInternalObservability(normalizedDeps.Logger, normalizedDeps.Telemetry)) + engine.GET("/api/v1/internal/sessions/:device_session_id", handleGetSession(normalizedDeps.GetSession, cfg.RequestTimeout)) + engine.GET("/api/v1/internal/users/:user_id/sessions", handleListUserSessions(normalizedDeps.ListUserSessions, cfg.RequestTimeout)) + engine.POST("/api/v1/internal/sessions/:device_session_id/revoke", handleRevokeDeviceSession(normalizedDeps.RevokeDeviceSession, cfg.RequestTimeout)) + engine.POST("/api/v1/internal/users/:user_id/sessions/revoke-all", handleRevokeAllUserSessions(normalizedDeps.RevokeAllUserSessions, cfg.RequestTimeout)) + engine.POST("/api/v1/internal/user-blocks", handleBlockUser(normalizedDeps.BlockUser, cfg.RequestTimeout)) + + return engine, nil +} + +func newOTelMiddleware(runtime *telemetry.Runtime) gin.HandlerFunc { + options := []otelgin.Option{} + if runtime != nil { + options = append( + options, + otelgin.WithTracerProvider(runtime.TracerProvider()), + otelgin.WithMeterProvider(runtime.MeterProvider()), + ) + } + + return otelgin.Middleware(internalHTTPServiceName, options...) 
+} + +func handleGetSession(useCase GetSessionUseCase, timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + result, err := useCase.Execute(callCtx, getsession.Input{ + DeviceSessionID: c.Param("device_session_id"), + }) + if err != nil { + abortWithProjection(c, projectInternalError(err)) + return + } + if err := validateGetSessionResult(&result); err != nil { + abortWithProjection(c, internalErrorProjection(fmt.Errorf("get session response: %w", err))) + return + } + + c.JSON(http.StatusOK, getSessionResponse{Session: toSessionResponseDTO(result.Session)}) + } +} + +func handleListUserSessions(useCase ListUserSessionsUseCase, timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + result, err := useCase.Execute(callCtx, listusersessions.Input{ + UserID: c.Param("user_id"), + }) + if err != nil { + abortWithProjection(c, projectInternalError(err)) + return + } + if err := validateListUserSessionsResult(&result); err != nil { + abortWithProjection(c, internalErrorProjection(fmt.Errorf("list user sessions response: %w", err))) + return + } + + c.JSON(http.StatusOK, listUserSessionsResponse{Sessions: toSessionResponseDTOs(result.Sessions)}) + } +} + +func handleRevokeDeviceSession(useCase RevokeDeviceSessionUseCase, timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + var request revokeDeviceSessionRequest + if err := decodeJSONRequest(c.Request, &request); err != nil { + abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error()))) + return + } + if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil { + abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error()))) + return + } + + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + 
+ result, err := useCase.Execute(callCtx, revokedevicesession.Input{ + DeviceSessionID: c.Param("device_session_id"), + ReasonCode: request.ReasonCode, + ActorType: request.Actor.Type, + ActorID: request.Actor.ID, + }) + if err != nil { + abortWithProjection(c, projectInternalError(err)) + return + } + if err := validateRevokeDeviceSessionResult(&result); err != nil { + abortWithProjection(c, internalErrorProjection(fmt.Errorf("revoke device session response: %w", err))) + return + } + + c.JSON(http.StatusOK, revokeDeviceSessionResponse{ + Outcome: result.Outcome, + DeviceSessionID: result.DeviceSessionID, + AffectedSessionCount: result.AffectedSessionCount, + }) + } +} + +func handleRevokeAllUserSessions(useCase RevokeAllUserSessionsUseCase, timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + var request revokeAllUserSessionsRequest + if err := decodeJSONRequest(c.Request, &request); err != nil { + abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error()))) + return + } + if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil { + abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error()))) + return + } + + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + result, err := useCase.Execute(callCtx, revokeallusersessions.Input{ + UserID: c.Param("user_id"), + ReasonCode: request.ReasonCode, + ActorType: request.Actor.Type, + ActorID: request.Actor.ID, + }) + if err != nil { + abortWithProjection(c, projectInternalError(err)) + return + } + if err := validateRevokeAllUserSessionsResult(&result); err != nil { + abortWithProjection(c, internalErrorProjection(fmt.Errorf("revoke all user sessions response: %w", err))) + return + } + + c.JSON(http.StatusOK, revokeAllUserSessionsResponse{ + Outcome: result.Outcome, + UserID: result.UserID, + AffectedSessionCount: result.AffectedSessionCount, + AffectedDeviceSessionIDs: 
cloneStrings(result.AffectedDeviceSessionIDs), + }) + } +} + +func handleBlockUser(useCase BlockUserUseCase, timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + var request blockUserRequest + if err := decodeJSONRequest(c.Request, &request); err != nil { + abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error()))) + return + } + if err := validateBlockUserRequest(&request); err != nil { + abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error()))) + return + } + + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + result, err := useCase.Execute(callCtx, blockuser.Input{ + UserID: request.UserID, + Email: request.Email, + ReasonCode: request.ReasonCode, + ActorType: request.Actor.Type, + ActorID: request.Actor.ID, + }) + if err != nil { + abortWithProjection(c, projectInternalError(err)) + return + } + if err := validateBlockUserResult(&result); err != nil { + abortWithProjection(c, internalErrorProjection(fmt.Errorf("block user response: %w", err))) + return + } + + c.JSON(http.StatusOK, blockUserResponse{ + Outcome: result.Outcome, + SubjectKind: result.SubjectKind, + SubjectValue: result.SubjectValue, + AffectedSessionCount: result.AffectedSessionCount, + AffectedDeviceSessionIDs: cloneStrings(result.AffectedDeviceSessionIDs), + }) + } +} + +func toSessionResponseDTO(session shared.Session) sessionResponseDTO { + return sessionResponseDTO{ + DeviceSessionID: session.DeviceSessionID, + UserID: session.UserID, + ClientPublicKey: session.ClientPublicKey, + Status: session.Status, + CreatedAt: session.CreatedAt, + RevokedAt: cloneStringPointer(session.RevokedAt), + RevokeReasonCode: cloneStringPointer(session.RevokeReasonCode), + RevokeActorType: cloneStringPointer(session.RevokeActorType), + RevokeActorID: cloneStringPointer(session.RevokeActorID), + } +} + +func toSessionResponseDTOs(sessions []shared.Session) []sessionResponseDTO { + result := 
make([]sessionResponseDTO, 0, len(sessions)) + for _, session := range sessions { + result = append(result, toSessionResponseDTO(session)) + } + + return result +} + +func cloneStrings(values []string) []string { + result := make([]string, 0, len(values)) + return append(result, values...) +} + +func cloneStringPointer(value *string) *string { + if value == nil { + return nil + } + + cloned := *value + return &cloned +} + +func validateAuditRequest(reasonCode string, actor actorRequest) error { + if strings.TrimSpace(reasonCode) == "" { + return errors.New("reason_code must not be empty") + } + if strings.TrimSpace(actor.Type) == "" { + return errors.New("actor.type must not be empty") + } + + return nil +} + +func validateBlockUserRequest(request *blockUserRequest) error { + if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil { + return err + } + + hasUserID := strings.TrimSpace(request.UserID) != "" + hasEmail := strings.TrimSpace(request.Email) != "" + switch { + case hasUserID && hasEmail: + return errors.New("exactly one of user_id or email must be provided") + case !hasUserID && !hasEmail: + return errors.New("exactly one of user_id or email must be provided") + default: + return nil + } +} + +func validateSessionDTO(session *shared.Session) error { + switch { + case strings.TrimSpace(session.DeviceSessionID) == "": + return errors.New("session.device_session_id must not be empty") + case strings.TrimSpace(session.UserID) == "": + return errors.New("session.user_id must not be empty") + case strings.TrimSpace(session.ClientPublicKey) == "": + return errors.New("session.client_public_key must not be empty") + case strings.TrimSpace(session.CreatedAt) == "": + return errors.New("session.created_at must not be empty") + } + + if _, err := time.Parse(time.RFC3339, session.CreatedAt); err != nil { + return fmt.Errorf("session.created_at: %w", err) + } + + switch session.Status { + case "active": + if session.RevokedAt != nil || 
session.RevokeReasonCode != nil || session.RevokeActorType != nil || session.RevokeActorID != nil { + return errors.New("active session must not contain revoke metadata") + } + case "revoked": + switch { + case session.RevokedAt == nil || strings.TrimSpace(*session.RevokedAt) == "": + return errors.New("revoked session must contain revoked_at") + case session.RevokeReasonCode == nil || strings.TrimSpace(*session.RevokeReasonCode) == "": + return errors.New("revoked session must contain revoke_reason_code") + case session.RevokeActorType == nil || strings.TrimSpace(*session.RevokeActorType) == "": + return errors.New("revoked session must contain revoke_actor_type") + } + if _, err := time.Parse(time.RFC3339, *session.RevokedAt); err != nil { + return fmt.Errorf("session.revoked_at: %w", err) + } + default: + return fmt.Errorf("session.status %q is unsupported", session.Status) + } + + return nil +} + +func validateGetSessionResult(result *getsession.Result) error { + return validateSessionDTO(&result.Session) +} + +func validateListUserSessionsResult(result *listusersessions.Result) error { + if result.Sessions == nil { + return errors.New("sessions must not be null") + } + + for index := range result.Sessions { + if err := validateSessionDTO(&result.Sessions[index]); err != nil { + return fmt.Errorf("sessions[%d]: %w", index, err) + } + } + + return nil +} + +func validateRevokeDeviceSessionResult(result *revokedevicesession.Result) error { + switch result.Outcome { + case "revoked": + if result.AffectedSessionCount != 1 { + return errors.New("revoked outcome must affect exactly one session") + } + case "already_revoked": + if result.AffectedSessionCount != 0 { + return errors.New("already_revoked outcome must affect zero sessions") + } + default: + return fmt.Errorf("revoke device session outcome %q is unsupported", result.Outcome) + } + if strings.TrimSpace(result.DeviceSessionID) == "" { + return errors.New("device_session_id must not be empty") + } + + return 
nil +} + +func validateRevokeAllUserSessionsResult(result *revokeallusersessions.Result) error { + switch result.Outcome { + case "revoked", "no_active_sessions": + default: + return fmt.Errorf("revoke all user sessions outcome %q is unsupported", result.Outcome) + } + if strings.TrimSpace(result.UserID) == "" { + return errors.New("user_id must not be empty") + } + if result.AffectedSessionCount < 0 { + return errors.New("affected_session_count must not be negative") + } + if result.AffectedDeviceSessionIDs == nil { + return errors.New("affected_device_session_ids must not be null") + } + if int64(len(result.AffectedDeviceSessionIDs)) != result.AffectedSessionCount { + return errors.New("affected_device_session_ids length must match affected_session_count") + } + for index, deviceSessionID := range result.AffectedDeviceSessionIDs { + if strings.TrimSpace(deviceSessionID) == "" { + return fmt.Errorf("affected_device_session_ids[%d] must not be empty", index) + } + } + + return nil +} + +func validateBlockUserResult(result *blockuser.Result) error { + switch result.Outcome { + case "blocked", "already_blocked": + default: + return fmt.Errorf("block user outcome %q is unsupported", result.Outcome) + } + switch result.SubjectKind { + case blockuser.SubjectKindUserID, blockuser.SubjectKindEmail: + default: + return fmt.Errorf("subject_kind %q is unsupported", result.SubjectKind) + } + if strings.TrimSpace(result.SubjectValue) == "" { + return errors.New("subject_value must not be empty") + } + if result.AffectedSessionCount < 0 { + return errors.New("affected_session_count must not be negative") + } + if result.AffectedDeviceSessionIDs == nil { + return errors.New("affected_device_session_ids must not be null") + } + if int64(len(result.AffectedDeviceSessionIDs)) != result.AffectedSessionCount { + return errors.New("affected_device_session_ids length must match affected_session_count") + } + for index, deviceSessionID := range result.AffectedDeviceSessionIDs { + if 
strings.TrimSpace(deviceSessionID) == "" { + return fmt.Errorf("affected_device_session_ids[%d] must not be empty", index) + } + } + + return nil +} + +func projectInternalError(err error) shared.InternalErrorProjection { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return shared.ProjectInternalError(shared.ServiceUnavailable(err)) + } + + return shared.ProjectInternalError(err) +} + +func internalErrorProjection(err error) shared.InternalErrorProjection { + return shared.ProjectInternalError(shared.InternalError(err)) +} diff --git a/authsession/internal/api/internalhttp/handler_test.go b/authsession/internal/api/internalhttp/handler_test.go new file mode 100644 index 0000000..e38de6d --- /dev/null +++ b/authsession/internal/api/internalhttp/handler_test.go @@ -0,0 +1,784 @@ +package internalhttp + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/shared" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestGetSessionHandlerSuccess(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(func(_ context.Context, input getsession.Input) (getsession.Result, error) { + assert.Equal(t, getsession.Input{DeviceSessionID: "device-session-123"}, input) + return getsession.Result{ + Session: validSessionDTO(), + }, nil + }), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: 
revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil) + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"session":{"device_session_id":"device-session-123","user_id":"user-123","client_public_key":"public-key-material","status":"active","created_at":"2026-04-05T12:00:00Z"}}`, recorder.Body.String()) +} + +func TestListUserSessionsHandlerSuccess(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(func(_ context.Context, input listusersessions.Input) (listusersessions.Result, error) { + assert.Equal(t, listusersessions.Input{UserID: "user-123"}, input) + first := validSessionDTO() + second := validRevokedSessionDTO() + second.DeviceSessionID = "device-session-122" + return listusersessions.Result{Sessions: []shared.Session{first, second}}, nil + }), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/users/user-123/sessions", nil) + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.Contains(t, recorder.Body.String(), `"sessions":[`) + assert.Contains(t, recorder.Body.String(), `"device_session_id":"device-session-123"`) + assert.Contains(t, recorder.Body.String(), `"device_session_id":"device-session-122"`) +} + 
+func TestListUserSessionsHandlerUnknownUserReturnsEmptyArray(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(func(_ context.Context, input listusersessions.Input) (listusersessions.Result, error) { + assert.Equal(t, listusersessions.Input{UserID: "unknown-user"}, input) + return listusersessions.Result{Sessions: []shared.Session{}}, nil + }), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/users/unknown-user/sessions", nil) + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"sessions":[]}`, recorder.Body.String()) +} + +func TestRevokeDeviceSessionHandlerAlreadyRevoked(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(func(_ context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error) { + assert.Equal(t, revokedevicesession.Input{ + DeviceSessionID: "device-session-123", + ReasonCode: "admin_revoke", + ActorType: "system", + }, input) + return revokedevicesession.Result{ + Outcome: "already_revoked", + DeviceSessionID: "device-session-123", + AffectedSessionCount: 0, + }, nil + }), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( 
+ http.MethodPost, + "/api/v1/internal/sessions/device-session-123/revoke", + bytes.NewBufferString(`{"reason_code":"admin_revoke","actor":{"type":"system"}}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"outcome":"already_revoked","device_session_id":"device-session-123","affected_session_count":0}`, recorder.Body.String()) +} + +func TestRevokeAllUserSessionsHandlerNoActiveSessions(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(func(_ context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) { + assert.Equal(t, revokeallusersessions.Input{ + UserID: "user-123", + ReasonCode: "logout_all", + ActorType: "system", + }, input) + return revokeallusersessions.Result{ + Outcome: "no_active_sessions", + UserID: "user-123", + AffectedSessionCount: 0, + AffectedDeviceSessionIDs: []string{}, + }, nil + }), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/internal/users/user-123/sessions/revoke-all", + bytes.NewBufferString(`{"reason_code":"logout_all","actor":{"type":"system"}}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-123","affected_session_count":0,"affected_device_session_ids":[]}`, 
recorder.Body.String()) +} + +func TestBlockUserHandlerSuccessByEmail(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(func(_ context.Context, input blockuser.Input) (blockuser.Result, error) { + assert.Equal(t, blockuser.Input{ + Email: "pilot@example.com", + ReasonCode: "policy_blocked", + ActorType: "admin", + }, input) + return blockuser.Result{ + Outcome: "blocked", + SubjectKind: blockuser.SubjectKindEmail, + SubjectValue: "pilot@example.com", + AffectedSessionCount: 0, + AffectedDeviceSessionIDs: []string{}, + }, nil + }), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/internal/user-blocks", + bytes.NewBufferString(`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String()) +} + +func TestBlockUserHandlerSuccessByUserID(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: 
blockUserFunc(func(_ context.Context, input blockuser.Input) (blockuser.Result, error) { + assert.Equal(t, blockuser.Input{ + UserID: "user-123", + ReasonCode: "policy_blocked", + ActorType: "admin", + }, input) + return blockuser.Result{ + Outcome: "already_blocked", + SubjectKind: blockuser.SubjectKindUserID, + SubjectValue: "user-123", + AffectedSessionCount: 0, + AffectedDeviceSessionIDs: []string{}, + }, nil + }), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/internal/user-blocks", + bytes.NewBufferString(`{"user_id":"user-123","reason_code":"policy_blocked","actor":{"type":"admin"}}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"outcome":"already_blocked","subject_kind":"user_id","subject_value":"user-123","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String()) +} + +func TestInternalHandlersRejectInvalidPathParams(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + target string + body string + wantStatus int + wantBody string + }{ + { + name: "get session empty device session id", + method: http.MethodGet, + target: "/api/v1/internal/sessions/%20", + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"device session id must not be empty"}}`, + }, + { + name: "list sessions empty user id", + method: http.MethodGet, + target: "/api/v1/internal/users/%20/sessions", + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"user id must not be empty"}}`, + }, + { + name: "revoke all empty user id", + method: http.MethodPost, + target: "/api/v1/internal/users/%20/sessions/revoke-all", + body: `{"reason_code":"logout_all","actor":{"type":"system"}}`, + 
wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"user id must not be empty"}}`, + }, + } + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{}, shared.InvalidRequest("device session id must not be empty") + }), + ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) { + return listusersessions.Result{}, shared.InvalidRequest("user id must not be empty") + }), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) { + return revokeallusersessions.Result{}, shared.InvalidRequest("user id must not be empty") + }), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body)) + if tt.body != "" { + request.Header.Set("Content-Type", "application/json") + } + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestInternalMutationHandlersRejectInvalidRequests(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + target string + body string + wantStatus int + wantBody string + }{ + { + name: "revoke device session empty body", + method: http.MethodPost, + target: "/api/v1/internal/sessions/device-session-123/revoke", + body: ``, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body must 
not be empty"}}`, + }, + { + name: "revoke device session malformed json", + method: http.MethodPost, + target: "/api/v1/internal/sessions/device-session-123/revoke", + body: `{"reason_code":`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`, + }, + { + name: "revoke device session multiple objects", + method: http.MethodPost, + target: "/api/v1/internal/sessions/device-session-123/revoke", + body: `{"reason_code":"admin_revoke","actor":{"type":"system"}}{"reason_code":"admin_revoke","actor":{"type":"system"}}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body must contain a single JSON object"}}`, + }, + { + name: "revoke device session unknown field", + method: http.MethodPost, + target: "/api/v1/internal/sessions/device-session-123/revoke", + body: `{"reason_code":"admin_revoke","actor":{"type":"system"},"extra":true}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body contains unknown field \"extra\""}}`, + }, + { + name: "revoke device session invalid json type", + method: http.MethodPost, + target: "/api/v1/internal/sessions/device-session-123/revoke", + body: `{"reason_code":123,"actor":{"type":"system"}}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body contains an invalid value for \"reason_code\""}}`, + }, + { + name: "revoke all missing reason code", + method: http.MethodPost, + target: "/api/v1/internal/users/user-123/sessions/revoke-all", + body: `{"reason_code":" ","actor":{"type":"system"}}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"reason_code must not be empty"}}`, + }, + { + name: "block user missing actor type", + method: http.MethodPost, + target: "/api/v1/internal/user-blocks", + body: 
`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":" "}}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"actor.type must not be empty"}}`, + }, + { + name: "block user missing subject", + method: http.MethodPost, + target: "/api/v1/internal/user-blocks", + body: `{"reason_code":"policy_blocked","actor":{"type":"system"}}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"exactly one of user_id or email must be provided"}}`, + }, + { + name: "block user conflicting subjects", + method: http.MethodPost, + target: "/api/v1/internal/user-blocks", + body: `{"user_id":"user-123","email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"system"}}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"exactly one of user_id or email must be provided"}}`, + }, + } + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body)) + if tt.body != "" { + request.Header.Set("Content-Type", "application/json") + } + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestInternalHandlersMapServiceErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name 
string + method string + target string + body string + deps Dependencies + wantStatus int + wantBody string + }{ + { + name: "get session not found", + method: http.MethodGet, + target: "/api/v1/internal/sessions/missing", + deps: Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{}, shared.SessionNotFound() + }), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }, + wantStatus: http.StatusNotFound, + wantBody: `{"error":{"code":"session_not_found","message":"session not found"}}`, + }, + { + name: "revoke all subject not found", + method: http.MethodPost, + target: "/api/v1/internal/users/missing/sessions/revoke-all", + body: `{"reason_code":"logout_all","actor":{"type":"system"}}`, + deps: Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) { + return revokeallusersessions.Result{}, shared.SubjectNotFound() + }), + BlockUser: blockUserFunc(unexpectedBlockUser), + }, + wantStatus: http.StatusNotFound, + wantBody: `{"error":{"code":"subject_not_found","message":"subject not found"}}`, + }, + { + name: "service unavailable", + method: http.MethodGet, + target: "/api/v1/internal/sessions/device-session-123", + deps: Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{}, shared.ServiceUnavailable(errors.New("redis timeout")) + }), + ListUserSessions: 
listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }, + wantStatus: http.StatusServiceUnavailable, + wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, + }, + { + name: "internal error", + method: http.MethodGet, + target: "/api/v1/internal/sessions/device-session-123", + deps: Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{}, shared.InternalError(errors.New("broken invariant")) + }), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }, + wantStatus: http.StatusInternalServerError, + wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`, + }, + { + name: "unexpected error hidden", + method: http.MethodGet, + target: "/api/v1/internal/sessions/device-session-123", + deps: Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{}, errors.New("boom") + }), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }, + wantStatus: http.StatusInternalServerError, + wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := 
mustNewHandler(t, DefaultConfig(), tt.deps) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body)) + if tt.body != "" { + request.Header.Set("Content-Type", "application/json") + } + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestInternalHandlerTimeoutMapsToServiceUnavailable(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.RequestTimeout = 5 * time.Millisecond + + handler := mustNewHandler(t, cfg, Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{}, context.DeadlineExceeded + }), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil) + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) + assert.JSONEq(t, `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, recorder.Body.String()) +} + +func TestInternalHandlersRejectInvalidSuccessPayloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + target string + body string + deps Dependencies + }{ + { + name: "get session malformed response", + method: http.MethodGet, + target: "/api/v1/internal/sessions/device-session-123", + deps: Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + dto := validSessionDTO() + 
dto.DeviceSessionID = "" + return getsession.Result{Session: dto}, nil + }), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(unexpectedBlockUser), + }, + }, + { + name: "revoke all malformed response", + method: http.MethodPost, + target: "/api/v1/internal/users/user-123/sessions/revoke-all", + body: `{"reason_code":"logout_all","actor":{"type":"system"}}`, + deps: Dependencies{ + GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) { + return revokeallusersessions.Result{ + Outcome: "revoked", + UserID: "user-123", + AffectedSessionCount: 2, + AffectedDeviceSessionIDs: []string{"device-session-1"}, + }, nil + }), + BlockUser: blockUserFunc(unexpectedBlockUser), + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), tt.deps) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body)) + if tt.body != "" { + request.Header.Set("Content-Type", "application/json") + } + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + assert.JSONEq(t, `{"error":{"code":"internal_error","message":"internal server error"}}`, recorder.Body.String()) + }) + } +} + +func TestInternalHandlerLogsDoNotContainSensitiveFields(t *testing.T) { + t.Parallel() + + logger, buffer := newObservedLogger() + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + Logger: logger, + 
GetSession: getSessionFunc(unexpectedGetSession), + ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions), + RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession), + RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions), + BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) { + return blockuser.Result{ + Outcome: "blocked", + SubjectKind: blockuser.SubjectKindEmail, + SubjectValue: "pilot@example.com", + AffectedSessionCount: 0, + AffectedDeviceSessionIDs: []string{}, + }, nil + }), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/internal/user-blocks", + bytes.NewBufferString(`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin","id":"admin-1"}}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + require.Equal(t, http.StatusOK, recorder.Code) + logOutput := buffer.String() + assert.NotContains(t, logOutput, "pilot@example.com") + assert.NotContains(t, logOutput, "admin-1") + assert.NotContains(t, logOutput, "reason_code") +} + +func mustNewHandler(t *testing.T, cfg Config, deps Dependencies) http.Handler { + t.Helper() + + handler, err := newHandlerWithConfig(cfg, deps) + require.NoError(t, err) + return handler +} + +type getSessionFunc func(ctx context.Context, input getsession.Input) (getsession.Result, error) + +func (f getSessionFunc) Execute(ctx context.Context, input getsession.Input) (getsession.Result, error) { + return f(ctx, input) +} + +type listUserSessionsFunc func(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error) + +func (f listUserSessionsFunc) Execute(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error) { + return f(ctx, input) +} + +type revokeDeviceSessionFunc func(ctx context.Context, input revokedevicesession.Input) 
(revokedevicesession.Result, error) + +func (f revokeDeviceSessionFunc) Execute(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error) { + return f(ctx, input) +} + +type revokeAllUserSessionsFunc func(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) + +func (f revokeAllUserSessionsFunc) Execute(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) { + return f(ctx, input) +} + +type blockUserFunc func(ctx context.Context, input blockuser.Input) (blockuser.Result, error) + +func (f blockUserFunc) Execute(ctx context.Context, input blockuser.Input) (blockuser.Result, error) { + return f(ctx, input) +} + +func validSessionDTO() shared.Session { + return shared.Session{ + DeviceSessionID: "device-session-123", + UserID: "user-123", + ClientPublicKey: "public-key-material", + Status: "active", + CreatedAt: "2026-04-05T12:00:00Z", + } +} + +func validRevokedSessionDTO() shared.Session { + dto := validSessionDTO() + dto.Status = "revoked" + revokedAt := "2026-04-05T12:01:00Z" + reasonCode := "admin_revoke" + actorType := "admin" + actorID := "admin-1" + dto.RevokedAt = &revokedAt + dto.RevokeReasonCode = &reasonCode + dto.RevokeActorType = &actorType + dto.RevokeActorID = &actorID + return dto +} + +func newObservedLogger() (*zap.Logger, *bytes.Buffer) { + buffer := &bytes.Buffer{} + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "" + + core := zapcore.NewCore( + zapcore.NewJSONEncoder(encoderConfig), + zapcore.AddSync(buffer), + zap.DebugLevel, + ) + + return zap.New(core), buffer +} + +func unexpectedGetSession(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{}, errors.New("unexpected call") +} + +func unexpectedListUserSessions(context.Context, listusersessions.Input) (listusersessions.Result, error) { + return listusersessions.Result{}, errors.New("unexpected call") +} + +func 
unexpectedRevokeDeviceSession(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) { + return revokedevicesession.Result{}, errors.New("unexpected call") +} + +func unexpectedRevokeAllUserSessions(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) { + return revokeallusersessions.Result{}, errors.New("unexpected call") +} + +func unexpectedBlockUser(context.Context, blockuser.Input) (blockuser.Result, error) { + return blockuser.Result{}, errors.New("unexpected call") +} diff --git a/authsession/internal/api/internalhttp/json.go b/authsession/internal/api/internalhttp/json.go new file mode 100644 index 0000000..171ac9e --- /dev/null +++ b/authsession/internal/api/internalhttp/json.go @@ -0,0 +1,93 @@ +package internalhttp + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "galaxy/authsession/internal/service/shared" + + "github.com/gin-gonic/gin" +) + +const internalErrorCodeContextKey = "internal_error_code" + +type malformedJSONRequestError struct { + message string +} + +func (e *malformedJSONRequestError) Error() string { + if e == nil { + return "" + } + + return e.message +} + +func decodeJSONRequest(request *http.Request, target any) error { + if request == nil || request.Body == nil { + return &malformedJSONRequestError{message: "request body must not be empty"} + } + + return decodeJSONReader(request.Body, target) +} + +func decodeJSONReader(reader io.Reader, target any) error { + decoder := json.NewDecoder(reader) + decoder.DisallowUnknownFields() + + if err := decoder.Decode(target); err != nil { + return describeJSONDecodeError(err) + } + + if err := decoder.Decode(&struct{}{}); err != nil { + if errors.Is(err, io.EOF) { + return nil + } + + return &malformedJSONRequestError{message: "request body must contain a single JSON object"} + } + + return &malformedJSONRequestError{message: "request body must contain a single JSON object"} +} + +func 
describeJSONDecodeError(err error) error { + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + + switch { + case errors.Is(err, io.EOF): + return &malformedJSONRequestError{message: "request body must not be empty"} + case errors.As(err, &syntaxErr): + return &malformedJSONRequestError{message: "request body contains malformed JSON"} + case errors.Is(err, io.ErrUnexpectedEOF): + return &malformedJSONRequestError{message: "request body contains malformed JSON"} + case errors.As(err, &typeErr): + if strings.TrimSpace(typeErr.Field) != "" { + return &malformedJSONRequestError{ + message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field), + } + } + + return &malformedJSONRequestError{message: "request body contains an invalid JSON value"} + case strings.HasPrefix(err.Error(), "json: unknown field "): + return &malformedJSONRequestError{ + message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")), + } + default: + return &malformedJSONRequestError{message: "request body contains invalid JSON"} + } +} + +func abortWithProjection(c *gin.Context, projection shared.InternalErrorProjection) { + c.Set(internalErrorCodeContextKey, projection.Code) + c.AbortWithStatusJSON(projection.StatusCode, errorResponse{ + Error: errorBody{ + Code: projection.Code, + Message: projection.Message, + }, + }) +} diff --git a/authsession/internal/api/internalhttp/observability.go b/authsession/internal/api/internalhttp/observability.go new file mode 100644 index 0000000..7f8d7a5 --- /dev/null +++ b/authsession/internal/api/internalhttp/observability.go @@ -0,0 +1,86 @@ +package internalhttp + +import ( + "time" + + authlogging "galaxy/authsession/internal/logging" + "galaxy/authsession/internal/telemetry" + + "github.com/gin-gonic/gin" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" +) + +type edgeOutcome string + +const ( + edgeOutcomeSuccess edgeOutcome = "success" + 
edgeOutcomeRejected edgeOutcome = "rejected" + edgeOutcomeFailed edgeOutcome = "failed" +) + +func withInternalObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc { + if logger == nil { + logger = zap.NewNop() + } + + return func(c *gin.Context) { + start := time.Now() + c.Next() + + statusCode := c.Writer.Status() + route := c.FullPath() + if route == "" { + route = "unmatched" + } + + errorCode, _ := c.Get(internalErrorCodeContextKey) + errorCodeValue, _ := errorCode.(string) + outcome := outcomeFromStatusCode(statusCode) + duration := time.Since(start) + + fields := []zap.Field{ + zap.String("component", "internal_http"), + zap.String("transport", "http"), + zap.String("route", route), + zap.String("method", c.Request.Method), + zap.Int("status_code", statusCode), + zap.Float64("duration_ms", float64(duration.Microseconds())/1000), + zap.String("edge_outcome", string(outcome)), + } + if errorCodeValue != "" { + fields = append(fields, zap.String("error_code", errorCodeValue)) + } + fields = append(fields, authlogging.TraceFieldsFromContext(c.Request.Context())...) + + metricAttrs := []attribute.KeyValue{ + attribute.String("route", route), + attribute.String("method", c.Request.Method), + attribute.String("edge_outcome", string(outcome)), + } + if errorCodeValue != "" { + metricAttrs = append(metricAttrs, attribute.String("error_code", errorCodeValue)) + } + // Telemetry may be nil (see Dependencies); skip metric recording in that case. + if metrics != nil { + metrics.RecordInternalHTTPRequest(c.Request.Context(), metricAttrs, duration) + } + + switch outcome { + case edgeOutcomeSuccess: + logger.Info("internal request completed", fields...) + case edgeOutcomeFailed: + logger.Error("internal request failed", fields...) + default: + logger.Warn("internal request rejected", fields...) 
+ } + } +} + +func outcomeFromStatusCode(statusCode int) edgeOutcome { + switch { + case statusCode >= 500: + return edgeOutcomeFailed + case statusCode >= 400: + return edgeOutcomeRejected + default: + return edgeOutcomeSuccess + } +} diff --git a/authsession/internal/api/internalhttp/observability_test.go b/authsession/internal/api/internalhttp/observability_test.go new file mode 100644 index 0000000..69ebf7b --- /dev/null +++ b/authsession/internal/api/internalhttp/observability_test.go @@ -0,0 +1,121 @@ +package internalhttp + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/shared" + authtelemetry "galaxy/authsession/internal/telemetry" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +func TestInternalHandlerEmitsTraceFieldsAndMetrics(t *testing.T) { + t.Parallel() + + logger, buffer := newObservedLogger() + telemetryRuntime, reader, recorder := newObservedInternalTelemetryRuntime(t) + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + Logger: logger, + Telemetry: telemetryRuntime, + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{Session: validSessionDTO()}, nil + }), + ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) { + return listusersessions.Result{Sessions: []shared.Session{}}, nil + }), + 
RevokeDeviceSession: revokeDeviceSessionFunc(func(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) { + return revokedevicesession.Result{}, nil + }), + RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) { + return revokeallusersessions.Result{}, nil + }), + BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) { + return blockuser.Result{}, nil + }), + }) + + recorderHTTP := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil) + + handler.ServeHTTP(recorderHTTP, request) + + require.Equal(t, http.StatusOK, recorderHTTP.Code) + require.NotEmpty(t, recorder.Ended()) + assert.Contains(t, buffer.String(), "otel_trace_id") + assert.Contains(t, buffer.String(), "otel_span_id") + + assertMetricCount(t, reader, "authsession.internal_http.requests", map[string]string{ + "route": "/api/v1/internal/sessions/:device_session_id", + "method": http.MethodGet, + "edge_outcome": "success", + }, 1) +} + +func newObservedInternalTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader, *tracetest.SpanRecorder) { + t.Helper() + + reader := sdkmetric.NewManualReader() + meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + recorder := tracetest.NewSpanRecorder() + tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)) + + runtime, err := authtelemetry.NewWithProviders(meterProvider, tracerProvider) + require.NoError(t, err) + + return runtime, reader, recorder +} + +func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) { + t.Helper() + + var resourceMetrics metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &resourceMetrics)) + + for _, scopeMetrics := range 
resourceMetrics.ScopeMetrics { + for _, metric := range scopeMetrics.Metrics { + if metric.Name != metricName { + continue + } + + sum, ok := metric.Data.(metricdata.Sum[int64]) + require.True(t, ok) + + for _, point := range sum.DataPoints { + if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) { + assert.Equal(t, wantValue, point.Value) + return + } + } + } + } + + require.Failf(t, "metric not found", "metric %q with attrs %v not found", metricName, wantAttrs) +} + +func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool { + if len(values) != len(want) { + return false + } + + for _, value := range values { + if want[string(value.Key)] != value.Value.AsString() { + return false + } + } + + return true +} diff --git a/authsession/internal/api/internalhttp/server.go b/authsession/internal/api/internalhttp/server.go new file mode 100644 index 0000000..324c533 --- /dev/null +++ b/authsession/internal/api/internalhttp/server.go @@ -0,0 +1,271 @@ +package internalhttp + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/telemetry" + + "go.uber.org/zap" +) + +const ( + defaultAddr = ":8081" + defaultReadHeaderTimeout = 2 * time.Second + defaultReadTimeout = 10 * time.Second + defaultIdleTimeout = time.Minute + defaultRequestTimeout = 3 * time.Second +) + +// GetSessionUseCase describes the trusted internal get-session service +// consumed by the HTTP transport layer. +type GetSessionUseCase interface { + // Execute loads one device session for trusted internal callers. 
+ Execute(ctx context.Context, input getsession.Input) (getsession.Result, error) +} + +// ListUserSessionsUseCase describes the trusted internal list-user-sessions +// service consumed by the HTTP transport layer. +type ListUserSessionsUseCase interface { + // Execute lists all sessions of one user for trusted internal callers. + Execute(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error) +} + +// RevokeDeviceSessionUseCase describes the trusted internal single-session +// revoke service consumed by the HTTP transport layer. +type RevokeDeviceSessionUseCase interface { + // Execute revokes one device session and returns the frozen + // acknowledgement. + Execute(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error) +} + +// RevokeAllUserSessionsUseCase describes the trusted internal bulk-revoke +// service consumed by the HTTP transport layer. +type RevokeAllUserSessionsUseCase interface { + // Execute revokes all active sessions of one user and returns the frozen + // acknowledgement. + Execute(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) +} + +// BlockUserUseCase describes the trusted internal block-user service consumed +// by the HTTP transport layer. +type BlockUserUseCase interface { + // Execute applies a block state to one subject and returns the frozen + // acknowledgement. + Execute(ctx context.Context, input blockuser.Input) (blockuser.Result, error) +} + +// Config describes the trusted internal HTTP listener owned by authsession. +type Config struct { + // Addr is the TCP listen address used by the trusted internal HTTP server. + Addr string + + // ReadHeaderTimeout bounds how long the listener may spend reading request + // headers before the server rejects the connection. + ReadHeaderTimeout time.Duration + + // ReadTimeout bounds how long the listener may spend reading one trusted + // internal request. 
+ ReadTimeout time.Duration + + // IdleTimeout bounds how long the listener keeps an idle keep-alive + // connection open. + IdleTimeout time.Duration + + // RequestTimeout bounds one application-layer internal use-case call. + RequestTimeout time.Duration +} + +// Validate reports whether cfg contains a usable internal HTTP listener +// configuration. +func (cfg Config) Validate() error { + switch { + case cfg.Addr == "": + return errors.New("internal HTTP addr must not be empty") + case cfg.ReadHeaderTimeout <= 0: + return errors.New("internal HTTP read header timeout must be positive") + case cfg.ReadTimeout <= 0: + return errors.New("internal HTTP read timeout must be positive") + case cfg.IdleTimeout <= 0: + return errors.New("internal HTTP idle timeout must be positive") + case cfg.RequestTimeout <= 0: + return errors.New("internal HTTP request timeout must be positive") + default: + return nil + } +} + +// DefaultConfig returns the default trusted internal HTTP listener settings. +func DefaultConfig() Config { + return Config{ + Addr: defaultAddr, + ReadHeaderTimeout: defaultReadHeaderTimeout, + ReadTimeout: defaultReadTimeout, + IdleTimeout: defaultIdleTimeout, + RequestTimeout: defaultRequestTimeout, + } +} + +// Dependencies describes the collaborators used by the trusted internal HTTP +// transport layer. +type Dependencies struct { + // GetSession executes the trusted internal get-session use case. + GetSession GetSessionUseCase + + // ListUserSessions executes the trusted internal list-user-sessions use + // case. + ListUserSessions ListUserSessionsUseCase + + // RevokeDeviceSession executes the trusted internal single-session revoke + // use case. + RevokeDeviceSession RevokeDeviceSessionUseCase + + // RevokeAllUserSessions executes the trusted internal bulk-revoke use case. + RevokeAllUserSessions RevokeAllUserSessionsUseCase + + // BlockUser executes the trusted internal block-user use case. 
+ BlockUser BlockUserUseCase + + // Logger writes structured transport logs. When nil, a no-op logger is + // used. + Logger *zap.Logger + + // Telemetry records OpenTelemetry spans and low-cardinality HTTP metrics. + // When nil, the transport still serves requests with no-op providers. + Telemetry *telemetry.Runtime +} + +// Server owns the trusted internal HTTP listener exposed by authsession. +type Server struct { + cfg Config + + handler http.Handler + logger *zap.Logger + + stateMu sync.RWMutex + server *http.Server + listener net.Listener +} + +// NewServer constructs one trusted internal HTTP server for cfg and deps. +func NewServer(cfg Config, deps Dependencies) (*Server, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("new internal HTTP server: %w", err) + } + + handler, err := newHandlerWithConfig(cfg, deps) + if err != nil { + return nil, fmt.Errorf("new internal HTTP server: %w", err) + } + + logger := deps.Logger + if logger == nil { + logger = zap.NewNop() + } + logger = logger.Named("internal_http") + + return &Server{ + cfg: cfg, + handler: handler, + logger: logger, + }, nil +} + +// Run binds the configured listener and serves the trusted internal HTTP +// surface until Shutdown closes the server. 
+func (s *Server) Run(ctx context.Context) error { + if ctx == nil { + return errors.New("run internal HTTP server: nil context") + } + if err := ctx.Err(); err != nil { + return err + } + + listener, err := net.Listen("tcp", s.cfg.Addr) + if err != nil { + return fmt.Errorf("run internal HTTP server: listen on %q: %w", s.cfg.Addr, err) + } + + server := &http.Server{ + Handler: s.handler, + ReadHeaderTimeout: s.cfg.ReadHeaderTimeout, + ReadTimeout: s.cfg.ReadTimeout, + IdleTimeout: s.cfg.IdleTimeout, + } + + s.stateMu.Lock() + s.server = server + s.listener = listener + s.stateMu.Unlock() + + s.logger.Info("internal HTTP server started", zap.String("addr", listener.Addr().String())) + + defer func() { + s.stateMu.Lock() + s.server = nil + s.listener = nil + s.stateMu.Unlock() + }() + + err = server.Serve(listener) + switch { + case err == nil: + return nil + case errors.Is(err, http.ErrServerClosed): + s.logger.Info("internal HTTP server stopped") + return nil + default: + return fmt.Errorf("run internal HTTP server: serve on %q: %w", s.cfg.Addr, err) + } +} + +// Shutdown gracefully stops the trusted internal HTTP server within ctx. 
+func (s *Server) Shutdown(ctx context.Context) error { + if ctx == nil { + return errors.New("shutdown internal HTTP server: nil context") + } + + s.stateMu.RLock() + server := s.server + s.stateMu.RUnlock() + + if server == nil { + return nil + } + + if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("shutdown internal HTTP server: %w", err) + } + + return nil +} + +func normalizeDependencies(deps Dependencies) (Dependencies, error) { + switch { + case deps.GetSession == nil: + return Dependencies{}, errors.New("get session use case must not be nil") + case deps.ListUserSessions == nil: + return Dependencies{}, errors.New("list user sessions use case must not be nil") + case deps.RevokeDeviceSession == nil: + return Dependencies{}, errors.New("revoke device session use case must not be nil") + case deps.RevokeAllUserSessions == nil: + return Dependencies{}, errors.New("revoke all user sessions use case must not be nil") + case deps.BlockUser == nil: + return Dependencies{}, errors.New("block user use case must not be nil") + case deps.Logger == nil: + deps.Logger = zap.NewNop() + } + + deps.Logger = deps.Logger.Named("internal_http") + return deps, nil +} diff --git a/authsession/internal/api/internalhttp/server_test.go b/authsession/internal/api/internalhttp/server_test.go new file mode 100644 index 0000000..31a295c --- /dev/null +++ b/authsession/internal/api/internalhttp/server_test.go @@ -0,0 +1,106 @@ +package internalhttp + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/shared" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func 
TestNewServerRejectsInvalidConfiguration(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Addr = "" + + _, err := NewServer(cfg, validDependencies()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "addr") +} + +func TestServerRunAndShutdown(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Addr = "127.0.0.1:0" + + server, err := NewServer(cfg, validDependencies()) + require.NoError(t, err) + + runErr := make(chan error, 1) + go func() { + runErr <- server.Run(context.Background()) + }() + + require.Eventually(t, func() bool { + server.stateMu.RLock() + defer server.stateMu.RUnlock() + return server.listener != nil + }, time.Second, 10*time.Millisecond) + + server.stateMu.RLock() + addr := server.listener.Addr().String() + server.stateMu.RUnlock() + + response, err := http.Post( + "http://"+addr+"/api/v1/internal/sessions/device-session-123/revoke", + "application/json", + bytes.NewBufferString(`{"reason_code":"admin_revoke","actor":{"type":"system"}}`), + ) + require.NoError(t, err) + defer response.Body.Close() + + assert.Equal(t, http.StatusOK, response.StatusCode) + + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, server.Shutdown(shutdownCtx)) + require.NoError(t, <-runErr) +} + +func validDependencies() Dependencies { + return Dependencies{ + GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) { + return getsession.Result{Session: validSessionDTO()}, nil + }), + ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) { + return listusersessions.Result{Sessions: []shared.Session{validSessionDTO()}}, nil + }), + RevokeDeviceSession: revokeDeviceSessionFunc(func(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) { + return revokedevicesession.Result{ + Outcome: "revoked", + DeviceSessionID: "device-session-123", + 
AffectedSessionCount: 1, + }, nil + }), + RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) { + return revokeallusersessions.Result{ + Outcome: "revoked", + UserID: "user-123", + AffectedSessionCount: 1, + AffectedDeviceSessionIDs: []string{"device-session-123"}, + }, nil + }), + BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) { + return blockuser.Result{ + Outcome: "blocked", + SubjectKind: blockuser.SubjectKindEmail, + SubjectValue: "pilot@example.com", + AffectedSessionCount: 0, + AffectedDeviceSessionIDs: []string{}, + }, nil + }), + } +} diff --git a/authsession/internal/api/publichttp/doc.go b/authsession/internal/api/publichttp/doc.go new file mode 100644 index 0000000..d95fe33 --- /dev/null +++ b/authsession/internal/api/publichttp/doc.go @@ -0,0 +1,3 @@ +// Package publichttp exposes the public HTTP transport expected by the +// gateway-facing authentication flow. 
+package publichttp diff --git a/authsession/internal/api/publichttp/e2e_test.go b/authsession/internal/api/publichttp/e2e_test.go new file mode 100644 index 0000000..6b2af4d --- /dev/null +++ b/authsession/internal/api/publichttp/e2e_test.go @@ -0,0 +1,391 @@ +package publichttp + +import ( + "bytes" + "context" + "crypto/ed25519" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "galaxy/authsession/internal/adapters/mail" + "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPublicHTTPEndToEndSendThenConfirm(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{}) + server := httptest.NewServer(app.handler) + defer server.Close() + + sendResponse := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, sendResponse.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, sendResponse.Body) + + attempts := app.mailSender.RecordedAttempts() + require.Len(t, attempts, 1) + + confirmBody := map[string]string{ + "challenge_id": "challenge-1", + "code": attempts[0].Input.Code, + "client_public_key": validClientPublicKey, + } + confirmResponse := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", confirmBody) + + assert.Equal(t, http.StatusOK, confirmResponse.StatusCode) + assert.JSONEq(t, `{"device_session_id":"device-session-1"}`, confirmResponse.Body) +} + +func 
TestPublicHTTPEndToEndBlockedSendReturnsChallengeID(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{ + SeedBlockedEmail: true, + }) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, response.Body) + assert.Empty(t, app.mailSender.RecordedAttempts()) +} + +func TestPublicHTTPEndToEndThrottledSendStillReturnsChallengeID(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{ + AbuseProtector: &testkit.InMemorySendEmailCodeAbuseProtector{}, + }) + server := httptest.NewServer(app.handler) + defer server.Close() + + first := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, first.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, first.Body) + + second := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, second.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-2"}`, second.Body) + assert.Len(t, app.mailSender.RecordedAttempts(), 1) +} + +func TestPublicHTTPEndToEndInvalidClientPublicKey(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{ + SeedChallenge: seedChallengeOptions{ + ID: "challenge-123", + Code: "123456", + Status: challenge.StatusSent, + }, + }) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSON( + t, + server.URL+"/api/v1/public/auth/confirm-email-code", + `{"challenge_id":"challenge-123","code":"123456","client_public_key":"invalid"}`, + ) + + assert.Equal(t, http.StatusBadRequest, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 
32-byte Ed25519 public key"}}`, response.Body) +} + +func TestPublicHTTPEndToEndChallengeNotFound(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{}) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "missing", + "code": "123456", + "client_public_key": validClientPublicKey, + }) + + assert.Equal(t, http.StatusNotFound, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"challenge_not_found","message":"challenge not found"}}`, response.Body) +} + +func TestPublicHTTPEndToEndChallengeExpired(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{ + SeedChallenge: seedChallengeOptions{ + ID: "challenge-123", + Code: "123456", + Status: challenge.StatusSent, + ExpiresAt: time.Date(2026, 4, 5, 11, 59, 0, 0, time.UTC), + }, + }) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "challenge-123", + "code": "123456", + "client_public_key": validClientPublicKey, + }) + + assert.Equal(t, http.StatusGone, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"challenge_expired","message":"challenge expired"}}`, response.Body) +} + +func TestPublicHTTPEndToEndInvalidCode(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{ + SeedChallenge: seedChallengeOptions{ + ID: "challenge-123", + Code: "123456", + Status: challenge.StatusSent, + }, + }) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "challenge-123", + "code": "654321", + "client_public_key": validClientPublicKey, + }) + + assert.Equal(t, http.StatusBadRequest, response.StatusCode) + assert.JSONEq(t, 
`{"error":{"code":"invalid_code","message":"confirmation code is invalid"}}`, response.Body) +} + +func TestPublicHTTPEndToEndThrottledChallengeConfirmReturnsInvalidCode(t *testing.T) { + t.Parallel() + + app := newEndToEndApp(t, endToEndOptions{ + SeedChallenge: seedChallengeOptions{ + ID: "challenge-123", + Code: "123456", + Status: challenge.StatusDeliveryThrottled, + }, + }) + server := httptest.NewServer(app.handler) + defer server.Close() + + response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "challenge-123", + "code": "123456", + "client_public_key": validClientPublicKey, + }) + + assert.Equal(t, http.StatusBadRequest, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"invalid_code","message":"confirmation code is invalid"}}`, response.Body) +} + +func TestPublicHTTPEndToEndSessionLimitExceeded(t *testing.T) { + t.Parallel() + + limit := 1 + app := newEndToEndApp(t, endToEndOptions{ + Config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, + SeedExistingUser: true, + SeedActiveSession: &devicesession.Session{ + ID: common.DeviceSessionID("device-session-existing"), + UserID: common.UserID("user-1"), + ClientPublicKey: mustClientPublicKey(t, secondValidClientPublicKey), + Status: devicesession.StatusActive, + CreatedAt: time.Date(2026, 4, 5, 11, 58, 0, 0, time.UTC), + }, + }) + server := httptest.NewServer(app.handler) + defer server.Close() + + sendResponse := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, sendResponse.StatusCode) + + attempts := app.mailSender.RecordedAttempts() + require.Len(t, attempts, 1) + + confirmResponse := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "challenge-1", + "code": attempts[0].Input.Code, + "client_public_key": validClientPublicKey, + }) + + assert.Equal(t, http.StatusConflict, 
confirmResponse.StatusCode) + assert.JSONEq(t, `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`, confirmResponse.Body) +} + +type endToEndOptions struct { + Config ports.SessionLimitConfig + AbuseProtector ports.SendEmailCodeAbuseProtector + SeedBlockedEmail bool + SeedExistingUser bool + SeedChallenge seedChallengeOptions + SeedActiveSession *devicesession.Session +} + +type seedChallengeOptions struct { + ID string + Code string + Status challenge.Status + ExpiresAt time.Time +} + +type endToEndApp struct { + handler http.Handler + mailSender *mail.StubSender +} + +func newEndToEndApp(t *testing.T, options endToEndOptions) endToEndApp { + t.Helper() + + now := time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC) + challengeStore := &testkit.InMemoryChallengeStore{} + sessionStore := &testkit.InMemorySessionStore{} + userDirectory := &userservice.StubDirectory{} + mailSender := &mail.StubSender{} + idGenerator := &testkit.SequenceIDGenerator{} + codeGenerator := testkit.FixedCodeGenerator{Code: "123456"} + codeHasher := testkit.DeterministicCodeHasher{} + clock := testkit.FixedClock{Time: now} + publisher := &testkit.RecordingProjectionPublisher{} + + if options.SeedBlockedEmail { + require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_blocked"))) + } + if options.SeedExistingUser { + require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + } + if options.SeedActiveSession != nil { + require.NoError(t, sessionStore.Create(context.Background(), *options.SeedActiveSession)) + } + if options.SeedChallenge.ID != "" { + expiresAt := options.SeedChallenge.ExpiresAt + if expiresAt.IsZero() { + expiresAt = now.Add(challenge.InitialTTL) + } + + record := challenge.Challenge{ + ID: common.ChallengeID(options.SeedChallenge.ID), + Email: common.Email("pilot@example.com"), + CodeHash: mustHashCode(t, 
options.SeedChallenge.Code), + Status: options.SeedChallenge.Status, + DeliveryState: deliveryStateForSeedChallenge(options.SeedChallenge.Status), + CreatedAt: now.Add(-time.Minute), + ExpiresAt: expiresAt, + } + require.NoError(t, challengeStore.Create(context.Background(), record)) + } + + sendService, err := sendemailcode.NewWithRuntime( + challengeStore, + userDirectory, + idGenerator, + codeGenerator, + codeHasher, + mailSender, + options.AbuseProtector, + clock, + nil, + ) + require.NoError(t, err) + + confirmService, err := confirmemailcode.New( + challengeStore, + sessionStore, + userDirectory, + testkit.StaticConfigProvider{Config: options.Config}, + publisher, + idGenerator, + codeHasher, + clock, + ) + require.NoError(t, err) + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + SendEmailCode: sendService, + ConfirmEmailCode: confirmService, + }) + + return endToEndApp{ + handler: handler, + mailSender: mailSender, + } +} + +func deliveryStateForSeedChallenge(status challenge.Status) challenge.DeliveryState { + switch status { + case challenge.StatusDeliverySuppressed: + return challenge.DeliverySuppressed + case challenge.StatusDeliveryThrottled: + return challenge.DeliveryThrottled + default: + return challenge.DeliverySent + } +} + +type httpResponse struct { + StatusCode int + Body string +} + +func postJSON(t *testing.T, url string, body string) httpResponse { + t.Helper() + + response, err := http.Post(url, "application/json", bytes.NewBufferString(body)) + require.NoError(t, err) + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + require.NoError(t, err) + + return httpResponse{StatusCode: response.StatusCode, Body: string(payload)} +} + +func postJSONValue(t *testing.T, url string, value any) httpResponse { + t.Helper() + + body, err := json.Marshal(value) + require.NoError(t, err) + return postJSON(t, url, string(body)) +} + +func mustHashCode(t *testing.T, code string) []byte { + t.Helper() + + sum := 
sha256.Sum256([]byte(code)) + return sum[:] +} + +func mustClientPublicKey(t *testing.T, encoded string) common.ClientPublicKey { + t.Helper() + + decoded, err := base64.StdEncoding.DecodeString(encoded) + require.NoError(t, err) + + key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded)) + require.NoError(t, err) + return key +} + +const ( + validClientPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=" + secondValidClientPublicKey = "ICEiIyQlJicoKSorLC0uLzAxMjM0NTY3ODk6Ozw9Pj8=" +) diff --git a/authsession/internal/api/publichttp/handler.go b/authsession/internal/api/publichttp/handler.go new file mode 100644 index 0000000..b5af8ab --- /dev/null +++ b/authsession/internal/api/publichttp/handler.go @@ -0,0 +1,242 @@ +package publichttp + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/mail" + "strings" + "sync" + "time" + + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/telemetry" + + "github.com/gin-gonic/gin" + "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" +) + +const jsonContentType = "application/json; charset=utf-8" + +const publicHTTPServiceName = "galaxy-authsession-public" + +type sendEmailCodeRequest struct { + Email string `json:"email"` +} + +type sendEmailCodeResponse struct { + ChallengeID string `json:"challenge_id"` +} + +type confirmEmailCodeRequest struct { + ChallengeID string `json:"challenge_id"` + Code string `json:"code"` + ClientPublicKey string `json:"client_public_key"` +} + +type confirmEmailCodeResponse struct { + DeviceSessionID string `json:"device_session_id"` +} + +type errorResponse struct { + Error errorBody `json:"error"` +} + +type errorBody struct { + Code string `json:"code"` + Message string `json:"message"` +} + +var configureGinModeOnce sync.Once + +func newHandlerWithConfig(cfg Config, deps Dependencies) (http.Handler, 
error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + + normalizedDeps, err := normalizeDependencies(deps) + if err != nil { + return nil, err + } + + configureGinModeOnce.Do(func() { + gin.SetMode(gin.ReleaseMode) + }) + + engine := gin.New() + engine.Use(newOTelMiddleware(normalizedDeps.Telemetry)) + engine.Use(withPublicObservability(normalizedDeps.Logger, normalizedDeps.Telemetry)) + engine.POST( + "/api/v1/public/auth/send-email-code", + handleSendEmailCode(normalizedDeps.SendEmailCode, cfg.RequestTimeout), + ) + engine.POST( + "/api/v1/public/auth/confirm-email-code", + handleConfirmEmailCode(normalizedDeps.ConfirmEmailCode, cfg.RequestTimeout), + ) + + return engine, nil +} + +func newOTelMiddleware(runtime *telemetry.Runtime) gin.HandlerFunc { + options := []otelgin.Option{} + if runtime != nil { + options = append( + options, + otelgin.WithTracerProvider(runtime.TracerProvider()), + otelgin.WithMeterProvider(runtime.MeterProvider()), + ) + } + + return otelgin.Middleware(publicHTTPServiceName, options...) 
+} + +func handleSendEmailCode(useCase SendEmailCodeUseCase, timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + var request sendEmailCodeRequest + if err := decodeJSONRequest(c.Request, &request); err != nil { + abortWithProjection(c, projectSendEmailCodeError(shared.InvalidRequest(err.Error()))) + return + } + if err := validateSendEmailCodeRequest(&request); err != nil { + abortWithProjection(c, projectSendEmailCodeError(shared.InvalidRequest(err.Error()))) + return + } + + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + result, err := useCase.Execute(callCtx, sendemailcode.Input{Email: request.Email}) + if err != nil { + abortWithProjection(c, projectSendEmailCodeError(err)) + return + } + if err := validateSendEmailCodeResult(&result); err != nil { + abortWithProjection(c, unavailableProjection(fmt.Errorf("send email code response: %w", err))) + return + } + + c.JSON(http.StatusOK, sendEmailCodeResponse{ChallengeID: result.ChallengeID}) + } +} + +func handleConfirmEmailCode(useCase ConfirmEmailCodeUseCase, timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + var request confirmEmailCodeRequest + if err := decodeJSONRequest(c.Request, &request); err != nil { + abortWithProjection(c, projectConfirmEmailCodeError(shared.InvalidRequest(err.Error()))) + return + } + if err := validateConfirmEmailCodeRequest(&request); err != nil { + abortWithProjection(c, projectConfirmEmailCodeError(shared.InvalidRequest(err.Error()))) + return + } + + callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + result, err := useCase.Execute(callCtx, confirmemailcode.Input{ + ChallengeID: request.ChallengeID, + Code: request.Code, + ClientPublicKey: request.ClientPublicKey, + }) + if err != nil { + abortWithProjection(c, projectConfirmEmailCodeError(err)) + return + } + if err := validateConfirmEmailCodeResult(&result); err != nil { + abortWithProjection(c, 
unavailableProjection(fmt.Errorf("confirm email code response: %w", err))) + return + } + + c.JSON(http.StatusOK, confirmEmailCodeResponse{DeviceSessionID: result.DeviceSessionID}) + } +} + +func validateSendEmailCodeRequest(request *sendEmailCodeRequest) error { + request.Email = strings.TrimSpace(request.Email) + if request.Email == "" { + return errors.New("email must not be empty") + } + + parsedAddress, err := mail.ParseAddress(request.Email) + if err != nil || parsedAddress.Name != "" || parsedAddress.Address != request.Email { + return errors.New("email must be a single valid email address") + } + + return nil +} + +func validateSendEmailCodeResult(result *sendemailcode.Result) error { + result.ChallengeID = strings.TrimSpace(result.ChallengeID) + if result.ChallengeID == "" { + return errors.New("challenge_id must not be empty") + } + + return nil +} + +func validateConfirmEmailCodeRequest(request *confirmEmailCodeRequest) error { + request.ChallengeID = strings.TrimSpace(request.ChallengeID) + if request.ChallengeID == "" { + return errors.New("challenge_id must not be empty") + } + + request.Code = strings.TrimSpace(request.Code) + if request.Code == "" { + return errors.New("code must not be empty") + } + + request.ClientPublicKey = strings.TrimSpace(request.ClientPublicKey) + if request.ClientPublicKey == "" { + return errors.New("client_public_key must not be empty") + } + + return nil +} + +func validateConfirmEmailCodeResult(result *confirmemailcode.Result) error { + result.DeviceSessionID = strings.TrimSpace(result.DeviceSessionID) + if result.DeviceSessionID == "" { + return errors.New("device_session_id must not be empty") + } + + return nil +} + +func projectSendEmailCodeError(err error) shared.PublicErrorProjection { + if isTimeoutOrCanceled(err) { + return unavailableProjection(err) + } + + projection := shared.ProjectPublicError(err) + if !shared.IsSendEmailCodePublicErrorCode(projection.Code) { + return unavailableProjection(err) + } + + 
return projection +} + +func projectConfirmEmailCodeError(err error) shared.PublicErrorProjection { + if isTimeoutOrCanceled(err) { + return unavailableProjection(err) + } + + projection := shared.ProjectPublicError(err) + if !shared.IsConfirmEmailCodePublicErrorCode(projection.Code) { + return unavailableProjection(err) + } + + return projection +} + +func unavailableProjection(err error) shared.PublicErrorProjection { + return shared.ProjectPublicError(shared.ServiceUnavailable(err)) +} + +func isTimeoutOrCanceled(err error) bool { + return errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) +} diff --git a/authsession/internal/api/publichttp/handler_test.go b/authsession/internal/api/publichttp/handler_test.go new file mode 100644 index 0000000..756dee2 --- /dev/null +++ b/authsession/internal/api/publichttp/handler_test.go @@ -0,0 +1,463 @@ +package publichttp + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/service/shared" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestSendEmailCodeHandlerSuccess(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{ChallengeID: "challenge-123"}, nil + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, errors.New("unexpected call") + }), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/send-email-code", + bytes.NewBufferString(`{"email":" pilot@example.com 
"}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String()) +} + +func TestConfirmEmailCodeHandlerSuccess(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(_ context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) { + assert.Equal(t, confirmemailcode.Input{ + ChallengeID: "challenge-123", + Code: "123456", + ClientPublicKey: "public-key-material", + }, input) + return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil + }), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/confirm-email-code", + bytes.NewBufferString(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String()) +} + +func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target string + body string + wantStatus int + wantBody string + }{ + { + name: "empty body", + target: "/api/v1/public/auth/send-email-code", + body: ``, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body must not be empty"}}`, + }, + { + 
name: "malformed json", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`, + }, + { + name: "multiple objects", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"pilot@example.com"}{"email":"next@example.com"}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body must contain a single JSON object"}}`, + }, + { + name: "unknown field", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"pilot@example.com","extra":true}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body contains unknown field \"extra\""}}`, + }, + { + name: "invalid json type", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":123}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"request body contains an invalid value for \"email\""}}`, + }, + { + name: "invalid email", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"not-an-email"}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`, + }, + { + name: "empty code", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_request","message":"code must not be empty"}}`, + }, + } + + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) 
(confirmemailcode.Result, error) { + return confirmemailcode.Result{}, errors.New("unexpected call") + }), + }) + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body)) + if tt.body != "" { + request.Header.Set("Content-Type", "application/json") + } + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestPublicAuthHandlersMapServiceErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target string + body string + deps Dependencies + wantStatus int + wantBody string + }{ + { + name: "send route hides blocked by policy", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"pilot@example.com"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, shared.BlockedByPolicy() + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, errors.New("unexpected call") + }), + }, + wantStatus: http.StatusServiceUnavailable, + wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, + }, + { + name: "confirm invalid client public key", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: 
confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, shared.InvalidClientPublicKey() + }), + }, + wantStatus: http.StatusBadRequest, + wantBody: `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key"}}`, + }, + { + name: "confirm challenge not found", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, shared.ChallengeNotFound() + }), + }, + wantStatus: http.StatusNotFound, + wantBody: `{"error":{"code":"challenge_not_found","message":"challenge not found"}}`, + }, + { + name: "confirm challenge expired", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, shared.ChallengeExpired() + }), + }, + wantStatus: http.StatusGone, + wantBody: `{"error":{"code":"challenge_expired","message":"challenge expired"}}`, + }, + { + name: "confirm blocked by policy", + target: "/api/v1/public/auth/confirm-email-code", + body: 
`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, shared.BlockedByPolicy() + }), + }, + wantStatus: http.StatusForbidden, + wantBody: `{"error":{"code":"blocked_by_policy","message":"authentication is blocked by policy"}}`, + }, + { + name: "confirm session limit exceeded", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, shared.SessionLimitExceeded() + }), + }, + wantStatus: http.StatusConflict, + wantBody: `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`, + }, + { + name: "confirm hides internal error", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, shared.InternalError(errors.New("broken invariant")) + }), + 
}, + wantStatus: http.StatusServiceUnavailable, + wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), tt.deps) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body)) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type")) + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.RequestTimeout = 5 * time.Millisecond + + handler := mustNewHandler(t, cfg, Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, context.DeadlineExceeded + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, errors.New("unexpected call") + }), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/send-email-code", + bytes.NewBufferString(`{"email":"pilot@example.com"}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) + assert.JSONEq(t, `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, recorder.Body.String()) +} + +func TestPublicAuthHandlersRejectInvalidSuccessPayloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target string + body string + deps Dependencies + wantBody string + }{ + 
{ + name: "send email blank challenge id", + target: "/api/v1/public/auth/send-email-code", + body: `{"email":"pilot@example.com"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{ChallengeID: " "}, nil + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, errors.New("unexpected call") + }), + }, + wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, + }, + { + name: "confirm blank device session id", + target: "/api/v1/public/auth/confirm-email-code", + body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`, + deps: Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{DeviceSessionID: " "}, nil + }), + }, + wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := mustNewHandler(t, DefaultConfig(), tt.deps) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body)) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + }) + } +} + +func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) { + t.Parallel() + + logger, buffer := newObservedLogger() + handler := mustNewHandler(t, 
DefaultConfig(), Dependencies{ + Logger: logger, + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, errors.New("unexpected call") + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil + }), + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/confirm-email-code", + bytes.NewBufferString(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`), + ) + request.Header.Set("Content-Type", "application/json") + + handler.ServeHTTP(recorder, request) + + require.Equal(t, http.StatusOK, recorder.Code) + logOutput := buffer.String() + assert.NotContains(t, logOutput, "challenge-123") + assert.NotContains(t, logOutput, "123456") + assert.NotContains(t, logOutput, "public-key-material") + assert.NotContains(t, logOutput, "pilot@example.com") + assert.NotContains(t, logOutput, "device-session-123") +} + +func mustNewHandler(t *testing.T, cfg Config, deps Dependencies) http.Handler { + t.Helper() + + handler, err := newHandlerWithConfig(cfg, deps) + require.NoError(t, err) + return handler +} + +type sendEmailCodeFunc func(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error) + +func (f sendEmailCodeFunc) Execute(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error) { + return f(ctx, input) +} + +type confirmEmailCodeFunc func(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) + +func (f confirmEmailCodeFunc) Execute(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) { + return f(ctx, input) +} + +func newObservedLogger() (*zap.Logger, *bytes.Buffer) { + buffer := &bytes.Buffer{} + encoderConfig := 
zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "" + + core := zapcore.NewCore( + zapcore.NewJSONEncoder(encoderConfig), + zapcore.AddSync(buffer), + zap.DebugLevel, + ) + + return zap.New(core), buffer +} diff --git a/authsession/internal/api/publichttp/json.go b/authsession/internal/api/publichttp/json.go new file mode 100644 index 0000000..f72a1d3 --- /dev/null +++ b/authsession/internal/api/publichttp/json.go @@ -0,0 +1,93 @@ +package publichttp + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "galaxy/authsession/internal/service/shared" + + "github.com/gin-gonic/gin" +) + +const publicErrorCodeContextKey = "public_error_code" + +type malformedJSONRequestError struct { + message string +} + +func (e *malformedJSONRequestError) Error() string { + if e == nil { + return "" + } + + return e.message +} + +func decodeJSONRequest(request *http.Request, target any) error { + if request == nil || request.Body == nil { + return &malformedJSONRequestError{message: "request body must not be empty"} + } + + return decodeJSONReader(request.Body, target) +} + +func decodeJSONReader(reader io.Reader, target any) error { + decoder := json.NewDecoder(reader) + decoder.DisallowUnknownFields() + + if err := decoder.Decode(target); err != nil { + return describeJSONDecodeError(err) + } + + if err := decoder.Decode(&struct{}{}); err != nil { + if errors.Is(err, io.EOF) { + return nil + } + + return &malformedJSONRequestError{message: "request body must contain a single JSON object"} + } + + return &malformedJSONRequestError{message: "request body must contain a single JSON object"} +} + +func describeJSONDecodeError(err error) error { + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + + switch { + case errors.Is(err, io.EOF): + return &malformedJSONRequestError{message: "request body must not be empty"} + case errors.As(err, &syntaxErr): + return &malformedJSONRequestError{message: "request body contains 
malformed JSON"} + case errors.Is(err, io.ErrUnexpectedEOF): + return &malformedJSONRequestError{message: "request body contains malformed JSON"} + case errors.As(err, &typeErr): + if strings.TrimSpace(typeErr.Field) != "" { + return &malformedJSONRequestError{ + message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field), + } + } + + return &malformedJSONRequestError{message: "request body contains an invalid JSON value"} + case strings.HasPrefix(err.Error(), "json: unknown field "): + return &malformedJSONRequestError{ + message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")), + } + default: + return &malformedJSONRequestError{message: "request body contains invalid JSON"} + } +} + +func abortWithProjection(c *gin.Context, projection shared.PublicErrorProjection) { + c.Set(publicErrorCodeContextKey, projection.Code) + c.AbortWithStatusJSON(projection.StatusCode, errorResponse{ + Error: errorBody{ + Code: projection.Code, + Message: projection.Message, + }, + }) +} diff --git a/authsession/internal/api/publichttp/observability.go b/authsession/internal/api/publichttp/observability.go new file mode 100644 index 0000000..8fd59a6 --- /dev/null +++ b/authsession/internal/api/publichttp/observability.go @@ -0,0 +1,86 @@ +package publichttp + +import ( + "time" + + authlogging "galaxy/authsession/internal/logging" + "galaxy/authsession/internal/telemetry" + + "github.com/gin-gonic/gin" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" +) + +type edgeOutcome string + +const ( + edgeOutcomeSuccess edgeOutcome = "success" + edgeOutcomeRejected edgeOutcome = "rejected" + edgeOutcomeFailed edgeOutcome = "failed" +) + +func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc { + if logger == nil { + logger = zap.NewNop() + } + + return func(c *gin.Context) { + start := time.Now() + c.Next() + + statusCode := c.Writer.Status() + route := 
c.FullPath() + if route == "" { + route = "unmatched" + } + + errorCode, _ := c.Get(publicErrorCodeContextKey) + errorCodeValue, _ := errorCode.(string) + outcome := outcomeFromStatusCode(statusCode) + duration := time.Since(start) + + fields := []zap.Field{ + zap.String("component", "public_http"), + zap.String("transport", "http"), + zap.String("route", route), + zap.String("method", c.Request.Method), + zap.Int("status_code", statusCode), + zap.Float64("duration_ms", float64(duration.Microseconds())/1000), + zap.String("edge_outcome", string(outcome)), + } + if errorCodeValue != "" { + fields = append(fields, zap.String("error_code", errorCodeValue)) + } + fields = append(fields, authlogging.TraceFieldsFromContext(c.Request.Context())...) + + metricAttrs := []attribute.KeyValue{ + attribute.String("route", route), + attribute.String("method", c.Request.Method), + attribute.String("edge_outcome", string(outcome)), + } + if errorCodeValue != "" { + metricAttrs = append(metricAttrs, attribute.String("error_code", errorCodeValue)) + } + if metrics != nil { metrics.RecordPublicHTTPRequest(c.Request.Context(), metricAttrs, duration) } + + switch outcome { + case edgeOutcomeSuccess: + logger.Info("public request completed", fields...) + case edgeOutcomeFailed: + logger.Error("public request failed", fields...) + default: + logger.Warn("public request rejected", fields...) 
+ } + } +} + +func outcomeFromStatusCode(statusCode int) edgeOutcome { + switch { + case statusCode >= 500: + return edgeOutcomeFailed + case statusCode >= 400: + return edgeOutcomeRejected + default: + return edgeOutcomeSuccess + } +} diff --git a/authsession/internal/api/publichttp/observability_test.go b/authsession/internal/api/publichttp/observability_test.go new file mode 100644 index 0000000..6ebacea --- /dev/null +++ b/authsession/internal/api/publichttp/observability_test.go @@ -0,0 +1,114 @@ +package publichttp + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "testing" + + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + authtelemetry "galaxy/authsession/internal/telemetry" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +func TestPublicHandlerEmitsTraceFieldsAndMetrics(t *testing.T) { + t.Parallel() + + logger, buffer := newObservedLogger() + telemetryRuntime, reader, recorder := newObservedPublicTelemetryRuntime(t) + handler := mustNewHandler(t, DefaultConfig(), Dependencies{ + Logger: logger, + Telemetry: telemetryRuntime, + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{ChallengeID: "challenge-123"}, nil + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, nil + }), + }) + + recorderHTTP := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "/api/v1/public/auth/send-email-code", + bytes.NewBufferString(`{"email":"pilot@example.com"}`), + ) + 
request.Header.Set("Content-Type", "application/json")
+
+	handler.ServeHTTP(recorderHTTP, request)
+
+	require.Equal(t, http.StatusOK, recorderHTTP.Code)
+	require.NotEmpty(t, recorder.Ended())
+	assert.Contains(t, buffer.String(), "otel_trace_id")
+	assert.Contains(t, buffer.String(), "otel_span_id")
+
+	assertMetricCount(t, reader, "authsession.public_http.requests", map[string]string{
+		"route":        "/api/v1/public/auth/send-email-code",
+		"method":       http.MethodPost,
+		"edge_outcome": "success",
+	}, 1)
+}
+
+func newObservedPublicTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader, *tracetest.SpanRecorder) {
+	t.Helper()
+
+	reader := sdkmetric.NewManualReader()
+	meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
+	recorder := tracetest.NewSpanRecorder()
+	tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
+
+	runtime, err := authtelemetry.NewWithProviders(meterProvider, tracerProvider)
+	require.NoError(t, err)
+
+	return runtime, reader, recorder
+}
+
+func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
+	t.Helper()
+
+	var resourceMetrics metricdata.ResourceMetrics
+	require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))
+
+	for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
+		for _, metric := range scopeMetrics.Metrics {
+			if metric.Name != metricName {
+				continue
+			}
+
+			sum, ok := metric.Data.(metricdata.Sum[int64])
+			require.True(t, ok)
+
+			for _, point := range sum.DataPoints {
+				if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
+					assert.Equal(t, wantValue, point.Value)
+					return
+				}
+			}
+		}
+	}
+
+	require.Failf(t, "metric not found", "metric %q with attrs %v not found", metricName, wantAttrs)
+}
+
+func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
+	if len(values) != len(want) {
+		return false
+	}
+
+	for _, value := 
range values { + if want[string(value.Key)] != value.Value.AsString() { + return false + } + } + + return true +} diff --git a/authsession/internal/api/publichttp/server.go b/authsession/internal/api/publichttp/server.go new file mode 100644 index 0000000..8197c23 --- /dev/null +++ b/authsession/internal/api/publichttp/server.go @@ -0,0 +1,228 @@ +package publichttp + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/telemetry" + + "go.uber.org/zap" +) + +const ( + defaultAddr = ":8080" + defaultReadHeaderTimeout = 2 * time.Second + defaultReadTimeout = 10 * time.Second + defaultIdleTimeout = time.Minute + defaultRequestTimeout = 3 * time.Second +) + +// SendEmailCodeUseCase describes the public send-email-code application +// service consumed by the HTTP transport layer. +type SendEmailCodeUseCase interface { + // Execute validates input and creates a new login challenge. + Execute(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error) +} + +// ConfirmEmailCodeUseCase describes the public confirm-email-code application +// service consumed by the HTTP transport layer. +type ConfirmEmailCodeUseCase interface { + // Execute validates input and completes an existing login challenge. + Execute(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) +} + +// Config describes the public HTTP listener owned by authsession. +type Config struct { + // Addr is the TCP listen address used by the public HTTP server. + Addr string + + // ReadHeaderTimeout bounds how long the listener may spend reading request + // headers before the server rejects the connection. + ReadHeaderTimeout time.Duration + + // ReadTimeout bounds how long the listener may spend reading one public + // request. 
+ ReadTimeout time.Duration + + // IdleTimeout bounds how long the listener keeps an idle keep-alive + // connection open. + IdleTimeout time.Duration + + // RequestTimeout bounds one application-layer public-auth use-case call. + RequestTimeout time.Duration +} + +// Validate reports whether cfg contains a usable public HTTP listener +// configuration. +func (cfg Config) Validate() error { + switch { + case cfg.Addr == "": + return errors.New("public HTTP addr must not be empty") + case cfg.ReadHeaderTimeout <= 0: + return errors.New("public HTTP read header timeout must be positive") + case cfg.ReadTimeout <= 0: + return errors.New("public HTTP read timeout must be positive") + case cfg.IdleTimeout <= 0: + return errors.New("public HTTP idle timeout must be positive") + case cfg.RequestTimeout <= 0: + return errors.New("public HTTP request timeout must be positive") + default: + return nil + } +} + +// DefaultConfig returns the default public HTTP listener settings aligned with +// the gateway public-auth transport timeouts. +func DefaultConfig() Config { + return Config{ + Addr: defaultAddr, + ReadHeaderTimeout: defaultReadHeaderTimeout, + ReadTimeout: defaultReadTimeout, + IdleTimeout: defaultIdleTimeout, + RequestTimeout: defaultRequestTimeout, + } +} + +// Dependencies describes the collaborators used by the public HTTP transport +// layer. +type Dependencies struct { + // SendEmailCode executes the public send-email-code use case. + SendEmailCode SendEmailCodeUseCase + + // ConfirmEmailCode executes the public confirm-email-code use case. + ConfirmEmailCode ConfirmEmailCodeUseCase + + // Logger writes structured transport logs. When nil, a no-op logger is + // used. + Logger *zap.Logger + + // Telemetry records OpenTelemetry spans and low-cardinality HTTP metrics. + // When nil, the transport still serves requests with no-op providers. + Telemetry *telemetry.Runtime +} + +// Server owns the public auth HTTP listener exposed by authsession. 
+type Server struct { + cfg Config + + handler http.Handler + logger *zap.Logger + + stateMu sync.RWMutex + server *http.Server + listener net.Listener +} + +// NewServer constructs one public auth HTTP server for cfg and deps. +func NewServer(cfg Config, deps Dependencies) (*Server, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("new public HTTP server: %w", err) + } + + handler, err := newHandlerWithConfig(cfg, deps) + if err != nil { + return nil, fmt.Errorf("new public HTTP server: %w", err) + } + + logger := deps.Logger + if logger == nil { + logger = zap.NewNop() + } + logger = logger.Named("public_http") + + return &Server{ + cfg: cfg, + handler: handler, + logger: logger, + }, nil +} + +// Run binds the configured listener and serves the public auth HTTP surface +// until Shutdown closes the server. +func (s *Server) Run(ctx context.Context) error { + if ctx == nil { + return errors.New("run public HTTP server: nil context") + } + if err := ctx.Err(); err != nil { + return err + } + + listener, err := net.Listen("tcp", s.cfg.Addr) + if err != nil { + return fmt.Errorf("run public HTTP server: listen on %q: %w", s.cfg.Addr, err) + } + + server := &http.Server{ + Handler: s.handler, + ReadHeaderTimeout: s.cfg.ReadHeaderTimeout, + ReadTimeout: s.cfg.ReadTimeout, + IdleTimeout: s.cfg.IdleTimeout, + } + + s.stateMu.Lock() + s.server = server + s.listener = listener + s.stateMu.Unlock() + + s.logger.Info("public HTTP server started", zap.String("addr", listener.Addr().String())) + + defer func() { + s.stateMu.Lock() + s.server = nil + s.listener = nil + s.stateMu.Unlock() + }() + + err = server.Serve(listener) + switch { + case err == nil: + return nil + case errors.Is(err, http.ErrServerClosed): + s.logger.Info("public HTTP server stopped") + return nil + default: + return fmt.Errorf("run public HTTP server: serve on %q: %w", s.cfg.Addr, err) + } +} + +// Shutdown gracefully stops the public HTTP server within ctx. 
+func (s *Server) Shutdown(ctx context.Context) error { + if ctx == nil { + return errors.New("shutdown public HTTP server: nil context") + } + + s.stateMu.RLock() + server := s.server + s.stateMu.RUnlock() + + if server == nil { + return nil + } + + if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("shutdown public HTTP server: %w", err) + } + + return nil +} + +func normalizeDependencies(deps Dependencies) (Dependencies, error) { + switch { + case deps.SendEmailCode == nil: + return Dependencies{}, errors.New("send email code use case must not be nil") + case deps.ConfirmEmailCode == nil: + return Dependencies{}, errors.New("confirm email code use case must not be nil") + case deps.Logger == nil: + deps.Logger = zap.NewNop() + } + + deps.Logger = deps.Logger.Named("public_http") + return deps, nil +} diff --git a/authsession/internal/api/publichttp/server_test.go b/authsession/internal/api/publichttp/server_test.go new file mode 100644 index 0000000..6070d65 --- /dev/null +++ b/authsession/internal/api/publichttp/server_test.go @@ -0,0 +1,81 @@ +package publichttp + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewServerRejectsInvalidConfiguration(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Addr = "" + + _, err := NewServer(cfg, Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{}, nil + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{}, nil + }), + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "addr") +} + +func 
TestServerRunAndShutdown(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Addr = "127.0.0.1:0" + + server, err := NewServer(cfg, Dependencies{ + SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) { + return sendemailcode.Result{ChallengeID: "challenge-123"}, nil + }), + ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) { + return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil + }), + }) + require.NoError(t, err) + + runErr := make(chan error, 1) + go func() { + runErr <- server.Run(context.Background()) + }() + + require.Eventually(t, func() bool { + server.stateMu.RLock() + defer server.stateMu.RUnlock() + return server.listener != nil + }, time.Second, 10*time.Millisecond) + + server.stateMu.RLock() + addr := server.listener.Addr().String() + server.stateMu.RUnlock() + + response, err := http.Post( + "http://"+addr+"/api/v1/public/auth/send-email-code", + "application/json", + bytes.NewBufferString(`{"email":"pilot@example.com"}`), + ) + require.NoError(t, err) + defer response.Body.Close() + + assert.Equal(t, http.StatusOK, response.StatusCode) + + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, server.Shutdown(shutdownCtx)) + require.NoError(t, <-runErr) +} diff --git a/authsession/internal/app/app.go b/authsession/internal/app/app.go new file mode 100644 index 0000000..82c722f --- /dev/null +++ b/authsession/internal/app/app.go @@ -0,0 +1,168 @@ +// Package app wires the authsession process lifecycle and coordinates +// component startup and graceful shutdown. +package app + +import ( + "context" + "errors" + "fmt" + "sync" + + "galaxy/authsession/internal/config" +) + +// Component is a long-lived authsession subsystem that participates in +// coordinated startup and graceful shutdown. 
+type Component interface { + // Run starts the component and blocks until it stops. + Run(context.Context) error + + // Shutdown stops the component within the provided timeout-bounded context. + Shutdown(context.Context) error +} + +// App owns the process-level lifecycle of authsession and its registered +// components. +type App struct { + cfg config.Config + components []Component +} + +// New constructs an App with a defensive copy of the supplied components. +func New(cfg config.Config, components ...Component) *App { + clonedComponents := append([]Component(nil), components...) + + return &App{ + cfg: cfg, + components: clonedComponents, + } +} + +// Run starts all configured components, waits for cancellation or the first +// component failure, and then executes best-effort graceful shutdown. +func (a *App) Run(ctx context.Context) error { + if ctx == nil { + return errors.New("run authsession app: nil context") + } + if err := a.validate(); err != nil { + return err + } + if len(a.components) == 0 { + <-ctx.Done() + return nil + } + + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + + results := make(chan componentResult, len(a.components)) + var runWG sync.WaitGroup + + for idx, component := range a.components { + runWG.Add(1) + + go func(index int, component Component) { + defer runWG.Done() + results <- componentResult{ + index: index, + err: component.Run(runCtx), + } + }(idx, component) + } + + var runErr error + + select { + case <-ctx.Done(): + case result := <-results: + runErr = classifyComponentResult(ctx, result) + } + + cancel() + + shutdownErr := a.shutdownComponents() + waitErr := a.waitForComponents(&runWG) + + return errors.Join(runErr, shutdownErr, waitErr) +} + +type componentResult struct { + index int + err error +} + +func (a *App) validate() error { + if a.cfg.ShutdownTimeout <= 0 { + return fmt.Errorf("run authsession app: shutdown timeout must be positive, got %s", a.cfg.ShutdownTimeout) + } + + for idx, component := 
range a.components { + if component == nil { + return fmt.Errorf("run authsession app: component %d is nil", idx) + } + } + + return nil +} + +func classifyComponentResult(parentCtx context.Context, result componentResult) error { + switch { + case result.err == nil: + if parentCtx.Err() != nil { + return nil + } + return fmt.Errorf("run authsession app: component %d exited without error before shutdown", result.index) + case errors.Is(result.err, context.Canceled) && parentCtx.Err() != nil: + return nil + default: + return fmt.Errorf("run authsession app: component %d: %w", result.index, result.err) + } +} + +func (a *App) shutdownComponents() error { + var shutdownWG sync.WaitGroup + errs := make(chan error, len(a.components)) + + for idx, component := range a.components { + shutdownWG.Add(1) + + go func(index int, component Component) { + defer shutdownWG.Done() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout) + defer cancel() + + if err := component.Shutdown(shutdownCtx); err != nil { + errs <- fmt.Errorf("shutdown authsession component %d: %w", index, err) + } + }(idx, component) + } + + shutdownWG.Wait() + close(errs) + + var joined error + for err := range errs { + joined = errors.Join(joined, err) + } + + return joined +} + +func (a *App) waitForComponents(runWG *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + runWG.Wait() + close(done) + }() + + waitCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout) + defer cancel() + + select { + case <-done: + return nil + case <-waitCtx.Done(): + return fmt.Errorf("wait for authsession components: %w", waitCtx.Err()) + } +} diff --git a/authsession/internal/app/runtime.go b/authsession/internal/app/runtime.go new file mode 100644 index 0000000..dd6f10f --- /dev/null +++ b/authsession/internal/app/runtime.go @@ -0,0 +1,284 @@ +package app + +import ( + "context" + "errors" + "fmt" + + "galaxy/authsession/internal/adapters/local" 
+ "galaxy/authsession/internal/adapters/mail" + "galaxy/authsession/internal/adapters/redis/challengestore" + "galaxy/authsession/internal/adapters/redis/configprovider" + "galaxy/authsession/internal/adapters/redis/projectionpublisher" + "galaxy/authsession/internal/adapters/redis/sendemailcodeabuse" + "galaxy/authsession/internal/adapters/redis/sessionstore" + "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/api/internalhttp" + "galaxy/authsession/internal/api/publichttp" + "galaxy/authsession/internal/config" + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/telemetry" + + "go.uber.org/zap" +) + +type pinger interface { + Ping(context.Context) error +} + +type closer interface { + Close() error +} + +// Runtime owns the runnable authsession application plus the adapter cleanup +// functions that must run after the process stops. +type Runtime struct { + // App coordinates the long-lived HTTP listeners. + App *App + + cleanupFns []func() error +} + +// NewRuntime constructs the runnable authsession process from cfg using the +// Stage 18 Redis adapters, local runtime helpers, and the selectable mail and +// user-service runtime adapters from Stages 20 and 21. 
+func NewRuntime(ctx context.Context, cfg config.Config, logger *zap.Logger, telemetryRuntime *telemetry.Runtime) (*Runtime, error) { + if ctx == nil { + return nil, errors.New("new authsession runtime: nil context") + } + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("new authsession runtime: %w", err) + } + if logger == nil { + logger = zap.NewNop() + } + + runtime := &Runtime{} + cleanupOnError := func(err error) (*Runtime, error) { + return nil, errors.Join(err, runtime.Close()) + } + + challengeStore, err := challengestore.New(challengestore.Config{ + Addr: cfg.Redis.Addr, + Username: cfg.Redis.Username, + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + TLSEnabled: cfg.Redis.TLSEnabled, + KeyPrefix: cfg.Redis.ChallengeKeyPrefix, + OperationTimeout: cfg.Redis.OperationTimeout, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: challenge store: %w", err)) + } + runtime.cleanupFns = append(runtime.cleanupFns, challengeStore.Close) + + sessionStore, err := sessionstore.New(sessionstore.Config{ + Addr: cfg.Redis.Addr, + Username: cfg.Redis.Username, + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + TLSEnabled: cfg.Redis.TLSEnabled, + SessionKeyPrefix: cfg.Redis.SessionKeyPrefix, + UserSessionsKeyPrefix: cfg.Redis.UserSessionsKeyPrefix, + UserActiveSessionsKeyPrefix: cfg.Redis.UserActiveSessionsKeyPrefix, + OperationTimeout: cfg.Redis.OperationTimeout, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: session store: %w", err)) + } + runtime.cleanupFns = append(runtime.cleanupFns, sessionStore.Close) + + configStore, err := configprovider.New(configprovider.Config{ + Addr: cfg.Redis.Addr, + Username: cfg.Redis.Username, + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + TLSEnabled: cfg.Redis.TLSEnabled, + SessionLimitKey: cfg.Redis.SessionLimitKey, + OperationTimeout: cfg.Redis.OperationTimeout, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession 
runtime: config provider: %w", err)) + } + runtime.cleanupFns = append(runtime.cleanupFns, configStore.Close) + + publisher, err := projectionpublisher.New(projectionpublisher.Config{ + Addr: cfg.Redis.Addr, + Username: cfg.Redis.Username, + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + TLSEnabled: cfg.Redis.TLSEnabled, + SessionCacheKeyPrefix: cfg.Redis.GatewaySessionCacheKeyPrefix, + SessionEventsStream: cfg.Redis.GatewaySessionEventsStream, + StreamMaxLen: cfg.Redis.GatewaySessionEventsStreamMaxLen, + OperationTimeout: cfg.Redis.OperationTimeout, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: projection publisher: %w", err)) + } + runtime.cleanupFns = append(runtime.cleanupFns, publisher.Close) + + abuseProtector, err := sendemailcodeabuse.New(sendemailcodeabuse.Config{ + Addr: cfg.Redis.Addr, + Username: cfg.Redis.Username, + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + TLSEnabled: cfg.Redis.TLSEnabled, + KeyPrefix: cfg.Redis.SendEmailCodeThrottleKeyPrefix, + OperationTimeout: cfg.Redis.OperationTimeout, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: send email code abuse protector: %w", err)) + } + runtime.cleanupFns = append(runtime.cleanupFns, abuseProtector.Close) + + for name, dependency := range map[string]pinger{ + "challenge store": challengeStore, + "session store": sessionStore, + "config provider": configStore, + "projection publisher": publisher, + "send email code abuse protector": abuseProtector, + } { + if err := dependency.Ping(ctx); err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: ping %s: %w", name, err)) + } + } + + clock := local.Clock{} + idGenerator := local.IDGenerator{} + codeGenerator := local.CodeGenerator{} + codeHasher := local.CodeHasher{} + var mailSender ports.MailSender + switch cfg.MailService.Mode { + case "stub": + mailSender = &mail.StubSender{} + case "rest": + restClient, err := mail.NewRESTClient(mail.Config{ + 
BaseURL: cfg.MailService.BaseURL, + RequestTimeout: cfg.MailService.RequestTimeout, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: mail service REST client: %w", err)) + } + runtime.cleanupFns = append(runtime.cleanupFns, restClient.Close) + mailSender = restClient + default: + return cleanupOnError(fmt.Errorf("new authsession runtime: unsupported mail service mode %q", cfg.MailService.Mode)) + } + var userDirectory ports.UserDirectory + switch cfg.UserService.Mode { + case "stub": + userDirectory = &userservice.StubDirectory{} + case "rest": + restClient, err := userservice.NewRESTClient(userservice.Config{ + BaseURL: cfg.UserService.BaseURL, + RequestTimeout: cfg.UserService.RequestTimeout, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: user service REST client: %w", err)) + } + runtime.cleanupFns = append(runtime.cleanupFns, restClient.Close) + userDirectory = restClient + default: + return cleanupOnError(fmt.Errorf("new authsession runtime: unsupported user service mode %q", cfg.UserService.Mode)) + } + + sendEmailCodeService, err := sendemailcode.NewWithObservability( + challengeStore, + userDirectory, + idGenerator, + codeGenerator, + codeHasher, + mailSender, + abuseProtector, + clock, + logger, + telemetryRuntime, + ) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: send email code service: %w", err)) + } + confirmEmailCodeService, err := confirmemailcode.NewWithObservability( + challengeStore, + sessionStore, + userDirectory, + configStore, + publisher, + idGenerator, + codeHasher, + clock, + logger, + telemetryRuntime, + ) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: confirm email code service: %w", err)) + } + getSessionService, err := getsession.New(sessionStore) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: get session service: %w", err)) + } + listUserSessionsService, err := 
listusersessions.New(sessionStore) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: list user sessions service: %w", err)) + } + revokeDeviceSessionService, err := revokedevicesession.NewWithObservability(sessionStore, publisher, clock, logger, telemetryRuntime) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: revoke device session service: %w", err)) + } + revokeAllUserSessionsService, err := revokeallusersessions.NewWithObservability(sessionStore, userDirectory, publisher, clock, logger, telemetryRuntime) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: revoke all user sessions service: %w", err)) + } + blockUserService, err := blockuser.NewWithObservability(userDirectory, sessionStore, publisher, clock, logger, telemetryRuntime) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: block user service: %w", err)) + } + + publicServer, err := publichttp.NewServer(cfg.PublicHTTP, publichttp.Dependencies{ + SendEmailCode: sendEmailCodeService, + ConfirmEmailCode: confirmEmailCodeService, + Logger: logger, + Telemetry: telemetryRuntime, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: public HTTP server: %w", err)) + } + + internalServer, err := internalhttp.NewServer(cfg.InternalHTTP, internalhttp.Dependencies{ + GetSession: getSessionService, + ListUserSessions: listUserSessionsService, + RevokeDeviceSession: revokeDeviceSessionService, + RevokeAllUserSessions: revokeAllUserSessionsService, + BlockUser: blockUserService, + Logger: logger, + Telemetry: telemetryRuntime, + }) + if err != nil { + return cleanupOnError(fmt.Errorf("new authsession runtime: internal HTTP server: %w", err)) + } + + runtime.App = New(cfg, publicServer, internalServer) + return runtime, nil +} + +// Close releases the runtime-managed adapter resources. 
Close is idempotent in +// practice because every underlying adapter Close method is idempotent. +func (r *Runtime) Close() error { + if r == nil { + return nil + } + + var joined error + for index := len(r.cleanupFns) - 1; index >= 0; index-- { + joined = errors.Join(joined, r.cleanupFns[index]()) + } + + return joined +} diff --git a/authsession/internal/app/runtime_test.go b/authsession/internal/app/runtime_test.go new file mode 100644 index 0000000..62c33d5 --- /dev/null +++ b/authsession/internal/app/runtime_test.go @@ -0,0 +1,212 @@ +package app + +import ( + "bytes" + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "galaxy/authsession/internal/config" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestNewRuntimeStartsAndStopsHTTPServers(t *testing.T) { + t.Parallel() + + redisServer := miniredis.RunT(t) + + cfg := config.DefaultConfig() + cfg.Redis.Addr = redisServer.Addr() + cfg.PublicHTTP.Addr = mustFreeAddr(t) + cfg.InternalHTTP.Addr = mustFreeAddr(t) + + runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, runtime.Close()) + }() + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- runtime.App.Run(runCtx) + }() + + require.Eventually(t, func() bool { + response, err := http.Post( + "http://"+cfg.PublicHTTP.Addr+"/api/v1/public/auth/send-email-code", + "application/json", + bytes.NewBufferString(`{"email":"pilot@example.com"}`), + ) + if err != nil { + return false + } + defer response.Body.Close() + _, _ = io.ReadAll(response.Body) + + return response.StatusCode == http.StatusOK + }, 5*time.Second, 25*time.Millisecond) + + require.Eventually(t, func() bool { + response, err := http.Get("http://" + cfg.InternalHTTP.Addr + 
"/api/v1/internal/sessions/missing") + if err != nil { + return false + } + defer response.Body.Close() + _, _ = io.ReadAll(response.Body) + + return response.StatusCode == http.StatusNotFound + }, 5*time.Second, 25*time.Millisecond) + + cancel() + require.NoError(t, <-runErrCh) +} + +func TestNewRuntimeUsesRESTUserDirectoryWhenConfigured(t *testing.T) { + t.Parallel() + + redisServer := miniredis.RunT(t) + userService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && r.URL.Path == "/api/v1/internal/users/user-1/exists" { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"exists":true}`) + return + } + + http.NotFound(w, r) + })) + defer userService.Close() + + cfg := config.DefaultConfig() + cfg.Redis.Addr = redisServer.Addr() + cfg.PublicHTTP.Addr = mustFreeAddr(t) + cfg.InternalHTTP.Addr = mustFreeAddr(t) + cfg.UserService.Mode = "rest" + cfg.UserService.BaseURL = userService.URL + cfg.UserService.RequestTimeout = 250 * time.Millisecond + + runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, runtime.Close()) + }() + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- runtime.App.Run(runCtx) + }() + + require.Eventually(t, func() bool { + response, err := http.Post( + "http://"+cfg.InternalHTTP.Addr+"/api/v1/internal/users/user-1/sessions/revoke-all", + "application/json", + bytes.NewBufferString(`{"reason_code":"logout_all","actor":{"type":"system"}}`), + ) + if err != nil { + return false + } + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + if err != nil { + return false + } + + return response.StatusCode == http.StatusOK && + bytes.Contains(payload, []byte(`"outcome":"no_active_sessions"`)) && + bytes.Contains(payload, []byte(`"user_id":"user-1"`)) + }, 
5*time.Second, 25*time.Millisecond) + + cancel() + require.NoError(t, <-runErrCh) +} + +func TestNewRuntimeUsesRESTMailSenderWhenConfigured(t *testing.T) { + t.Parallel() + + redisServer := miniredis.RunT(t) + var calls atomic.Int64 + mailService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/v1/internal/login-code-deliveries" { + calls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"outcome":"suppressed"}`) + return + } + + http.NotFound(w, r) + })) + defer mailService.Close() + + cfg := config.DefaultConfig() + cfg.Redis.Addr = redisServer.Addr() + cfg.PublicHTTP.Addr = mustFreeAddr(t) + cfg.InternalHTTP.Addr = mustFreeAddr(t) + cfg.MailService.Mode = "rest" + cfg.MailService.BaseURL = mailService.URL + cfg.MailService.RequestTimeout = 250 * time.Millisecond + + runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, runtime.Close()) + }() + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- runtime.App.Run(runCtx) + }() + + require.Eventually(t, func() bool { + response, err := http.Post( + "http://"+cfg.PublicHTTP.Addr+"/api/v1/public/auth/send-email-code", + "application/json", + bytes.NewBufferString(`{"email":"pilot@example.com"}`), + ) + if err != nil { + return false + } + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + if err != nil { + return false + } + + return response.StatusCode == http.StatusOK && + bytes.Contains(payload, []byte(`"challenge_id":"`)) && + calls.Load() == 1 + }, 5*time.Second, 25*time.Millisecond) + + cancel() + require.NoError(t, <-runErrCh) +} + +func mustFreeAddr(t *testing.T) string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + 
assert.NoError(t, listener.Close()) + }() + + return listener.Addr().String() +} diff --git a/authsession/internal/config/config.go b/authsession/internal/config/config.go new file mode 100644 index 0000000..8b65385 --- /dev/null +++ b/authsession/internal/config/config.go @@ -0,0 +1,610 @@ +// Package config loads the authsession process configuration from environment +// variables. +package config + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" + + "galaxy/authsession/internal/api/internalhttp" + "galaxy/authsession/internal/api/publichttp" + + "go.uber.org/zap/zapcore" +) + +const ( + shutdownTimeoutEnvVar = "AUTHSESSION_SHUTDOWN_TIMEOUT" + logLevelEnvVar = "AUTHSESSION_LOG_LEVEL" + + publicHTTPAddrEnvVar = "AUTHSESSION_PUBLIC_HTTP_ADDR" + publicHTTPReadHeaderTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_READ_HEADER_TIMEOUT" + publicHTTPReadTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_READ_TIMEOUT" + publicHTTPIdleTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_IDLE_TIMEOUT" + publicHTTPRequestTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_REQUEST_TIMEOUT" + + internalHTTPAddrEnvVar = "AUTHSESSION_INTERNAL_HTTP_ADDR" + internalHTTPReadHeaderTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_READ_HEADER_TIMEOUT" + internalHTTPReadTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_READ_TIMEOUT" + internalHTTPIdleTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_IDLE_TIMEOUT" + internalHTTPRequestTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_REQUEST_TIMEOUT" + + redisAddrEnvVar = "AUTHSESSION_REDIS_ADDR" + redisUsernameEnvVar = "AUTHSESSION_REDIS_USERNAME" + redisPasswordEnvVar = "AUTHSESSION_REDIS_PASSWORD" + redisDBEnvVar = "AUTHSESSION_REDIS_DB" + redisTLSEnabledEnvVar = "AUTHSESSION_REDIS_TLS_ENABLED" + redisOperationTimeoutEnvVar = "AUTHSESSION_REDIS_OPERATION_TIMEOUT" + + redisChallengeKeyPrefixEnvVar = "AUTHSESSION_REDIS_CHALLENGE_KEY_PREFIX" + redisSessionKeyPrefixEnvVar = "AUTHSESSION_REDIS_SESSION_KEY_PREFIX" + redisUserSessionsKeyPrefixEnvVar = "AUTHSESSION_REDIS_USER_SESSIONS_KEY_PREFIX" 
+ redisUserActiveSessionsKeyPrefixEnvVar = "AUTHSESSION_REDIS_USER_ACTIVE_SESSIONS_KEY_PREFIX" + redisSessionLimitKeyEnvVar = "AUTHSESSION_REDIS_SESSION_LIMIT_KEY" + redisGatewaySessionCacheKeyPrefixEnvVar = "AUTHSESSION_REDIS_GATEWAY_SESSION_CACHE_KEY_PREFIX" + redisGatewaySessionEventsStreamEnvVar = "AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM" + redisGatewaySessionEventsStreamMaxLenEnvVar = "AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM_MAX_LEN" + redisSendEmailCodeThrottleKeyPrefixEnvVar = "AUTHSESSION_REDIS_SEND_EMAIL_CODE_THROTTLE_KEY_PREFIX" + + userServiceModeEnvVar = "AUTHSESSION_USER_SERVICE_MODE" + userServiceBaseURLEnvVar = "AUTHSESSION_USER_SERVICE_BASE_URL" + userServiceRequestTimeoutEnvVar = "AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT" + + mailServiceModeEnvVar = "AUTHSESSION_MAIL_SERVICE_MODE" + mailServiceBaseURLEnvVar = "AUTHSESSION_MAIL_SERVICE_BASE_URL" + mailServiceRequestTimeoutEnvVar = "AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT" + + otelServiceNameEnvVar = "OTEL_SERVICE_NAME" + otelTracesExporterEnvVar = "OTEL_TRACES_EXPORTER" + otelMetricsExporterEnvVar = "OTEL_METRICS_EXPORTER" + otelExporterOTLPProtocolEnvVar = "OTEL_EXPORTER_OTLP_PROTOCOL" + otelExporterOTLPTracesProtocolEnvVar = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL" + otelExporterOTLPMetricsProtocolEnvVar = "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL" + otelStdoutTracesEnabledEnvVar = "AUTHSESSION_OTEL_STDOUT_TRACES_ENABLED" + otelStdoutMetricsEnabledEnvVar = "AUTHSESSION_OTEL_STDOUT_METRICS_ENABLED" + + defaultShutdownTimeout = 5 * time.Second + defaultLogLevel = "info" + defaultRedisDB = 0 + defaultRedisOperationTimeout = 250 * time.Millisecond + defaultChallengeKeyPrefix = "authsession:challenge:" + defaultSessionKeyPrefix = "authsession:session:" + defaultUserSessionsKeyPrefix = "authsession:user-sessions:" + defaultUserActiveSessionsKeyPrefix = "authsession:user-active-sessions:" + defaultSessionLimitKey = "authsession:config:active-session-limit" + defaultGatewaySessionCacheKeyPrefix = 
"gateway:session:" + defaultGatewaySessionEventsStream = "gateway:session_events" + defaultGatewaySessionEventsStreamMaxLen = 1024 + defaultSendEmailCodeThrottleKeyPrefix = "authsession:send-email-code-throttle:" + defaultUserServiceMode = userServiceModeStub + defaultUserServiceRequestTimeout = time.Second + defaultMailServiceMode = mailServiceModeStub + defaultMailServiceRequestTimeout = time.Second + defaultOTelServiceName = "galaxy-authsession" + otelExporterNone = "none" + otelExporterOTLP = "otlp" + otelProtocolHTTPProtobuf = "http/protobuf" + otelProtocolGRPC = "grpc" + userServiceModeStub = "stub" + userServiceModeREST = "rest" + mailServiceModeStub = "stub" + mailServiceModeREST = "rest" +) + +// Config stores the full process-level authsession configuration. +type Config struct { + // ShutdownTimeout bounds graceful shutdown of every long-lived component. + ShutdownTimeout time.Duration + + // Logging configures the process-wide structured logger. + Logging LoggingConfig + + // PublicHTTP configures the public HTTP listener. + PublicHTTP publichttp.Config + + // InternalHTTP configures the trusted internal HTTP listener. + InternalHTTP internalhttp.Config + + // Redis configures the Redis-backed adapters. + Redis RedisConfig + + // UserService configures the selectable runtime user-directory adapter. + UserService UserServiceConfig + + // MailService configures the selectable runtime mail-delivery adapter. + MailService MailServiceConfig + + // Telemetry configures the process-wide OpenTelemetry runtime. + Telemetry TelemetryConfig +} + +// LoggingConfig configures the process-wide structured logger. +type LoggingConfig struct { + // Level stores the zap-compatible log level string. + Level string +} + +// RedisConfig configures the Redis-backed authsession adapters. +type RedisConfig struct { + // Addr is the shared Redis address used by the authsession adapters. + Addr string + + // Username is the optional Redis ACL username. 
+ Username string + + // Password is the optional Redis ACL password. + Password string + + // DB is the Redis logical database index. + DB int + + // TLSEnabled configures whether Redis connections use TLS. + TLSEnabled bool + + // OperationTimeout bounds each adapter Redis round trip. + OperationTimeout time.Duration + + // ChallengeKeyPrefix namespaces the challenge source-of-truth records. + ChallengeKeyPrefix string + + // SessionKeyPrefix namespaces the primary session records. + SessionKeyPrefix string + + // UserSessionsKeyPrefix namespaces the all-session user index. + UserSessionsKeyPrefix string + + // UserActiveSessionsKeyPrefix namespaces the active-session user index. + UserActiveSessionsKeyPrefix string + + // SessionLimitKey stores the exact session-limit Redis key. + SessionLimitKey string + + // GatewaySessionCacheKeyPrefix namespaces the projected gateway session + // cache keys. + GatewaySessionCacheKeyPrefix string + + // GatewaySessionEventsStream stores the projected gateway session-events + // Redis Stream key. + GatewaySessionEventsStream string + + // GatewaySessionEventsStreamMaxLen bounds the projected gateway session + // event stream with approximate trimming. + GatewaySessionEventsStreamMaxLen int64 + + // SendEmailCodeThrottleKeyPrefix namespaces the resend-throttle TTL keys. + SendEmailCodeThrottleKeyPrefix string +} + +// UserServiceConfig configures the runtime user-directory integration mode. +type UserServiceConfig struct { + // Mode selects the runtime adapter implementation. Supported values are + // `stub` and `rest`. + Mode string + + // BaseURL is the absolute base URL of the REST-backed user-service when + // Mode is `rest`. + BaseURL string + + // RequestTimeout bounds each outbound user-service request when Mode is + // `rest`. + RequestTimeout time.Duration +} + +// MailServiceConfig configures the runtime mail-delivery integration mode. 
+type MailServiceConfig struct { + // Mode selects the runtime adapter implementation. Supported values are + // `stub` and `rest`. + Mode string + + // BaseURL is the absolute base URL of the REST-backed mail service when + // Mode is `rest`. + BaseURL string + + // RequestTimeout bounds each outbound mail-service request when Mode is + // `rest`. + RequestTimeout time.Duration +} + +// TelemetryConfig configures the authsession OpenTelemetry runtime. +type TelemetryConfig struct { + // ServiceName overrides the default OpenTelemetry service name. + ServiceName string + + // TracesExporter selects the external traces exporter. Supported values are + // `none` and `otlp`. + TracesExporter string + + // MetricsExporter selects the external metrics exporter. Supported values + // are `none` and `otlp`. + MetricsExporter string + + // TracesProtocol selects the OTLP traces protocol when TracesExporter is + // `otlp`. + TracesProtocol string + + // MetricsProtocol selects the OTLP metrics protocol when MetricsExporter is + // `otlp`. + MetricsProtocol string + + // StdoutTracesEnabled enables the additional stdout trace exporter used for + // local development and debugging. + StdoutTracesEnabled bool + + // StdoutMetricsEnabled enables the additional stdout metric exporter used + // for local development and debugging. + StdoutMetricsEnabled bool +} + +// DefaultConfig returns the default authsession process configuration with all +// optional values filled. 
+func DefaultConfig() Config { + return Config{ + ShutdownTimeout: defaultShutdownTimeout, + Logging: LoggingConfig{ + Level: defaultLogLevel, + }, + PublicHTTP: publichttp.DefaultConfig(), + InternalHTTP: internalhttp.DefaultConfig(), + Redis: RedisConfig{ + DB: defaultRedisDB, + OperationTimeout: defaultRedisOperationTimeout, + ChallengeKeyPrefix: defaultChallengeKeyPrefix, + SessionKeyPrefix: defaultSessionKeyPrefix, + UserSessionsKeyPrefix: defaultUserSessionsKeyPrefix, + UserActiveSessionsKeyPrefix: defaultUserActiveSessionsKeyPrefix, + SessionLimitKey: defaultSessionLimitKey, + GatewaySessionCacheKeyPrefix: defaultGatewaySessionCacheKeyPrefix, + GatewaySessionEventsStream: defaultGatewaySessionEventsStream, + GatewaySessionEventsStreamMaxLen: defaultGatewaySessionEventsStreamMaxLen, + SendEmailCodeThrottleKeyPrefix: defaultSendEmailCodeThrottleKeyPrefix, + }, + UserService: UserServiceConfig{ + Mode: defaultUserServiceMode, + RequestTimeout: defaultUserServiceRequestTimeout, + }, + MailService: MailServiceConfig{ + Mode: defaultMailServiceMode, + RequestTimeout: defaultMailServiceRequestTimeout, + }, + Telemetry: TelemetryConfig{ + ServiceName: defaultOTelServiceName, + TracesExporter: otelExporterNone, + MetricsExporter: otelExporterNone, + }, + } +} + +// LoadFromEnv loads the authsession process configuration from environment +// variables, applying documented defaults where appropriate. 
+func LoadFromEnv() (Config, error) { + cfg := DefaultConfig() + + var err error + + cfg.ShutdownTimeout, err = loadDurationEnvWithDefault(shutdownTimeoutEnvVar, cfg.ShutdownTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + + cfg.Logging.Level = loadStringEnvWithDefault(logLevelEnvVar, cfg.Logging.Level) + if err := validateLogLevel(cfg.Logging.Level); err != nil { + return Config{}, fmt.Errorf("load authsession config: %s: %w", logLevelEnvVar, err) + } + + cfg.PublicHTTP.Addr = loadStringEnvWithDefault(publicHTTPAddrEnvVar, cfg.PublicHTTP.Addr) + cfg.PublicHTTP.ReadHeaderTimeout, err = loadDurationEnvWithDefault(publicHTTPReadHeaderTimeoutEnvVar, cfg.PublicHTTP.ReadHeaderTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.PublicHTTP.ReadTimeout, err = loadDurationEnvWithDefault(publicHTTPReadTimeoutEnvVar, cfg.PublicHTTP.ReadTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.PublicHTTP.IdleTimeout, err = loadDurationEnvWithDefault(publicHTTPIdleTimeoutEnvVar, cfg.PublicHTTP.IdleTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.PublicHTTP.RequestTimeout, err = loadDurationEnvWithDefault(publicHTTPRequestTimeoutEnvVar, cfg.PublicHTTP.RequestTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + + cfg.InternalHTTP.Addr = loadStringEnvWithDefault(internalHTTPAddrEnvVar, cfg.InternalHTTP.Addr) + cfg.InternalHTTP.ReadHeaderTimeout, err = loadDurationEnvWithDefault(internalHTTPReadHeaderTimeoutEnvVar, cfg.InternalHTTP.ReadHeaderTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.InternalHTTP.ReadTimeout, err = loadDurationEnvWithDefault(internalHTTPReadTimeoutEnvVar, cfg.InternalHTTP.ReadTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", 
err) + } + cfg.InternalHTTP.IdleTimeout, err = loadDurationEnvWithDefault(internalHTTPIdleTimeoutEnvVar, cfg.InternalHTTP.IdleTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.InternalHTTP.RequestTimeout, err = loadDurationEnvWithDefault(internalHTTPRequestTimeoutEnvVar, cfg.InternalHTTP.RequestTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + + cfg.Redis.Addr = loadStringEnvWithDefault(redisAddrEnvVar, cfg.Redis.Addr) + cfg.Redis.Username = os.Getenv(redisUsernameEnvVar) + cfg.Redis.Password = os.Getenv(redisPasswordEnvVar) + cfg.Redis.DB, err = loadIntEnvWithDefault(redisDBEnvVar, cfg.Redis.DB) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.Redis.TLSEnabled, err = loadBoolEnvWithDefault(redisTLSEnabledEnvVar, cfg.Redis.TLSEnabled) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.Redis.OperationTimeout, err = loadDurationEnvWithDefault(redisOperationTimeoutEnvVar, cfg.Redis.OperationTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.Redis.ChallengeKeyPrefix = loadStringEnvWithDefault(redisChallengeKeyPrefixEnvVar, cfg.Redis.ChallengeKeyPrefix) + cfg.Redis.SessionKeyPrefix = loadStringEnvWithDefault(redisSessionKeyPrefixEnvVar, cfg.Redis.SessionKeyPrefix) + cfg.Redis.UserSessionsKeyPrefix = loadStringEnvWithDefault(redisUserSessionsKeyPrefixEnvVar, cfg.Redis.UserSessionsKeyPrefix) + cfg.Redis.UserActiveSessionsKeyPrefix = loadStringEnvWithDefault(redisUserActiveSessionsKeyPrefixEnvVar, cfg.Redis.UserActiveSessionsKeyPrefix) + cfg.Redis.SessionLimitKey = loadStringEnvWithDefault(redisSessionLimitKeyEnvVar, cfg.Redis.SessionLimitKey) + cfg.Redis.GatewaySessionCacheKeyPrefix = loadStringEnvWithDefault(redisGatewaySessionCacheKeyPrefixEnvVar, cfg.Redis.GatewaySessionCacheKeyPrefix) + cfg.Redis.GatewaySessionEventsStream = 
loadStringEnvWithDefault(redisGatewaySessionEventsStreamEnvVar, cfg.Redis.GatewaySessionEventsStream) + streamMaxLen, err := loadInt64EnvWithDefault(redisGatewaySessionEventsStreamMaxLenEnvVar, cfg.Redis.GatewaySessionEventsStreamMaxLen) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.Redis.GatewaySessionEventsStreamMaxLen = streamMaxLen + cfg.Redis.SendEmailCodeThrottleKeyPrefix = loadStringEnvWithDefault(redisSendEmailCodeThrottleKeyPrefixEnvVar, cfg.Redis.SendEmailCodeThrottleKeyPrefix) + + cfg.UserService.Mode = strings.TrimSpace(loadStringEnvWithDefault(userServiceModeEnvVar, cfg.UserService.Mode)) + cfg.UserService.BaseURL = loadStringEnvWithDefault(userServiceBaseURLEnvVar, cfg.UserService.BaseURL) + cfg.UserService.RequestTimeout, err = loadDurationEnvWithDefault(userServiceRequestTimeoutEnvVar, cfg.UserService.RequestTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + + cfg.MailService.Mode = strings.TrimSpace(loadStringEnvWithDefault(mailServiceModeEnvVar, cfg.MailService.Mode)) + cfg.MailService.BaseURL = loadStringEnvWithDefault(mailServiceBaseURLEnvVar, cfg.MailService.BaseURL) + cfg.MailService.RequestTimeout, err = loadDurationEnvWithDefault(mailServiceRequestTimeoutEnvVar, cfg.MailService.RequestTimeout) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + + cfg.Telemetry.ServiceName = loadStringEnvWithDefault(otelServiceNameEnvVar, cfg.Telemetry.ServiceName) + cfg.Telemetry.TracesExporter = normalizeExporterValue(loadStringEnvWithDefault(otelTracesExporterEnvVar, cfg.Telemetry.TracesExporter)) + cfg.Telemetry.MetricsExporter = normalizeExporterValue(loadStringEnvWithDefault(otelMetricsExporterEnvVar, cfg.Telemetry.MetricsExporter)) + cfg.Telemetry.TracesProtocol = loadOTLPProtocol( + os.Getenv(otelExporterOTLPTracesProtocolEnvVar), + os.Getenv(otelExporterOTLPProtocolEnvVar), + cfg.Telemetry.TracesExporter, + ) + 
cfg.Telemetry.MetricsProtocol = loadOTLPProtocol( + os.Getenv(otelExporterOTLPMetricsProtocolEnvVar), + os.Getenv(otelExporterOTLPProtocolEnvVar), + cfg.Telemetry.MetricsExporter, + ) + cfg.Telemetry.StdoutTracesEnabled, err = loadBoolEnvWithDefault(otelStdoutTracesEnabledEnvVar, cfg.Telemetry.StdoutTracesEnabled) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + cfg.Telemetry.StdoutMetricsEnabled, err = loadBoolEnvWithDefault(otelStdoutMetricsEnabledEnvVar, cfg.Telemetry.StdoutMetricsEnabled) + if err != nil { + return Config{}, fmt.Errorf("load authsession config: %w", err) + } + + if err := cfg.Validate(); err != nil { + return Config{}, err + } + + return cfg, nil +} + +// Validate reports whether cfg contains a consistent authsession process +// configuration. +func (cfg Config) Validate() error { + switch { + case cfg.ShutdownTimeout <= 0: + return fmt.Errorf("load authsession config: %s must be positive", shutdownTimeoutEnvVar) + case strings.TrimSpace(cfg.Redis.Addr) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisAddrEnvVar) + case cfg.Redis.DB < 0: + return fmt.Errorf("load authsession config: %s must not be negative", redisDBEnvVar) + case cfg.Redis.OperationTimeout <= 0: + return fmt.Errorf("load authsession config: %s must be positive", redisOperationTimeoutEnvVar) + case strings.TrimSpace(cfg.Redis.ChallengeKeyPrefix) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisChallengeKeyPrefixEnvVar) + case strings.TrimSpace(cfg.Redis.SessionKeyPrefix) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisSessionKeyPrefixEnvVar) + case strings.TrimSpace(cfg.Redis.UserSessionsKeyPrefix) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisUserSessionsKeyPrefixEnvVar) + case strings.TrimSpace(cfg.Redis.UserActiveSessionsKeyPrefix) == "": + return fmt.Errorf("load authsession config: %s must not be empty", 
redisUserActiveSessionsKeyPrefixEnvVar) + case strings.TrimSpace(cfg.Redis.SessionLimitKey) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisSessionLimitKeyEnvVar) + case strings.TrimSpace(cfg.Redis.GatewaySessionCacheKeyPrefix) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisGatewaySessionCacheKeyPrefixEnvVar) + case strings.TrimSpace(cfg.Redis.GatewaySessionEventsStream) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisGatewaySessionEventsStreamEnvVar) + case cfg.Redis.GatewaySessionEventsStreamMaxLen <= 0: + return fmt.Errorf("load authsession config: %s must be positive", redisGatewaySessionEventsStreamMaxLenEnvVar) + case strings.TrimSpace(cfg.Redis.SendEmailCodeThrottleKeyPrefix) == "": + return fmt.Errorf("load authsession config: %s must not be empty", redisSendEmailCodeThrottleKeyPrefixEnvVar) + } + + if err := cfg.PublicHTTP.Validate(); err != nil { + return fmt.Errorf("load authsession config: public HTTP: %w", err) + } + if err := cfg.InternalHTTP.Validate(); err != nil { + return fmt.Errorf("load authsession config: internal HTTP: %w", err) + } + if err := cfg.UserService.Validate(); err != nil { + return fmt.Errorf("load authsession config: %w", err) + } + if err := cfg.MailService.Validate(); err != nil { + return fmt.Errorf("load authsession config: %w", err) + } + if err := cfg.Telemetry.Validate(); err != nil { + return fmt.Errorf("load authsession config: %w", err) + } + + return nil +} + +// Validate reports whether cfg contains a supported user-service runtime +// configuration. 
+func (cfg UserServiceConfig) Validate() error { + switch cfg.Mode { + case userServiceModeStub: + return nil + case userServiceModeREST: + if strings.TrimSpace(cfg.BaseURL) == "" { + return fmt.Errorf("%s must not be empty in rest mode", userServiceBaseURLEnvVar) + } + if cfg.RequestTimeout <= 0 { + return fmt.Errorf("%s must be positive in rest mode", userServiceRequestTimeoutEnvVar) + } + return nil + default: + return fmt.Errorf("%s %q is unsupported", userServiceModeEnvVar, cfg.Mode) + } +} + +// Validate reports whether cfg contains a supported mail-service runtime +// configuration. +func (cfg MailServiceConfig) Validate() error { + switch cfg.Mode { + case mailServiceModeStub: + return nil + case mailServiceModeREST: + if strings.TrimSpace(cfg.BaseURL) == "" { + return fmt.Errorf("%s must not be empty in rest mode", mailServiceBaseURLEnvVar) + } + if cfg.RequestTimeout <= 0 { + return fmt.Errorf("%s must be positive in rest mode", mailServiceRequestTimeoutEnvVar) + } + return nil + default: + return fmt.Errorf("%s %q is unsupported", mailServiceModeEnvVar, cfg.Mode) + } +} + +// Validate reports whether cfg contains a supported OpenTelemetry exporter +// configuration. 
+func (cfg TelemetryConfig) Validate() error { + switch cfg.TracesExporter { + case otelExporterNone, otelExporterOTLP: + default: + return fmt.Errorf("%s %q is unsupported", otelTracesExporterEnvVar, cfg.TracesExporter) + } + + switch cfg.MetricsExporter { + case otelExporterNone, otelExporterOTLP: + default: + return fmt.Errorf("%s %q is unsupported", otelMetricsExporterEnvVar, cfg.MetricsExporter) + } + + if cfg.TracesProtocol != "" && cfg.TracesProtocol != otelProtocolHTTPProtobuf && cfg.TracesProtocol != otelProtocolGRPC { + return fmt.Errorf("%s %q is unsupported", otelExporterOTLPTracesProtocolEnvVar, cfg.TracesProtocol) + } + if cfg.MetricsProtocol != "" && cfg.MetricsProtocol != otelProtocolHTTPProtobuf && cfg.MetricsProtocol != otelProtocolGRPC { + return fmt.Errorf("%s %q is unsupported", otelExporterOTLPMetricsProtocolEnvVar, cfg.MetricsProtocol) + } + + return nil +} + +func loadStringEnvWithDefault(name string, value string) string { + if raw, ok := os.LookupEnv(name); ok { + return strings.TrimSpace(raw) + } + + return value +} + +func loadDurationEnvWithDefault(name string, value time.Duration) (time.Duration, error) { + raw, ok := os.LookupEnv(name) + if !ok { + return value, nil + } + + parsed, err := time.ParseDuration(strings.TrimSpace(raw)) + if err != nil { + return 0, fmt.Errorf("%s: %w", name, err) + } + + return parsed, nil +} + +func loadIntEnvWithDefault(name string, value int) (int, error) { + raw, ok := os.LookupEnv(name) + if !ok { + return value, nil + } + + parsed, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil { + return 0, fmt.Errorf("%s: %w", name, err) + } + + return parsed, nil +} + +func loadInt64EnvWithDefault(name string, value int64) (int64, error) { + raw, ok := os.LookupEnv(name) + if !ok { + return value, nil + } + + parsed, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64) + if err != nil { + return 0, fmt.Errorf("%s: %w", name, err) + } + + return parsed, nil +} + +func loadBoolEnvWithDefault(name 
string, value bool) (bool, error) { + raw, ok := os.LookupEnv(name) + if !ok { + return value, nil + } + + parsed, err := strconv.ParseBool(strings.TrimSpace(raw)) + if err != nil { + return false, fmt.Errorf("%s: %w", name, err) + } + + return parsed, nil +} + +func validateLogLevel(value string) error { + var level zapcore.Level + if err := level.UnmarshalText([]byte(strings.TrimSpace(value))); err != nil { + return err + } + + return nil +} + +func normalizeExporterValue(value string) string { + switch strings.TrimSpace(value) { + case "", otelExporterNone: + return otelExporterNone + default: + return strings.TrimSpace(value) + } +} + +func loadOTLPProtocol(primary string, fallback string, exporter string) string { + protocol := strings.TrimSpace(primary) + if protocol == "" { + protocol = strings.TrimSpace(fallback) + } + if protocol == "" && exporter == otelExporterOTLP { + return otelProtocolHTTPProtobuf + } + + return protocol +} diff --git a/authsession/internal/config/config_test.go b/authsession/internal/config/config_test.go new file mode 100644 index 0000000..aafe427 --- /dev/null +++ b/authsession/internal/config/config_test.go @@ -0,0 +1,161 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadFromEnvUsesDefaults(t *testing.T) { + t.Setenv(redisAddrEnvVar, "127.0.0.1:6379") + + cfg, err := LoadFromEnv() + require.NoError(t, err) + + defaults := DefaultConfig() + assert.Equal(t, defaults.ShutdownTimeout, cfg.ShutdownTimeout) + assert.Equal(t, defaults.Logging.Level, cfg.Logging.Level) + assert.Equal(t, defaults.PublicHTTP, cfg.PublicHTTP) + assert.Equal(t, defaults.InternalHTTP, cfg.InternalHTTP) + assert.Equal(t, "127.0.0.1:6379", cfg.Redis.Addr) + assert.Equal(t, defaults.Redis.DB, cfg.Redis.DB) + assert.Equal(t, defaults.Redis.OperationTimeout, cfg.Redis.OperationTimeout) + assert.Equal(t, defaults.UserService, cfg.UserService) + assert.Equal(t, 
defaults.MailService, cfg.MailService) + assert.Equal(t, defaults.Telemetry.ServiceName, cfg.Telemetry.ServiceName) + assert.Equal(t, defaults.Telemetry.TracesExporter, cfg.Telemetry.TracesExporter) + assert.Equal(t, defaults.Telemetry.MetricsExporter, cfg.Telemetry.MetricsExporter) + assert.False(t, cfg.Telemetry.StdoutTracesEnabled) + assert.False(t, cfg.Telemetry.StdoutMetricsEnabled) +} + +func TestLoadFromEnvAppliesOverrides(t *testing.T) { + t.Setenv(shutdownTimeoutEnvVar, "9s") + t.Setenv(logLevelEnvVar, "debug") + t.Setenv(publicHTTPAddrEnvVar, "127.0.0.1:18080") + t.Setenv(internalHTTPAddrEnvVar, "127.0.0.1:18081") + t.Setenv(redisAddrEnvVar, "127.0.0.1:6380") + t.Setenv(redisUsernameEnvVar, "alice") + t.Setenv(redisPasswordEnvVar, "secret") + t.Setenv(redisDBEnvVar, "3") + t.Setenv(redisTLSEnabledEnvVar, "true") + t.Setenv(redisOperationTimeoutEnvVar, "750ms") + t.Setenv(userServiceModeEnvVar, "rest") + t.Setenv(userServiceBaseURLEnvVar, "http://127.0.0.1:19090") + t.Setenv(userServiceRequestTimeoutEnvVar, "900ms") + t.Setenv(mailServiceModeEnvVar, "rest") + t.Setenv(mailServiceBaseURLEnvVar, "http://127.0.0.1:19091") + t.Setenv(mailServiceRequestTimeoutEnvVar, "950ms") + t.Setenv(otelServiceNameEnvVar, "custom-authsession") + t.Setenv(otelTracesExporterEnvVar, "otlp") + t.Setenv(otelMetricsExporterEnvVar, "otlp") + t.Setenv(otelExporterOTLPProtocolEnvVar, "grpc") + t.Setenv(otelStdoutTracesEnabledEnvVar, "true") + t.Setenv(otelStdoutMetricsEnabledEnvVar, "true") + + cfg, err := LoadFromEnv() + require.NoError(t, err) + + assert.Equal(t, 9*time.Second, cfg.ShutdownTimeout) + assert.Equal(t, "debug", cfg.Logging.Level) + assert.Equal(t, "127.0.0.1:18080", cfg.PublicHTTP.Addr) + assert.Equal(t, "127.0.0.1:18081", cfg.InternalHTTP.Addr) + assert.Equal(t, "127.0.0.1:6380", cfg.Redis.Addr) + assert.Equal(t, "alice", cfg.Redis.Username) + assert.Equal(t, "secret", cfg.Redis.Password) + assert.Equal(t, 3, cfg.Redis.DB) + assert.True(t, cfg.Redis.TLSEnabled) + 
assert.Equal(t, 750*time.Millisecond, cfg.Redis.OperationTimeout) + assert.Equal(t, UserServiceConfig{ + Mode: "rest", + BaseURL: "http://127.0.0.1:19090", + RequestTimeout: 900 * time.Millisecond, + }, cfg.UserService) + assert.Equal(t, MailServiceConfig{ + Mode: "rest", + BaseURL: "http://127.0.0.1:19091", + RequestTimeout: 950 * time.Millisecond, + }, cfg.MailService) + assert.Equal(t, "custom-authsession", cfg.Telemetry.ServiceName) + assert.Equal(t, "otlp", cfg.Telemetry.TracesExporter) + assert.Equal(t, "otlp", cfg.Telemetry.MetricsExporter) + assert.Equal(t, "grpc", cfg.Telemetry.TracesProtocol) + assert.Equal(t, "grpc", cfg.Telemetry.MetricsProtocol) + assert.True(t, cfg.Telemetry.StdoutTracesEnabled) + assert.True(t, cfg.Telemetry.StdoutMetricsEnabled) +} + +func TestLoadFromEnvRejectsInvalidValues(t *testing.T) { + tests := []struct { + name string + envName string + envVal string + }{ + {name: "invalid duration", envName: shutdownTimeoutEnvVar, envVal: "later"}, + {name: "invalid bool", envName: otelStdoutTracesEnabledEnvVar, envVal: "sometimes"}, + {name: "invalid log level", envName: logLevelEnvVar, envVal: "verbose"}, + {name: "invalid traces protocol", envName: otelExporterOTLPTracesProtocolEnvVar, envVal: "udp"}, + {name: "invalid user service mode", envName: userServiceModeEnvVar, envVal: "grpc"}, + {name: "invalid user service timeout", envName: userServiceRequestTimeoutEnvVar, envVal: "never"}, + {name: "invalid mail service mode", envName: mailServiceModeEnvVar, envVal: "grpc"}, + {name: "invalid mail service timeout", envName: mailServiceRequestTimeoutEnvVar, envVal: "never"}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Setenv(redisAddrEnvVar, "127.0.0.1:6379") + t.Setenv(tt.envName, tt.envVal) + if tt.envName == otelExporterOTLPTracesProtocolEnvVar { + t.Setenv(otelTracesExporterEnvVar, "otlp") + } + + _, err := LoadFromEnv() + require.Error(t, err) + assert.Contains(t, err.Error(), tt.envName) + 
}) + } +} + +func TestLoadFromEnvRejectsInvalidRESTUserServiceConfiguration(t *testing.T) { + t.Setenv(redisAddrEnvVar, "127.0.0.1:6379") + t.Setenv(userServiceModeEnvVar, "rest") + + t.Run("missing base url", func(t *testing.T) { + _, err := LoadFromEnv() + require.Error(t, err) + assert.Contains(t, err.Error(), userServiceBaseURLEnvVar) + }) + + t.Run("non positive timeout", func(t *testing.T) { + t.Setenv(userServiceBaseURLEnvVar, "http://127.0.0.1:19090") + t.Setenv(userServiceRequestTimeoutEnvVar, "0s") + + _, err := LoadFromEnv() + require.Error(t, err) + assert.Contains(t, err.Error(), userServiceRequestTimeoutEnvVar) + }) +} + +func TestLoadFromEnvRejectsInvalidRESTMailServiceConfiguration(t *testing.T) { + t.Setenv(redisAddrEnvVar, "127.0.0.1:6379") + t.Setenv(mailServiceModeEnvVar, "rest") + + t.Run("missing base url", func(t *testing.T) { + _, err := LoadFromEnv() + require.Error(t, err) + assert.Contains(t, err.Error(), mailServiceBaseURLEnvVar) + }) + + t.Run("non positive timeout", func(t *testing.T) { + t.Setenv(mailServiceBaseURLEnvVar, "http://127.0.0.1:19091") + t.Setenv(mailServiceRequestTimeoutEnvVar, "0s") + + _, err := LoadFromEnv() + require.Error(t, err) + assert.Contains(t, err.Error(), mailServiceRequestTimeoutEnvVar) + }) +} diff --git a/authsession/internal/domain/challenge/model.go b/authsession/internal/domain/challenge/model.go new file mode 100644 index 0000000..4555bd1 --- /dev/null +++ b/authsession/internal/domain/challenge/model.go @@ -0,0 +1,342 @@ +// Package challenge defines the source-of-truth domain model for one e-mail +// confirmation challenge. +package challenge + +import ( + "errors" + "fmt" + "time" + + "galaxy/authsession/internal/domain/common" +) + +// Status identifies the coarse lifecycle state of one challenge. +type Status string + +const ( + // StatusPendingSend reports that the challenge has been created but its + // delivery outcome has not been recorded yet. 
+ StatusPendingSend Status = "pending_send" + + // StatusSent reports that the confirmation code was delivered successfully. + StatusSent Status = "sent" + + // StatusDeliverySuppressed reports that outward send succeeded but actual + // delivery was intentionally suppressed by policy. + StatusDeliverySuppressed Status = "delivery_suppressed" + + // StatusDeliveryThrottled reports that a fresh challenge was created but + // delivery was skipped because the auth-side resend cooldown is still + // active. + StatusDeliveryThrottled Status = "delivery_throttled" + + // StatusConfirmedPendingExpire reports that the challenge was confirmed + // successfully and is temporarily retained for idempotent retry handling. + StatusConfirmedPendingExpire Status = "confirmed_pending_expire" + + // StatusExpired reports that the challenge can no longer be confirmed. + StatusExpired Status = "expired" + + // StatusFailed reports that the challenge reached a terminal failure state. + StatusFailed Status = "failed" + + // StatusCancelled reports that the challenge was cancelled explicitly. + StatusCancelled Status = "cancelled" +) + +// IsKnown reports whether Status is one of the challenge states supported by +// the current domain model. +func (s Status) IsKnown() bool { + switch s { + case StatusPendingSend, + StatusSent, + StatusDeliverySuppressed, + StatusDeliveryThrottled, + StatusConfirmedPendingExpire, + StatusExpired, + StatusFailed, + StatusCancelled: + return true + default: + return false + } +} + +// IsTerminal reports whether Status can no longer accept any lifecycle +// transition in the v1 challenge state machine. +func (s Status) IsTerminal() bool { + switch s { + case StatusExpired, StatusFailed, StatusCancelled: + return true + default: + return false + } +} + +// AcceptsFreshConfirm reports whether Status may still consume a first +// successful confirmation attempt. 
+func (s Status) AcceptsFreshConfirm() bool { + switch s { + case StatusSent, StatusDeliverySuppressed: + return true + default: + return false + } +} + +// IsConfirmedRetryState reports whether Status should use the idempotent retry +// path for a previously successful confirmation. +func (s Status) IsConfirmedRetryState() bool { + return s == StatusConfirmedPendingExpire +} + +// CanTransitionTo reports whether the current challenge Status may move to +// next under the coarse lifecycle rules fixed by Stage 2. +func (s Status) CanTransitionTo(next Status) bool { + switch s { + case StatusPendingSend: + switch next { + case StatusSent, StatusDeliverySuppressed, StatusDeliveryThrottled, StatusFailed, StatusCancelled, StatusExpired: + return true + } + case StatusSent, StatusDeliverySuppressed: + switch next { + case StatusConfirmedPendingExpire, StatusFailed, StatusCancelled, StatusExpired: + return true + } + case StatusConfirmedPendingExpire: + return next == StatusExpired + } + + return false +} + +// DeliveryState identifies the recorded delivery result of one challenge. +type DeliveryState string + +const ( + // DeliveryPending reports that no delivery outcome has been recorded yet. + DeliveryPending DeliveryState = "pending" + + // DeliverySent reports that the challenge code was sent successfully. + DeliverySent DeliveryState = "sent" + + // DeliverySuppressed reports that the outward flow stays success-shaped + // while actual delivery is intentionally skipped. + DeliverySuppressed DeliveryState = "suppressed" + + // DeliveryThrottled reports that the outward flow stays success-shaped + // while actual delivery is skipped because the resend cooldown is active. + DeliveryThrottled DeliveryState = "throttled" + + // DeliveryFailed reports that delivery was attempted and failed explicitly. + DeliveryFailed DeliveryState = "failed" +) + +// IsKnown reports whether DeliveryState is one of the delivery states +// supported by the current domain model. 
+func (s DeliveryState) IsKnown() bool {
+	switch s {
+	case DeliveryPending, DeliverySent, DeliverySuppressed, DeliveryThrottled, DeliveryFailed:
+		return true
+	default:
+		return false
+	}
+}
+
+// CanTransitionTo reports whether the current DeliveryState may move to next
+// under the coarse delivery rules fixed by Stage 2.
+func (s DeliveryState) CanTransitionTo(next DeliveryState) bool {
+	if s != DeliveryPending {
+		return false
+	}
+
+	switch next {
+	case DeliverySent, DeliverySuppressed, DeliveryThrottled, DeliveryFailed:
+		return true
+	default:
+		return false
+	}
+}
+
+// AttemptCounters groups the mutable send and confirm counters tracked by one
+// challenge aggregate.
+type AttemptCounters struct {
+	// Send counts delivery attempts initiated for the challenge.
+	Send int
+
+	// Confirm counts confirmation attempts evaluated against the challenge.
+	Confirm int
+}
+
+// Validate reports whether AttemptCounters contains only non-negative values.
+func (c AttemptCounters) Validate() error {
+	if c.Send < 0 {
+		return errors.New("challenge send attempt count must not be negative")
+	}
+	if c.Confirm < 0 {
+		return errors.New("challenge confirm attempt count must not be negative")
+	}
+
+	return nil
+}
+
+// AbuseMetadata stores minimal abuse-related timestamps without fixing later
+// anti-abuse policy details too early.
+type AbuseMetadata struct {
+	// LastAttemptAt optionally records the last send or confirm attempt time
+	// associated with the challenge.
+	LastAttemptAt *time.Time
+}
+
+// Validate reports whether AbuseMetadata contains structurally valid values.
+func (m AbuseMetadata) Validate() error {
+	if m.LastAttemptAt != nil && m.LastAttemptAt.IsZero() {
+		return errors.New("challenge abuse metadata last attempt time must not be zero")
+	}
+
+	return nil
+}
+
+// Confirmation stores the idempotency metadata recorded after a successful
+// challenge confirmation.
+type Confirmation struct {
+	// SessionID is the created device session returned by the successful
+	// confirmation.
+	SessionID common.DeviceSessionID
+
+	// ClientPublicKey is the validated client key bound to SessionID.
+	ClientPublicKey common.ClientPublicKey
+
+	// ConfirmedAt records when the successful confirmation happened.
+	ConfirmedAt time.Time
+}
+
+// Validate reports whether Confirmation contains all metadata required for a
+// confirmed challenge.
+func (c Confirmation) Validate() error {
+	if err := c.SessionID.Validate(); err != nil {
+		return fmt.Errorf("challenge confirmation session id: %w", err)
+	}
+	if err := c.ClientPublicKey.Validate(); err != nil {
+		return fmt.Errorf("challenge confirmation client public key: %w", err)
+	}
+	if c.ConfirmedAt.IsZero() {
+		return errors.New("challenge confirmation time must not be zero")
+	}
+
+	return nil
+}
+
+// Challenge is the minimal source-of-truth aggregate shape fixed by Stage 2.
+type Challenge struct {
+	// ID identifies the challenge.
+	ID common.ChallengeID
+
+	// Email stores the normalized target e-mail address.
+	Email common.Email
+
+	// CodeHash stores only the hashed confirmation code.
+	CodeHash []byte
+
+	// Status reports the coarse challenge lifecycle state.
+	Status Status
+
+	// DeliveryState reports the recorded delivery outcome.
+	DeliveryState DeliveryState
+
+	// CreatedAt reports when the challenge was created.
+	CreatedAt time.Time
+
+	// ExpiresAt reports when the challenge becomes unusable.
+	ExpiresAt time.Time
+
+	// Attempts groups the send and confirm counters.
+	Attempts AttemptCounters
+
+	// Abuse stores minimal abuse-related timestamps.
+	Abuse AbuseMetadata
+
+	// Confirmation is present only after a successful confirm transition.
+	Confirmation *Confirmation
+}
+
+// IsExpiredAt reports whether the challenge is unusable at now either because
+// it is already marked expired or because its expiration timestamp has passed.
+func (c Challenge) IsExpiredAt(now time.Time) bool {
+	return c.Status == StatusExpired || !c.ExpiresAt.After(now)
+}
+
+// Validate reports whether Challenge satisfies the Stage-2 structural and
+// lifecycle invariants.
+func (c Challenge) Validate() error {
+	if err := c.ID.Validate(); err != nil {
+		return fmt.Errorf("challenge id: %w", err)
+	}
+	if err := c.Email.Validate(); err != nil {
+		return fmt.Errorf("challenge email: %w", err)
+	}
+	if len(c.CodeHash) == 0 {
+		return errors.New("challenge code hash must not be empty")
+	}
+	if !c.Status.IsKnown() {
+		return fmt.Errorf("challenge status %q is unsupported", c.Status)
+	}
+	if !c.DeliveryState.IsKnown() {
+		return fmt.Errorf("challenge delivery state %q is unsupported", c.DeliveryState)
+	}
+	if c.CreatedAt.IsZero() {
+		return errors.New("challenge creation time must not be zero")
+	}
+	if c.ExpiresAt.IsZero() {
+		return errors.New("challenge expiration time must not be zero")
+	}
+	if c.ExpiresAt.Before(c.CreatedAt) {
+		return errors.New("challenge expiration time must not be before creation time")
+	}
+	if err := c.Attempts.Validate(); err != nil {
+		return err
+	}
+	if err := c.Abuse.Validate(); err != nil {
+		return err
+	}
+
+	switch c.Status {
+	case StatusPendingSend:
+		if c.DeliveryState != DeliveryPending {
+			return errors.New("pending_send challenge must keep pending delivery state")
+		}
+	case StatusSent:
+		if c.DeliveryState != DeliverySent {
+			return errors.New("sent challenge must keep sent delivery state")
+		}
+	case StatusDeliverySuppressed:
+		if c.DeliveryState != DeliverySuppressed {
+			return errors.New("delivery_suppressed challenge must keep suppressed delivery state")
+		}
+	case StatusDeliveryThrottled:
+		if c.DeliveryState != DeliveryThrottled {
+			return errors.New("delivery_throttled challenge must keep throttled delivery state")
+		}
+	case StatusConfirmedPendingExpire:
+		if c.DeliveryState != DeliverySent && c.DeliveryState != DeliverySuppressed {
+			return errors.New("confirmed_pending_expire challenge must come from sent or suppressed delivery state")
+		}
+	}
+
+	if c.Status == StatusConfirmedPendingExpire {
+		if c.Confirmation == nil {
+			return errors.New("confirmed_pending_expire challenge must contain confirmation metadata")
+		}
+		if err := c.Confirmation.Validate(); err != nil {
+			return fmt.Errorf("challenge confirmation: %w", err)
+		}
+		return nil
+	}
+
+	if c.Confirmation != nil {
+		return errors.New("only confirmed_pending_expire challenge may contain confirmation metadata")
+	}
+
+	return nil
+}
diff --git a/authsession/internal/domain/challenge/model_test.go b/authsession/internal/domain/challenge/model_test.go
new file mode 100644
index 0000000..8769f2c
--- /dev/null
+++ b/authsession/internal/domain/challenge/model_test.go
@@ -0,0 +1,439 @@
+package challenge
+
+import (
+	"crypto/ed25519"
+	"github.com/stretchr/testify/require"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+)
+
+func TestPolicyConstants(t *testing.T) {
+	t.Parallel()
+
+	if InitialTTL != 5*time.Minute {
+		require.Failf(t, "test failed", "InitialTTL = %s, want %s", InitialTTL, 5*time.Minute)
+	}
+	if ResendThrottleCooldown != time.Minute {
+		require.Failf(t, "test failed", "ResendThrottleCooldown = %s, want %s", ResendThrottleCooldown, time.Minute)
+	}
+	if ConfirmedRetention != 5*time.Minute {
+		require.Failf(t, "test failed", "ConfirmedRetention = %s, want %s", ConfirmedRetention, 5*time.Minute)
+	}
+	if MaxInvalidConfirmAttempts != 5 {
+		require.Failf(t, "test failed", "MaxInvalidConfirmAttempts = %d, want %d", MaxInvalidConfirmAttempts, 5)
+	}
+}
+
+func TestStatusIsKnown(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Status
+		want  bool
+	}{
+		{name: "pending send", value: StatusPendingSend, want: true},
+		{name: "sent", value: StatusSent, want: true},
+		{name: "suppressed", value: StatusDeliverySuppressed, want: true},
+		{name: "throttled", value: StatusDeliveryThrottled, want: true},
+		{name: "confirmed", value: StatusConfirmedPendingExpire, want: true},
+		{name: "expired", value: StatusExpired, want: true},
+		{name: "failed", value: StatusFailed, want: true},
+		{name: "cancelled", value: StatusCancelled, want: true},
+		{name: "unknown", value: Status("unknown"), want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsKnown(); got != tt.want {
+				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestStatusIsTerminal(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Status
+		want  bool
+	}{
+		{name: "pending send", value: StatusPendingSend, want: false},
+		{name: "sent", value: StatusSent, want: false},
+		{name: "delivery suppressed", value: StatusDeliverySuppressed, want: false},
+		{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
+		{name: "confirmed pending expire", value: StatusConfirmedPendingExpire, want: false},
+		{name: "expired", value: StatusExpired, want: true},
+		{name: "failed", value: StatusFailed, want: true},
+		{name: "cancelled", value: StatusCancelled, want: true},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsTerminal(); got != tt.want {
+				require.Failf(t, "test failed", "IsTerminal() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestStatusAcceptsFreshConfirm(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Status
+		want  bool
+	}{
+		{name: "pending send", value: StatusPendingSend, want: false},
+		{name: "sent", value: StatusSent, want: true},
+		{name: "delivery suppressed", value: StatusDeliverySuppressed, want: true},
+		{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
+		{name: "confirmed", value: StatusConfirmedPendingExpire, want: false},
+		{name: "expired", value: StatusExpired, want: false},
+		{name: "failed", value: StatusFailed, want: false},
+		{name: "cancelled", value: StatusCancelled, want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.AcceptsFreshConfirm(); got != tt.want {
+				require.Failf(t, "test failed", "AcceptsFreshConfirm() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestStatusIsConfirmedRetryState(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Status
+		want  bool
+	}{
+		{name: "sent", value: StatusSent, want: false},
+		{name: "delivery suppressed", value: StatusDeliverySuppressed, want: false},
+		{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
+		{name: "confirmed", value: StatusConfirmedPendingExpire, want: true},
+		{name: "expired", value: StatusExpired, want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsConfirmedRetryState(); got != tt.want {
+				require.Failf(t, "test failed", "IsConfirmedRetryState() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestStatusCanTransitionTo(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		from Status
+		to   Status
+		want bool
+	}{
+		{name: "pending to sent", from: StatusPendingSend, to: StatusSent, want: true},
+		{name: "pending to suppressed", from: StatusPendingSend, to: StatusDeliverySuppressed, want: true},
+		{name: "pending to throttled", from: StatusPendingSend, to: StatusDeliveryThrottled, want: true},
+		{name: "pending to failed", from: StatusPendingSend, to: StatusFailed, want: true},
+		{name: "pending to cancelled", from: StatusPendingSend, to: StatusCancelled, want: true},
+		{name: "pending to expired", from: StatusPendingSend, to: StatusExpired, want: true},
+		{name: "pending to confirmed", from: StatusPendingSend, to: StatusConfirmedPendingExpire, want: false},
+		{name: "sent to confirmed", from: StatusSent, to: StatusConfirmedPendingExpire, want: true},
+		{name: "sent to failed", from: StatusSent, to: StatusFailed, want: true},
+		{name: "suppressed to confirmed", from: StatusDeliverySuppressed, to: StatusConfirmedPendingExpire, want: true},
+		{name: "throttled to confirmed", from: StatusDeliveryThrottled, to: StatusConfirmedPendingExpire, want: false},
+		{name: "confirmed to expired", from: StatusConfirmedPendingExpire, to: StatusExpired, want: true},
+		{name: "confirmed to failed", from: StatusConfirmedPendingExpire, to: StatusFailed, want: false},
+		{name: "expired terminal", from: StatusExpired, to: StatusCancelled, want: false},
+		{name: "failed terminal", from: StatusFailed, to: StatusExpired, want: false},
+		{name: "cancelled terminal", from: StatusCancelled, to: StatusExpired, want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
+				require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestDeliveryStateIsKnown(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value DeliveryState
+		want  bool
+	}{
+		{name: "pending", value: DeliveryPending, want: true},
+		{name: "sent", value: DeliverySent, want: true},
+		{name: "suppressed", value: DeliverySuppressed, want: true},
+		{name: "throttled", value: DeliveryThrottled, want: true},
+		{name: "failed", value: DeliveryFailed, want: true},
+		{name: "unknown", value: DeliveryState("unknown"), want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsKnown(); got != tt.want {
+				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestDeliveryStateCanTransitionTo(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		from DeliveryState
+		to   DeliveryState
+		want bool
+	}{
+		{name: "pending to sent", from: DeliveryPending, to: DeliverySent, want: true},
+		{name: "pending to suppressed", from: DeliveryPending, to: DeliverySuppressed, want: true},
+		{name: "pending to throttled", from: DeliveryPending, to: DeliveryThrottled, want: true},
+		{name: "pending to failed", from: DeliveryPending, to: DeliveryFailed, want: true},
+		{name: "sent terminal", from: DeliverySent, to: DeliveryFailed, want: false},
+		{name: "suppressed terminal", from: DeliverySuppressed, to: DeliverySent, want: false},
+		{name: "failed terminal", from: DeliveryFailed, to: DeliverySent, want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
+				require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestChallengeIsExpiredAt(t *testing.T) {
+	t.Parallel()
+
+	now := time.Unix(1_775_121_700, 0).UTC()
+	tests := []struct {
+		name   string
+		mutate func(*Challenge)
+		want   bool
+	}{
+		{name: "active before expiration", want: false},
+		{
+			name: "expired status",
+			mutate: func(c *Challenge) {
+				c.Status = StatusExpired
+			},
+			want: true,
+		},
+		{
+			name: "expiration timestamp passed",
+			mutate: func(c *Challenge) {
+				c.ExpiresAt = now
+			},
+			want: true,
+		},
+		{
+			name: "confirmed retained before expiration",
+			mutate: func(c *Challenge) {
+				c.Status = StatusConfirmedPendingExpire
+				c.DeliveryState = DeliverySent
+				c.Confirmation = validConfirmation(t)
+				c.ExpiresAt = now.Add(time.Second)
+			},
+			want: false,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			challenge := validChallenge(t)
+			challenge.CreatedAt = now.Add(-time.Minute)
+			challenge.ExpiresAt = now.Add(time.Minute)
+			if tt.mutate != nil {
+				tt.mutate(&challenge)
+			}
+			if err := challenge.Validate(); err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+
+			if got := challenge.IsExpiredAt(now); got != tt.want {
+				require.Failf(t, "test failed", "IsExpiredAt() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestChallengeValidate(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		mutate  func(*Challenge)
+		wantErr bool
+	}{
+		{name: "valid pending"},
+		{
+			name: "valid confirmed",
+			mutate: func(c *Challenge) {
+				c.Status = StatusConfirmedPendingExpire
+				c.DeliveryState = DeliverySent
+				c.Confirmation = validConfirmation(t)
+			},
+		},
+		{
+			name: "confirmed requires metadata",
+			mutate: func(c *Challenge) {
+				c.Status = StatusConfirmedPendingExpire
+				c.DeliveryState = DeliverySent
+			},
+			wantErr: true,
+		},
+		{
+			name: "unconfirmed rejects metadata",
+			mutate: func(c *Challenge) {
+				c.Confirmation = validConfirmation(t)
+			},
+			wantErr: true,
+		},
+		{
+			name: "pending requires pending delivery",
+			mutate: func(c *Challenge) {
+				c.DeliveryState = DeliverySent
+			},
+			wantErr: true,
+		},
+		{
+			name: "sent requires sent delivery",
+			mutate: func(c *Challenge) {
+				c.Status = StatusSent
+				c.DeliveryState = DeliverySuppressed
+			},
+			wantErr: true,
+		},
+		{
+			name: "throttled requires throttled delivery",
+			mutate: func(c *Challenge) {
+				c.Status = StatusDeliveryThrottled
+				c.DeliveryState = DeliverySent
+			},
+			wantErr: true,
+		},
+		{
+			name: "expiration before creation",
+			mutate: func(c *Challenge) {
+				c.ExpiresAt = c.CreatedAt.Add(-time.Second)
+			},
+			wantErr: true,
+		},
+		{
+			name: "negative confirm attempts",
+			mutate: func(c *Challenge) {
+				c.Attempts.Confirm = -1
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			challenge := validChallenge(t)
+			if tt.mutate != nil {
+				tt.mutate(&challenge)
+			}
+
+			err := challenge.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
+
+func validChallenge(t *testing.T) Challenge {
+	t.Helper()
+
+	return Challenge{
+		ID:            common.ChallengeID("challenge-123"),
+		Email:         common.Email("pilot@example.com"),
+		CodeHash:      []byte("hash-123"),
+		Status:        StatusPendingSend,
+		DeliveryState: DeliveryPending,
+		CreatedAt:     time.Unix(1_775_121_600, 0).UTC(),
+		ExpiresAt:     time.Unix(1_775_121_900, 0).UTC(),
+		Attempts: AttemptCounters{
+			Send:    0,
+			Confirm: 0,
+		},
+	}
+}
+
+func validConfirmation(t *testing.T) *Confirmation {
+	t.Helper()
+
+	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
+	for index := range raw {
+		raw[index] = byte(index + 1)
+	}
+
+	key, err := common.NewClientPublicKey(raw)
+	if err != nil {
+		require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
+	}
+
+	return &Confirmation{
+		SessionID:       common.DeviceSessionID("device-session-123"),
+		ClientPublicKey: key,
+		ConfirmedAt:     time.Unix(1_775_121_700, 0).UTC(),
+	}
+}
diff --git a/authsession/internal/domain/challenge/policy.go b/authsession/internal/domain/challenge/policy.go
new file mode 100644
index 0000000..62ec1e4
--- /dev/null
+++ b/authsession/internal/domain/challenge/policy.go
@@ -0,0 +1,26 @@
+package challenge
+
+import "time"
+
+const (
+	// InitialTTL is the v1 lifetime of a newly created challenge before it
+	// becomes expired.
+	InitialTTL = 5 * time.Minute
+
+	// ResendThrottleCooldown is the fixed Stage-17 cooldown applied to repeated
+	// public send-email-code requests for the same normalized e-mail address.
+	ResendThrottleCooldown = time.Minute
+
+	// ConfirmedRetention is the v1 idempotency window kept after a successful
+	// challenge confirmation.
+	ConfirmedRetention = 5 * time.Minute
+
+	// MaxInvalidConfirmAttempts is the v1 threshold after which repeated invalid
+	// confirmation codes move a challenge into the failed state.
+	MaxInvalidConfirmAttempts = 5
+)
+
+// V1 resend policy keeps every public send-email-code request independent:
+// each call creates a fresh challenge, existing challenges are not reused or
+// deduplicated, and Stage 17 adds a fixed auth-side resend cooldown that may
+// record the fresh challenge as delivery_throttled.
diff --git a/authsession/internal/domain/common/types.go b/authsession/internal/domain/common/types.go
new file mode 100644
index 0000000..b698e31
--- /dev/null
+++ b/authsession/internal/domain/common/types.go
@@ -0,0 +1,201 @@
+// Package common defines small shared domain primitives used by auth/session
+// aggregates and integration models.
+package common
+
+import (
+	"bytes"
+	"crypto/ed25519"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"net/mail"
+	"strings"
+)
+
+// ChallengeID identifies one auth confirmation challenge owned by the service.
+type ChallengeID string
+
+// String returns ChallengeID as a plain string identifier.
+func (id ChallengeID) String() string {
+	return string(id)
+}
+
+// IsZero reports whether ChallengeID does not contain a usable identifier.
+func (id ChallengeID) IsZero() bool {
+	return strings.TrimSpace(string(id)) == ""
+}
+
+// Validate reports whether ChallengeID is non-empty and already normalized for
+// domain use.
+func (id ChallengeID) Validate() error {
+	return validateToken("challenge id", string(id))
+}
+
+// DeviceSessionID identifies one persisted device session.
+type DeviceSessionID string
+
+// String returns DeviceSessionID as a plain string identifier.
+func (id DeviceSessionID) String() string {
+	return string(id)
+}
+
+// IsZero reports whether DeviceSessionID does not contain a usable identifier.
+func (id DeviceSessionID) IsZero() bool {
+	return strings.TrimSpace(string(id)) == ""
+}
+
+// Validate reports whether DeviceSessionID is non-empty and already
+// normalized for domain use.
+func (id DeviceSessionID) Validate() error {
+	return validateToken("device session id", string(id))
+}
+
+// UserID identifies one user resolved through the user-service boundary.
+type UserID string
+
+// String returns UserID as a plain string identifier.
+func (id UserID) String() string {
+	return string(id)
+}
+
+// IsZero reports whether UserID does not contain a usable identifier.
+func (id UserID) IsZero() bool {
+	return strings.TrimSpace(string(id)) == ""
+}
+
+// Validate reports whether UserID is non-empty and already normalized for
+// domain use.
+func (id UserID) Validate() error {
+	return validateToken("user id", string(id))
+}
+
+// Email stores one already-normalized e-mail address used by the auth domain.
+type Email string
+
+// String returns Email as the stored canonical e-mail string.
+func (e Email) String() string {
+	return string(e)
+}
+
+// IsZero reports whether Email does not contain a usable e-mail value.
+func (e Email) IsZero() bool {
+	return strings.TrimSpace(string(e)) == ""
+}
+
+// Validate reports whether Email is non-empty, does not contain surrounding
+// whitespace, and matches the same single-address syntax expected by the
+// public gateway contract.
+func (e Email) Validate() error {
+	raw := string(e)
+	if err := validateToken("email", raw); err != nil {
+		return err
+	}
+
+	parsedAddress, err := mail.ParseAddress(raw)
+	if err != nil || parsedAddress.Name != "" || parsedAddress.Address != raw {
+		return fmt.Errorf("email %q must be a single valid email address", raw)
+	}
+
+	return nil
+}
+
+// RevokeReasonCode stores one machine-readable revoke reason code.
+type RevokeReasonCode string
+
+// String returns RevokeReasonCode as its stored code value.
+func (code RevokeReasonCode) String() string {
+	return string(code)
+}
+
+// IsZero reports whether RevokeReasonCode is empty.
+func (code RevokeReasonCode) IsZero() bool {
+	return strings.TrimSpace(string(code)) == ""
+}
+
+// Validate reports whether RevokeReasonCode is non-empty and normalized for
+// domain use.
+func (code RevokeReasonCode) Validate() error {
+	return validateToken("revoke reason code", string(code))
+}
+
+// RevokeActorType stores one machine-readable actor type for revoke audit.
+type RevokeActorType string
+
+// String returns RevokeActorType as its stored type value.
+func (actorType RevokeActorType) String() string {
+	return string(actorType)
+}
+
+// IsZero reports whether RevokeActorType is empty.
+func (actorType RevokeActorType) IsZero() bool {
+	return strings.TrimSpace(string(actorType)) == ""
+}
+
+// Validate reports whether RevokeActorType is non-empty and normalized for
+// domain use.
+func (actorType RevokeActorType) Validate() error {
+	return validateToken("revoke actor type", string(actorType))
+}
+
+// ClientPublicKey stores one validated Ed25519 public key in parsed binary
+// form inside the domain model.
+type ClientPublicKey struct {
+	value ed25519.PublicKey
+}
+
+// NewClientPublicKey validates value and returns a defensive copy suitable for
+// storing inside domain aggregates.
+func NewClientPublicKey(value ed25519.PublicKey) (ClientPublicKey, error) {
+	key := ClientPublicKey{
+		value: bytes.Clone(value),
+	}
+	if err := key.Validate(); err != nil {
+		return ClientPublicKey{}, err
+	}
+
+	return key, nil
+}
+
+// String returns ClientPublicKey as the standard base64-encoded raw 32-byte
+// Ed25519 public key string.
+func (key ClientPublicKey) String() string {
+	if key.IsZero() {
+		return ""
+	}
+
+	return base64.StdEncoding.EncodeToString(key.value)
+}
+
+// IsZero reports whether ClientPublicKey does not contain key material.
+func (key ClientPublicKey) IsZero() bool {
+	return len(key.value) == 0
+}
+
+// Validate reports whether ClientPublicKey contains exactly one Ed25519 public
+// key.
+func (key ClientPublicKey) Validate() error {
+	switch len(key.value) {
+	case 0:
+		return errors.New("client public key must not be empty")
+	case ed25519.PublicKeySize:
+		return nil
+	default:
+		return fmt.Errorf("client public key must contain exactly %d bytes", ed25519.PublicKeySize)
+	}
+}
+
+// PublicKey returns a defensive copy of the parsed Ed25519 public key.
+func (key ClientPublicKey) PublicKey() ed25519.PublicKey {
+	return bytes.Clone(key.value)
+}
+
+func validateToken(name string, value string) error {
+	switch {
+	case strings.TrimSpace(value) == "":
+		return fmt.Errorf("%s must not be empty", name)
+	case strings.TrimSpace(value) != value:
+		return fmt.Errorf("%s must not contain surrounding whitespace", name)
+	default:
+		return nil
+	}
+}
diff --git a/authsession/internal/domain/common/types_test.go b/authsession/internal/domain/common/types_test.go
new file mode 100644
index 0000000..cbcf3ef
--- /dev/null
+++ b/authsession/internal/domain/common/types_test.go
@@ -0,0 +1,133 @@
+package common
+
+import (
+	"crypto/ed25519"
+	"github.com/stretchr/testify/require"
+	"testing"
+)
+
+func TestChallengeIDValidate(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		value   ChallengeID
+		wantErr bool
+	}{
+		{name: "valid", value: ChallengeID("challenge-123")},
+		{name: "empty", value: ChallengeID(""), wantErr: true},
+		{name: "whitespace", value: ChallengeID(" challenge-123 "), wantErr: true},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			err := tt.value.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
+
+func TestEmailValidate(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		value   Email
+		wantErr bool
+	}{
+		{name: "valid", value: Email("pilot@example.com")},
+		{name: "invalid", value: Email("pilot"), wantErr: true},
+		{name: "surrounding whitespace", value: Email(" pilot@example.com "), wantErr: true},
+		{name: "display name", value: Email("Pilot <pilot@example.com>"), wantErr: true},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			err := tt.value.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
+
+func TestNewClientPublicKey(t *testing.T) {
+	t.Parallel()
+
+	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
+	for i := range raw {
+		raw[i] = byte(i)
+	}
+
+	key, err := NewClientPublicKey(raw)
+	if err != nil {
+		require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
+	}
+
+	if key.IsZero() {
+		require.FailNow(t, "IsZero() = true, want false")
+	}
+
+	cloned := key.PublicKey()
+	if len(cloned) != ed25519.PublicKeySize {
+		require.Failf(t, "test failed", "PublicKey() length = %d, want %d", len(cloned), ed25519.PublicKeySize)
+	}
+
+	raw[0] = 99
+	if key.PublicKey()[0] == 99 {
+		require.FailNow(t, "PublicKey() was mutated through constructor input")
+	}
+}
+
+func TestClientPublicKeyValidate(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		value   ClientPublicKey
+		wantErr bool
+	}{
+		{name: "empty", value: ClientPublicKey{}, wantErr: true},
+		{
+			name:    "short",
+			value:   ClientPublicKey{value: make(ed25519.PublicKey, ed25519.PublicKeySize-1)},
+			wantErr: true,
+		},
+		{
+			name:  "valid",
+			value: ClientPublicKey{value: make(ed25519.PublicKey, ed25519.PublicKeySize)},
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			err := tt.value.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
diff --git a/authsession/internal/domain/devicesession/model.go b/authsession/internal/domain/devicesession/model.go
new file mode 100644
index 0000000..ca8c84d
--- /dev/null
+++ b/authsession/internal/domain/devicesession/model.go
@@ -0,0 +1,162 @@
+// Package devicesession defines the source-of-truth domain model for one
+// authenticated device session.
+package devicesession
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+)
+
+// Status identifies the coarse lifecycle state of one device session.
+type Status string
+
+const (
+	// StatusActive reports that the session may be used for authenticated
+	// request verification.
+	StatusActive Status = "active"
+
+	// StatusRevoked reports that the session has been revoked and must no
+	// longer authenticate requests.
+	StatusRevoked Status = "revoked"
+)
+
+// RevokeReasonDeviceLogout reports that one device logged itself out.
+const RevokeReasonDeviceLogout common.RevokeReasonCode = "device_logout"
+
+// RevokeReasonLogoutAll reports that the session was revoked by a
+// user-scoped logout-all action.
+const RevokeReasonLogoutAll common.RevokeReasonCode = "logout_all"
+
+// RevokeReasonAdminRevoke reports that the session was revoked
+// administratively.
+const RevokeReasonAdminRevoke common.RevokeReasonCode = "admin_revoke"
+
+// RevokeReasonUserBlocked reports that the session was revoked because future
+// auth flow for the user or e-mail was blocked.
+const RevokeReasonUserBlocked common.RevokeReasonCode = "user_blocked"
+
+// IsKnown reports whether Status is one of the device-session states
+// supported by the current domain model.
+func (s Status) IsKnown() bool {
+	switch s {
+	case StatusActive, StatusRevoked:
+		return true
+	default:
+		return false
+	}
+}
+
+// CanTransitionTo reports whether the current device-session Status may move
+// to next under the Stage-2 lifecycle rules.
+func (s Status) CanTransitionTo(next Status) bool {
+	return s == StatusActive && next == StatusRevoked
+}
+
+// IsKnownRevokeReasonCode reports whether code is one of the built-in revoke
+// reasons fixed by the Stage-2 domain model.
+func IsKnownRevokeReasonCode(code common.RevokeReasonCode) bool {
+	switch code {
+	case RevokeReasonDeviceLogout,
+		RevokeReasonLogoutAll,
+		RevokeReasonAdminRevoke,
+		RevokeReasonUserBlocked:
+		return true
+	default:
+		return false
+	}
+}
+
+// Revocation stores the audit metadata recorded when a session is revoked.
+type Revocation struct {
+	// At reports when the revoke took effect.
+	At time.Time
+
+	// ReasonCode stores one machine-readable revoke reason code.
+	ReasonCode common.RevokeReasonCode
+
+	// ActorType stores one machine-readable initiator type.
+	ActorType common.RevokeActorType
+
+	// ActorID optionally stores a stable initiator identifier.
+	ActorID string
+}
+
+// Validate reports whether Revocation contains all metadata required for a
+// revoked session.
+func (r Revocation) Validate() error {
+	if r.At.IsZero() {
+		return errors.New("session revocation time must not be zero")
+	}
+	if err := r.ReasonCode.Validate(); err != nil {
+		return fmt.Errorf("session revocation reason code: %w", err)
+	}
+	if err := r.ActorType.Validate(); err != nil {
+		return fmt.Errorf("session revocation actor type: %w", err)
+	}
+	if strings.TrimSpace(r.ActorID) != r.ActorID {
+		return errors.New("session revocation actor id must not contain surrounding whitespace")
+	}
+
+	return nil
+}
+
+// Session is the minimal source-of-truth aggregate shape fixed by Stage 2.
+type Session struct {
+	// ID identifies the device session.
+	ID common.DeviceSessionID
+
+	// UserID identifies the durable user linkage for the session.
+	UserID common.UserID
+
+	// ClientPublicKey stores the validated device public key in parsed form.
+	ClientPublicKey common.ClientPublicKey
+
+	// Status reports the coarse lifecycle state of the session.
+	Status Status
+
+	// CreatedAt reports when the session was created.
+	CreatedAt time.Time
+
+	// Revocation is present only when Status is StatusRevoked.
+	Revocation *Revocation
+}
+
+// Validate reports whether Session satisfies the Stage-2 structural and
+// lifecycle invariants.
+func (s Session) Validate() error {
+	if err := s.ID.Validate(); err != nil {
+		return fmt.Errorf("session id: %w", err)
+	}
+	if err := s.UserID.Validate(); err != nil {
+		return fmt.Errorf("session user id: %w", err)
+	}
+	if err := s.ClientPublicKey.Validate(); err != nil {
+		return fmt.Errorf("session client public key: %w", err)
+	}
+	if !s.Status.IsKnown() {
+		return fmt.Errorf("session status %q is unsupported", s.Status)
+	}
+	if s.CreatedAt.IsZero() {
+		return errors.New("session creation time must not be zero")
+	}
+
+	switch s.Status {
+	case StatusActive:
+		if s.Revocation != nil {
+			return errors.New("active session must not contain revocation metadata")
+		}
+	case StatusRevoked:
+		if s.Revocation == nil {
+			return errors.New("revoked session must contain revocation metadata")
+		}
+		if err := s.Revocation.Validate(); err != nil {
+			return fmt.Errorf("session revocation: %w", err)
+		}
+	}
+
+	return nil
+}
diff --git a/authsession/internal/domain/devicesession/model_test.go b/authsession/internal/domain/devicesession/model_test.go
new file mode 100644
index 0000000..e60d7d6
--- /dev/null
+++ b/authsession/internal/domain/devicesession/model_test.go
@@ -0,0 +1,186 @@
+package devicesession
+
+import (
+	"crypto/ed25519"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"galaxy/authsession/internal/domain/common"
+)
+
+func TestStatusIsKnown(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Status
+		want  bool
+	}{
+		{name: "active", value: StatusActive, want: true},
+		{name: "revoked", value: StatusRevoked, want: true},
+		{name: "unknown", value: Status("unknown"), want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsKnown(); got != tt.want {
+				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestStatusCanTransitionTo(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		from Status
+		to   Status
+		want bool
+	}{
+		{name: "active to revoked", from: StatusActive, to: StatusRevoked, want: true},
+		{name: "active to active", from: StatusActive, to: StatusActive, want: false},
+		{name: "revoked terminal", from: StatusRevoked, to: StatusActive, want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
+				require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestIsKnownRevokeReasonCode(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value common.RevokeReasonCode
+		want  bool
+	}{
+		{name: "device logout", value: RevokeReasonDeviceLogout, want: true},
+		{name: "logout all", value: RevokeReasonLogoutAll, want: true},
+		{name: "admin revoke", value: RevokeReasonAdminRevoke, want: true},
+		{name: "user blocked", value: RevokeReasonUserBlocked, want: true},
+		{name: "custom code", value: common.RevokeReasonCode("custom_policy"), want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := IsKnownRevokeReasonCode(tt.value); got != tt.want {
+				require.Failf(t, "test failed", "IsKnownRevokeReasonCode() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestSessionValidate(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		mutate  func(*Session)
+		wantErr bool
+	}{
+		{name: "active valid"},
+		{
+			name: "revoked valid",
+			mutate: func(s *Session) {
+				s.Status = StatusRevoked
+				s.Revocation = validRevocation()
+			},
+		},
+		{
+			name: "active rejects revocation",
+			mutate: func(s *Session) {
+				s.Revocation = validRevocation()
+			},
+			wantErr: true,
+		},
+		{
+			name: "revoked requires revocation",
+			mutate: func(s *Session) {
+				s.Status = StatusRevoked
+			},
+			wantErr: true,
+		},
+		{
+			name: "revoked requires complete metadata",
+			mutate: func(s *Session) {
+				s.Status = StatusRevoked
+				revocation := validRevocation()
+				revocation.ReasonCode = ""
+				s.Revocation = revocation
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			session := validSession(t)
+			if tt.mutate != nil {
+				tt.mutate(&session)
+			}
+
+			err := session.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
+
+func validSession(t *testing.T) Session {
+	t.Helper()
+
+	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
+	for index := range raw {
+		raw[index] = byte(index + 7)
+	}
+
+	key, err := common.NewClientPublicKey(raw)
+	if err != nil {
+		require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
+	}
+
+	return Session{
+		ID:              common.DeviceSessionID("device-session-123"),
+		UserID:          common.UserID("user-123"),
+		ClientPublicKey: key,
+		Status:          StatusActive,
+		CreatedAt:       time.Unix(1_775_121_600, 0).UTC(),
+	}
+}
+
+func validRevocation() *Revocation {
+	return &Revocation{
+		At:         time.Unix(1_775_121_800, 0).UTC(),
+		ReasonCode: RevokeReasonAdminRevoke,
+		ActorType:  common.RevokeActorType("admin"),
+		ActorID:    "admin-123",
+	}
+}
diff --git a/authsession/internal/domain/gatewayprojection/model.go b/authsession/internal/domain/gatewayprojection/model.go
new file mode 100644
index 0000000..180926a
--- /dev/null
+++ b/authsession/internal/domain/gatewayprojection/model.go
@@ -0,0 +1,141 @@
+// Package gatewayprojection defines the gateway-facing integration snapshot
+// model that stays separate from source-of-truth session entities.
+package gatewayprojection
+
+import (
+	"crypto/ed25519"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+)
+
+// Status identifies the coarse lifecycle state projected to the gateway.
+type Status string
+
+const (
+	// StatusActive reports that the projected session may authenticate
+	// requests on the gateway hot path.
+	StatusActive Status = "active"
+
+	// StatusRevoked reports that the projected session must be rejected on the
+	// gateway hot path.
+	StatusRevoked Status = "revoked"
+)
+
+// IsKnown reports whether Status is one of the projection states supported by
+// the current integration model.
+func (s Status) IsKnown() bool {
+	switch s {
+	case StatusActive, StatusRevoked:
+		return true
+	default:
+		return false
+	}
+}
+
+// Snapshot stores the gateway-facing session projection without exposing any
+// Redis-specific field naming or storage encoding.
+type Snapshot struct {
+	// DeviceSessionID identifies the projected device session.
+	DeviceSessionID common.DeviceSessionID
+
+	// UserID identifies the projected user.
+	UserID common.UserID
+
+	// ClientPublicKey stores the standard base64-encoded raw 32-byte Ed25519
+	// public key string expected by the gateway.
+	ClientPublicKey string
+
+	// Status reports whether the projected session is active or revoked.
+	Status Status
+
+	// RevokedAt optionally reports when the revoke took effect.
+	RevokedAt *time.Time
+
+	// RevokeReasonCode optionally stores the machine-readable revoke reason.
+	RevokeReasonCode common.RevokeReasonCode
+
+	// RevokeActorType optionally stores the machine-readable revoke actor type.
+	RevokeActorType common.RevokeActorType
+
+	// RevokeActorID optionally stores a stable revoke actor identifier.
+	RevokeActorID string
+}
+
+// Validate reports whether Snapshot satisfies the Stage-2 structural
+// invariants.
+func (s Snapshot) Validate() error {
+	if err := s.DeviceSessionID.Validate(); err != nil {
+		return fmt.Errorf("gateway projection device session id: %w", err)
+	}
+	if err := s.UserID.Validate(); err != nil {
+		return fmt.Errorf("gateway projection user id: %w", err)
+	}
+	if err := validateClientPublicKey(s.ClientPublicKey); err != nil {
+		return fmt.Errorf("gateway projection client public key: %w", err)
+	}
+	if !s.Status.IsKnown() {
+		return fmt.Errorf("gateway projection status %q is unsupported", s.Status)
+	}
+
+	if s.Status == StatusActive {
+		if s.RevokedAt != nil {
+			return errors.New("active gateway projection must not contain revoked time")
+		}
+		if !s.RevokeReasonCode.IsZero() {
+			return errors.New("active gateway projection must not contain revoke reason code")
+		}
+		if !s.RevokeActorType.IsZero() {
+			return errors.New("active gateway projection must not contain revoke actor type")
+		}
+		if s.RevokeActorID != "" {
+			return errors.New("active gateway projection must not contain revoke actor id")
+		}
+		return nil
+	}
+
+	if s.RevokedAt != nil && s.RevokedAt.IsZero() {
+		return errors.New("gateway projection revoked time must not be zero")
+	}
+	if !s.RevokeReasonCode.IsZero() {
+		if err := s.RevokeReasonCode.Validate(); err != nil {
+			return fmt.Errorf("gateway projection revoke reason code: %w", err)
+		}
+	}
+	if !s.RevokeActorType.IsZero() {
+		if err := s.RevokeActorType.Validate(); err != nil {
+			return fmt.Errorf("gateway projection revoke actor type: %w", err)
+		}
+	}
+	if s.RevokeActorType.IsZero() && s.RevokeActorID != "" {
+		return errors.New("gateway projection revoke actor id requires revoke actor type")
+	}
+	if strings.TrimSpace(s.RevokeActorID) != s.RevokeActorID {
+		return errors.New("gateway projection revoke actor id must not contain surrounding whitespace")
+	}
+
+	return nil
+}
+
+func validateClientPublicKey(value string) error {
+	switch {
+	case strings.TrimSpace(value) == "":
+		return errors.New("client public key must not be empty")
+	case strings.TrimSpace(value) != value:
+		return errors.New("client public key must not contain surrounding whitespace")
+	}
+
+	decoded, err := base64.StdEncoding.DecodeString(value)
+	if err != nil {
+		return fmt.Errorf("client public key must be valid base64: %w", err)
+	}
+	if len(decoded) != ed25519.PublicKeySize {
+		return fmt.Errorf("client public key must contain exactly %d bytes", ed25519.PublicKeySize)
+	}
+
+	return nil
+}
diff --git a/authsession/internal/domain/gatewayprojection/model_test.go b/authsession/internal/domain/gatewayprojection/model_test.go
new file mode 100644
index 0000000..ad6479c
--- /dev/null
+++ b/authsession/internal/domain/gatewayprojection/model_test.go
@@ -0,0 +1,146 @@
+package gatewayprojection
+
+import (
+	"crypto/ed25519"
+	"encoding/base64"
+	"reflect"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+)
+
+func TestStatusIsKnown(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Status
+		want  bool
+	}{
+		{name: "active", value: StatusActive, want: true},
+		{name: "revoked", value: StatusRevoked, want: true},
+		{name: "unknown", value: Status("unknown"), want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsKnown(); got != tt.want {
+				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestSnapshotValidate(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		mutate  func(*Snapshot)
+		wantErr bool
+	}{
+		{name: "active valid"},
+		{
+			name: "revoked valid",
+			mutate: func(snapshot *Snapshot) {
+				snapshot.Status = StatusRevoked
+				revokedAt := time.Unix(1_775_121_900, 0).UTC()
+				snapshot.RevokedAt = &revokedAt
+				snapshot.RevokeReasonCode = common.RevokeReasonCode("admin_revoke")
+				snapshot.RevokeActorType = common.RevokeActorType("admin")
+				snapshot.RevokeActorID = "admin-123"
+			},
+		},
+		{
+			name: "active rejects revoke metadata",
+			mutate: func(snapshot *Snapshot) {
+				snapshot.RevokeReasonCode = common.RevokeReasonCode("admin_revoke")
+			},
+			wantErr: true,
+		},
+		{
+			name: "invalid key encoding",
+			mutate: func(snapshot *Snapshot) {
+				snapshot.ClientPublicKey = "not-base64"
+			},
+			wantErr: true,
+		},
+		{
+			name: "actor id requires actor type",
+			mutate: func(snapshot *Snapshot) {
+				snapshot.Status = StatusRevoked
+				snapshot.RevokeActorID = "admin-123"
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			snapshot := validSnapshot()
+			if tt.mutate != nil {
+				tt.mutate(&snapshot)
+			}
+
+			err := snapshot.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
+
+func TestSnapshotStaysSeparateFromSessionDomainShape(t *testing.T) {
+	t.Parallel()
+
+	snapshotType := reflect.TypeOf(Snapshot{})
+	sessionType := reflect.TypeOf(devicesession.Session{})
+
+	clientPublicKeyField, ok := snapshotType.FieldByName("ClientPublicKey")
+	if !ok {
+		require.FailNow(t, "Snapshot is missing ClientPublicKey field")
+	}
+	if clientPublicKeyField.Type.Kind() != reflect.String {
+		require.Failf(t, "test failed", "Snapshot.ClientPublicKey kind = %s, want string", clientPublicKeyField.Type.Kind())
+	}
+
+	sessionClientPublicKeyField, ok := sessionType.FieldByName("ClientPublicKey")
+	if !ok {
+		require.FailNow(t, "devicesession.Session is missing ClientPublicKey field")
+	}
+	if clientPublicKeyField.Type == sessionClientPublicKeyField.Type {
+		require.FailNow(t, "Snapshot.ClientPublicKey must stay separate from devicesession.Session.ClientPublicKey type")
+	}
+
+	if _, ok := snapshotType.FieldByName("RevokedAtMS"); ok {
+		require.FailNow(t, "Snapshot must not expose Redis-specific RevokedAtMS field")
+	}
+}
+
+func validSnapshot() Snapshot {
+	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
+	for index := range raw {
+		raw[index] = byte(index + 17)
+	}
+
+	return Snapshot{
+		DeviceSessionID: common.DeviceSessionID("device-session-123"),
+		UserID:          common.UserID("user-123"),
+		ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
+		Status:          StatusActive,
+	}
+}
diff --git a/authsession/internal/domain/sessionlimit/model.go b/authsession/internal/domain/sessionlimit/model.go
new file mode 100644
index 0000000..f1148a7
--- /dev/null
+++ b/authsession/internal/domain/sessionlimit/model.go
@@ -0,0 +1,89 @@
+// Package sessionlimit defines the domain decision shape used for active
+// device-session limit evaluation.
+package sessionlimit
+
+import (
+	"errors"
+	"fmt"
+)
+
+// Kind identifies the coarse outcome of evaluating the active-session limit.
+type Kind string
+
+const (
+	// KindDisabled reports that no configured limit is currently active.
+	KindDisabled Kind = "disabled"
+
+	// KindAllowed reports that creating the next session is allowed.
+	KindAllowed Kind = "allowed"
+
+	// KindExceeded reports that creating the next session would exceed the
+	// configured limit.
+	KindExceeded Kind = "exceeded"
+)
+
+// IsKnown reports whether Kind is one of the session-limit outcomes supported
+// by the current domain model.
+func (k Kind) IsKnown() bool {
+	switch k {
+	case KindDisabled, KindAllowed, KindExceeded:
+		return true
+	default:
+		return false
+	}
+}
+
+// Decision stores the result of evaluating one possible next session creation.
+type Decision struct {
+	// Kind reports the coarse decision outcome.
+	Kind Kind
+
+	// ConfiguredLimit stores the active configured limit when one exists.
+	ConfiguredLimit *int
+
+	// ActiveSessionCount stores the current active-session count before create.
+	ActiveSessionCount int
+
+	// NextSessionCount stores the count that would exist after creating the next
+	// session.
+	NextSessionCount int
+}
+
+// Validate reports whether Decision satisfies the Stage-2 structural
+// invariants.
+func (d Decision) Validate() error {
+	if !d.Kind.IsKnown() {
+		return fmt.Errorf("session-limit decision kind %q is unsupported", d.Kind)
+	}
+	if d.ActiveSessionCount < 0 {
+		return errors.New("session-limit active session count must not be negative")
+	}
+	if d.NextSessionCount < 0 {
+		return errors.New("session-limit next session count must not be negative")
+	}
+	if d.NextSessionCount != d.ActiveSessionCount+1 {
+		return errors.New("session-limit next session count must equal active session count plus one")
+	}
+
+	switch d.Kind {
+	case KindDisabled:
+		if d.ConfiguredLimit != nil {
+			return errors.New("disabled session-limit decision must not contain configured limit")
+		}
+	case KindAllowed, KindExceeded:
+		if d.ConfiguredLimit == nil {
+			return errors.New("limited session-limit decision must contain configured limit")
+		}
+		if *d.ConfiguredLimit <= 0 {
+			return errors.New("session-limit configured limit must be positive")
+		}
+		if d.Kind == KindAllowed && d.NextSessionCount > *d.ConfiguredLimit {
+			return errors.New("allowed session-limit decision must not exceed configured limit")
+		}
+		if d.Kind == KindExceeded && d.NextSessionCount <= *d.ConfiguredLimit {
+			return errors.New("exceeded session-limit decision must be above configured limit")
+		}
+	}
+
+	return nil
+}
diff --git a/authsession/internal/domain/sessionlimit/model_test.go b/authsession/internal/domain/sessionlimit/model_test.go
new file mode 100644
index 0000000..df6d407
--- /dev/null
+++ b/authsession/internal/domain/sessionlimit/model_test.go
@@ -0,0 +1,128 @@
+package sessionlimit
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestKindIsKnown(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Kind
+		want  bool
+	}{
+		{name: "disabled", value: KindDisabled, want: true},
+		{name: "allowed", value: KindAllowed, want: true},
+		{name: "exceeded", value: KindExceeded, want: true},
+		{name: "unknown", value: Kind("unknown"), want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsKnown(); got != tt.want {
+				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestDecisionValidate(t *testing.T) {
+	t.Parallel()
+
+	limitTwo := 2
+	limitThree := 3
+
+	tests := []struct {
+		name    string
+		value   Decision
+		wantErr bool
+	}{
+		{
+			name: "disabled valid",
+			value: Decision{
+				Kind:               KindDisabled,
+				ActiveSessionCount: 0,
+				NextSessionCount:   1,
+			},
+		},
+		{
+			name: "allowed valid",
+			value: Decision{
+				Kind:               KindAllowed,
+				ConfiguredLimit:    &limitThree,
+				ActiveSessionCount: 1,
+				NextSessionCount:   2,
+			},
+		},
+		{
+			name: "exceeded valid",
+			value: Decision{
+				Kind:               KindExceeded,
+				ConfiguredLimit:    &limitTwo,
+				ActiveSessionCount: 2,
+				NextSessionCount:   3,
+			},
+		},
+		{
+			name: "disabled rejects limit",
+			value: Decision{
+				Kind:               KindDisabled,
+				ConfiguredLimit:    &limitTwo,
+				ActiveSessionCount: 0,
+				NextSessionCount:   1,
+			},
+			wantErr: true,
+		},
+		{
+			name: "allowed requires limit",
+			value: Decision{
+				Kind:               KindAllowed,
+				ActiveSessionCount: 0,
+				NextSessionCount:   1,
+			},
+			wantErr: true,
+		},
+		{
+			name: "allowed rejects overflow",
+			value: Decision{
+				Kind:               KindAllowed,
+				ConfiguredLimit:    &limitTwo,
+				ActiveSessionCount: 2,
+				NextSessionCount:   3,
+			},
+			wantErr: true,
+		},
+		{
+			name: "next count must be active plus one",
+			value: Decision{
+				Kind:               KindDisabled,
+				ActiveSessionCount: 2,
+				NextSessionCount:   2,
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			err := tt.value.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
diff --git a/authsession/internal/domain/userresolution/model.go b/authsession/internal/domain/userresolution/model.go
new file mode 100644
index 0000000..1d158d3
--- /dev/null
+++ b/authsession/internal/domain/userresolution/model.go
@@ -0,0 +1,110 @@
+// Package userresolution defines the domain result returned by the user
+// resolution boundary before session creation.
+package userresolution
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"galaxy/authsession/internal/domain/common"
+)
+
+// Kind identifies the coarse user-resolution result for one normalized e-mail.
+type Kind string
+
+const (
+	// KindExisting reports that the e-mail belongs to an existing user.
+	KindExisting Kind = "existing"
+
+	// KindCreatable reports that the e-mail is free and user creation is
+	// allowed.
+	KindCreatable Kind = "creatable"
+
+	// KindBlocked reports that the e-mail or subject is blocked from login or
+	// registration.
+	KindBlocked Kind = "blocked"
+)
+
+// IsKnown reports whether Kind is one of the user-resolution kinds supported
+// by the current domain model.
+func (k Kind) IsKnown() bool {
+	switch k {
+	case KindExisting, KindCreatable, KindBlocked:
+		return true
+	default:
+		return false
+	}
+}
+
+// BlockReasonCode stores one machine-readable user-block reason.
+type BlockReasonCode string
+
+// String returns BlockReasonCode as its stored code value.
+func (code BlockReasonCode) String() string {
+	return string(code)
+}
+
+// IsZero reports whether BlockReasonCode is empty.
+func (code BlockReasonCode) IsZero() bool {
+	return strings.TrimSpace(string(code)) == ""
+}
+
+// Validate reports whether BlockReasonCode is non-empty and normalized for
+// domain use.
+func (code BlockReasonCode) Validate() error {
+	switch {
+	case code.IsZero():
+		return errors.New("block reason code must not be empty")
+	case strings.TrimSpace(string(code)) != string(code):
+		return errors.New("block reason code must not contain surrounding whitespace")
+	default:
+		return nil
+	}
+}
+
+// Result stores the coarse user-resolution outcome consumed by later auth
+// workflow stages.
+type Result struct {
+	// Kind reports the coarse resolution outcome.
+	Kind Kind
+
+	// UserID is set only when Kind is KindExisting.
+	UserID common.UserID
+
+	// BlockReasonCode is set only when Kind is KindBlocked.
+	BlockReasonCode BlockReasonCode
+}
+
+// Validate reports whether Result satisfies the Stage-2 structural invariants.
+func (r Result) Validate() error {
+	if !r.Kind.IsKnown() {
+		return fmt.Errorf("user resolution kind %q is unsupported", r.Kind)
+	}
+
+	switch r.Kind {
+	case KindExisting:
+		if err := r.UserID.Validate(); err != nil {
+			return fmt.Errorf("user resolution user id: %w", err)
+		}
+		if !r.BlockReasonCode.IsZero() {
+			return errors.New("existing user resolution must not contain block reason code")
+		}
+	case KindCreatable:
+		if !r.UserID.IsZero() {
+			return errors.New("creatable user resolution must not contain user id")
+		}
+		if !r.BlockReasonCode.IsZero() {
+			return errors.New("creatable user resolution must not contain block reason code")
+		}
+	case KindBlocked:
+		if !r.UserID.IsZero() {
+			return errors.New("blocked user resolution must not contain user id")
+		}
+		if err := r.BlockReasonCode.Validate(); err != nil {
+			return fmt.Errorf("user resolution block reason code: %w", err)
+		}
+	}
+
+	return nil
+}
diff --git a/authsession/internal/domain/userresolution/model_test.go b/authsession/internal/domain/userresolution/model_test.go
new file mode 100644
index 0000000..9d637a1
--- /dev/null
+++ b/authsession/internal/domain/userresolution/model_test.go
@@ -0,0 +1,113 @@
+package userresolution
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"galaxy/authsession/internal/domain/common"
+)
+
+func TestKindIsKnown(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		value Kind
+		want  bool
+	}{
+		{name: "existing", value: KindExisting, want: true},
+		{name: "creatable", value: KindCreatable, want: true},
+		{name: "blocked", value: KindBlocked, want: true},
+		{name: "unknown", value: Kind("unknown"), want: false},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			if got := tt.value.IsKnown(); got != tt.want {
+				require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestResultValidate(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		value   Result
+		wantErr bool
+	}{
+		{
+			name: "existing valid",
+			value: Result{
+				Kind:   KindExisting,
+				UserID: common.UserID("user-123"),
+			},
+		},
+		{
+			name: "creatable valid",
+			value: Result{
+				Kind: KindCreatable,
+			},
+		},
+		{
+			name: "blocked valid",
+			value: Result{
+				Kind:            KindBlocked,
+				BlockReasonCode: BlockReasonCode("policy_blocked"),
+			},
+		},
+		{
+			name: "existing requires user id",
+			value: Result{
+				Kind: KindExisting,
+			},
+			wantErr: true,
+		},
+		{
+			name: "creatable rejects user id",
+			value: Result{
+				Kind:   KindCreatable,
+				UserID: common.UserID("user-123"),
+			},
+			wantErr: true,
+		},
+		{
+			name: "blocked requires reason",
+			value: Result{
+				Kind: KindBlocked,
+			},
+			wantErr: true,
+		},
+		{
+			name: "blocked rejects user id",
+			value: Result{
+				Kind:            KindBlocked,
+				UserID:          common.UserID("user-123"),
+				BlockReasonCode: BlockReasonCode("policy_blocked"),
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			err := tt.value.Validate()
+			if tt.wantErr && err == nil {
+				require.FailNow(t, "Validate() returned nil error")
+			}
+			if !tt.wantErr && err != nil {
+				require.Failf(t, "test failed", "Validate() returned error: %v", err)
+			}
+		})
+	}
+}
diff --git a/authsession/internal/logging/logger.go b/authsession/internal/logging/logger.go
new file mode 100644
index 0000000..8ff4b16
--- /dev/null
+++ b/authsession/internal/logging/logger.go
@@ -0,0 +1,82 @@
+// Package logging configures the authsession structured logger and provides
+// context-aware helpers for attaching OpenTelemetry trace identifiers.
+package logging
+
+import (
+	"context"
+	"strings"
+
+	"go.opentelemetry.io/otel/trace"
+	"go.uber.org/zap"
+	"go.uber.org/zap/zapcore"
+)
+
+// New constructs the process-wide JSON logger from level.
+func New(level string) (*zap.Logger, error) {
+	atomicLevel := zap.NewAtomicLevel()
+	if err := atomicLevel.UnmarshalText([]byte(strings.TrimSpace(level))); err != nil {
+		return nil, err
+	}
+
+	zapCfg := zap.NewProductionConfig()
+	zapCfg.Level = atomicLevel
+	zapCfg.Sampling = nil
+	zapCfg.Encoding = "json"
+	zapCfg.EncoderConfig.TimeKey = "timestamp"
+	zapCfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
+	zapCfg.OutputPaths = []string{"stdout"}
+	zapCfg.ErrorOutputPaths = []string{"stderr"}
+
+	return zapCfg.Build()
+}
+
+// TraceFieldsFromContext returns zap fields for the active OpenTelemetry span
+// when ctx carries a valid span context.
+func TraceFieldsFromContext(ctx context.Context) []zap.Field {
+	if ctx == nil {
+		return nil
+	}
+
+	spanContext := trace.SpanContextFromContext(ctx)
+	if !spanContext.IsValid() {
+		return nil
+	}
+
+	return []zap.Field{
+		zap.String("otel_trace_id", spanContext.TraceID().String()),
+		zap.String("otel_span_id", spanContext.SpanID().String()),
+	}
+}
+
+// Sync flushes logger and ignores the benign stdout or stderr sync errors
+// commonly returned by containerized or redirected process outputs.
+func Sync(logger *zap.Logger) error {
+	if logger == nil {
+		return nil
+	}
+
+	err := logger.Sync()
+	if err == nil || isIgnorableSyncError(err) {
+		return nil
+	}
+
+	return err
+}
+
+func isIgnorableSyncError(err error) bool {
+	if err == nil {
+		return false
+	}
+
+	message := strings.ToLower(err.Error())
+	switch {
+	case strings.Contains(message, "invalid argument"):
+		return true
+	case strings.Contains(message, "bad file descriptor"):
+		return true
+	case strings.Contains(message, "inappropriate ioctl for device"):
+		return true
+	default:
+		return false
+	}
+}
diff --git a/authsession/internal/logging/logger_test.go b/authsession/internal/logging/logger_test.go
new file mode 100644
index 0000000..5a00f56
--- /dev/null
+++ b/authsession/internal/logging/logger_test.go
@@ -0,0 +1,37 @@
+package logging
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	sdktrace "go.opentelemetry.io/otel/sdk/trace"
+	"go.opentelemetry.io/otel/sdk/trace/tracetest"
+)
+
+func TestNewRejectsInvalidLogLevel(t *testing.T) {
+	t.Parallel()
+
+	_, err := New("verbose")
+
+	require.Error(t, err)
+}
+
+func TestTraceFieldsFromContextReturnsTraceAndSpanIDs(t *testing.T) {
+	t.Parallel()
+
+	recorder := tracetest.NewSpanRecorder()
+	provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
+
+	ctx, span := provider.Tracer("test").Start(context.Background(), "operation")
+	defer span.End()
+
+	fields := TraceFieldsFromContext(ctx)
+
+	require.Len(t, fields, 2)
+	assert.Equal(t, "otel_trace_id", fields[0].Key)
+	assert.Equal(t, "otel_span_id", fields[1].Key)
+	assert.NotEmpty(t, fields[0].String)
+	assert.NotEmpty(t, fields[1].String)
+}
diff --git a/authsession/internal/ports/challenge_store.go b/authsession/internal/ports/challenge_store.go
new file mode 100644
index 0000000..7aacf68
--- /dev/null
+++ b/authsession/internal/ports/challenge_store.go
@@ -0,0 +1,43 @@
+package ports
+
+import (
+	"context"
+	"fmt"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+)
+
+// ChallengeStore provides source-of-truth persistence for auth confirmation
+// challenges without exposing storage-specific primitives.
+type ChallengeStore interface {
+	// Get returns the stored challenge for challengeID. Implementations must
+	// wrap ErrNotFound when challengeID does not exist.
+	Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error)
+
+	// Create persists record as a new challenge. Implementations must wrap
+	// ErrConflict when record.ID already exists.
+	Create(ctx context.Context, record challenge.Challenge) error
+
+	// CompareAndSwap replaces previous with next when the currently stored
+	// challenge matches previous exactly. Implementations must wrap ErrConflict
+	// when the stored challenge differs from previous and wrap ErrNotFound when
+	// previous.ID does not exist.
+	CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error
+}
+
+// ValidateComparableChallenges reports whether previous and next are suitable
+// for one ChallengeStore compare-and-swap call.
+func ValidateComparableChallenges(previous challenge.Challenge, next challenge.Challenge) error {
+	if err := previous.Validate(); err != nil {
+		return fmt.Errorf("previous challenge: %w", err)
+	}
+	if err := next.Validate(); err != nil {
+		return fmt.Errorf("next challenge: %w", err)
+	}
+	if previous.ID != next.ID {
+		return fmt.Errorf("challenge compare-and-swap ids must match: %q != %q", previous.ID, next.ID)
+	}
+
+	return nil
+}
diff --git a/authsession/internal/ports/clock.go b/authsession/internal/ports/clock.go
new file mode 100644
index 0000000..cfe56b1
--- /dev/null
+++ b/authsession/internal/ports/clock.go
@@ -0,0 +1,9 @@
+package ports
+
+import "time"
+
+// Clock returns current UTC time for the auth/session application layer.
+type Clock interface {
+	// Now returns the current service time.
+	Now() time.Time
+}
diff --git a/authsession/internal/ports/code_generator.go b/authsession/internal/ports/code_generator.go
new file mode 100644
index 0000000..007154b
--- /dev/null
+++ b/authsession/internal/ports/code_generator.go
@@ -0,0 +1,8 @@
+package ports
+
+// CodeGenerator generates cleartext confirmation codes for new auth
+// challenges.
+type CodeGenerator interface {
+	// Generate returns one fresh cleartext confirmation code.
+	Generate() (string, error)
+}
diff --git a/authsession/internal/ports/code_hasher.go b/authsession/internal/ports/code_hasher.go
new file mode 100644
index 0000000..922dc83
--- /dev/null
+++ b/authsession/internal/ports/code_hasher.go
@@ -0,0 +1,11 @@
+package ports
+
+// CodeHasher hashes cleartext confirmation codes and compares later user input
+// against stored hashes.
+type CodeHasher interface {
+	// Hash returns the stored representation for code.
+	Hash(code string) ([]byte, error)
+
+	// Compare reports whether hash matches code.
+	Compare(hash []byte, code string) (bool, error)
+}
diff --git a/authsession/internal/ports/config_provider.go b/authsession/internal/ports/config_provider.go
new file mode 100644
index 0000000..4111cef
--- /dev/null
+++ b/authsession/internal/ports/config_provider.go
@@ -0,0 +1,42 @@
+package ports
+
+import (
+	"context"
+	"errors"
+	"fmt"
+)
+
+// ConfigProvider returns dynamic auth/session configuration required by later
+// service workflows.
+type ConfigProvider interface {
+	// LoadSessionLimit returns the current active-session-limit configuration.
+	// A nil ActiveSessionLimit means that the limit is disabled.
+	LoadSessionLimit(ctx context.Context) (SessionLimitConfig, error)
+}
+
+// SessionLimitConfig stores the active-session-limit configuration in a form
+// that preserves "limit absent" as a first-class state.
+type SessionLimitConfig struct { + // ActiveSessionLimit stores the configured limit when one is present. Nil + // means that no active-session limit is configured. + ActiveSessionLimit *int +} + +// Validate reports whether SessionLimitConfig contains a valid limit value +// when one is configured. +func (c SessionLimitConfig) Validate() error { + if c.ActiveSessionLimit != nil && *c.ActiveSessionLimit <= 0 { + return errors.New("session limit config active session limit must be positive when configured") + } + + return nil +} + +// String returns a debug-friendly representation of SessionLimitConfig. +func (c SessionLimitConfig) String() string { + if c.ActiveSessionLimit == nil { + return "session_limit=disabled" + } + + return fmt.Sprintf("session_limit=%d", *c.ActiveSessionLimit) +} diff --git a/authsession/internal/ports/errors.go b/authsession/internal/ports/errors.go new file mode 100644 index 0000000..6d0716c --- /dev/null +++ b/authsession/internal/ports/errors.go @@ -0,0 +1,16 @@ +// Package ports defines the storage-agnostic and transport-agnostic service +// boundaries used by the auth/session application layer. +package ports + +import "errors" + +var ( + // ErrNotFound reports that a requested source-of-truth record or remote + // subject does not exist in the dependency behind the port. + ErrNotFound = errors.New("ports: record not found") + + // ErrConflict reports that a create or compare-and-swap style mutation + // cannot be applied because the current dependency state no longer matches + // the caller expectation. 
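The nil-pointer convention keeps "no limit configured" distinct from every numeric value. A short sketch of how a caller might consume the config — the enforcement helper is invented for illustration and is not the service's actual `sessionlimit` policy:

```go
package main

import "fmt"

// SessionLimitConfig mirrors the port type: nil means the limit is disabled.
type SessionLimitConfig struct {
	ActiveSessionLimit *int
}

// allowNewSession reports whether one more active session may be created
// given the current active-session count for a user.
func allowNewSession(cfg SessionLimitConfig, activeCount int) bool {
	if cfg.ActiveSessionLimit == nil {
		return true // limit disabled: always allow
	}
	return activeCount < *cfg.ActiveSessionLimit
}

func main() {
	limit := 2
	cfg := SessionLimitConfig{ActiveSessionLimit: &limit}
	fmt.Println(allowNewSession(cfg, 1))                    // prints "true"
	fmt.Println(allowNewSession(cfg, 2))                    // prints "false"
	fmt.Println(allowNewSession(SessionLimitConfig{}, 100)) // prints "true"
}
```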
+ ErrConflict = errors.New("ports: conflict") +) diff --git a/authsession/internal/ports/id_generator.go b/authsession/internal/ports/id_generator.go new file mode 100644 index 0000000..b38ca26 --- /dev/null +++ b/authsession/internal/ports/id_generator.go @@ -0,0 +1,13 @@ +package ports + +import "galaxy/authsession/internal/domain/common" + +// IDGenerator generates stable domain identifiers for new challenges and +// device sessions. +type IDGenerator interface { + // NewChallengeID returns a fresh challenge identifier. + NewChallengeID() (common.ChallengeID, error) + + // NewDeviceSessionID returns a fresh device-session identifier. + NewDeviceSessionID() (common.DeviceSessionID, error) +} diff --git a/authsession/internal/ports/mail_sender.go b/authsession/internal/ports/mail_sender.go new file mode 100644 index 0000000..4e503cb --- /dev/null +++ b/authsession/internal/ports/mail_sender.go @@ -0,0 +1,86 @@ +package ports + +import ( + "context" + "errors" + "fmt" + "strings" + + "galaxy/authsession/internal/domain/common" +) + +// MailSender delivers the public login code or intentionally suppresses +// outward delivery while keeping the auth flow success-shaped. +type MailSender interface { + // SendLoginCode attempts delivery for one generated login code. Explicit + // delivery failure is reported through error, while sent vs suppressed is + // returned in the result. + SendLoginCode(ctx context.Context, input SendLoginCodeInput) (SendLoginCodeResult, error) +} + +// SendLoginCodeInput describes one mail-delivery request generated by the auth +// flow. +type SendLoginCodeInput struct { + // Email identifies the normalized target e-mail address. + Email common.Email + + // Code stores the cleartext login code that should be delivered to Email. + Code string +} + +// Validate reports whether SendLoginCodeInput contains a complete delivery +// request. 
+func (i SendLoginCodeInput) Validate() error { + if err := i.Email.Validate(); err != nil { + return fmt.Errorf("send login code input email: %w", err) + } + switch { + case strings.TrimSpace(i.Code) == "": + return errors.New("send login code input code must not be empty") + case strings.TrimSpace(i.Code) != i.Code: + return errors.New("send login code input code must not contain surrounding whitespace") + default: + return nil + } +} + +// SendLoginCodeOutcome identifies the coarse mail-delivery outcome reported +// back to the auth flow. +type SendLoginCodeOutcome string + +const ( + // SendLoginCodeOutcomeSent reports that delivery was attempted and accepted. + SendLoginCodeOutcomeSent SendLoginCodeOutcome = "sent" + + // SendLoginCodeOutcomeSuppressed reports that outward behavior remains + // success-shaped while actual delivery is intentionally skipped. + SendLoginCodeOutcomeSuppressed SendLoginCodeOutcome = "suppressed" +) + +// IsKnown reports whether SendLoginCodeOutcome is supported by the current +// mail-sender contract. +func (o SendLoginCodeOutcome) IsKnown() bool { + switch o { + case SendLoginCodeOutcomeSent, SendLoginCodeOutcomeSuppressed: + return true + default: + return false + } +} + +// SendLoginCodeResult describes the stable outcome returned by MailSender for +// one delivery request. +type SendLoginCodeResult struct { + // Outcome reports whether delivery was sent or intentionally suppressed. + Outcome SendLoginCodeOutcome +} + +// Validate reports whether SendLoginCodeResult satisfies the mail-sender +// contract invariants. 
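A `MailSender` implementation keeps the flow "success-shaped" by returning the suppressed outcome instead of an error, for example for non-allowlisted recipients in a test environment. A hedged sketch — the allowlist rule is invented for illustration, not part of the real mail adapter:

```go
package main

import (
	"fmt"
	"strings"
)

// Local copies of the outcome values from the mail-sender contract.
const (
	OutcomeSent       = "sent"
	OutcomeSuppressed = "suppressed"
)

// allowlistSender "delivers" only to the configured domain and silently
// suppresses everything else, mirroring the sent-vs-suppressed contract.
type allowlistSender struct {
	domain string
}

func (s allowlistSender) SendLoginCode(email, code string) (string, error) {
	if !strings.HasSuffix(email, "@"+s.domain) {
		// Not a failure: outward behavior stays success-shaped.
		return OutcomeSuppressed, nil
	}
	// A real implementation would hand the code to an SMTP/API client here.
	return OutcomeSent, nil
}

func main() {
	s := allowlistSender{domain: "example.com"}
	out, err := s.SendLoginCode("pilot@example.com", "654321")
	fmt.Println(out, err) // prints "sent <nil>"
	out, _ = s.SendLoginCode("pilot@other.org", "654321")
	fmt.Println(out) // prints "suppressed"
}
```

Reserving the error return for genuine delivery failures (and never for policy-driven suppression) is what lets the auth flow treat both outcomes as success toward the client.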
+func (r SendLoginCodeResult) Validate() error { + if !r.Outcome.IsKnown() { + return fmt.Errorf("send login code result outcome %q is unsupported", r.Outcome) + } + + return nil +} diff --git a/authsession/internal/ports/ports_test.go b/authsession/internal/ports/ports_test.go new file mode 100644 index 0000000..c39f01b --- /dev/null +++ b/authsession/internal/ports/ports_test.go @@ -0,0 +1,371 @@ +package ports + +import ( + "github.com/stretchr/testify/require" + "testing" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/userresolution" +) + +func TestRevokeSessionOutcomeIsKnown(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value RevokeSessionOutcome + want bool + }{ + {name: "revoked", value: RevokeSessionOutcomeRevoked, want: true}, + {name: "already revoked", value: RevokeSessionOutcomeAlreadyRevoked, want: true}, + {name: "unknown", value: RevokeSessionOutcome("unknown"), want: false}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.value.IsKnown(); got != tt.want { + require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRevokeUserSessionsOutcomeIsKnown(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value RevokeUserSessionsOutcome + want bool + }{ + {name: "revoked", value: RevokeUserSessionsOutcomeRevoked, want: true}, + {name: "no active sessions", value: RevokeUserSessionsOutcomeNoActiveSessions, want: true}, + {name: "unknown", value: RevokeUserSessionsOutcome("unknown"), want: false}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.value.IsKnown(); got != tt.want { + require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want) + } + }) + } +} + +func 
TestEnsureUserOutcomeIsKnown(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value EnsureUserOutcome + want bool + }{ + {name: "existing", value: EnsureUserOutcomeExisting, want: true}, + {name: "created", value: EnsureUserOutcomeCreated, want: true}, + {name: "blocked", value: EnsureUserOutcomeBlocked, want: true}, + {name: "unknown", value: EnsureUserOutcome("unknown"), want: false}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.value.IsKnown(); got != tt.want { + require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBlockUserOutcomeIsKnown(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value BlockUserOutcome + want bool + }{ + {name: "blocked", value: BlockUserOutcomeBlocked, want: true}, + {name: "already blocked", value: BlockUserOutcomeAlreadyBlocked, want: true}, + {name: "unknown", value: BlockUserOutcome("unknown"), want: false}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.value.IsKnown(); got != tt.want { + require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSendLoginCodeOutcomeIsKnown(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value SendLoginCodeOutcome + want bool + }{ + {name: "sent", value: SendLoginCodeOutcomeSent, want: true}, + {name: "suppressed", value: SendLoginCodeOutcomeSuppressed, want: true}, + {name: "unknown", value: SendLoginCodeOutcome("unknown"), want: false}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.value.IsKnown(); got != tt.want { + require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSessionLimitConfigValidate(t *testing.T) { + t.Parallel() + + positive := 3 + zero := 0 + + tests := []struct { + 
name string + value SessionLimitConfig + wantErr bool + }{ + {name: "absent", value: SessionLimitConfig{}}, + {name: "positive", value: SessionLimitConfig{ActiveSessionLimit: &positive}}, + {name: "zero", value: SessionLimitConfig{ActiveSessionLimit: &zero}, wantErr: true}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.value.Validate() + if tt.wantErr && err == nil { + require.FailNow(t, "Validate() returned nil error") + } + if !tt.wantErr && err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } + }) + } +} + +func TestRevokeSessionInputValidate(t *testing.T) { + t.Parallel() + + input := RevokeSessionInput{ + DeviceSessionID: common.DeviceSessionID("device-session-1"), + Revocation: devicesession.Revocation{ + At: time.Unix(10, 0).UTC(), + ReasonCode: devicesession.RevokeReasonLogoutAll, + ActorType: common.RevokeActorType("system"), + }, + } + + if err := input.Validate(); err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } +} + +func TestRevokeSessionResultValidate(t *testing.T) { + t.Parallel() + + result := RevokeSessionResult{ + Outcome: RevokeSessionOutcomeRevoked, + Session: revokedSessionFixture(), + } + + if err := result.Validate(); err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } +} + +func TestRevokeUserSessionsResultValidate(t *testing.T) { + t.Parallel() + + result := RevokeUserSessionsResult{ + Outcome: RevokeUserSessionsOutcomeRevoked, + UserID: common.UserID("user-1"), + Sessions: []devicesession.Session{ + revokedSessionFixture(), + }, + } + + if err := result.Validate(); err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } +} + +func TestEnsureUserResultValidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value EnsureUserResult + wantErr bool + }{ + { + name: "existing", + value: EnsureUserResult{ + 
Outcome: EnsureUserOutcomeExisting, + UserID: common.UserID("user-1"), + }, + }, + { + name: "created", + value: EnsureUserResult{ + Outcome: EnsureUserOutcomeCreated, + UserID: common.UserID("user-2"), + }, + }, + { + name: "blocked", + value: EnsureUserResult{ + Outcome: EnsureUserOutcomeBlocked, + BlockReasonCode: userresolution.BlockReasonCode("policy_block"), + }, + }, + { + name: "blocked with user id", + value: EnsureUserResult{ + Outcome: EnsureUserOutcomeBlocked, + UserID: common.UserID("user-1"), + BlockReasonCode: userresolution.BlockReasonCode("policy_block"), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.value.Validate() + if tt.wantErr && err == nil { + require.FailNow(t, "Validate() returned nil error") + } + if !tt.wantErr && err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } + }) + } +} + +func TestBlockUserInputsAndResultValidate(t *testing.T) { + t.Parallel() + + byID := BlockUserByIDInput{ + UserID: common.UserID("user-1"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + } + if err := byID.Validate(); err != nil { + require.Failf(t, "test failed", "BlockUserByIDInput.Validate() returned error: %v", err) + } + + byEmail := BlockUserByEmailInput{ + Email: common.Email("pilot@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + } + if err := byEmail.Validate(); err != nil { + require.Failf(t, "test failed", "BlockUserByEmailInput.Validate() returned error: %v", err) + } + + result := BlockUserResult{ + Outcome: BlockUserOutcomeBlocked, + UserID: common.UserID("user-1"), + } + if err := result.Validate(); err != nil { + require.Failf(t, "test failed", "BlockUserResult.Validate() returned error: %v", err) + } +} + +func TestSendLoginCodeInputAndResultValidate(t *testing.T) { + t.Parallel() + + input := SendLoginCodeInput{ + Email: common.Email("pilot@example.com"), + Code: 
"654321", + } + if err := input.Validate(); err != nil { + require.Failf(t, "test failed", "SendLoginCodeInput.Validate() returned error: %v", err) + } + + result := SendLoginCodeResult{Outcome: SendLoginCodeOutcomeSent} + if err := result.Validate(); err != nil { + require.Failf(t, "test failed", "SendLoginCodeResult.Validate() returned error: %v", err) + } +} + +func TestValidateComparableChallenges(t *testing.T) { + t.Parallel() + + previous := challengeFixture() + next := challengeFixture() + next.Status = challenge.StatusSent + next.DeliveryState = challenge.DeliverySent + + if err := ValidateComparableChallenges(previous, next); err != nil { + require.Failf(t, "test failed", "ValidateComparableChallenges() returned error: %v", err) + } +} + +func challengeFixture() challenge.Challenge { + timestamp := time.Unix(10, 0).UTC() + return challenge.Challenge{ + ID: common.ChallengeID("challenge-1"), + Email: common.Email("pilot@example.com"), + CodeHash: []byte("hash"), + Status: challenge.StatusPendingSend, + DeliveryState: challenge.DeliveryPending, + CreatedAt: timestamp, + ExpiresAt: timestamp.Add(5 * time.Minute), + } +} + +func revokedSessionFixture() devicesession.Session { + timestamp := time.Unix(10, 0).UTC() + key, err := common.NewClientPublicKey(make([]byte, 32)) + if err != nil { + panic(err) + } + + return devicesession.Session{ + ID: common.DeviceSessionID("device-session-1"), + UserID: common.UserID("user-1"), + ClientPublicKey: key, + Status: devicesession.StatusRevoked, + CreatedAt: timestamp.Add(-time.Minute), + Revocation: &devicesession.Revocation{ + At: timestamp, + ReasonCode: devicesession.RevokeReasonLogoutAll, + ActorType: common.RevokeActorType("system"), + }, + } +} diff --git a/authsession/internal/ports/projection_publisher.go b/authsession/internal/ports/projection_publisher.go new file mode 100644 index 0000000..28e17fc --- /dev/null +++ b/authsession/internal/ports/projection_publisher.go @@ -0,0 +1,15 @@ +package ports + +import ( 
+ "context" + + "galaxy/authsession/internal/domain/gatewayprojection" +) + +// GatewaySessionProjectionPublisher publishes gateway-facing session snapshots +// after source-of-truth session changes. +type GatewaySessionProjectionPublisher interface { + // PublishSession writes or propagates snapshot in the gateway-facing + // projection model. + PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error +} diff --git a/authsession/internal/ports/send_email_code_abuse.go b/authsession/internal/ports/send_email_code_abuse.go new file mode 100644 index 0000000..dee566f --- /dev/null +++ b/authsession/internal/ports/send_email_code_abuse.go @@ -0,0 +1,100 @@ +package ports + +import ( + "context" + "fmt" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" +) + +// SendEmailCodeAbuseProtector decides whether one public send-email-code +// attempt may proceed immediately or must be throttled by the auth-side resend +// cooldown. +type SendEmailCodeAbuseProtector interface { + // CheckAndReserve validates input, checks the current resend cooldown + // decision for input.Email, and reserves a new cooldown window immediately + // when the outcome is allowed. + CheckAndReserve(ctx context.Context, input SendEmailCodeAbuseInput) (SendEmailCodeAbuseResult, error) +} + +// SendEmailCodeAbuseInput describes one resend-throttle decision request for +// a normalized public send-email-code attempt. +type SendEmailCodeAbuseInput struct { + // Email identifies the normalized e-mail address addressed by the public + // request. + Email common.Email + + // Now records when the send attempt is being evaluated. + Now time.Time +} + +// Validate reports whether SendEmailCodeAbuseInput contains a complete resend +// cooldown decision request. 
+func (i SendEmailCodeAbuseInput) Validate() error { + if err := i.Email.Validate(); err != nil { + return fmt.Errorf("send email code abuse input email: %w", err) + } + if i.Now.IsZero() { + return fmt.Errorf("send email code abuse input now must not be zero") + } + + return nil +} + +// SendEmailCodeAbuseOutcome identifies the coarse resend-throttle decision for +// one public send-email-code attempt. +type SendEmailCodeAbuseOutcome string + +const ( + // SendEmailCodeAbuseOutcomeAllowed reports that the attempt may proceed and + // that the cooldown window has been reserved immediately. + SendEmailCodeAbuseOutcomeAllowed SendEmailCodeAbuseOutcome = "allowed" + + // SendEmailCodeAbuseOutcomeThrottled reports that the cooldown window is + // still active and that the caller must not extend it. + SendEmailCodeAbuseOutcomeThrottled SendEmailCodeAbuseOutcome = "throttled" +) + +// IsKnown reports whether SendEmailCodeAbuseOutcome is supported by the +// current resend-throttle contract. +func (o SendEmailCodeAbuseOutcome) IsKnown() bool { + switch o { + case SendEmailCodeAbuseOutcomeAllowed, SendEmailCodeAbuseOutcomeThrottled: + return true + default: + return false + } +} + +// SendEmailCodeAbuseResult describes one resend-throttle decision returned by +// SendEmailCodeAbuseProtector. +type SendEmailCodeAbuseResult struct { + // Outcome reports whether the current send attempt may proceed or must be + // throttled. + Outcome SendEmailCodeAbuseOutcome +} + +// Validate reports whether SendEmailCodeAbuseResult satisfies the resend +// cooldown contract. +func (r SendEmailCodeAbuseResult) Validate() error { + if !r.Outcome.IsKnown() { + return fmt.Errorf("send email code abuse result outcome %q is unsupported", r.Outcome) + } + + return nil +} + +// SendEmailCodeThrottleStatusToChallengeStatus maps one resend-throttle +// outcome to the challenge lifecycle state used by sendemailcode. 
+func SendEmailCodeThrottleStatusToChallengeStatus(outcome SendEmailCodeAbuseOutcome) (challenge.Status, challenge.DeliveryState, error) { + switch outcome { + case SendEmailCodeAbuseOutcomeAllowed: + return challenge.StatusPendingSend, challenge.DeliveryPending, nil + case SendEmailCodeAbuseOutcomeThrottled: + return challenge.StatusDeliveryThrottled, challenge.DeliveryThrottled, nil + default: + return "", "", fmt.Errorf("map send email code abuse outcome %q: unsupported outcome", outcome) + } +} diff --git a/authsession/internal/ports/send_email_code_abuse_test.go b/authsession/internal/ports/send_email_code_abuse_test.go new file mode 100644 index 0000000..4d6e7fd --- /dev/null +++ b/authsession/internal/ports/send_email_code_abuse_test.go @@ -0,0 +1,47 @@ +package ports + +import ( + "testing" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSendEmailCodeAbuseOutcomeIsKnown(t *testing.T) { + t.Parallel() + + assert.True(t, SendEmailCodeAbuseOutcomeAllowed.IsKnown()) + assert.True(t, SendEmailCodeAbuseOutcomeThrottled.IsKnown()) + assert.False(t, SendEmailCodeAbuseOutcome("unknown").IsKnown()) +} + +func TestSendEmailCodeAbuseInputAndResultValidate(t *testing.T) { + t.Parallel() + + input := SendEmailCodeAbuseInput{ + Email: common.Email("pilot@example.com"), + Now: time.Unix(10, 0).UTC(), + } + require.NoError(t, input.Validate()) + + result := SendEmailCodeAbuseResult{Outcome: SendEmailCodeAbuseOutcomeThrottled} + require.NoError(t, result.Validate()) +} + +func TestSendEmailCodeThrottleStatusToChallengeStatus(t *testing.T) { + t.Parallel() + + status, deliveryState, err := SendEmailCodeThrottleStatusToChallengeStatus(SendEmailCodeAbuseOutcomeAllowed) + require.NoError(t, err) + assert.Equal(t, challenge.StatusPendingSend, status) + assert.Equal(t, challenge.DeliveryPending, deliveryState) + + status, 
deliveryState, err = SendEmailCodeThrottleStatusToChallengeStatus(SendEmailCodeAbuseOutcomeThrottled) + require.NoError(t, err) + assert.Equal(t, challenge.StatusDeliveryThrottled, status) + assert.Equal(t, challenge.DeliveryThrottled, deliveryState) +} diff --git a/authsession/internal/ports/session_store.go b/authsession/internal/ports/session_store.go new file mode 100644 index 0000000..3c03638 --- /dev/null +++ b/authsession/internal/ports/session_store.go @@ -0,0 +1,214 @@ +package ports + +import ( + "context" + "errors" + "fmt" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" +) + +// SessionStore provides source-of-truth persistence for device sessions +// without exposing storage-specific encoding or transaction primitives. +type SessionStore interface { + // Get returns the stored session for deviceSessionID. Implementations must + // wrap ErrNotFound when deviceSessionID does not exist. + Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) + + // ListByUserID returns every stored session for userID in newest-first + // order. Implementations must return an empty slice, not ErrNotFound, when + // userID has no stored sessions. + ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) + + // CountActiveByUserID returns the number of active sessions currently stored + // for userID. + CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) + + // Create persists record as a new device session. Implementations must wrap + // ErrConflict when record.ID already exists. + Create(ctx context.Context, record devicesession.Session) error + + // Revoke stores a revoked view of one target session. Implementations must + // wrap ErrNotFound when input.DeviceSessionID does not exist. 
+ Revoke(ctx context.Context, input RevokeSessionInput) (RevokeSessionResult, error) + + // RevokeAllByUserID stores revoked views for all currently active sessions + // owned by input.UserID. + RevokeAllByUserID(ctx context.Context, input RevokeUserSessionsInput) (RevokeUserSessionsResult, error) +} + +// RevokeSessionInput describes one single-session revoke mutation requested +// from SessionStore. +type RevokeSessionInput struct { + // DeviceSessionID identifies the session that should be revoked. + DeviceSessionID common.DeviceSessionID + + // Revocation stores the audit metadata that must be attached to the revoked + // session. + Revocation devicesession.Revocation +} + +// Validate reports whether RevokeSessionInput contains a complete revoke +// request. +func (i RevokeSessionInput) Validate() error { + if err := i.DeviceSessionID.Validate(); err != nil { + return fmt.Errorf("revoke session input device session id: %w", err) + } + if err := i.Revocation.Validate(); err != nil { + return fmt.Errorf("revoke session input revocation: %w", err) + } + + return nil +} + +// RevokeSessionOutcome identifies the coarse outcome of revoking one device +// session. +type RevokeSessionOutcome string + +const ( + // RevokeSessionOutcomeRevoked reports that an active session was moved to + // the revoked state by the current mutation. + RevokeSessionOutcomeRevoked RevokeSessionOutcome = "revoked" + + // RevokeSessionOutcomeAlreadyRevoked reports that the requested session had + // already been revoked before the current mutation. + RevokeSessionOutcomeAlreadyRevoked RevokeSessionOutcome = "already_revoked" +) + +// IsKnown reports whether RevokeSessionOutcome is supported by the current +// session-store contract. 
+func (o RevokeSessionOutcome) IsKnown() bool { + switch o { + case RevokeSessionOutcomeRevoked, RevokeSessionOutcomeAlreadyRevoked: + return true + default: + return false + } +} + +// RevokeSessionResult describes the stable outcome returned by SessionStore +// after a single-session revoke attempt. +type RevokeSessionResult struct { + // Outcome reports whether the session was revoked just now or had already + // been revoked. + Outcome RevokeSessionOutcome + + // Session stores the current source-of-truth session state after the revoke + // attempt. + Session devicesession.Session +} + +// Validate reports whether RevokeSessionResult satisfies the session-store +// contract invariants. +func (r RevokeSessionResult) Validate() error { + if !r.Outcome.IsKnown() { + return fmt.Errorf("revoke session result outcome %q is unsupported", r.Outcome) + } + if err := r.Session.Validate(); err != nil { + return fmt.Errorf("revoke session result session: %w", err) + } + if r.Session.Status != devicesession.StatusRevoked { + return errors.New("revoke session result session must be revoked") + } + + return nil +} + +// RevokeUserSessionsInput describes one bulk user-session revoke mutation +// requested from SessionStore. +type RevokeUserSessionsInput struct { + // UserID identifies the owner whose active sessions should be revoked. + UserID common.UserID + + // Revocation stores the audit metadata that must be attached to every + // revoked session. + Revocation devicesession.Revocation +} + +// Validate reports whether RevokeUserSessionsInput contains a complete bulk +// revoke request. 
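The revoked / already_revoked split is what keeps retried revoke calls idempotent: the second attempt reports the existing state instead of failing. A reduced sketch of that behavior, using a simplified stand-in for `devicesession.Session`:

```go
package main

import "fmt"

// Reduced status and outcome values mirroring the session-store contract.
const (
	StatusActive  = "active"
	StatusRevoked = "revoked"

	OutcomeRevoked        = "revoked"
	OutcomeAlreadyRevoked = "already_revoked"
)

// session is a reduced stand-in for devicesession.Session.
type session struct {
	ID     string
	Status string
}

// revoke flips an active session to revoked and reports which case applied,
// so retries observe already_revoked rather than an error.
func revoke(s *session) string {
	if s.Status == StatusRevoked {
		return OutcomeAlreadyRevoked
	}
	s.Status = StatusRevoked
	return OutcomeRevoked
}

func main() {
	s := &session{ID: "device-session-1", Status: StatusActive}
	fmt.Println(revoke(s)) // prints "revoked"
	fmt.Println(revoke(s)) // prints "already_revoked"
}
```

Both outcomes also return the current session state in the real contract, so callers can publish an accurate gateway projection regardless of which branch was taken.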
+func (i RevokeUserSessionsInput) Validate() error { + if err := i.UserID.Validate(); err != nil { + return fmt.Errorf("revoke user sessions input user id: %w", err) + } + if err := i.Revocation.Validate(); err != nil { + return fmt.Errorf("revoke user sessions input revocation: %w", err) + } + + return nil +} + +// RevokeUserSessionsOutcome identifies the coarse outcome of revoking all +// active sessions of one user. +type RevokeUserSessionsOutcome string + +const ( + // RevokeUserSessionsOutcomeRevoked reports that one or more active sessions + // were revoked by the current mutation. + RevokeUserSessionsOutcomeRevoked RevokeUserSessionsOutcome = "revoked" + + // RevokeUserSessionsOutcomeNoActiveSessions reports that the target user did + // not currently own any active sessions. + RevokeUserSessionsOutcomeNoActiveSessions RevokeUserSessionsOutcome = "no_active_sessions" +) + +// IsKnown reports whether RevokeUserSessionsOutcome is supported by the +// current session-store contract. +func (o RevokeUserSessionsOutcome) IsKnown() bool { + switch o { + case RevokeUserSessionsOutcomeRevoked, RevokeUserSessionsOutcomeNoActiveSessions: + return true + default: + return false + } +} + +// RevokeUserSessionsResult describes the stable outcome returned by +// SessionStore after one bulk revoke attempt. +type RevokeUserSessionsResult struct { + // Outcome reports whether at least one active session was revoked. + Outcome RevokeUserSessionsOutcome + + // UserID identifies the owner whose sessions were evaluated. + UserID common.UserID + + // Sessions stores the current source-of-truth session states for every + // session affected by the bulk revoke operation. + Sessions []devicesession.Session +} + +// Validate reports whether RevokeUserSessionsResult satisfies the bulk +// session-store contract invariants. 
+func (r RevokeUserSessionsResult) Validate() error { + if !r.Outcome.IsKnown() { + return fmt.Errorf("revoke user sessions result outcome %q is unsupported", r.Outcome) + } + if err := r.UserID.Validate(); err != nil { + return fmt.Errorf("revoke user sessions result user id: %w", err) + } + for index, session := range r.Sessions { + if err := session.Validate(); err != nil { + return fmt.Errorf("revoke user sessions result session %d: %w", index, err) + } + if session.Status != devicesession.StatusRevoked { + return fmt.Errorf("revoke user sessions result session %d must be revoked", index) + } + if session.UserID != r.UserID { + return fmt.Errorf("revoke user sessions result session %d belongs to %q, want %q", index, session.UserID, r.UserID) + } + } + + switch r.Outcome { + case RevokeUserSessionsOutcomeRevoked: + if len(r.Sessions) == 0 { + return errors.New("revoke user sessions result must include sessions when outcome is revoked") + } + case RevokeUserSessionsOutcomeNoActiveSessions: + if len(r.Sessions) != 0 { + return errors.New("revoke user sessions result must not include sessions when outcome is no_active_sessions") + } + } + + return nil +} diff --git a/authsession/internal/ports/user_directory.go b/authsession/internal/ports/user_directory.go new file mode 100644 index 0000000..835a544 --- /dev/null +++ b/authsession/internal/ports/user_directory.go @@ -0,0 +1,203 @@ +package ports + +import ( + "context" + "errors" + "fmt" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" +) + +// UserDirectory provides the auth/session boundary to user ownership, +// registration, and block-policy decisions. +type UserDirectory interface { + // ResolveByEmail returns the current resolution state for email without + // creating any new user record. 
+ ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) + + // ExistsByUserID reports whether userID currently identifies a stored user + // record. + ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) + + // EnsureUserByEmail returns an existing user for email, creates a new user + // when registration is allowed, or reports a blocked outcome when the + // address may not continue through confirm flow. + EnsureUserByEmail(ctx context.Context, email common.Email) (EnsureUserResult, error) + + // BlockByUserID applies a block state to the user identified by + // input.UserID. Implementations must wrap ErrNotFound when input.UserID does + // not exist. + BlockByUserID(ctx context.Context, input BlockUserByIDInput) (BlockUserResult, error) + + // BlockByEmail applies a block state to input.Email, even when no user + // record currently exists for that e-mail address. + BlockByEmail(ctx context.Context, input BlockUserByEmailInput) (BlockUserResult, error) +} + +// EnsureUserOutcome identifies the coarse outcome of ensuring a user record +// for one normalized e-mail address. +type EnsureUserOutcome string + +const ( + // EnsureUserOutcomeExisting reports that the e-mail already belonged to a + // stored user. + EnsureUserOutcomeExisting EnsureUserOutcome = "existing" + + // EnsureUserOutcomeCreated reports that a new user was created for the + // e-mail address. + EnsureUserOutcomeCreated EnsureUserOutcome = "created" + + // EnsureUserOutcomeBlocked reports that the e-mail cannot be used for login + // or registration. + EnsureUserOutcomeBlocked EnsureUserOutcome = "blocked" +) + +// IsKnown reports whether EnsureUserOutcome is supported by the current +// user-directory contract. 
+func (o EnsureUserOutcome) IsKnown() bool { + switch o { + case EnsureUserOutcomeExisting, EnsureUserOutcomeCreated, EnsureUserOutcomeBlocked: + return true + default: + return false + } +} + +// EnsureUserResult describes the stable outcome returned by UserDirectory +// after one ensure-user attempt. +type EnsureUserResult struct { + // Outcome reports whether the user already existed, was created, or is + // blocked by policy. + Outcome EnsureUserOutcome + + // UserID is present when Outcome is EnsureUserOutcomeExisting or + // EnsureUserOutcomeCreated. + UserID common.UserID + + // BlockReasonCode is present only when Outcome is EnsureUserOutcomeBlocked. + BlockReasonCode userresolution.BlockReasonCode +} + +// Validate reports whether EnsureUserResult satisfies the user-directory +// contract invariants. +func (r EnsureUserResult) Validate() error { + if !r.Outcome.IsKnown() { + return fmt.Errorf("ensure user result outcome %q is unsupported", r.Outcome) + } + + switch r.Outcome { + case EnsureUserOutcomeExisting, EnsureUserOutcomeCreated: + if err := r.UserID.Validate(); err != nil { + return fmt.Errorf("ensure user result user id: %w", err) + } + if !r.BlockReasonCode.IsZero() { + return errors.New("ensure user result must not contain block reason code for existing or created outcomes") + } + case EnsureUserOutcomeBlocked: + if !r.UserID.IsZero() { + return errors.New("ensure user result must not contain user id for blocked outcome") + } + if err := r.BlockReasonCode.Validate(); err != nil { + return fmt.Errorf("ensure user result block reason code: %w", err) + } + } + + return nil +} + +// BlockUserByIDInput describes one block mutation targeted by stable user id. +type BlockUserByIDInput struct { + // UserID identifies the user that should be blocked. + UserID common.UserID + + // ReasonCode stores the machine-readable block reason to apply. 
+ ReasonCode userresolution.BlockReasonCode +} + +// Validate reports whether BlockUserByIDInput contains a complete block +// request. +func (i BlockUserByIDInput) Validate() error { + if err := i.UserID.Validate(); err != nil { + return fmt.Errorf("block user by id input user id: %w", err) + } + if err := i.ReasonCode.Validate(); err != nil { + return fmt.Errorf("block user by id input reason code: %w", err) + } + + return nil +} + +// BlockUserByEmailInput describes one block mutation targeted by normalized +// e-mail address. +type BlockUserByEmailInput struct { + // Email identifies the e-mail address that should be blocked. + Email common.Email + + // ReasonCode stores the machine-readable block reason to apply. + ReasonCode userresolution.BlockReasonCode +} + +// Validate reports whether BlockUserByEmailInput contains a complete block +// request. +func (i BlockUserByEmailInput) Validate() error { + if err := i.Email.Validate(); err != nil { + return fmt.Errorf("block user by email input email: %w", err) + } + if err := i.ReasonCode.Validate(); err != nil { + return fmt.Errorf("block user by email input reason code: %w", err) + } + + return nil +} + +// BlockUserOutcome identifies the coarse outcome of blocking one user or +// e-mail subject. +type BlockUserOutcome string + +const ( + // BlockUserOutcomeBlocked reports that the current mutation applied a new + // block state. + BlockUserOutcomeBlocked BlockUserOutcome = "blocked" + + // BlockUserOutcomeAlreadyBlocked reports that the target subject had already + // been blocked before the current mutation. + BlockUserOutcomeAlreadyBlocked BlockUserOutcome = "already_blocked" +) + +// IsKnown reports whether BlockUserOutcome is supported by the current +// user-directory contract. 
+func (o BlockUserOutcome) IsKnown() bool { + switch o { + case BlockUserOutcomeBlocked, BlockUserOutcomeAlreadyBlocked: + return true + default: + return false + } +} + +// BlockUserResult describes the stable outcome returned by UserDirectory after +// one block attempt. +type BlockUserResult struct { + // Outcome reports whether the current mutation applied a new block state. + Outcome BlockUserOutcome + + // UserID optionally stores the stable user identifier resolved for the + // blocked subject when one exists. + UserID common.UserID +} + +// Validate reports whether BlockUserResult satisfies the user-directory +// contract invariants. +func (r BlockUserResult) Validate() error { + if !r.Outcome.IsKnown() { + return fmt.Errorf("block user result outcome %q is unsupported", r.Outcome) + } + if !r.UserID.IsZero() { + if err := r.UserID.Validate(); err != nil { + return fmt.Errorf("block user result user id: %w", err) + } + } + + return nil +} diff --git a/authsession/internal/service/blockuser/consistency_test.go b/authsession/internal/service/blockuser/consistency_test.go new file mode 100644 index 0000000..b0d7a27 --- /dev/null +++ b/authsession/internal/service/blockuser/consistency_test.go @@ -0,0 +1,88 @@ +package blockuser + +import ( + "context" + "errors" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteRetriesProjectionPublishesForBlockFlow(t *testing.T) { + t.Parallel() + + userDirectory := &testkit.InMemoryUserDirectory{} + store := &testkit.InMemorySessionStore{} + publisher := &testkit.RecordingProjectionPublisher{ + Errors: []error{errors.New("publish failed"), nil}, + } + require.NoError(t, 
userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()))) + + service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + assert.Equal(t, "blocked", result.Outcome) + assert.EqualValues(t, 1, result.AffectedSessionCount) + require.Len(t, publisher.PublishedSnapshots(), 2) +} + +func TestExecuteRepairsProjectionOnRepeatedAlreadyBlockedRequest(t *testing.T) { + t.Parallel() + + userDirectory := &testkit.InMemoryUserDirectory{} + store := &testkit.InMemorySessionStore{} + publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")} + require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()))) + + service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.Error(t, err) + assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err)) + require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts) + + sessionRecord, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, getErr) + require.NotNil(t, sessionRecord.Revocation) + assert.Equal(t, devicesession.StatusRevoked, sessionRecord.Status) + assert.Equal(t, devicesession.RevokeReasonUserBlocked, sessionRecord.Revocation.ReasonCode) + 
+ resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com")) + require.NoError(t, resolveErr) + assert.Equal(t, userresolution.KindBlocked, resolution.Kind) + + publisher.Err = nil + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + assert.Equal(t, "already_blocked", result.Outcome) + assert.EqualValues(t, 0, result.AffectedSessionCount) + require.NotNil(t, result.AffectedDeviceSessionIDs) + assert.Empty(t, result.AffectedDeviceSessionIDs) + require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1) +} diff --git a/authsession/internal/service/blockuser/cross_flow_test.go b/authsession/internal/service/blockuser/cross_flow_test.go new file mode 100644 index 0000000..c338a3d --- /dev/null +++ b/authsession/internal/service/blockuser/cross_flow_test.go @@ -0,0 +1,91 @@ +package blockuser + +import ( + "context" + "testing" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const blockFlowPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=" + +func TestBlockUserAffectsLaterSendAndConfirmFlows(t *testing.T) { + t.Parallel() + + challengeStore := &testkit.InMemoryChallengeStore{} + sessionStore := &testkit.InMemorySessionStore{} + userDirectory := &testkit.InMemoryUserDirectory{} + publisher := &testkit.RecordingProjectionPublisher{} + idGenerator := &testkit.SequenceIDGenerator{ + ChallengeIDs: []common.ChallengeID{"challenge-1"}, + DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"}, + } + hasher := 
testkit.DeterministicCodeHasher{} + mailSender := &testkit.RecordingMailSender{} + now := time.Unix(20, 0).UTC() + clock := testkit.FixedClock{Time: now} + + blockService, err := New(userDirectory, sessionStore, publisher, clock) + require.NoError(t, err) + + _, err = blockService.Execute(context.Background(), Input{ + Email: "pilot@example.com", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + + sendService, err := sendemailcode.New( + challengeStore, + userDirectory, + idGenerator, + testkit.FixedCodeGenerator{Code: "654321"}, + hasher, + mailSender, + clock, + ) + require.NoError(t, err) + + sendResult, err := sendService.Execute(context.Background(), sendemailcode.Input{Email: "pilot@example.com"}) + require.NoError(t, err) + assert.Equal(t, "challenge-1", sendResult.ChallengeID) + assert.Empty(t, mailSender.RecordedInputs()) + + challengeRecord, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + require.NoError(t, err) + assert.Equal(t, challenge.StatusDeliverySuppressed, challengeRecord.Status) + assert.Equal(t, challenge.DeliverySuppressed, challengeRecord.DeliveryState) + + confirmService, err := confirmemailcode.New( + challengeStore, + sessionStore, + userDirectory, + testkit.StaticConfigProvider{}, + publisher, + idGenerator, + hasher, + clock, + ) + require.NoError(t, err) + + _, err = confirmService.Execute(context.Background(), confirmemailcode.Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: blockFlowPublicKey, + }) + require.Error(t, err) + assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err)) + + updatedChallenge, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + require.NoError(t, getErr) + assert.Equal(t, challenge.StatusFailed, updatedChallenge.Status) +} diff --git a/authsession/internal/service/blockuser/observability_test.go b/authsession/internal/service/blockuser/observability_test.go new file 
mode 100644 index 0000000..006a321 --- /dev/null +++ b/authsession/internal/service/blockuser/observability_test.go @@ -0,0 +1,64 @@ +package blockuser + +import ( + "bytes" + "context" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestExecuteLogsSafeOutcomeFields(t *testing.T) { + t.Parallel() + + userDirectory := &testkit.InMemoryUserDirectory{} + require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + + sessionStore := &testkit.InMemorySessionStore{} + require.NoError(t, sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()))) + + logger, buffer := newObservedServiceLogger() + service, err := NewWithObservability( + userDirectory, + sessionStore, + &testkit.RecordingProjectionPublisher{}, + testkit.FixedClock{Time: time.Unix(20, 0).UTC()}, + logger, + nil, + ) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + + logOutput := buffer.String() + assert.Contains(t, logOutput, "block_user") + assert.Contains(t, logOutput, "\"user_id\":\"user-1\"") + assert.Contains(t, logOutput, "\"reason_code\":\"policy_block\"") + assert.NotContains(t, logOutput, "pilot@example.com") +} + +func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) { + buffer := &bytes.Buffer{} + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "" + + core := zapcore.NewCore( + zapcore.NewJSONEncoder(encoderConfig), + zapcore.AddSync(buffer), + zap.DebugLevel, + ) + + return zap.New(core), buffer +} diff --git a/authsession/internal/service/blockuser/service.go b/authsession/internal/service/blockuser/service.go new file 
mode 100644 index 0000000..b8f0c20 --- /dev/null +++ b/authsession/internal/service/blockuser/service.go @@ -0,0 +1,294 @@ +// Package blockuser implements the trusted internal block-user use case. +package blockuser + +import ( + "context" + "errors" + "fmt" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/telemetry" + + "go.uber.org/zap" +) + +const ( + // SubjectKindUserID identifies a block request addressed by stable user id. + SubjectKindUserID = "user_id" + + // SubjectKindEmail identifies a block request addressed by normalized e-mail + // address. + SubjectKindEmail = "email" +) + +// Input describes one trusted internal block-user request. +type Input struct { + // UserID identifies the subject to block when the request is user-id based. + UserID string + + // Email identifies the subject to block when the request is e-mail based. + Email string + + // ReasonCode stores the machine-readable block reason code applied to the + // user directory. + ReasonCode string + + // ActorType stores the machine-readable actor type for any derived session + // revocation. + ActorType string + + // ActorID stores the optional stable actor identifier for any derived + // session revocation. + ActorID string +} + +// Result describes the frozen internal block-user acknowledgement. +type Result struct { + // Outcome reports whether the block state was newly applied or already + // existed. + Outcome string + + // SubjectKind reports whether the request targeted `user_id` or `email`. + SubjectKind string + + // SubjectValue stores the normalized subject value addressed by the + // operation. + SubjectValue string + + // AffectedSessionCount reports how many sessions changed state during the + // current call. 
+ AffectedSessionCount int64 + + // AffectedDeviceSessionIDs lists every session identifier affected during + // the current call. + AffectedDeviceSessionIDs []string +} + +// Service executes the trusted internal block-user use case. +type Service struct { + userDirectory ports.UserDirectory + sessionStore ports.SessionStore + publisher ports.GatewaySessionProjectionPublisher + clock ports.Clock + logger *zap.Logger + telemetry *telemetry.Runtime +} + +// New returns a block-user service wired to the required ports. +func New(userDirectory ports.UserDirectory, sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) { + return NewWithObservability(userDirectory, sessionStore, publisher, clock, nil, nil) +} + +// NewWithObservability returns a block-user service wired to the required +// ports plus optional structured logging and telemetry dependencies. +func NewWithObservability( + userDirectory ports.UserDirectory, + sessionStore ports.SessionStore, + publisher ports.GatewaySessionProjectionPublisher, + clock ports.Clock, + logger *zap.Logger, + telemetryRuntime *telemetry.Runtime, +) (*Service, error) { + switch { + case userDirectory == nil: + return nil, fmt.Errorf("blockuser: user directory must not be nil") + case sessionStore == nil: + return nil, fmt.Errorf("blockuser: session store must not be nil") + case publisher == nil: + return nil, fmt.Errorf("blockuser: projection publisher must not be nil") + case clock == nil: + return nil, fmt.Errorf("blockuser: clock must not be nil") + default: + return &Service{ + userDirectory: userDirectory, + sessionStore: sessionStore, + publisher: publisher, + clock: clock, + logger: namedLogger(logger, "block_user"), + telemetry: telemetryRuntime, + }, nil + } +} + +// Execute applies the requested block state and revokes any active sessions of +// the resolved user when one exists. 
+func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) { + logFields := []zap.Field{ + zap.String("component", "service"), + zap.String("use_case", "block_user"), + } + defer func() { + if result.Outcome != "" { + logFields = append(logFields, zap.String("outcome", result.Outcome)) + } + if result.SubjectKind != "" { + logFields = append(logFields, zap.String("subject_kind", result.SubjectKind)) + } + if result.AffectedSessionCount > 0 { + logFields = append(logFields, zap.Int64("affected_session_count", result.AffectedSessionCount)) + } + shared.LogServiceOutcome(s.logger, ctx, "block user completed", err, logFields...) + }() + + subjectKind, subjectValue, storeResult, err := s.blockSubject(ctx, input) + if err != nil { + return Result{}, err + } + logFields = append(logFields, zap.String("reason_code", shared.NormalizeString(input.ReasonCode))) + if !storeResult.UserID.IsZero() { + logFields = append(logFields, zap.String("user_id", storeResult.UserID.String())) + } + + affectedDeviceSessionIDs := []string{} + affectedSessionCount := int64(0) + if !storeResult.UserID.IsZero() { + revocation, err := shared.BuildRevocation( + devicesession.RevokeReasonUserBlocked.String(), + input.ActorType, + input.ActorID, + s.clock.Now(), + ) + if err != nil { + return Result{}, err + } + + revokeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{ + UserID: storeResult.UserID, + Revocation: revocation, + }) + if err != nil { + return Result{}, shared.ServiceUnavailable(err) + } + if err := revokeResult.Validate(); err != nil { + return Result{}, shared.InternalError(err) + } + + for _, record := range revokeResult.Sessions { + if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user"); err != nil { + return Result{}, err + } + affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String()) + } + if revokeResult.Outcome == 
ports.RevokeUserSessionsOutcomeNoActiveSessions { + if err := s.republishCurrentRevokedSessions(ctx, storeResult.UserID); err != nil { + return Result{}, err + } + } + affectedSessionCount = int64(len(revokeResult.Sessions)) + if affectedSessionCount > 0 { + s.telemetry.RecordSessionRevocations(ctx, "block_user", devicesession.RevokeReasonUserBlocked.String(), affectedSessionCount) + } + } + + result = Result{ + Outcome: string(storeResult.Outcome), + SubjectKind: subjectKind, + SubjectValue: subjectValue, + AffectedSessionCount: affectedSessionCount, + AffectedDeviceSessionIDs: affectedDeviceSessionIDs, + } + + return result, nil +} + +func (s *Service) blockSubject(ctx context.Context, input Input) (string, string, ports.BlockUserResult, error) { + userID := shared.NormalizeString(input.UserID) + email := shared.NormalizeString(input.Email) + + switch { + case userID == "" && email == "": + return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided") + case userID != "" && email != "": + return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided") + case userID != "": + parsedUserID, err := shared.ParseUserID(userID) + if err != nil { + return "", "", ports.BlockUserResult{}, err + } + reasonCode, err := parseBlockReasonCode(input.ReasonCode) + if err != nil { + return "", "", ports.BlockUserResult{}, err + } + + result, err := s.userDirectory.BlockByUserID(ctx, ports.BlockUserByIDInput{ + UserID: parsedUserID, + ReasonCode: reasonCode, + }) + if err != nil { + switch { + case errors.Is(err, ports.ErrNotFound): + return "", "", ports.BlockUserResult{}, shared.SubjectNotFound() + default: + return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err) + } + } + if err := result.Validate(); err != nil { + return "", "", ports.BlockUserResult{}, shared.InternalError(err) + } + s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_user_id", 
string(result.Outcome)) + + return SubjectKindUserID, parsedUserID.String(), result, nil + default: + parsedEmail, err := shared.ParseEmail(email) + if err != nil { + return "", "", ports.BlockUserResult{}, err + } + reasonCode, err := parseBlockReasonCode(input.ReasonCode) + if err != nil { + return "", "", ports.BlockUserResult{}, err + } + + result, err := s.userDirectory.BlockByEmail(ctx, ports.BlockUserByEmailInput{ + Email: parsedEmail, + ReasonCode: reasonCode, + }) + if err != nil { + return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err) + } + if err := result.Validate(); err != nil { + return "", "", ports.BlockUserResult{}, shared.InternalError(err) + } + s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_email", string(result.Outcome)) + + return SubjectKindEmail, parsedEmail.String(), result, nil + } +} + +func parseBlockReasonCode(value string) (userresolution.BlockReasonCode, error) { + reasonCode := userresolution.BlockReasonCode(shared.NormalizeString(value)) + if err := reasonCode.Validate(); err != nil { + return "", shared.InvalidRequest(err.Error()) + } + + return reasonCode, nil +} + +func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error { + records, err := s.sessionStore.ListByUserID(ctx, userID) + if err != nil { + return shared.ServiceUnavailable(err) + } + + for _, record := range records { + if record.Status != devicesession.StatusRevoked { + continue + } + if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user_repair"); err != nil { + return err + } + } + + return nil +} + +func namedLogger(logger *zap.Logger, name string) *zap.Logger { + if logger == nil { + logger = zap.NewNop() + } + + return logger.Named(name) +} diff --git a/authsession/internal/service/blockuser/service_test.go b/authsession/internal/service/blockuser/service_test.go new file mode 100644 index 0000000..e58fb0b --- /dev/null +++ 
b/authsession/internal/service/blockuser/service_test.go @@ -0,0 +1,237 @@ +package blockuser + +import ( + "context" + "errors" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteBlocksByUserIDAndRevokesSessions(t *testing.T) { + t.Parallel() + + userDirectory := &testkit.InMemoryUserDirectory{} + store := &testkit.InMemorySessionStore{} + publisher := &testkit.RecordingProjectionPublisher{} + if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + assert.Equal(t, "blocked", result.Outcome) + assert.EqualValues(t, 1, result.AffectedSessionCount) + assert.Equal(t, SubjectKindUserID, result.SubjectKind) + assert.Equal(t, "user-1", result.SubjectValue) + assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs) + + stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, getErr) + require.NotNil(t, stored.Revocation) + assert.Equal(t, devicesession.StatusRevoked, stored.Status) + 
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode) + assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType) + + resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com")) + require.NoError(t, resolveErr) + assert.Equal(t, userresolution.KindBlocked, resolution.Kind) + assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode) + + published := publisher.PublishedSnapshots() + require.Len(t, published, 1) + assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status) + assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode) + assert.Equal(t, common.RevokeActorType("admin"), published[0].RevokeActorType) +} + +func TestExecuteBlocksByEmailWithoutExistingUser(t *testing.T) { + t.Parallel() + + userDirectory := &testkit.InMemoryUserDirectory{} + publisher := &testkit.RecordingProjectionPublisher{} + service, err := New(userDirectory, &testkit.InMemorySessionStore{}, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + Email: "pilot@example.com", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + assert.Equal(t, "blocked", result.Outcome) + assert.EqualValues(t, 0, result.AffectedSessionCount) + assert.Equal(t, SubjectKindEmail, result.SubjectKind) + assert.Equal(t, "pilot@example.com", result.SubjectValue) + require.NotNil(t, result.AffectedDeviceSessionIDs) + assert.Empty(t, result.AffectedDeviceSessionIDs) + + resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com")) + require.NoError(t, resolveErr) + assert.Equal(t, userresolution.KindBlocked, resolution.Kind) + assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode) + assert.Empty(t, publisher.PublishedSnapshots()) +} 
+ +func TestExecuteBlocksByEmailWithExistingUserAndRevokesSessions(t *testing.T) { + t.Parallel() + + userDirectory := &testkit.InMemoryUserDirectory{} + store := &testkit.InMemorySessionStore{} + publisher := &testkit.RecordingProjectionPublisher{} + if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + Email: "pilot@example.com", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + assert.Equal(t, "blocked", result.Outcome) + assert.EqualValues(t, 1, result.AffectedSessionCount) + assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs) + + stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, getErr) + require.NotNil(t, stored.Revocation) + assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode) + assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType) + + resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com")) + require.NoError(t, resolveErr) + assert.Equal(t, userresolution.KindBlocked, resolution.Kind) + assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode) + + published := publisher.PublishedSnapshots() + require.Len(t, published, 1) + assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode) +} + +func 
TestExecuteReturnsSubjectNotFoundForUnknownUserID(t *testing.T) { + t.Parallel() + + service, err := New(&testkit.InMemoryUserDirectory{}, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + if err != nil { + require.Failf(t, "test failed", "New() returned error: %v", err) + } + + _, err = service.Execute(context.Background(), Input{ + UserID: "missing", + ReasonCode: "policy_block", + ActorType: "admin", + }) + assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err)) +} + +func TestExecuteAlreadyBlockedStillRevokesLingeringSessions(t *testing.T) { + t.Parallel() + + userDirectory := &testkit.InMemoryUserDirectory{} + store := &testkit.InMemorySessionStore{} + publisher := &testkit.RecordingProjectionPublisher{} + if err := userDirectory.SeedBlockedUser(common.Email("pilot@example.com"), common.UserID("user-1"), userresolution.BlockReasonCode("policy_block")); err != nil { + require.Failf(t, "test failed", "SeedBlockedUser() returned error: %v", err) + } + if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + Email: "pilot@example.com", + ReasonCode: "policy_block", + ActorType: "admin", + }) + require.NoError(t, err) + assert.Equal(t, "already_blocked", result.Outcome) + assert.EqualValues(t, 1, result.AffectedSessionCount) + assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs) + + stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, getErr) + require.NotNil(t, stored.Revocation) + assert.Equal(t, devicesession.RevokeReasonUserBlocked, 
stored.Revocation.ReasonCode)
+	assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
+
+	published := publisher.PublishedSnapshots()
+	require.Len(t, published, 1)
+	assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
+}
+
+func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
+	t.Parallel()
+
+	userDirectory := &testkit.InMemoryUserDirectory{}
+	store := &testkit.InMemorySessionStore{}
+	publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
+	require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
+	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
+
+	service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
+	require.NoError(t, err)
+
+	_, err = service.Execute(context.Background(), Input{
+		UserID:     "user-1",
+		ReasonCode: "policy_block",
+		ActorType:  "admin",
+	})
+	assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
+
+	stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
+	require.NoError(t, getErr)
+	require.NotNil(t, stored.Revocation)
+	assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
+
+	resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
+	require.NoError(t, resolveErr)
+	assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
+	assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
+}
+
+func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
+	key, err := common.NewClientPublicKey(make([]byte, 32))
+	if err != nil {
+		panic(err)
+	}
+
+	return devicesession.Session{
+		ID:              common.DeviceSessionID(deviceSessionID),
+		UserID:          common.UserID(userID),
+		ClientPublicKey: key,
+		Status:          devicesession.StatusActive,
+		CreatedAt:       createdAt,
+	}
+}
diff --git a/authsession/internal/service/blockuser/stub_user_directory_test.go b/authsession/internal/service/blockuser/stub_user_directory_test.go
new file mode 100644
index 0000000..e3b68b2
--- /dev/null
+++ b/authsession/internal/service/blockuser/stub_user_directory_test.go
@@ -0,0 +1,60 @@
+package blockuser
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	stubuserservice "galaxy/authsession/internal/adapters/userservice"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/testkit"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
+	t.Parallel()
+
+	t.Run("blocks by email through runtime stub", func(t *testing.T) {
+		t.Parallel()
+
+		userDirectory := &stubuserservice.StubDirectory{}
+		require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
+
+		store := &testkit.InMemorySessionStore{}
+		require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
+
+		service, err := New(userDirectory, store, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
+		require.NoError(t, err)
+
+		result, err := service.Execute(context.Background(), Input{
+			Email:      "pilot@example.com",
+			ReasonCode: "policy_block",
+			ActorType:  "admin",
+		})
+		require.NoError(t, err)
+		assert.Equal(t, SubjectKindEmail, result.SubjectKind)
+		assert.Equal(t, "blocked", result.Outcome)
+		assert.EqualValues(t, 1, result.AffectedSessionCount)
+	})
+
+	t.Run("blocks by user id through runtime stub", func(t *testing.T) {
+		t.Parallel()
+
+		userDirectory := &stubuserservice.StubDirectory{}
+		require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
+
+		service, err := New(userDirectory, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
+		require.NoError(t, err)
+
+		result, err := service.Execute(context.Background(), Input{
+			UserID:     "user-1",
+			ReasonCode: "policy_block",
+			ActorType:  "admin",
+		})
+		require.NoError(t, err)
+		assert.Equal(t, SubjectKindUserID, result.SubjectKind)
+		assert.Equal(t, "blocked", result.Outcome)
+	})
+}
diff --git a/authsession/internal/service/confirmemailcode/anti_abuse_test.go b/authsession/internal/service/confirmemailcode/anti_abuse_test.go
new file mode 100644
index 0000000..f235cd6
--- /dev/null
+++ b/authsession/internal/service/confirmemailcode/anti_abuse_test.go
@@ -0,0 +1,39 @@
+package confirmemailcode
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/service/shared"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecuteReturnsInvalidCodeForThrottledChallengeWithoutConsumingAttempts(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
+	record.Status = challenge.StatusDeliveryThrottled
+	record.DeliveryState = challenge.DeliveryThrottled
+	require.NoError(t, record.Validate())
+	require.NoError(t, deps.challengeStore.Create(context.Background(), record))
+
+	service := mustNewConfirmService(t, deps)
+	_, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.Error(t, err)
+	assert.Equal(t, shared.ErrorCodeInvalidCode, shared.CodeOf(err))
+
+	updated, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	require.NoError(t, getErr)
+	assert.Equal(t, 0, updated.Attempts.Confirm)
+	assert.Equal(t, challenge.StatusDeliveryThrottled, updated.Status)
+}
diff --git a/authsession/internal/service/confirmemailcode/consistency_test.go b/authsession/internal/service/confirmemailcode/consistency_test.go
new file mode 100644
index 0000000..bb8dd45
--- /dev/null
+++ b/authsession/internal/service/confirmemailcode/consistency_test.go
@@ -0,0 +1,106 @@
+package confirmemailcode
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/service/shared"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecuteConfirmsChallengeAfterTransientProjectionPublishFailures(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	deps.publisher.Errors = []error{errors.New("publish failed"), nil}
+	require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
+	require.NoError(t, deps.challengeStore.Create(
+		context.Background(),
+		sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
+	))
+
+	service := mustNewConfirmService(t, deps)
+	result, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+	require.Len(t, deps.publisher.PublishedSnapshots(), 2)
+}
+
+func TestExecuteConfirmedRetryRepublishesAfterTransientProjectionPublishFailures(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	deps.publisher.Errors = []error{errors.New("publish failed"), nil}
+	key := mustClientPublicKey(t, publicKeyString())
+	require.NoError(t, deps.challengeStore.Create(
+		context.Background(),
+		confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
+	))
+	require.NoError(t, deps.sessionStore.Create(
+		context.Background(),
+		activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute)),
+	))
+
+	service := mustNewConfirmService(t, deps)
+	result, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+	require.Len(t, deps.publisher.PublishedSnapshots(), 2)
+}
+
+func TestExecuteRepairsProjectionOnIdenticalRetryAfterExhaustedPublishRetries(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	deps.publisher.Err = errors.New("publish failed")
+	require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
+	require.NoError(t, deps.challengeStore.Create(
+		context.Background(),
+		sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
+	))
+
+	service := mustNewConfirmService(t, deps)
+	_, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.Error(t, err)
+	assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
+	require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
+
+	sessionRecord, getErr := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
+	require.NoError(t, getErr)
+	assert.Equal(t, devicesession.StatusActive, sessionRecord.Status)
+
+	challengeRecord, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	require.NoError(t, getErr)
+	assert.Equal(t, challenge.StatusConfirmedPendingExpire, challengeRecord.Status)
+	require.NotNil(t, challengeRecord.Confirmation)
+
+	deps.publisher.Err = nil
+
+	result, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+	require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
+}
diff --git a/authsession/internal/service/confirmemailcode/service.go b/authsession/internal/service/confirmemailcode/service.go
new file mode 100644
index 0000000..18b7243
--- /dev/null
+++ b/authsession/internal/service/confirmemailcode/service.go
@@ -0,0 +1,588 @@
+// Package confirmemailcode implements the public confirm-email-code use case.
+package confirmemailcode
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/domain/sessionlimit"
+	"galaxy/authsession/internal/ports"
+	"galaxy/authsession/internal/service/shared"
+	"galaxy/authsession/internal/telemetry"
+
+	"go.uber.org/zap"
+)
+
+const (
+	revokeReasonConfirmRace common.RevokeReasonCode = "confirm_race_repair"
+	revokeActorTypeService  common.RevokeActorType  = "service"
+	revokeActorIDService                            = "confirmemailcode"
+)
+
+// Input describes one public confirm-email-code request.
+type Input struct {
+	// ChallengeID identifies the challenge that should be confirmed.
+	ChallengeID string
+
+	// Code is the cleartext confirmation code submitted by the caller.
+	Code string
+
+	// ClientPublicKey is the base64-encoded raw 32-byte Ed25519 public key that
+	// should be registered for the created device session.
+	ClientPublicKey string
+}
+
+// Result describes one public confirm-email-code response.
+type Result struct {
+	// DeviceSessionID is the stable identifier of the created or idempotently
+	// recovered device session.
+	DeviceSessionID string
+}
+
+// Service executes the public confirm-email-code use case.
+type Service struct {
+	challengeStore ports.ChallengeStore
+	sessionStore   ports.SessionStore
+	userDirectory  ports.UserDirectory
+	configProvider ports.ConfigProvider
+	publisher      ports.GatewaySessionProjectionPublisher
+	idGenerator    ports.IDGenerator
+	codeHasher     ports.CodeHasher
+	clock          ports.Clock
+	logger         *zap.Logger
+	telemetry      *telemetry.Runtime
+}
+
+// New returns a confirm-email-code service wired to the required ports.
+func New(
+	challengeStore ports.ChallengeStore,
+	sessionStore ports.SessionStore,
+	userDirectory ports.UserDirectory,
+	configProvider ports.ConfigProvider,
+	publisher ports.GatewaySessionProjectionPublisher,
+	idGenerator ports.IDGenerator,
+	codeHasher ports.CodeHasher,
+	clock ports.Clock,
+) (*Service, error) {
+	return NewWithTelemetry(
+		challengeStore,
+		sessionStore,
+		userDirectory,
+		configProvider,
+		publisher,
+		idGenerator,
+		codeHasher,
+		clock,
+		nil,
+	)
+}
+
+// NewWithTelemetry returns a confirm-email-code service wired to the required
+// ports plus the optional Stage-17 telemetry runtime.
+func NewWithTelemetry(
+	challengeStore ports.ChallengeStore,
+	sessionStore ports.SessionStore,
+	userDirectory ports.UserDirectory,
+	configProvider ports.ConfigProvider,
+	publisher ports.GatewaySessionProjectionPublisher,
+	idGenerator ports.IDGenerator,
+	codeHasher ports.CodeHasher,
+	clock ports.Clock,
+	telemetryRuntime *telemetry.Runtime,
+) (*Service, error) {
+	return NewWithObservability(
+		challengeStore,
+		sessionStore,
+		userDirectory,
+		configProvider,
+		publisher,
+		idGenerator,
+		codeHasher,
+		clock,
+		nil,
+		telemetryRuntime,
+	)
+}
+
+// NewWithObservability returns a confirm-email-code service wired to the
+// required ports plus optional structured logging and telemetry dependencies.
+func NewWithObservability(
+	challengeStore ports.ChallengeStore,
+	sessionStore ports.SessionStore,
+	userDirectory ports.UserDirectory,
+	configProvider ports.ConfigProvider,
+	publisher ports.GatewaySessionProjectionPublisher,
+	idGenerator ports.IDGenerator,
+	codeHasher ports.CodeHasher,
+	clock ports.Clock,
+	logger *zap.Logger,
+	telemetryRuntime *telemetry.Runtime,
+) (*Service, error) {
+	switch {
+	case challengeStore == nil:
+		return nil, fmt.Errorf("confirmemailcode: challenge store must not be nil")
+	case sessionStore == nil:
+		return nil, fmt.Errorf("confirmemailcode: session store must not be nil")
+	case userDirectory == nil:
+		return nil, fmt.Errorf("confirmemailcode: user directory must not be nil")
+	case configProvider == nil:
+		return nil, fmt.Errorf("confirmemailcode: config provider must not be nil")
+	case publisher == nil:
+		return nil, fmt.Errorf("confirmemailcode: projection publisher must not be nil")
+	case idGenerator == nil:
+		return nil, fmt.Errorf("confirmemailcode: id generator must not be nil")
+	case codeHasher == nil:
+		return nil, fmt.Errorf("confirmemailcode: code hasher must not be nil")
+	case clock == nil:
+		return nil, fmt.Errorf("confirmemailcode: clock must not be nil")
+	default:
+		return &Service{
+			challengeStore: challengeStore,
+			sessionStore:   sessionStore,
+			userDirectory:  userDirectory,
+			configProvider: configProvider,
+			publisher:      publisher,
+			idGenerator:    idGenerator,
+			codeHasher:     codeHasher,
+			clock:          clock,
+			logger:         namedLogger(logger, "confirm_email_code"),
+			telemetry:      telemetryRuntime,
+		}, nil
+	}
+}
+
+// Execute validates one challenge confirmation attempt, creates a device
+// session when policy allows it, and handles short-window idempotent retries.
+func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
+	logFields := []zap.Field{
+		zap.String("component", "service"),
+		zap.String("use_case", "confirm_email_code"),
+	}
+	defer func() {
+		outcome := string(telemetry.ConfirmEmailCodeOutcomeSuccess)
+		if err != nil {
+			outcome = shared.CodeOf(err)
+			if outcome == "" {
+				outcome = shared.ErrorCodeServiceUnavailable
+			}
+		}
+		s.telemetry.RecordConfirmEmailCode(ctx, outcome)
+		logFields = append(logFields, zap.String("outcome", outcome))
+		if result.DeviceSessionID != "" {
+			logFields = append(logFields, zap.String("device_session_id", result.DeviceSessionID))
+		}
+		shared.LogServiceOutcome(s.logger, ctx, "confirm email code completed", err, logFields...)
+	}()
+
+	challengeID, err := shared.ParseChallengeID(input.ChallengeID)
+	if err != nil {
+		return Result{}, err
+	}
+	logFields = append(logFields, zap.String("challenge_id", challengeID.String()))
+	code, err := shared.ParseRequiredCode(input.Code)
+	if err != nil {
+		return Result{}, err
+	}
+	clientPublicKey, err := shared.ParseClientPublicKey(input.ClientPublicKey)
+	if err != nil {
+		return Result{}, err
+	}
+
+	for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
+		current, err := s.challengeStore.Get(ctx, challengeID)
+		if err != nil {
+			switch {
+			case errors.Is(err, ports.ErrNotFound):
+				return Result{}, shared.ChallengeNotFound()
+			default:
+				return Result{}, shared.ServiceUnavailable(err)
+			}
+		}
+
+		now := s.clock.Now().UTC()
+		if expired, err := s.ensureChallengeNotExpired(ctx, current, now); err != nil {
+			if errors.Is(err, ports.ErrConflict) {
+				continue
+			}
+			return Result{}, err
+		} else if expired {
+			return Result{}, shared.ChallengeExpired()
+		}
+
+		switch {
+		case current.Status.IsConfirmedRetryState():
+			return s.handleConfirmedRetry(ctx, current, code, clientPublicKey)
+		case !current.Status.AcceptsFreshConfirm():
+			return Result{}, shared.InvalidCode()
+		}
+
+		match, err := s.codeHasher.Compare(current.CodeHash, code)
+		if err != nil {
+			return Result{}, shared.ServiceUnavailable(err)
+		}
+		if !match {
+			if err := s.recordInvalidConfirmAttempt(ctx, current, now); err != nil {
+				if errors.Is(err, ports.ErrConflict) {
+					continue
+				}
+				return Result{}, err
+			}
+
+			return Result{}, shared.InvalidCode()
+		}
+
+		ensureUserResult, err := s.userDirectory.EnsureUserByEmail(ctx, current.Email)
+		if err != nil {
+			return Result{}, shared.ServiceUnavailable(err)
+		}
+		if err := ensureUserResult.Validate(); err != nil {
+			return Result{}, shared.InternalError(err)
+		}
+		s.telemetry.RecordUserDirectoryOutcome(ctx, "ensure_user_by_email", string(ensureUserResult.Outcome))
+		if !ensureUserResult.UserID.IsZero() {
+			logFields = append(logFields, zap.String("user_id", ensureUserResult.UserID.String()))
+		}
+		if ensureUserResult.Outcome == ports.EnsureUserOutcomeBlocked {
+			if err := s.markChallengeFailed(ctx, current, now); err != nil {
+				if errors.Is(err, ports.ErrConflict) {
+					continue
+				}
+				return Result{}, err
+			}
+
+			return Result{}, shared.BlockedByPolicy()
+		}
+
+		limitConfig, err := s.configProvider.LoadSessionLimit(ctx)
+		if err != nil {
+			return Result{}, shared.ServiceUnavailable(err)
+		}
+		decision, err := s.evaluateSessionLimit(ctx, ensureUserResult.UserID, limitConfig)
+		if err != nil {
+			return Result{}, err
+		}
+		if decision.Kind == sessionlimit.KindExceeded {
+			s.telemetry.RecordSessionLimitRejection(ctx)
+			return Result{}, shared.SessionLimitExceeded()
+		}
+
+		sessionRecord, err := s.createSession(ctx, ensureUserResult.UserID, clientPublicKey, now)
+		if err != nil {
+			return Result{}, err
+		}
+
+		next := current
+		next.Status = challenge.StatusConfirmedPendingExpire
+		next.ExpiresAt = now.Add(challenge.ConfirmedRetention)
+		next.Abuse.LastAttemptAt = &now
+		next.Confirmation = &challenge.Confirmation{
+			SessionID:       sessionRecord.ID,
+			ClientPublicKey: clientPublicKey,
+			ConfirmedAt:     now,
+		}
+		if err := next.Validate(); err != nil {
+			s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
+			return Result{}, shared.InternalError(err)
+		}
+
+		if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
+			if errors.Is(err, ports.ErrConflict) {
+				return s.handleCreateSessionCASConflict(ctx, challengeID, code, clientPublicKey, sessionRecord)
+			}
+
+			s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
+			return Result{}, shared.ServiceUnavailable(err)
+		}
+
+		// Publish the currently stored session view so a concurrent revoke/block
+		// cannot overwrite source of truth with a stale active projection.
+		currentSession, err := s.sessionStore.Get(ctx, sessionRecord.ID)
+		if err != nil {
+			switch {
+			case errors.Is(err, ports.ErrNotFound):
+				return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: newly created session %q was not found", sessionRecord.ID))
+			default:
+				return Result{}, shared.ServiceUnavailable(err)
+			}
+		}
+		if err := s.publishSession(ctx, currentSession, "confirm_email_code"); err != nil {
+			return Result{}, err
+		}
+
+		return Result{DeviceSessionID: currentSession.ID.String()}, nil
+	}
+
+	return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: compare-and-swap retry limit exceeded"))
+}
+
+func (s *Service) ensureChallengeNotExpired(ctx context.Context, current challenge.Challenge, now time.Time) (bool, error) {
+	if current.IsExpiredAt(now) {
+		if current.Status != challenge.StatusExpired && current.Status.CanTransitionTo(challenge.StatusExpired) {
+			next := current
+			next.Status = challenge.StatusExpired
+			next.Abuse.LastAttemptAt = &now
+			next.Confirmation = nil
+			if err := next.Validate(); err != nil {
+				return true, shared.InternalError(err)
+			}
+			if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
+				if !errors.Is(err, ports.ErrConflict) {
+					return true, shared.ServiceUnavailable(err)
+				}
+				return false, err
+			}
+		}
+
+		return true, nil
+	}
+
+	return false, nil
+}
+
+func (s *Service) handleConfirmedRetry(ctx context.Context, current challenge.Challenge, code string, clientPublicKey common.ClientPublicKey) (Result, error) {
+	match, err := s.codeHasher.Compare(current.CodeHash, code)
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+	if !match {
+		return Result{}, shared.InvalidCode()
+	}
+	if current.Confirmation == nil {
+		return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed challenge is missing confirmation metadata"))
+	}
+	if current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
+		return Result{}, shared.InvalidCode()
+	}
+
+	record, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
+	if err != nil {
+		switch {
+		case errors.Is(err, ports.ErrNotFound):
+			return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed session %q was not found", current.Confirmation.SessionID))
+		default:
+			return Result{}, shared.ServiceUnavailable(err)
+		}
+	}
+	if err := s.publishSession(ctx, record, "confirm_email_code_retry"); err != nil {
+		return Result{}, err
+	}
+
+	return Result{DeviceSessionID: record.ID.String()}, nil
+}
+
+func (s *Service) recordInvalidConfirmAttempt(ctx context.Context, current challenge.Challenge, now time.Time) error {
+	next := current
+	next.Attempts.Confirm++
+	next.Abuse.LastAttemptAt = &now
+	if next.Attempts.Confirm >= challenge.MaxInvalidConfirmAttempts {
+		next.Status = challenge.StatusFailed
+	}
+	if err := next.Validate(); err != nil {
+		return shared.InternalError(err)
+	}
+
+	if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
+		switch {
+		case errors.Is(err, ports.ErrConflict):
+			return err
+		default:
+			return shared.ServiceUnavailable(err)
+		}
+	}
+
+	return nil
+}
+
+func (s *Service) markChallengeFailed(ctx context.Context, current challenge.Challenge, now time.Time) error {
+	next := current
+	next.Status = challenge.StatusFailed
+	next.Abuse.LastAttemptAt = &now
+	if err := next.Validate(); err != nil {
+		return shared.InternalError(err)
+	}
+
+	if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
+		switch {
+		case errors.Is(err, ports.ErrConflict):
+			return err
+		default:
+			return shared.ServiceUnavailable(err)
+		}
+	}
+
+	return nil
+}
+
+func (s *Service) evaluateSessionLimit(ctx context.Context, userID common.UserID, config ports.SessionLimitConfig) (sessionlimit.Decision, error) {
+	activeSessionCount, err := s.sessionStore.CountActiveByUserID(ctx, userID)
+	if err != nil {
+		return sessionlimit.Decision{}, shared.ServiceUnavailable(err)
+	}
+
+	decision, err := shared.EvaluateSessionLimit(config, activeSessionCount)
+	if err != nil {
+		return sessionlimit.Decision{}, err
+	}
+
+	return decision, nil
+}
+
+func (s *Service) createSession(ctx context.Context, userID common.UserID, clientPublicKey common.ClientPublicKey, now time.Time) (devicesession.Session, error) {
+	for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
+		deviceSessionID, err := s.idGenerator.NewDeviceSessionID()
+		if err != nil {
+			return devicesession.Session{}, shared.ServiceUnavailable(err)
+		}
+
+		record := devicesession.Session{
+			ID:              deviceSessionID,
+			UserID:          userID,
+			ClientPublicKey: clientPublicKey,
+			Status:          devicesession.StatusActive,
+			CreatedAt:       now,
+		}
+		if err := record.Validate(); err != nil {
+			return devicesession.Session{}, shared.InternalError(err)
+		}
+
+		if err := s.sessionStore.Create(ctx, record); err != nil {
+			if errors.Is(err, ports.ErrConflict) {
+				continue
+			}
+			return devicesession.Session{}, shared.ServiceUnavailable(err)
+		}
+		s.telemetry.RecordSessionCreated(ctx)
+
+		return record, nil
+	}
+
+	return devicesession.Session{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: session id conflict retry limit exceeded"))
+}
+
+func (s *Service) handleCreateSessionCASConflict(
+	ctx context.Context,
+	challengeID common.ChallengeID,
+	code string,
+	clientPublicKey common.ClientPublicKey,
+	createdSession devicesession.Session,
+) (Result, error) {
+	defer s.bestEffortRevokeSupersededSession(ctx, createdSession)
+
+	// A not-found challenge here is as unexpected as any other store failure,
+	// so both map to the same service-unavailable error.
+	current, err := s.challengeStore.Get(ctx, challengeID)
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+
+	if current.Status != challenge.StatusConfirmedPendingExpire || current.Confirmation == nil {
+		return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q changed to unexpected status %q after create", challengeID, current.Status))
+	}
+
+	match, err := s.codeHasher.Compare(current.CodeHash, code)
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+	if !match || current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
+		return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q was confirmed by a different payload", challengeID))
+	}
+
+	winningSession, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
+	if err != nil {
+		switch {
+		case errors.Is(err, ports.ErrNotFound):
+			return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: winning session %q was not found", current.Confirmation.SessionID))
+		default:
+			return Result{}, shared.ServiceUnavailable(err)
+		}
+	}
+	if err := s.publishSession(ctx, winningSession, "confirm_email_code_race_winner"); err != nil {
+		return Result{}, err
+	}
+
+	return Result{DeviceSessionID: winningSession.ID.String()}, nil
+}
+
+func (s *Service) bestEffortRevokeSupersededSession(ctx context.Context, record devicesession.Session) {
+	revocation := devicesession.Revocation{
+		At:         s.clock.Now().UTC(),
+		ReasonCode: revokeReasonConfirmRace,
+		ActorType:  revokeActorTypeService,
+		ActorID:    revokeActorIDService,
+	}
+	if err := revocation.Validate(); err != nil {
+		return
+	}
+
+	revokeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{
+		DeviceSessionID: record.ID,
+		Revocation:      revocation,
+	})
+	if err != nil {
+		s.logger.Warn(
+			"best-effort superseded session revoke failed",
+			zap.String("component", "service"),
+			zap.String("use_case", "confirm_email_code"),
+			zap.String("operation", "confirm_email_code_race_cleanup"),
+			zap.String("device_session_id", record.ID.String()),
+			zap.String("reason_code", revocation.ReasonCode.String()),
+			zap.Error(err),
+		)
+		return
+	}
+	if err := revokeResult.Validate(); err != nil {
+		s.logger.Warn(
+			"best-effort superseded session revoke produced invalid result",
+			zap.String("component", "service"),
+			zap.String("use_case", "confirm_email_code"),
+			zap.String("operation", "confirm_email_code_race_cleanup"),
+			zap.String("device_session_id", record.ID.String()),
+			zap.Error(err),
+		)
+		return
+	}
+	if revokeResult.Outcome == ports.RevokeSessionOutcomeRevoked {
+		s.telemetry.RecordSessionRevocations(ctx, "confirm_email_code_race_cleanup", revocation.ReasonCode.String(), 1)
+	}
+
+	snapshot, err := shared.ToGatewayProjectionSnapshot(revokeResult.Session)
+	if err != nil {
+		s.logger.Warn(
+			"best-effort superseded session snapshot mapping failed",
+			zap.String("component", "service"),
+			zap.String("use_case", "confirm_email_code"),
+			zap.String("operation", "confirm_email_code_race_cleanup"),
+			zap.String("device_session_id", revokeResult.Session.ID.String()),
+			zap.Error(err),
+		)
+		return
+	}
+	if err := shared.PublishProjectionSnapshotWithTelemetry(ctx, s.publisher, snapshot, s.telemetry, "confirm_email_code_race_cleanup"); err != nil {
+		s.logger.Warn(
+			"best-effort superseded session publish failed",
+			zap.String("component", "service"),
+			zap.String("use_case", "confirm_email_code"),
+			zap.String("operation", "confirm_email_code_race_cleanup"),
+			zap.String("device_session_id", revokeResult.Session.ID.String()),
+			zap.Error(err),
+		)
+	}
+}
+
+func (s *Service) publishSession(ctx context.Context, record devicesession.Session, operation string) error {
+	return shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, operation)
+}
+
+func namedLogger(logger *zap.Logger, name string) *zap.Logger {
+	if logger == nil {
+		logger = zap.NewNop()
+	}
+
+	return logger.Named(name)
+}
diff --git a/authsession/internal/service/confirmemailcode/service_test.go b/authsession/internal/service/confirmemailcode/service_test.go
new file mode 100644
index 0000000..93e3234
--- /dev/null
+++ b/authsession/internal/service/confirmemailcode/service_test.go
@@ -0,0 +1,682 @@
+package confirmemailcode
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/domain/userresolution"
+	"galaxy/authsession/internal/service/shared"
+	"galaxy/authsession/internal/testkit"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecuteConfirmsChallengeForExistingUser(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
+	require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))))
+
+	service := mustNewConfirmService(t, deps)
+	result, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+
+	record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
+	require.NoError(t, err)
+	assert.Equal(t, devicesession.StatusActive, record.Status)
+
+	challengeRecord, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	require.NoError(t, err)
+	assert.Equal(t, challenge.StatusConfirmedPendingExpire, challengeRecord.Status)
+	require.NotNil(t, challengeRecord.Confirmation)
+	require.Len(t, deps.publisher.PublishedSnapshots(), 1)
+}
+
+func TestExecuteConfirmsChallengeByCreatingUser(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	require.NoError(t, deps.userDirectory.QueueCreatedUserIDs(common.UserID("user-created")))
+	require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "new@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))))
+
+	service := mustNewConfirmService(t, deps)
+	result, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+
+	record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
+	require.NoError(t, err)
+	assert.Equal(t, common.UserID("user-created"), record.UserID)
+}
+
+func TestExecuteConfirmsSuppressedChallenge(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
+	record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
+	record.Status = challenge.StatusDeliverySuppressed
+	record.DeliveryState = challenge.DeliverySuppressed
+	require.NoError(t, record.Validate())
+	require.NoError(t, deps.challengeStore.Create(context.Background(), record))
+
+	service := mustNewConfirmService(t, deps)
+	result, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+}
+
+func TestExecuteReturnsChallengeNotFound(t *testing.T) {
+	t.Parallel()
+
+	service := mustNewConfirmService(t, newConfirmDeps(t))
+
+	_, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "missing",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	assert.Equal(t, shared.ErrorCodeChallengeNotFound, shared.CodeOf(err))
+}
+
+func TestExecuteReturnsChallengeExpiredAndMarksExpired(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-2*time.Minute), deps.now.Add(-time.Second))))
+
+	service := mustNewConfirmService(t, deps)
+
+	_, err := service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	assert.Equal(t, shared.ErrorCodeChallengeExpired, shared.CodeOf(err))
+
+	record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	require.NoError(t, err)
+	assert.Equal(t, challenge.StatusExpired, record.Status)
+}
+
+func TestExecuteReturnsChallengeExpiredForConfirmedChallengeAfterRetentionWindow(t *testing.T) {
+	t.Parallel()
+
+	deps := newConfirmDeps(t)
+	key, err := shared.ParseClientPublicKey(publicKeyString())
+	require.NoError(t, err)
+	record := confirmedChallengeFixture(
+		t,
+		deps.hasher,
+		"challenge-1",
+		"pilot@example.com",
+		"654321",
+		"device-session-1",
+		key,
+		deps.now.Add(-2*challenge.ConfirmedRetention),
+		deps.now.Add(-time.Second),
+	)
+	require.NoError(t, deps.challengeStore.Create(context.Background(), record))
+
+	service := mustNewConfirmService(t, deps)
+	_, err = service.Execute(context.Background(), Input{
+		ChallengeID:     "challenge-1",
+		Code:            "654321",
+		ClientPublicKey: publicKeyString(),
+	})
+	assert.Equal(t, shared.ErrorCodeChallengeExpired, shared.CodeOf(err))
+
+	updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil { + require.Failf(t, "test failed", "Get() returned error: %v", err) + } + if updated.Status != challenge.StatusExpired { + require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusExpired) + } + if updated.Confirmation != nil { + require.Failf(t, "test failed", "Confirmation = %+v, want nil after expiration", updated.Confirmation) + } +} + +func TestExecuteReturnsInvalidClientPublicKey(t *testing.T) { + t.Parallel() + + service := mustNewConfirmService(t, newConfirmDeps(t)) + + _, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: "invalid", + }) + if shared.CodeOf(err) != shared.ErrorCodeInvalidClientPublicKey { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidClientPublicKey) + } +} + +func TestExecuteInvalidCodeIncrementsAttempts(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + _, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "000000", + ClientPublicKey: publicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeInvalidCode { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode) + } + + record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + if err != nil { + require.Failf(t, "test failed", "Get() returned error: %v", err) + } + if record.Attempts.Confirm != 1 { + require.Failf(t, "test failed", "Attempts.Confirm = %d, want 1", record.Attempts.Confirm) + } +} + +func 
TestExecuteFifthInvalidAttemptMarksChallengeFailed(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)) + record.Attempts.Confirm = 4 + if err := deps.challengeStore.Create(context.Background(), record); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + _, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "000000", + ClientPublicKey: publicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeInvalidCode { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode) + } + + updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + if err != nil { + require.Failf(t, "test failed", "Get() returned error: %v", err) + } + if updated.Status != challenge.StatusFailed { + require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusFailed) + } +} + +func TestExecuteDoesNotCreateSessionAfterTooManyAttempts(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)) + record.Attempts.Confirm = challenge.MaxInvalidConfirmAttempts + record.Status = challenge.StatusFailed + if err := record.Validate(); err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } + if err := deps.challengeStore.Create(context.Background(), record); err != nil { + require.Failf(t, "test failed", "Create() 
returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + _, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeInvalidCode { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode) + } + + if got, err := deps.sessionStore.CountActiveByUserID(context.Background(), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "CountActiveByUserID() returned error: %v", err) + } else if got != 0 { + require.Failf(t, "test failed", "CountActiveByUserID() = %d, want 0", got) + } +} + +func TestExecuteReturnsSameSessionIDForIdempotentRetryAndRepublishes(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + key, err := shared.ParseClientPublicKey(publicKeyString()) + if err != nil { + require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err) + } + record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute)) + if err := deps.challengeStore.Create(context.Background(), record); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + result, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + if err != nil { + require.Failf(t, "test failed", "Execute() returned error: %v", err) + } + if result.DeviceSessionID != "device-session-1" { + require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want 
%q", result.DeviceSessionID, "device-session-1") + } + if len(deps.publisher.PublishedSnapshots()) != 1 { + require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(deps.publisher.PublishedSnapshots())) + } +} + +func TestExecuteReturnsInvalidCodeForDifferentKeyDuringIdempotentRetry(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + key, err := shared.ParseClientPublicKey(publicKeyString()) + if err != nil { + require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err) + } + record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute)) + if err := deps.challengeStore.Create(context.Background(), record); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + _, err = service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: alternatePublicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeInvalidCode { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode) + } + + updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + if err != nil { + require.Failf(t, "test failed", "Get() returned error: %v", err) + } + if updated.Attempts.Confirm != 0 { + require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", updated.Attempts.Confirm) + } + if updated.Confirmation == nil { + require.FailNow(t, "Confirmation = nil, want metadata to stay intact") + } + if updated.Confirmation.SessionID != 
common.DeviceSessionID("device-session-1") { + require.Failf(t, "test failed", "Confirmation.SessionID = %q, want %q", updated.Confirmation.SessionID, common.DeviceSessionID("device-session-1")) + } +} + +func TestExecuteReturnsInvalidCodeForNonConfirmableStates(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + status challenge.Status + deliveryState challenge.DeliveryState + }{ + {name: "pending send", status: challenge.StatusPendingSend, deliveryState: challenge.DeliveryPending}, + {name: "failed", status: challenge.StatusFailed, deliveryState: challenge.DeliveryFailed}, + {name: "cancelled", status: challenge.StatusCancelled, deliveryState: challenge.DeliverySent}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)) + record.Status = tt.status + record.DeliveryState = tt.deliveryState + if err := record.Validate(); err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } + if err := deps.challengeStore.Create(context.Background(), record); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + _, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeInvalidCode { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode) + } + + updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + if err != nil { + require.Failf(t, "test failed", "Get() returned error: %v", err) + } + if updated.Attempts.Confirm != 0 { + require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", 
updated.Attempts.Confirm) + } + }) + } +} + +func TestExecuteMarksChallengeFailedAndReturnsBlockedByPolicy(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + if err := deps.userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil { + require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err) + } + if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + _, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeBlockedByPolicy { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeBlockedByPolicy) + } + + record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + if err != nil { + require.Failf(t, "test failed", "Get() returned error: %v", err) + } + if record.Status != challenge.StatusFailed { + require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusFailed) + } +} + +func TestExecuteReturnsSessionLimitExceededWithoutConsumingChallenge(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil { + require.Failf(t, "test failed", 
"Create() returned error: %v", err) + } + if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-existing", "user-1", mustClientPublicKey(t, publicKeyString()), deps.now.Add(-2*time.Minute))); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + limit := 1 + deps.configProvider.Config.ActiveSessionLimit = &limit + + service := mustNewConfirmService(t, deps) + _, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeSessionLimitExceeded { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionLimitExceeded) + } + + record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + if err != nil { + require.Failf(t, "test failed", "Get() returned error: %v", err) + } + if record.Status != challenge.StatusSent { + require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusSent) + } + if record.Attempts.Confirm != 0 { + require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", record.Attempts.Confirm) + } +} + +func TestExecuteReturnsServiceUnavailableThenSucceedsIdempotentlyAfterPublishFailure(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + deps.publisher.Err = errors.New("publish failed") + if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service := mustNewConfirmService(t, deps) + _, 
err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + if shared.CodeOf(err) != shared.ErrorCodeServiceUnavailable { + require.Failf(t, "test failed", "first Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeServiceUnavailable) + } + + deps.publisher.Err = nil + result, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + if err != nil { + require.Failf(t, "test failed", "second Execute() returned error: %v", err) + } + if result.DeviceSessionID != "device-session-1" { + require.Failf(t, "test failed", "second Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1") + } +} + +type confirmDeps struct { + challengeStore *testkit.InMemoryChallengeStore + sessionStore *testkit.InMemorySessionStore + userDirectory *testkit.InMemoryUserDirectory + configProvider testkit.StaticConfigProvider + publisher *testkit.RecordingProjectionPublisher + idGenerator *testkit.SequenceIDGenerator + hasher testkit.DeterministicCodeHasher + now time.Time +} + +func newConfirmDeps(t *testing.T) confirmDeps { + t.Helper() + + return confirmDeps{ + challengeStore: &testkit.InMemoryChallengeStore{}, + sessionStore: &testkit.InMemorySessionStore{}, + userDirectory: &testkit.InMemoryUserDirectory{}, + configProvider: testkit.StaticConfigProvider{}, + publisher: &testkit.RecordingProjectionPublisher{}, + idGenerator: &testkit.SequenceIDGenerator{ + DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"}, + }, + hasher: testkit.DeterministicCodeHasher{}, + now: time.Unix(20, 0).UTC(), + } +} + +func mustNewConfirmService(t *testing.T, deps confirmDeps) *Service { + t.Helper() + + service, err := New( + deps.challengeStore, + deps.sessionStore, + deps.userDirectory, + deps.configProvider, + deps.publisher, + deps.idGenerator, + deps.hasher, + 
testkit.FixedClock{Time: deps.now}, + ) + if err != nil { + require.Failf(t, "test failed", "New() returned error: %v", err) + } + + return service +} + +func sentChallengeFixture( + t *testing.T, + hasher testkit.DeterministicCodeHasher, + challengeID string, + email string, + code string, + createdAt time.Time, + expiresAt time.Time, +) challenge.Challenge { + t.Helper() + + codeHash, err := hasher.Hash(code) + if err != nil { + require.Failf(t, "test failed", "Hash() returned error: %v", err) + } + + record := challenge.Challenge{ + ID: common.ChallengeID(challengeID), + Email: common.Email(email), + CodeHash: codeHash, + Status: challenge.StatusSent, + DeliveryState: challenge.DeliverySent, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + } + if err := record.Validate(); err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } + + return record +} + +func confirmedChallengeFixture( + t *testing.T, + hasher testkit.DeterministicCodeHasher, + challengeID string, + email string, + code string, + deviceSessionID string, + clientPublicKey common.ClientPublicKey, + createdAt time.Time, + expiresAt time.Time, +) challenge.Challenge { + t.Helper() + + record := sentChallengeFixture(t, hasher, challengeID, email, code, createdAt, expiresAt) + record.Status = challenge.StatusConfirmedPendingExpire + record.Confirmation = &challenge.Confirmation{ + SessionID: common.DeviceSessionID(deviceSessionID), + ClientPublicKey: clientPublicKey, + ConfirmedAt: createdAt.Add(time.Minute), + } + if err := record.Validate(); err != nil { + require.Failf(t, "test failed", "Validate() returned error: %v", err) + } + + return record +} + +func activeSessionFixture(deviceSessionID string, userID string, clientPublicKey common.ClientPublicKey, createdAt time.Time) devicesession.Session { + return devicesession.Session{ + ID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID(userID), + ClientPublicKey: clientPublicKey, + Status: 
devicesession.StatusActive, + CreatedAt: createdAt, + } +} + +func mustClientPublicKey(t *testing.T, value string) common.ClientPublicKey { + t.Helper() + + key, err := shared.ParseClientPublicKey(value) + if err != nil { + require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err) + } + + return key +} + +func publicKeyString() string { + return "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=" +} + +func alternatePublicKeyString() string { + return "AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE=" +} diff --git a/authsession/internal/service/confirmemailcode/stub_user_directory_test.go b/authsession/internal/service/confirmemailcode/stub_user_directory_test.go new file mode 100644 index 0000000..6289735 --- /dev/null +++ b/authsession/internal/service/confirmemailcode/stub_user_directory_test.go @@ -0,0 +1,109 @@ +package confirmemailcode + +import ( + "context" + "testing" + "time" + + stubuserservice "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/service/shared" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) { + t.Parallel() + + t.Run("creates user through EnsureUserByEmail", func(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + userDirectory := &stubuserservice.StubDirectory{} + require.NoError(t, userDirectory.QueueCreatedUserIDs(common.UserID("user-created"))) + deps.userDirectory = nil + require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture( + t, + deps.hasher, + "challenge-1", + "pilot@example.com", + "654321", + deps.now.Add(-time.Minute), + deps.now.Add(time.Minute), + ))) + + service, err := New( + deps.challengeStore, + deps.sessionStore, + userDirectory, + deps.configProvider, + deps.publisher, + 
deps.idGenerator, + deps.hasher, + fixedClock(deps.now), + ) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + require.NoError(t, err) + assert.Equal(t, "device-session-1", result.DeviceSessionID) + + sessionRecord, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, err) + assert.Equal(t, common.UserID("user-created"), sessionRecord.UserID) + }) + + t.Run("blocked email returns blocked by policy", func(t *testing.T) { + t.Parallel() + + deps := newConfirmDeps(t) + userDirectory := &stubuserservice.StubDirectory{} + require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block"))) + require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture( + t, + deps.hasher, + "challenge-1", + "pilot@example.com", + "654321", + deps.now.Add(-time.Minute), + deps.now.Add(time.Minute), + ))) + + service, err := New( + deps.challengeStore, + deps.sessionStore, + userDirectory, + deps.configProvider, + deps.publisher, + deps.idGenerator, + deps.hasher, + fixedClock(deps.now), + ) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + require.Error(t, err) + assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err)) + + record, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + require.NoError(t, getErr) + assert.Equal(t, challenge.StatusFailed, record.Status) + }) +} + +type fixedClock time.Time + +func (c fixedClock) Now() time.Time { + return time.Time(c) +} diff --git a/authsession/internal/service/confirmemailcode/telemetry_test.go b/authsession/internal/service/confirmemailcode/telemetry_test.go new file mode 100644 
index 0000000..873901d --- /dev/null +++ b/authsession/internal/service/confirmemailcode/telemetry_test.go @@ -0,0 +1,104 @@ +package confirmemailcode + +import ( + "context" + "testing" + "time" + + "galaxy/authsession/internal/domain/challenge" + authtelemetry "galaxy/authsession/internal/telemetry" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" +) + +func TestExecuteRecordsInvalidCodeMetricForThrottledChallenge(t *testing.T) { + t.Parallel() + + runtime, reader := newObservedConfirmTelemetryRuntime(t) + deps := newConfirmDeps(t) + record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)) + record.Status = challenge.StatusDeliveryThrottled + record.DeliveryState = challenge.DeliveryThrottled + require.NoError(t, record.Validate()) + require.NoError(t, deps.challengeStore.Create(context.Background(), record)) + + service, err := NewWithTelemetry( + deps.challengeStore, + deps.sessionStore, + deps.userDirectory, + deps.configProvider, + deps.publisher, + deps.idGenerator, + deps.hasher, + testkit.FixedClock{Time: deps.now}, + runtime, + ) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + ChallengeID: "challenge-1", + Code: "654321", + ClientPublicKey: publicKeyString(), + }) + require.Error(t, err) + + assertConfirmMetricCount(t, reader, map[string]string{"outcome": "invalid_code"}, 1) +} + +func newObservedConfirmTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader) { + t.Helper() + + reader := sdkmetric.NewManualReader() + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + + runtime, err := authtelemetry.New(provider) + require.NoError(t, err) + + return runtime, reader +} + +func 
assertConfirmMetricCount(t *testing.T, reader *sdkmetric.ManualReader, wantAttrs map[string]string, wantValue int64) { + t.Helper() + + var resourceMetrics metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &resourceMetrics)) + + for _, scopeMetrics := range resourceMetrics.ScopeMetrics { + for _, metric := range scopeMetrics.Metrics { + if metric.Name != "authsession.confirm_email_code.attempts" { + continue + } + + sum, ok := metric.Data.(metricdata.Sum[int64]) + require.True(t, ok) + + for _, point := range sum.DataPoints { + if hasConfirmMetricAttributes(point.Attributes.ToSlice(), wantAttrs) { + assert.Equal(t, wantValue, point.Value) + return + } + } + } + } + + require.Failf(t, "test failed", "confirm metric with attrs %v not found", wantAttrs) +} + +func hasConfirmMetricAttributes(values []attribute.KeyValue, want map[string]string) bool { + if len(values) != len(want) { + return false + } + + for _, value := range values { + if want[string(value.Key)] != value.Value.AsString() { + return false + } + } + + return true +} diff --git a/authsession/internal/service/getsession/service.go b/authsession/internal/service/getsession/service.go new file mode 100644 index 0000000..e1986c2 --- /dev/null +++ b/authsession/internal/service/getsession/service.go @@ -0,0 +1,65 @@ +// Package getsession implements the trusted internal read use case for one +// device session. +package getsession + +import ( + "context" + "errors" + "fmt" + + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/shared" +) + +// Input describes one trusted internal get-session request. +type Input struct { + // DeviceSessionID identifies the session that should be read. + DeviceSessionID string +} + +// Result describes one trusted internal get-session response. +type Result struct { + // Session stores the frozen internal read-model DTO. 
+ Session shared.Session +} + +// Service executes the trusted internal get-session use case against the +// configured ports. +type Service struct { + sessionStore ports.SessionStore +} + +// New returns a get-session service wired to sessionStore. +func New(sessionStore ports.SessionStore) (*Service, error) { + if sessionStore == nil { + return nil, fmt.Errorf("getsession: session store must not be nil") + } + + return &Service{sessionStore: sessionStore}, nil +} + +// Execute loads one source-of-truth session and projects it into the frozen +// internal read DTO shape. +func (s *Service) Execute(ctx context.Context, input Input) (Result, error) { + deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID) + if err != nil { + return Result{}, err + } + + record, err := s.sessionStore.Get(ctx, deviceSessionID) + if err != nil { + switch { + case errors.Is(err, ports.ErrNotFound): + return Result{}, shared.SessionNotFound() + default: + return Result{}, shared.ServiceUnavailable(err) + } + } + + session, err := shared.ToSession(record) + if err != nil { + return Result{}, shared.InternalError(err) + } + + return Result{Session: session}, nil +} diff --git a/authsession/internal/service/getsession/service_test.go b/authsession/internal/service/getsession/service_test.go new file mode 100644 index 0000000..9fdacd3 --- /dev/null +++ b/authsession/internal/service/getsession/service_test.go @@ -0,0 +1,68 @@ +package getsession + +import ( + "context" + "github.com/stretchr/testify/require" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" +) + +func TestExecuteReturnsMappedSession(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()) + if err := store.Create(context.Background(), record); 
err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service, err := New(store) + if err != nil { + require.Failf(t, "test failed", "New() returned error: %v", err) + } + + result, err := service.Execute(context.Background(), Input{DeviceSessionID: " device-session-1 "}) + if err != nil { + require.Failf(t, "test failed", "Execute() returned error: %v", err) + } + if result.Session.DeviceSessionID != "device-session-1" { + require.Failf(t, "test failed", "Execute().Session.DeviceSessionID = %q, want %q", result.Session.DeviceSessionID, "device-session-1") + } + if result.Session.CreatedAt != time.Unix(10, 0).UTC().Format(time.RFC3339) { + require.Failf(t, "test failed", "Execute().Session.CreatedAt = %q", result.Session.CreatedAt) + } +} + +func TestExecuteReturnsSessionNotFound(t *testing.T) { + t.Parallel() + + service, err := New(&testkit.InMemorySessionStore{}) + if err != nil { + require.Failf(t, "test failed", "New() returned error: %v", err) + } + + _, err = service.Execute(context.Background(), Input{DeviceSessionID: "missing"}) + if shared.CodeOf(err) != shared.ErrorCodeSessionNotFound { + require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionNotFound) + } +} + +func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { + key, err := common.NewClientPublicKey(make([]byte, 32)) + if err != nil { + panic(err) + } + + return devicesession.Session{ + ID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID(userID), + ClientPublicKey: key, + Status: devicesession.StatusActive, + CreatedAt: createdAt, + } +} diff --git a/authsession/internal/service/listusersessions/service.go b/authsession/internal/service/listusersessions/service.go new file mode 100644 index 0000000..9dc17fb --- /dev/null +++ b/authsession/internal/service/listusersessions/service.go @@ -0,0 +1,58 @@ +// Package listusersessions 
implements the trusted internal read use case for +// listing all sessions of one user. +package listusersessions + +import ( + "context" + "fmt" + + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/shared" +) + +// Input describes one trusted internal list-user-sessions request. +type Input struct { + // UserID identifies the owner whose sessions should be listed. + UserID string +} + +// Result describes one trusted internal list-user-sessions response. +type Result struct { + // Sessions stores the frozen internal read-model DTO slice. + Sessions []shared.Session +} + +// Service executes the trusted internal list-user-sessions use case. +type Service struct { + sessionStore ports.SessionStore +} + +// New returns a list-user-sessions service wired to sessionStore. +func New(sessionStore ports.SessionStore) (*Service, error) { + if sessionStore == nil { + return nil, fmt.Errorf("listusersessions: session store must not be nil") + } + + return &Service{sessionStore: sessionStore}, nil +} + +// Execute loads all source-of-truth sessions for one user and projects them +// into the frozen internal read DTO shape. 
+func (s *Service) Execute(ctx context.Context, input Input) (Result, error) {
+	userID, err := shared.ParseUserID(input.UserID)
+	if err != nil {
+		return Result{}, err
+	}
+
+	records, err := s.sessionStore.ListByUserID(ctx, userID)
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+
+	sessions, err := shared.ToSessions(records)
+	if err != nil {
+		return Result{}, shared.InternalError(err)
+	}
+
+	return Result{Sessions: sessions}, nil
+}
diff --git a/authsession/internal/service/listusersessions/service_test.go b/authsession/internal/service/listusersessions/service_test.go
new file mode 100644
index 0000000..3fdfe40
--- /dev/null
+++ b/authsession/internal/service/listusersessions/service_test.go
@@ -0,0 +1,74 @@
+package listusersessions
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/testkit"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecutePreservesNewestFirstOrder(t *testing.T) {
+	t.Parallel()
+
+	store := &testkit.InMemorySessionStore{}
+	older := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	newer := activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
+	for _, record := range []devicesession.Session{older, newer} {
+		if err := store.Create(context.Background(), record); err != nil {
+			require.Failf(t, "test failed", "Create() returned error: %v", err)
+		}
+	}
+
+	service, err := New(store)
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	result, err := service.Execute(context.Background(), Input{UserID: "user-1"})
+	if err != nil {
+		require.Failf(t, "test failed", "Execute() returned error: %v", err)
+	}
+	if len(result.Sessions) != 2 {
+		require.Failf(t, "test failed", "Execute().Sessions length = %d, want 2", len(result.Sessions))
+	}
+	if result.Sessions[0].DeviceSessionID != 
"device-session-2" || result.Sessions[1].DeviceSessionID != "device-session-1" { + require.Failf(t, "test failed", "Execute().Sessions order = [%q %q]", result.Sessions[0].DeviceSessionID, result.Sessions[1].DeviceSessionID) + } +} + +func TestExecuteReturnsEmptyForUnknownUser(t *testing.T) { + t.Parallel() + + service, err := New(&testkit.InMemorySessionStore{}) + if err != nil { + require.Failf(t, "test failed", "New() returned error: %v", err) + } + + result, err := service.Execute(context.Background(), Input{UserID: "missing"}) + if err != nil { + require.Failf(t, "test failed", "Execute() returned error: %v", err) + } + if len(result.Sessions) != 0 { + require.Failf(t, "test failed", "Execute().Sessions length = %d, want 0", len(result.Sessions)) + } +} + +func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { + key, err := common.NewClientPublicKey(make([]byte, 32)) + if err != nil { + panic(err) + } + + return devicesession.Session{ + ID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID(userID), + ClientPublicKey: key, + Status: devicesession.StatusActive, + CreatedAt: createdAt, + } +} diff --git a/authsession/internal/service/revokeallusersessions/consistency_test.go b/authsession/internal/service/revokeallusersessions/consistency_test.go new file mode 100644 index 0000000..26850f3 --- /dev/null +++ b/authsession/internal/service/revokeallusersessions/consistency_test.go @@ -0,0 +1,106 @@ +package revokeallusersessions + +import ( + "context" + "errors" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteRetriesProjectionPublishesForBulkRevoke(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + 
userDirectory := &testkit.InMemoryUserDirectory{} + publisher := &testkit.RecordingProjectionPublisher{ + Errors: []error{ + errors.New("publish failed"), + nil, + errors.New("publish failed"), + nil, + }, + } + require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()))) + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC()))) + + service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.NoError(t, err) + assert.Equal(t, "revoked", result.Outcome) + assert.EqualValues(t, 2, result.AffectedSessionCount) + assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs) + require.Len(t, publisher.PublishedSnapshots(), 4) +} + +func TestExecuteRepublishesCurrentRevokedSessionsOnNoActiveSessionsRetry(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + userDirectory := &testkit.InMemoryUserDirectory{} + publisher := &testkit.RecordingProjectionPublisher{ + Errors: []error{ + nil, + errors.New("publish failed"), + errors.New("publish failed"), + errors.New("publish failed"), + }, + } + require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()))) + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC()))) + + service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: 
time.Unix(30, 0).UTC()}) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.Error(t, err) + assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err)) + require.Len(t, publisher.PublishedSnapshots(), 4) + + for _, deviceSessionID := range []common.DeviceSessionID{"device-session-1", "device-session-2"} { + record, getErr := store.Get(context.Background(), deviceSessionID) + require.NoError(t, getErr) + require.NotNil(t, record.Revocation) + assert.Equal(t, devicesession.StatusRevoked, record.Status) + } + + publisher.Errors = nil + publisher.Err = nil + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.NoError(t, err) + assert.Equal(t, "no_active_sessions", result.Outcome) + assert.EqualValues(t, 0, result.AffectedSessionCount) + require.NotNil(t, result.AffectedDeviceSessionIDs) + assert.Empty(t, result.AffectedDeviceSessionIDs) + + published := publisher.PublishedSnapshots() + require.Len(t, published, 6) + assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{ + published[4].DeviceSessionID, + published[5].DeviceSessionID, + }) +} diff --git a/authsession/internal/service/revokeallusersessions/service.go b/authsession/internal/service/revokeallusersessions/service.go new file mode 100644 index 0000000..8190025 --- /dev/null +++ b/authsession/internal/service/revokeallusersessions/service.go @@ -0,0 +1,200 @@ +// Package revokeallusersessions implements the trusted internal bulk revoke +// use case for all sessions of one user. 
+package revokeallusersessions + +import ( + "context" + "fmt" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/telemetry" + + "go.uber.org/zap" +) + +// Input describes one trusted internal revoke-all-user-sessions request. +type Input struct { + // UserID identifies the owner whose sessions should be revoked. + UserID string + + // ReasonCode stores the machine-readable revoke reason code. + ReasonCode string + + // ActorType stores the machine-readable revoke actor type. + ActorType string + + // ActorID stores the optional stable revoke actor identifier. + ActorID string +} + +// Result describes the frozen internal bulk revoke acknowledgement. +type Result struct { + // Outcome reports whether active sessions were revoked during the current + // call. + Outcome string + + // UserID identifies the user addressed by the operation. + UserID string + + // AffectedSessionCount reports how many sessions changed state during the + // current call. + AffectedSessionCount int64 + + // AffectedDeviceSessionIDs lists every session identifier affected during + // the current call. + AffectedDeviceSessionIDs []string +} + +// Service executes the trusted internal revoke-all-user-sessions use case. +type Service struct { + sessionStore ports.SessionStore + userDirectory ports.UserDirectory + publisher ports.GatewaySessionProjectionPublisher + clock ports.Clock + logger *zap.Logger + telemetry *telemetry.Runtime +} + +// New returns a revoke-all-user-sessions service wired to the required ports. 
+func New(sessionStore ports.SessionStore, userDirectory ports.UserDirectory, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) { + return NewWithObservability(sessionStore, userDirectory, publisher, clock, nil, nil) +} + +// NewWithObservability returns a revoke-all-user-sessions service wired to the +// required ports plus optional structured logging and telemetry dependencies. +func NewWithObservability( + sessionStore ports.SessionStore, + userDirectory ports.UserDirectory, + publisher ports.GatewaySessionProjectionPublisher, + clock ports.Clock, + logger *zap.Logger, + telemetryRuntime *telemetry.Runtime, +) (*Service, error) { + switch { + case sessionStore == nil: + return nil, fmt.Errorf("revokeallusersessions: session store must not be nil") + case userDirectory == nil: + return nil, fmt.Errorf("revokeallusersessions: user directory must not be nil") + case publisher == nil: + return nil, fmt.Errorf("revokeallusersessions: projection publisher must not be nil") + case clock == nil: + return nil, fmt.Errorf("revokeallusersessions: clock must not be nil") + default: + return &Service{ + sessionStore: sessionStore, + userDirectory: userDirectory, + publisher: publisher, + clock: clock, + logger: namedLogger(logger, "revoke_all_user_sessions"), + telemetry: telemetryRuntime, + }, nil + } +} + +// Execute revokes all active sessions of one user and republishes revoked +// gateway projections for every affected session. +func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) { + logFields := []zap.Field{ + zap.String("component", "service"), + zap.String("use_case", "revoke_all_user_sessions"), + } + defer func() { + shared.LogServiceOutcome(s.logger, ctx, "revoke all user sessions completed", err, logFields...) 
+ }() + + userID, err := shared.ParseUserID(input.UserID) + if err != nil { + return Result{}, err + } + logFields = append(logFields, zap.String("user_id", userID.String())) + + revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now()) + if err != nil { + return Result{}, err + } + logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String())) + + exists, err := s.userDirectory.ExistsByUserID(ctx, userID) + if err != nil { + return Result{}, shared.ServiceUnavailable(err) + } + s.telemetry.RecordUserDirectoryOutcome(ctx, "exists_by_user_id", boolOutcome(exists)) + if !exists { + return Result{}, shared.SubjectNotFound() + } + + storeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{ + UserID: userID, + Revocation: revocation, + }) + if err != nil { + return Result{}, shared.ServiceUnavailable(err) + } + if err := storeResult.Validate(); err != nil { + return Result{}, shared.InternalError(err) + } + logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome))) + + affectedDeviceSessionIDs := make([]string, 0, len(storeResult.Sessions)) + for _, record := range storeResult.Sessions { + if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions"); err != nil { + return Result{}, err + } + affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String()) + } + if storeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions { + if err := s.republishCurrentRevokedSessions(ctx, userID); err != nil { + return Result{}, err + } + } + + affectedSessionCount := int64(len(storeResult.Sessions)) + if affectedSessionCount > 0 { + s.telemetry.RecordSessionRevocations(ctx, "revoke_all_user_sessions", revocation.ReasonCode.String(), affectedSessionCount) + } + logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount)) + + return Result{ + 
Outcome: string(storeResult.Outcome), + UserID: storeResult.UserID.String(), + AffectedSessionCount: affectedSessionCount, + AffectedDeviceSessionIDs: affectedDeviceSessionIDs, + }, nil +} + +func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error { + records, err := s.sessionStore.ListByUserID(ctx, userID) + if err != nil { + return shared.ServiceUnavailable(err) + } + + for _, record := range records { + if record.Status != devicesession.StatusRevoked { + continue + } + if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions_repair"); err != nil { + return err + } + } + + return nil +} + +func boolOutcome(value bool) string { + if value { + return "exists" + } + + return "missing" +} + +func namedLogger(logger *zap.Logger, name string) *zap.Logger { + if logger == nil { + logger = zap.NewNop() + } + + return logger.Named(name) +} diff --git a/authsession/internal/service/revokeallusersessions/service_test.go b/authsession/internal/service/revokeallusersessions/service_test.go new file mode 100644 index 0000000..7ac6ea7 --- /dev/null +++ b/authsession/internal/service/revokeallusersessions/service_test.go @@ -0,0 +1,162 @@ +package revokeallusersessions + +import ( + "context" + "errors" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteRevokesExistingUserSessionsAndPublishes(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + userDirectory := &testkit.InMemoryUserDirectory{} + publisher := &testkit.RecordingProjectionPublisher{} + if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), 
common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + for _, record := range []devicesession.Session{ + activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()), + activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC()), + } { + if err := store.Create(context.Background(), record); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + } + + service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.NoError(t, err) + assert.Equal(t, "revoked", result.Outcome) + assert.EqualValues(t, 2, result.AffectedSessionCount) + assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs) + + for _, deviceSessionID := range result.AffectedDeviceSessionIDs { + stored, getErr := store.Get(context.Background(), common.DeviceSessionID(deviceSessionID)) + require.NoError(t, getErr) + require.NotNil(t, stored.Revocation) + assert.Equal(t, devicesession.StatusRevoked, stored.Status) + assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode) + assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType) + assert.Empty(t, stored.Revocation.ActorID) + assert.Equal(t, time.Unix(30, 0).UTC(), stored.Revocation.At) + } + + published := publisher.PublishedSnapshots() + require.Len(t, published, 2) + assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{ + published[0].DeviceSessionID, + published[1].DeviceSessionID, + }) + for _, snapshot := range published { + assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status) + assert.Equal(t, devicesession.RevokeReasonLogoutAll, snapshot.RevokeReasonCode) + 
assert.Equal(t, common.RevokeActorType("system"), snapshot.RevokeActorType) + require.NotNil(t, snapshot.RevokedAt) + assert.Equal(t, time.Unix(30, 0).UTC(), *snapshot.RevokedAt) + } +} + +func TestExecuteReturnsNoActiveSessionsForExistingUserWithoutActiveSessions(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + userDirectory := &testkit.InMemoryUserDirectory{} + publisher := &testkit.RecordingProjectionPublisher{} + if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + + service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.NoError(t, err) + assert.Equal(t, "no_active_sessions", result.Outcome) + assert.EqualValues(t, 0, result.AffectedSessionCount) + require.NotNil(t, result.AffectedDeviceSessionIDs) + assert.Empty(t, result.AffectedDeviceSessionIDs) + assert.Empty(t, publisher.PublishedSnapshots()) +} + +func TestExecuteReturnsSubjectNotFoundForUnknownUser(t *testing.T) { + t.Parallel() + + service, err := New(&testkit.InMemorySessionStore{}, &testkit.InMemoryUserDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()}) + if err != nil { + require.Failf(t, "test failed", "New() returned error: %v", err) + } + + _, err = service.Execute(context.Background(), Input{ + UserID: "missing", + ReasonCode: "logout_all", + ActorType: "system", + }) + assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err)) +} + +func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + userDirectory := &testkit.InMemoryUserDirectory{} + publisher := 
&testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")} + if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil { + require.Failf(t, "test failed", "Create() returned error: %v", err) + } + + service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()}) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err)) + + stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, getErr) + require.NotNil(t, stored.Revocation) + assert.Equal(t, devicesession.StatusRevoked, stored.Status) + assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode) + assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType) +} + +func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { + key, err := common.NewClientPublicKey(make([]byte, 32)) + if err != nil { + panic(err) + } + + return devicesession.Session{ + ID: common.DeviceSessionID(deviceSessionID), + UserID: common.UserID(userID), + ClientPublicKey: key, + Status: devicesession.StatusActive, + CreatedAt: createdAt, + } +} diff --git a/authsession/internal/service/revokeallusersessions/stub_user_directory_test.go b/authsession/internal/service/revokeallusersessions/stub_user_directory_test.go new file mode 100644 index 0000000..386ccf2 --- /dev/null +++ b/authsession/internal/service/revokeallusersessions/stub_user_directory_test.go @@ -0,0 +1,53 @@ +package revokeallusersessions + 
+import ( + "context" + "testing" + "time" + + stubuserservice "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) { + t.Parallel() + + t.Run("existing user uses ExistsByUserID and returns no active sessions", func(t *testing.T) { + t.Parallel() + + userDirectory := &stubuserservice.StubDirectory{} + require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + + service, err := New(&testkit.InMemorySessionStore{}, userDirectory, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + UserID: "user-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.NoError(t, err) + assert.Equal(t, "no_active_sessions", result.Outcome) + assert.Zero(t, result.AffectedSessionCount) + }) + + t.Run("unknown user returns subject not found", func(t *testing.T) { + t.Parallel() + + service, err := New(&testkit.InMemorySessionStore{}, &stubuserservice.StubDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()}) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + UserID: "missing", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.Error(t, err) + assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err)) + }) +} diff --git a/authsession/internal/service/revokedevicesession/consistency_test.go b/authsession/internal/service/revokedevicesession/consistency_test.go new file mode 100644 index 0000000..bc11bbd --- /dev/null +++ b/authsession/internal/service/revokedevicesession/consistency_test.go @@ -0,0 +1,75 @@ +package 
revokedevicesession + +import ( + "context" + "errors" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteRetriesProjectionPublishUntilSuccess(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + publisher := &testkit.RecordingProjectionPublisher{ + Errors: []error{errors.New("publish failed"), nil}, + } + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()))) + + service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{ + DeviceSessionID: "device-session-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.NoError(t, err) + assert.Equal(t, "revoked", result.Outcome) + require.Len(t, publisher.PublishedSnapshots(), 2) +} + +func TestExecuteRepairsProjectionOnRepeatedAlreadyRevokedRequest(t *testing.T) { + t.Parallel() + + store := &testkit.InMemorySessionStore{} + publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")} + require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()))) + + service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()}) + require.NoError(t, err) + + _, err = service.Execute(context.Background(), Input{ + DeviceSessionID: "device-session-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.Error(t, err) + assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err)) + require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts) + + stored, getErr := 
store.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, getErr) + require.NotNil(t, stored.Revocation) + assert.Equal(t, devicesession.StatusRevoked, stored.Status) + + publisher.Err = nil + + result, err := service.Execute(context.Background(), Input{ + DeviceSessionID: "device-session-1", + ReasonCode: "logout_all", + ActorType: "system", + }) + require.NoError(t, err) + assert.Equal(t, "already_revoked", result.Outcome) + assert.EqualValues(t, 0, result.AffectedSessionCount) + require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1) +} diff --git a/authsession/internal/service/revokedevicesession/service.go b/authsession/internal/service/revokedevicesession/service.go new file mode 100644 index 0000000..afc556b --- /dev/null +++ b/authsession/internal/service/revokedevicesession/service.go @@ -0,0 +1,151 @@ +// Package revokedevicesession implements the trusted internal single-session +// revoke use case. +package revokedevicesession + +import ( + "context" + "errors" + "fmt" + + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/telemetry" + + "go.uber.org/zap" +) + +// Input describes one trusted internal revoke-device-session request. +type Input struct { + // DeviceSessionID identifies the session that should be revoked. + DeviceSessionID string + + // ReasonCode stores the machine-readable revoke reason code. + ReasonCode string + + // ActorType stores the machine-readable revoke actor type. + ActorType string + + // ActorID stores the optional stable revoke actor identifier. + ActorID string +} + +// Result describes the frozen internal revoke-device-session acknowledgement. +type Result struct { + // Outcome reports whether the current call revoked the session or found it + // already revoked. + Outcome string + + // DeviceSessionID identifies the session addressed by the operation. 
+ DeviceSessionID string + + // AffectedSessionCount reports how many sessions changed state during the + // current call. + AffectedSessionCount int64 +} + +// Service executes the trusted internal revoke-device-session use case. +type Service struct { + sessionStore ports.SessionStore + publisher ports.GatewaySessionProjectionPublisher + clock ports.Clock + logger *zap.Logger + telemetry *telemetry.Runtime +} + +// New returns a revoke-device-session service wired to the required ports. +func New(sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) { + return NewWithObservability(sessionStore, publisher, clock, nil, nil) +} + +// NewWithObservability returns a revoke-device-session service wired to the +// required ports plus optional structured logging and telemetry dependencies. +func NewWithObservability( + sessionStore ports.SessionStore, + publisher ports.GatewaySessionProjectionPublisher, + clock ports.Clock, + logger *zap.Logger, + telemetryRuntime *telemetry.Runtime, +) (*Service, error) { + switch { + case sessionStore == nil: + return nil, fmt.Errorf("revokedevicesession: session store must not be nil") + case publisher == nil: + return nil, fmt.Errorf("revokedevicesession: projection publisher must not be nil") + case clock == nil: + return nil, fmt.Errorf("revokedevicesession: clock must not be nil") + default: + return &Service{ + sessionStore: sessionStore, + publisher: publisher, + clock: clock, + logger: namedLogger(logger, "revoke_device_session"), + telemetry: telemetryRuntime, + }, nil + } +} + +// Execute revokes one device session and republishes the current gateway +// projection for the resulting source-of-truth session state. 
+func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) { + logFields := []zap.Field{ + zap.String("component", "service"), + zap.String("use_case", "revoke_device_session"), + } + defer func() { + shared.LogServiceOutcome(s.logger, ctx, "revoke device session completed", err, logFields...) + }() + + deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID) + if err != nil { + return Result{}, err + } + logFields = append(logFields, zap.String("device_session_id", deviceSessionID.String())) + + revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now()) + if err != nil { + return Result{}, err + } + logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String())) + + storeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{ + DeviceSessionID: deviceSessionID, + Revocation: revocation, + }) + if err != nil { + switch { + case errors.Is(err, ports.ErrNotFound): + return Result{}, shared.SessionNotFound() + default: + return Result{}, shared.ServiceUnavailable(err) + } + } + if err := storeResult.Validate(); err != nil { + return Result{}, shared.InternalError(err) + } + logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome))) + + if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, storeResult.Session, s.telemetry, "revoke_device_session"); err != nil { + return Result{}, err + } + + affectedSessionCount := int64(0) + if storeResult.Outcome == ports.RevokeSessionOutcomeRevoked { + affectedSessionCount = 1 + s.telemetry.RecordSessionRevocations(ctx, "revoke_device_session", revocation.ReasonCode.String(), affectedSessionCount) + } + logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount)) + + return Result{ + Outcome: string(storeResult.Outcome), + DeviceSessionID: storeResult.Session.ID.String(), + AffectedSessionCount: affectedSessionCount, + }, nil +} + 
+func namedLogger(logger *zap.Logger, name string) *zap.Logger {
+	if logger == nil {
+		logger = zap.NewNop()
+	}
+
+	return logger.Named(name)
+}
diff --git a/authsession/internal/service/revokedevicesession/service_test.go b/authsession/internal/service/revokedevicesession/service_test.go
new file mode 100644
index 0000000..9ccffe3
--- /dev/null
+++ b/authsession/internal/service/revokedevicesession/service_test.go
@@ -0,0 +1,166 @@
+package revokedevicesession
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/domain/gatewayprojection"
+	"galaxy/authsession/internal/service/shared"
+	"galaxy/authsession/internal/testkit"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecuteRevokesActiveSessionAndPublishes(t *testing.T) {
+	t.Parallel()
+
+	store := &testkit.InMemorySessionStore{}
+	publisher := &testkit.RecordingProjectionPublisher{}
+	record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	if err := store.Create(context.Background(), record); err != nil {
+		require.Failf(t, "test failed", "Create() returned error: %v", err)
+	}
+
+	service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
+	require.NoError(t, err)
+
+	result, err := service.Execute(context.Background(), Input{
+		DeviceSessionID: "device-session-1",
+		ReasonCode:      "logout_all",
+		ActorType:       "system",
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "revoked", result.Outcome)
+	assert.EqualValues(t, 1, result.AffectedSessionCount)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+
+	stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
+	require.NoError(t, err)
+	require.NotNil(t, stored.Revocation)
+	assert.Equal(t, devicesession.StatusRevoked, stored.Status)
+	assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
+	assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
+	assert.Empty(t, stored.Revocation.ActorID)
+	assert.Equal(t, time.Unix(20, 0).UTC(), stored.Revocation.At)
+
+	published := publisher.PublishedSnapshots()
+	require.Len(t, published, 1)
+	assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
+	assert.Equal(t, common.DeviceSessionID("device-session-1"), published[0].DeviceSessionID)
+	assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
+	assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
+	require.NotNil(t, published[0].RevokedAt)
+	assert.Equal(t, time.Unix(20, 0).UTC(), published[0].RevokedAt.UTC())
+}
+
+func TestExecuteAlreadyRevokedReturnsZeroAffectedAndRepublishes(t *testing.T) {
+	t.Parallel()
+
+	store := &testkit.InMemorySessionStore{}
+	publisher := &testkit.RecordingProjectionPublisher{}
+	record := revokedSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	if err := store.Create(context.Background(), record); err != nil {
+		require.Failf(t, "test failed", "Create() returned error: %v", err)
+	}
+
+	service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
+	require.NoError(t, err)
+
+	result, err := service.Execute(context.Background(), Input{
+		DeviceSessionID: "device-session-1",
+		ReasonCode:      "logout_all",
+		ActorType:       "system",
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "already_revoked", result.Outcome)
+	assert.EqualValues(t, 0, result.AffectedSessionCount)
+	assert.Equal(t, "device-session-1", result.DeviceSessionID)
+
+	stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
+	require.NoError(t, err)
+	require.NotNil(t, stored.Revocation)
+	assert.Equal(t, *record.Revocation, *stored.Revocation)
+
+	published := publisher.PublishedSnapshots()
+	require.Len(t, published, 1)
+	assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
+	assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
+	assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
+	require.NotNil(t, published[0].RevokedAt)
+	assert.Equal(t, record.Revocation.At, *published[0].RevokedAt)
+}
+
+func TestExecuteReturnsSessionNotFound(t *testing.T) {
+	t.Parallel()
+
+	service, err := New(&testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	_, err = service.Execute(context.Background(), Input{
+		DeviceSessionID: "missing",
+		ReasonCode:      "logout_all",
+		ActorType:       "system",
+	})
+	assert.Equal(t, shared.ErrorCodeSessionNotFound, shared.CodeOf(err))
+}
+
+func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
+	t.Parallel()
+
+	store := &testkit.InMemorySessionStore{}
+	publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
+	record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	if err := store.Create(context.Background(), record); err != nil {
+		require.Failf(t, "test failed", "Create() returned error: %v", err)
+	}
+
+	service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
+	require.NoError(t, err)
+
+	_, err = service.Execute(context.Background(), Input{
+		DeviceSessionID: "device-session-1",
+		ReasonCode:      "logout_all",
+		ActorType:       "system",
+	})
+	assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
+
+	stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
+	require.NoError(t, getErr)
+	require.NotNil(t, stored.Revocation)
+	assert.Equal(t, devicesession.StatusRevoked, stored.Status)
+	assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
+	assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
+}
+
+func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
+	key, err := common.NewClientPublicKey(make([]byte, 32))
+	if err != nil {
+		panic(err)
+	}
+
+	return devicesession.Session{
+		ID:              common.DeviceSessionID(deviceSessionID),
+		UserID:          common.UserID(userID),
+		ClientPublicKey: key,
+		Status:          devicesession.StatusActive,
+		CreatedAt:       createdAt,
+	}
+}
+
+func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
+	record := activeSessionFixture(deviceSessionID, userID, createdAt)
+	record.Status = devicesession.StatusRevoked
+	record.Revocation = &devicesession.Revocation{
+		At:         createdAt.Add(time.Minute),
+		ReasonCode: devicesession.RevokeReasonLogoutAll,
+		ActorType:  common.RevokeActorType("system"),
+	}
+	return record
+}
diff --git a/authsession/internal/service/sendemailcode/anti_abuse_test.go b/authsession/internal/service/sendemailcode/anti_abuse_test.go
new file mode 100644
index 0000000..2b2bf46
--- /dev/null
+++ b/authsession/internal/service/sendemailcode/anti_abuse_test.go
@@ -0,0 +1,167 @@
+package sendemailcode
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/userresolution"
+	"galaxy/authsession/internal/ports"
+	"galaxy/authsession/internal/testkit"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecuteCreatesThrottledChallengeWithoutUserDirectoryOrMail(t *testing.T) {
+	t.Parallel()
+
+	challengeStore := &testkit.InMemoryChallengeStore{}
+	abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
+	now := time.Unix(10, 0).UTC()
+	require.NoError(t, reserveSendCooldown(abuseProtector, common.Email("pilot@example.com"), now))
+
+	userDirectory := &countingUserDirectory{}
+	mailSender := &testkit.RecordingMailSender{}
+	service, err := NewWithRuntime(
+		challengeStore,
+		userDirectory,
+		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		mailSender,
+		abuseProtector,
+		testkit.FixedClock{Time: now},
+		nil,
+	)
+	require.NoError(t, err)
+
+	result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	require.NoError(t, err)
+	assert.Equal(t, "challenge-1", result.ChallengeID)
+	assert.Zero(t, userDirectory.resolveCalls)
+	assert.Empty(t, mailSender.RecordedInputs())
+
+	record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	require.NoError(t, getErr)
+	assert.Equal(t, challenge.StatusDeliveryThrottled, record.Status)
+	assert.Equal(t, challenge.DeliveryThrottled, record.DeliveryState)
+	assert.Equal(t, 1, record.Attempts.Send)
+}
+
+func TestExecuteBlockedEmailOutsideThrottleStillSuppressesDelivery(t *testing.T) {
+	t.Parallel()
+
+	challengeStore := &testkit.InMemoryChallengeStore{}
+	userDirectory := &testkit.InMemoryUserDirectory{}
+	require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
+	mailSender := &testkit.RecordingMailSender{}
+
+	service, err := NewWithRuntime(
+		challengeStore,
+		userDirectory,
+		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		mailSender,
+		&testkit.InMemorySendEmailCodeAbuseProtector{},
+		testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+		nil,
+	)
+	require.NoError(t, err)
+
+	result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	require.NoError(t, err)
+	assert.Equal(t, "challenge-1", result.ChallengeID)
+	assert.Empty(t, mailSender.RecordedInputs())
+
+	record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	require.NoError(t, getErr)
+	assert.Equal(t, challenge.StatusDeliverySuppressed, record.Status)
+	assert.Equal(t, challenge.DeliverySuppressed, record.DeliveryState)
+}
+
+func TestExecuteAllowsAgainAfterCooldown(t *testing.T) {
+	t.Parallel()
+
+	challengeStore := &testkit.InMemoryChallengeStore{}
+	userDirectory := &testkit.InMemoryUserDirectory{}
+	mailSender := &testkit.RecordingMailSender{}
+	abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
+	clock := &mutableClock{time: time.Unix(10, 0).UTC()}
+	idGenerator := &testkit.SequenceIDGenerator{
+		ChallengeIDs: []common.ChallengeID{"challenge-1", "challenge-2"},
+	}
+
+	service, err := NewWithRuntime(
+		challengeStore,
+		userDirectory,
+		idGenerator,
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		mailSender,
+		abuseProtector,
+		clock,
+		nil,
+	)
+	require.NoError(t, err)
+
+	first, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	require.NoError(t, err)
+	assert.Equal(t, "challenge-1", first.ChallengeID)
+
+	clock.time = clock.time.Add(challenge.ResendThrottleCooldown)
+
+	second, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	require.NoError(t, err)
+	assert.Equal(t, "challenge-2", second.ChallengeID)
+	require.Len(t, mailSender.RecordedInputs(), 2)
+
+	secondRecord, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-2"))
+	require.NoError(t, getErr)
+	assert.Equal(t, challenge.StatusSent, secondRecord.Status)
+	assert.Equal(t, challenge.DeliverySent, secondRecord.DeliveryState)
+}
+
+func reserveSendCooldown(protector ports.SendEmailCodeAbuseProtector, email common.Email, now time.Time) error {
+	_, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
+		Email: email,
+		Now:   now,
+	})
+	return err
+}
+
+type mutableClock struct {
+	time time.Time
+}
+
+func (c *mutableClock) Now() time.Time {
+	return c.time
+}
+
+type countingUserDirectory struct {
+	resolveCalls int
+}
+
+func (d *countingUserDirectory) ResolveByEmail(_ context.Context, _ common.Email) (userresolution.Result, error) {
+	d.resolveCalls++
+	return userresolution.Result{Kind: userresolution.KindCreatable}, nil
+}
+
+func (d *countingUserDirectory) ExistsByUserID(context.Context, common.UserID) (bool, error) {
+	return false, nil
+}
+
+func (d *countingUserDirectory) EnsureUserByEmail(context.Context, common.Email) (ports.EnsureUserResult, error) {
+	return ports.EnsureUserResult{}, nil
+}
+
+func (d *countingUserDirectory) BlockByUserID(context.Context, ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
+	return ports.BlockUserResult{}, nil
+}
+
+func (d *countingUserDirectory) BlockByEmail(context.Context, ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
+	return ports.BlockUserResult{}, nil
+}
diff --git a/authsession/internal/service/sendemailcode/observability_test.go b/authsession/internal/service/sendemailcode/observability_test.go
new file mode 100644
index 0000000..bf89729
--- /dev/null
+++ b/authsession/internal/service/sendemailcode/observability_test.go
@@ -0,0 +1,59 @@
+package sendemailcode
+
+import (
+	"bytes"
+	"context"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/testkit"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
+	"go.uber.org/zap/zapcore"
+)
+
+func TestExecuteLogsSafeOutcomeFields(t *testing.T) {
+	t.Parallel()
+
+	logger, buffer := newObservedServiceLogger()
+	service, err := NewWithObservability(
+		&testkit.InMemoryChallengeStore{},
+		&testkit.InMemoryUserDirectory{},
+		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		&testkit.RecordingMailSender{},
+		nil,
+		testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+		logger,
+		nil,
+	)
+	require.NoError(t, err)
+
+	_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	require.NoError(t, err)
+
+	logOutput := buffer.String()
+	assert.Contains(t, logOutput, "send_email_code")
+	assert.Contains(t, logOutput, "challenge-1")
+	assert.Contains(t, logOutput, "\"outcome\":\"sent\"")
+	assert.NotContains(t, logOutput, "pilot@example.com")
+	assert.NotContains(t, logOutput, "654321")
+}
+
+func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) {
+	buffer := &bytes.Buffer{}
+	encoderConfig := zap.NewProductionEncoderConfig()
+	encoderConfig.TimeKey = ""
+
+	core := zapcore.NewCore(
+		zapcore.NewJSONEncoder(encoderConfig),
+		zapcore.AddSync(buffer),
+		zap.DebugLevel,
+	)
+
+	return zap.New(core), buffer
+}
diff --git a/authsession/internal/service/sendemailcode/service.go b/authsession/internal/service/sendemailcode/service.go
new file mode 100644
index 0000000..d583bb4
--- /dev/null
+++ b/authsession/internal/service/sendemailcode/service.go
@@ -0,0 +1,331 @@
+// Package sendemailcode implements the public send-email-code use case.
+package sendemailcode
+
+import (
+	"context"
+	"fmt"
+	"reflect"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/userresolution"
+	"galaxy/authsession/internal/ports"
+	"galaxy/authsession/internal/service/shared"
+	"galaxy/authsession/internal/telemetry"
+
+	"go.uber.org/zap"
+)
+
+// Input describes one public send-email-code request.
+type Input struct {
+	// Email is the user-supplied e-mail address that should receive the login
+	// code.
+	Email string
+}
+
+// Result describes one public send-email-code response.
+type Result struct {
+	// ChallengeID is the stable challenge identifier returned to the caller.
+	ChallengeID string
+}
+
+// Service executes the public send-email-code use case.
+type Service struct {
+	challengeStore ports.ChallengeStore
+	userDirectory  ports.UserDirectory
+	idGenerator    ports.IDGenerator
+	codeGenerator  ports.CodeGenerator
+	codeHasher     ports.CodeHasher
+	mailSender     ports.MailSender
+	abuseProtector ports.SendEmailCodeAbuseProtector
+	clock          ports.Clock
+	logger         *zap.Logger
+	telemetry      *telemetry.Runtime
+}
+
+// New returns a send-email-code service wired to the required ports.
+func New(
+	challengeStore ports.ChallengeStore,
+	userDirectory ports.UserDirectory,
+	idGenerator ports.IDGenerator,
+	codeGenerator ports.CodeGenerator,
+	codeHasher ports.CodeHasher,
+	mailSender ports.MailSender,
+	clock ports.Clock,
+) (*Service, error) {
+	return NewWithRuntime(
+		challengeStore,
+		userDirectory,
+		idGenerator,
+		codeGenerator,
+		codeHasher,
+		mailSender,
+		nil,
+		clock,
+		nil,
+	)
+}
+
+// NewWithRuntime returns a send-email-code service wired to the required
+// ports plus the optional Stage-17 runtime collaborators.
+func NewWithRuntime(
+	challengeStore ports.ChallengeStore,
+	userDirectory ports.UserDirectory,
+	idGenerator ports.IDGenerator,
+	codeGenerator ports.CodeGenerator,
+	codeHasher ports.CodeHasher,
+	mailSender ports.MailSender,
+	abuseProtector ports.SendEmailCodeAbuseProtector,
+	clock ports.Clock,
+	telemetryRuntime *telemetry.Runtime,
+) (*Service, error) {
+	return NewWithObservability(
+		challengeStore,
+		userDirectory,
+		idGenerator,
+		codeGenerator,
+		codeHasher,
+		mailSender,
+		abuseProtector,
+		clock,
+		nil,
+		telemetryRuntime,
+	)
+}
+
+// NewWithObservability returns a send-email-code service wired to the required
+// ports plus optional structured logging and telemetry dependencies.
+func NewWithObservability(
+	challengeStore ports.ChallengeStore,
+	userDirectory ports.UserDirectory,
+	idGenerator ports.IDGenerator,
+	codeGenerator ports.CodeGenerator,
+	codeHasher ports.CodeHasher,
+	mailSender ports.MailSender,
+	abuseProtector ports.SendEmailCodeAbuseProtector,
+	clock ports.Clock,
+	logger *zap.Logger,
+	telemetryRuntime *telemetry.Runtime,
+) (*Service, error) {
+	switch {
+	case challengeStore == nil:
+		return nil, fmt.Errorf("sendemailcode: challenge store must not be nil")
+	case userDirectory == nil:
+		return nil, fmt.Errorf("sendemailcode: user directory must not be nil")
+	case idGenerator == nil:
+		return nil, fmt.Errorf("sendemailcode: id generator must not be nil")
+	case codeGenerator == nil:
+		return nil, fmt.Errorf("sendemailcode: code generator must not be nil")
+	case codeHasher == nil:
+		return nil, fmt.Errorf("sendemailcode: code hasher must not be nil")
+	case mailSender == nil:
+		return nil, fmt.Errorf("sendemailcode: mail sender must not be nil")
+	case clock == nil:
+		return nil, fmt.Errorf("sendemailcode: clock must not be nil")
+	default:
+		return &Service{
+			challengeStore: challengeStore,
+			userDirectory:  userDirectory,
+			idGenerator:    idGenerator,
+			codeGenerator:  codeGenerator,
+			codeHasher:     codeHasher,
+			mailSender:     mailSender,
+			abuseProtector: normalizeAbuseProtector(abuseProtector),
+			clock:          clock,
+			logger:         namedLogger(logger, "send_email_code"),
+			telemetry:      telemetryRuntime,
+		}, nil
+	}
+}
+
+// Execute creates a fresh challenge for every request, stores only the hashed
+// confirmation code, and records whether delivery was sent or intentionally
+// suppressed.
+func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
+	logFields := []zap.Field{
+		zap.String("component", "service"),
+		zap.String("use_case", "send_email_code"),
+	}
+	outcome := ""
+	defer func() {
+		if outcome != "" {
+			logFields = append(logFields, zap.String("outcome", outcome))
+		}
+		if result.ChallengeID != "" {
+			logFields = append(logFields, zap.String("challenge_id", result.ChallengeID))
+		}
+		shared.LogServiceOutcome(s.logger, ctx, "send email code completed", err, logFields...)
+	}()
+
+	email, err := shared.ParseEmail(input.Email)
+	if err != nil {
+		return Result{}, err
+	}
+
+	now := s.clock.Now().UTC()
+	abuseResult, err := s.abuseProtector.CheckAndReserve(ctx, ports.SendEmailCodeAbuseInput{
+		Email: email,
+		Now:   now,
+	})
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+	if err := abuseResult.Validate(); err != nil {
+		return Result{}, shared.InternalError(err)
+	}
+
+	challengeID, err := s.idGenerator.NewChallengeID()
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+	code, err := s.codeGenerator.Generate()
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+	codeHash, err := s.codeHasher.Hash(code)
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+
+	pendingStatus, pendingDeliveryState, err := ports.SendEmailCodeThrottleStatusToChallengeStatus(abuseResult.Outcome)
+	if err != nil {
+		return Result{}, shared.InternalError(err)
+	}
+	pending := challenge.Challenge{
+		ID:            challengeID,
+		Email:         email,
+		CodeHash:      codeHash,
+		Status:        pendingStatus,
+		DeliveryState: pendingDeliveryState,
+		CreatedAt:     now,
+		ExpiresAt:     now.Add(challenge.InitialTTL),
+	}
+	if err := pending.Validate(); err != nil {
+		return Result{}, shared.InternalError(err)
+	}
+	if err := s.challengeStore.Create(ctx, pending); err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+	s.telemetry.RecordChallengeCreated(ctx)
+
+	final := pending
+	final.Attempts.Send = 1
+	final.Abuse.LastAttemptAt = &now
+	if abuseResult.Outcome == ports.SendEmailCodeAbuseOutcomeThrottled {
+		result, err = s.finishChallenge(ctx, pending, final)
+		if err == nil {
+			outcome = string(telemetry.SendEmailCodeOutcomeThrottled)
+			s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeThrottled, telemetry.SendEmailCodeReasonThrottled)
+		}
+		return result, err
+	}
+
+	resolution, err := s.userDirectory.ResolveByEmail(ctx, email)
+	if err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+	if err := resolution.Validate(); err != nil {
+		return Result{}, shared.InternalError(err)
+	}
+	s.telemetry.RecordUserDirectoryOutcome(ctx, "resolve_by_email", string(resolution.Kind))
+
+	switch resolution.Kind {
+	case userresolution.KindBlocked:
+		final.Status = challenge.StatusDeliverySuppressed
+		final.DeliveryState = challenge.DeliverySuppressed
+		result, err = s.finishChallenge(ctx, pending, final)
+		if err == nil {
+			outcome = string(telemetry.SendEmailCodeOutcomeSuppressed)
+			s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSuppressed, telemetry.SendEmailCodeReasonBlocked)
+		}
+		return result, err
+	default:
+		deliveryResult, err := s.mailSender.SendLoginCode(ctx, ports.SendLoginCodeInput{
+			Email: email,
+			Code:  code,
+		})
+		if err != nil {
+			final.Status = challenge.StatusFailed
+			final.DeliveryState = challenge.DeliveryFailed
+			if _, persistErr := s.finishChallenge(ctx, pending, final); persistErr != nil {
+				return Result{}, persistErr
+			}
+			outcome = string(telemetry.SendEmailCodeOutcomeFailed)
+			s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeFailed, telemetry.SendEmailCodeReasonMailSender)
+
+			return Result{}, shared.ServiceUnavailable(err)
+		}
+		if err := deliveryResult.Validate(); err != nil {
+			return Result{}, shared.InternalError(err)
+		}
+
+		switch deliveryResult.Outcome {
+		case ports.SendLoginCodeOutcomeSent:
+			final.Status = challenge.StatusSent
+			final.DeliveryState = challenge.DeliverySent
+			result, err = s.finishChallenge(ctx, pending, final)
+			if err == nil {
+				outcome = string(telemetry.SendEmailCodeOutcomeSent)
+				s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSent, "")
+			}
+			return result, err
+		case ports.SendLoginCodeOutcomeSuppressed:
+			final.Status = challenge.StatusDeliverySuppressed
+			final.DeliveryState = challenge.DeliverySuppressed
+			result, err = s.finishChallenge(ctx, pending, final)
+			if err == nil {
+				outcome = string(telemetry.SendEmailCodeOutcomeSuppressed)
+				s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSuppressed, telemetry.SendEmailCodeReasonMailSender)
+			}
+			return result, err
+		default:
+			return Result{}, shared.InternalError(fmt.Errorf("sendemailcode: unsupported delivery outcome %q", deliveryResult.Outcome))
+		}
+	}
+}
+
+func (s *Service) finishChallenge(ctx context.Context, pending challenge.Challenge, final challenge.Challenge) (Result, error) {
+	if err := final.Validate(); err != nil {
+		return Result{}, shared.InternalError(err)
+	}
+	if err := s.challengeStore.CompareAndSwap(ctx, pending, final); err != nil {
+		return Result{}, shared.ServiceUnavailable(err)
+	}
+
+	return Result{ChallengeID: final.ID.String()}, nil
+}
+
+func normalizeAbuseProtector(protector ports.SendEmailCodeAbuseProtector) ports.SendEmailCodeAbuseProtector {
+	if protector == nil {
+		return allowAllSendEmailCodeAbuseProtector{}
+	}
+
+	value := reflect.ValueOf(protector)
+	switch value.Kind() {
+	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
+		if value.IsNil() {
+			return allowAllSendEmailCodeAbuseProtector{}
+		}
+	}
+
+	return protector
+}
+
+type allowAllSendEmailCodeAbuseProtector struct{}
+
+func (allowAllSendEmailCodeAbuseProtector) CheckAndReserve(_ context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
+	if err := input.Validate(); err != nil {
+		return ports.SendEmailCodeAbuseResult{}, err
+	}
+
+	return ports.SendEmailCodeAbuseResult{
+		Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
+	}, nil
+}
+
+func namedLogger(logger *zap.Logger, name string) *zap.Logger {
+	if logger == nil {
+		logger = zap.NewNop()
+	}
+
+	return logger.Named(name)
+}
diff --git a/authsession/internal/service/sendemailcode/service_test.go b/authsession/internal/service/sendemailcode/service_test.go
new file mode 100644
index 0000000..df781a7
--- /dev/null
+++ b/authsession/internal/service/sendemailcode/service_test.go
@@ -0,0 +1,310 @@
+package sendemailcode
+
+import (
+	"context"
+	"errors"
+	"github.com/stretchr/testify/require"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/userresolution"
+	"galaxy/authsession/internal/ports"
+	"galaxy/authsession/internal/service/shared"
+	"galaxy/authsession/internal/testkit"
+)
+
+func TestExecuteSendsChallengeForExistingAndCreatableUsers(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name  string
+		seed  func(*testkit.InMemoryUserDirectory) error
+		email string
+	}{
+		{
+			name: "existing",
+			seed: func(directory *testkit.InMemoryUserDirectory) error {
+				return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))
+			},
+			email: " pilot@example.com ",
+		},
+		{
+			name:  "creatable",
+			seed:  func(*testkit.InMemoryUserDirectory) error { return nil },
+			email: "new@example.com",
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			challengeStore := &testkit.InMemoryChallengeStore{}
+			userDirectory := &testkit.InMemoryUserDirectory{}
+			if err := tt.seed(userDirectory); err != nil {
+				require.Failf(t, "test failed", "seed() returned error: %v", err)
+			}
+			mailSender := &testkit.RecordingMailSender{}
+			service, err := New(
+				challengeStore,
+				userDirectory,
+				&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+				testkit.FixedCodeGenerator{Code: "654321"},
+				testkit.DeterministicCodeHasher{},
+				mailSender,
+				testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+			)
+			if err != nil {
+				require.Failf(t, "test failed", "New() returned error: %v", err)
+			}
+
+			result, err := service.Execute(context.Background(), Input{Email: tt.email})
+			if err != nil {
+				require.Failf(t, "test failed", "Execute() returned error: %v", err)
+			}
+			if result.ChallengeID != "challenge-1" {
+				require.Failf(t, "test failed", "Execute().ChallengeID = %q, want %q", result.ChallengeID, "challenge-1")
+			}
+			if len(mailSender.RecordedInputs()) != 1 {
+				require.Failf(t, "test failed", "RecordedInputs() length = %d, want 1", len(mailSender.RecordedInputs()))
+			}
+
+			record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+			if err != nil {
+				require.Failf(t, "test failed", "Get() returned error: %v", err)
+			}
+			if record.Status != challenge.StatusSent || record.DeliveryState != challenge.DeliverySent {
+				require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
+			}
+			if record.Attempts.Send != 1 {
+				require.Failf(t, "test failed", "Attempts.Send = %d, want 1", record.Attempts.Send)
+			}
+			if string(record.CodeHash) == "654321" {
+				require.FailNow(t, "CodeHash stored cleartext code")
+			}
+		})
+	}
+}
+
+func TestExecuteSuppressesDeliveryForBlockedEmail(t *testing.T) {
+	t.Parallel()
+
+	challengeStore := &testkit.InMemoryChallengeStore{}
+	userDirectory := &testkit.InMemoryUserDirectory{}
+	if err := userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
+		require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
+	}
+	mailSender := &testkit.RecordingMailSender{}
+
+	service, err := New(
+		challengeStore,
+		userDirectory,
+		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		mailSender,
+		testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+	)
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	if err != nil {
+		require.Failf(t, "test failed", "Execute() returned error: %v", err)
+	}
+	if result.ChallengeID != "challenge-1" {
+		require.Failf(t, "test failed", "Execute().ChallengeID = %q, want %q", result.ChallengeID, "challenge-1")
+	}
+	if len(mailSender.RecordedInputs()) != 0 {
+		require.Failf(t, "test failed", "RecordedInputs() length = %d, want 0", len(mailSender.RecordedInputs()))
+	}
+
+	record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	if err != nil {
+		require.Failf(t, "test failed", "Get() returned error: %v", err)
+	}
+	if record.Status != challenge.StatusDeliverySuppressed || record.DeliveryState != challenge.DeliverySuppressed {
+		require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
+	}
+}
+
+func TestExecuteHandlesMailSenderSuppressedOutcome(t *testing.T) {
+	t.Parallel()
+
+	challengeStore := &testkit.InMemoryChallengeStore{}
+	mailSender := &testkit.RecordingMailSender{
+		DefaultResult: ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSuppressed},
+	}
+
+	service, err := New(
+		challengeStore,
+		&testkit.InMemoryUserDirectory{},
+		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		mailSender,
+		testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+	)
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	if err != nil {
+		require.Failf(t, "test failed", "Execute() returned error: %v", err)
+	}
+
+	record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	if err != nil {
+		require.Failf(t, "test failed", "Get() returned error: %v", err)
+	}
+	if record.Status != challenge.StatusDeliverySuppressed || record.DeliveryState != challenge.DeliverySuppressed {
+		require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
+	}
+}
+
+func TestExecuteMarksChallengeFailedWhenMailSenderFails(t *testing.T) {
+	t.Parallel()
+
+	challengeStore := &testkit.InMemoryChallengeStore{}
+	mailSender := &testkit.RecordingMailSender{Err: errors.New("mail failed")}
+
+	service, err := New(
+		challengeStore,
+		&testkit.InMemoryUserDirectory{},
+		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		mailSender,
+		testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+	)
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	if shared.CodeOf(err) != shared.ErrorCodeServiceUnavailable {
+		require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeServiceUnavailable)
+	}
+
+	record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	if err != nil {
+		require.Failf(t, "test failed", "Get() returned error: %v", err)
+	}
+	if record.Status != challenge.StatusFailed || record.DeliveryState != challenge.DeliveryFailed {
+		require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
+	}
+}
+
+func TestExecuteReturnsInvalidRequestForBadEmail(t *testing.T) {
+	t.Parallel()
+
+	service, err := New(
+		&testkit.InMemoryChallengeStore{},
+		&testkit.InMemoryUserDirectory{},
+		&testkit.SequenceIDGenerator{},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		&testkit.RecordingMailSender{},
+		testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+	)
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	_, err = service.Execute(context.Background(), Input{Email: "pilot"})
+	if shared.CodeOf(err) != shared.ErrorCodeInvalidRequest {
+		require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidRequest)
+	}
+}
+
+func TestExecuteCreatesFreshChallengeForRepeatedSend(t *testing.T) {
+	t.Parallel()
+
+	challengeStore := &testkit.InMemoryChallengeStore{}
+	mailSender := &testkit.RecordingMailSender{}
+	clock := testkit.FixedClock{Time: time.Unix(10, 0).UTC()}
+
+	service, err := New(
+		challengeStore,
+		&testkit.InMemoryUserDirectory{},
+		&testkit.SequenceIDGenerator{
+			ChallengeIDs: []common.ChallengeID{"challenge-1", "challenge-2"},
+		},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		mailSender,
+		clock,
+	)
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	first, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	if err != nil {
+		require.Failf(t, "test failed", "first Execute() returned error: %v", err)
+	}
+	second, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+	if err != nil {
+		require.Failf(t, "test failed", "second Execute() returned error: %v", err)
+	}
+	if first.ChallengeID == second.ChallengeID {
+		require.Failf(t, "test failed", "challenge ids are equal: %q", first.ChallengeID)
+	}
+
+	firstRecord, err := challengeStore.Get(context.Background(), common.ChallengeID(first.ChallengeID))
+	if err != nil {
+		require.Failf(t, "test failed", "Get(%q) returned error: %v", first.ChallengeID, err)
+	}
+	secondRecord, err := challengeStore.Get(context.Background(), common.ChallengeID(second.ChallengeID))
+	if err != nil {
+		require.Failf(t, "test failed", "Get(%q) returned error: %v", second.ChallengeID, err)
+	}
+	if firstRecord.Status != challenge.StatusSent {
+		require.Failf(t, "test failed", "first challenge status = %q, want %q", firstRecord.Status, challenge.StatusSent)
+	}
+	if secondRecord.Status != challenge.StatusSent {
+		require.Failf(t, "test failed", "second challenge status = %q, want %q", secondRecord.Status, challenge.StatusSent)
+	}
+	if len(mailSender.RecordedInputs()) != 2 {
+		require.Failf(t, "test failed", "RecordedInputs() length = %d, want 2", len(mailSender.RecordedInputs()))
+	}
+}
+
+func TestExecuteSetsChallengeExpirationFromInitialTTL(t *testing.T) {
+	t.Parallel()
+
+	now := time.Unix(10, 0).UTC()
+	challengeStore := &testkit.InMemoryChallengeStore{}
+
+	service, err := New(
+		challengeStore,
+		&testkit.InMemoryUserDirectory{},
+		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+		testkit.FixedCodeGenerator{Code: "654321"},
+		testkit.DeterministicCodeHasher{},
+		&testkit.RecordingMailSender{},
+		testkit.FixedClock{Time: now},
+	)
+	if err != nil {
+		require.Failf(t, "test failed", "New() returned error: %v", err)
+	}
+
+	if _, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"}); err != nil {
+		require.Failf(t, "test failed", "Execute() returned error: %v", err)
+	}
+
+	record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
+	if err != nil {
+		require.Failf(t, "test failed", "Get() returned error: %v", err)
+	}
+	wantExpiresAt := now.Add(challenge.InitialTTL)
+	if !record.ExpiresAt.Equal(wantExpiresAt) {
+		require.Failf(t, "test failed", "ExpiresAt = %s, want %s", record.ExpiresAt, wantExpiresAt)
+	}
+}
diff --git a/authsession/internal/service/sendemailcode/stub_sender_test.go b/authsession/internal/service/sendemailcode/stub_sender_test.go
new file mode 100644
index 0000000..5dc113a
--- /dev/null
+++ b/authsession/internal/service/sendemailcode/stub_sender_test.go
@@ -0,0 +1,98 @@
+package sendemailcode
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	stubmail "galaxy/authsession/internal/adapters/mail"
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/service/shared"
+	"galaxy/authsession/internal/testkit"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExecuteWithStubSender(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name                string
+		sender              *stubmail.StubSender
+		wantStatus          challenge.Status
+		wantDeliveryState   challenge.DeliveryState
+		wantErrorCode       string
+		wantRecordedAttempt int
+	}{
+		{
+			name:                "sent",
+			sender:              &stubmail.StubSender{},
+			wantStatus:          challenge.StatusSent,
+			wantDeliveryState:   challenge.DeliverySent,
+			wantRecordedAttempt: 1,
+		},
+		{
+			name: "suppressed",
+			sender: &stubmail.StubSender{
+				DefaultMode: stubmail.StubModeSuppressed,
+			},
+			wantStatus:          challenge.StatusDeliverySuppressed,
+			wantDeliveryState:   challenge.DeliverySuppressed,
+			wantRecordedAttempt: 1,
+		},
+		{
+			name: "failed",
+			sender: &stubmail.StubSender{
+				DefaultMode:  stubmail.StubModeFailed,
+				DefaultError: errors.New("stub delivery failed"),
+			},
+			wantStatus:          challenge.StatusFailed,
+			wantDeliveryState:   challenge.DeliveryFailed,
+			wantErrorCode:       shared.ErrorCodeServiceUnavailable,
+			wantRecordedAttempt: 1,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			challengeStore := &testkit.InMemoryChallengeStore{}
+			service, err := New(
+				challengeStore,
+				&testkit.InMemoryUserDirectory{},
+				&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
+				testkit.FixedCodeGenerator{Code: "654321"},
+				testkit.DeterministicCodeHasher{},
+				tt.sender,
+				testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
+			)
+			require.NoError(t, err)
+
+			result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
+			if tt.wantErrorCode == "" {
+
require.NoError(t, err) + assert.Equal(t, "challenge-1", result.ChallengeID) + } else { + require.Error(t, err) + assert.Equal(t, tt.wantErrorCode, shared.CodeOf(err)) + assert.Equal(t, Result{}, result) + } + + record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + require.NoError(t, getErr) + assert.Equal(t, tt.wantStatus, record.Status) + assert.Equal(t, tt.wantDeliveryState, record.DeliveryState) + + attempts := tt.sender.RecordedAttempts() + require.Len(t, attempts, tt.wantRecordedAttempt) + assert.Equal(t, common.Email("pilot@example.com"), attempts[0].Input.Email) + assert.Equal(t, "654321", attempts[0].Input.Code) + }) + } +} diff --git a/authsession/internal/service/sendemailcode/stub_user_directory_test.go b/authsession/internal/service/sendemailcode/stub_user_directory_test.go new file mode 100644 index 0000000..0a7e738 --- /dev/null +++ b/authsession/internal/service/sendemailcode/stub_user_directory_test.go @@ -0,0 +1,93 @@ +package sendemailcode + +import ( + "context" + "testing" + "time" + + stubmail "galaxy/authsession/internal/adapters/mail" + stubuserservice "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*stubuserservice.StubDirectory) error + email string + wantStatus challenge.Status + wantDeliveryState challenge.DeliveryState + wantMailCalls int + }{ + { + name: "existing user", + email: "pilot@example.com", + seed: func(directory *stubuserservice.StubDirectory) error { + return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")) + }, + wantStatus: challenge.StatusSent, + 
wantDeliveryState: challenge.DeliverySent, + wantMailCalls: 1, + }, + { + name: "creatable user", + email: "new@example.com", + seed: func(*stubuserservice.StubDirectory) error { return nil }, + wantStatus: challenge.StatusSent, + wantDeliveryState: challenge.DeliverySent, + wantMailCalls: 1, + }, + { + name: "blocked email", + email: "blocked@example.com", + seed: func(directory *stubuserservice.StubDirectory) error { + return directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")) + }, + wantStatus: challenge.StatusDeliverySuppressed, + wantDeliveryState: challenge.DeliverySuppressed, + wantMailCalls: 0, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + userDirectory := &stubuserservice.StubDirectory{} + require.NoError(t, tt.seed(userDirectory)) + + challengeStore := &testkit.InMemoryChallengeStore{} + mailSender := &stubmail.StubSender{} + service, err := New( + challengeStore, + userDirectory, + &testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}}, + testkit.FixedCodeGenerator{Code: "654321"}, + testkit.DeterministicCodeHasher{}, + mailSender, + testkit.FixedClock{Time: time.Unix(10, 0).UTC()}, + ) + require.NoError(t, err) + + result, err := service.Execute(context.Background(), Input{Email: tt.email}) + require.NoError(t, err) + assert.Equal(t, "challenge-1", result.ChallengeID) + + record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1")) + require.NoError(t, err) + assert.Equal(t, tt.wantStatus, record.Status) + assert.Equal(t, tt.wantDeliveryState, record.DeliveryState) + assert.Len(t, mailSender.RecordedAttempts(), tt.wantMailCalls) + }) + } +} diff --git a/authsession/internal/service/sendemailcode/telemetry_test.go b/authsession/internal/service/sendemailcode/telemetry_test.go new file mode 100644 index 0000000..23f8f15 --- /dev/null +++ 
b/authsession/internal/service/sendemailcode/telemetry_test.go @@ -0,0 +1,171 @@ +package sendemailcode + +import ( + "context" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + authtelemetry "galaxy/authsession/internal/telemetry" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" +) + +func TestExecuteRecordsSentMetric(t *testing.T) { + t.Parallel() + + runtime, reader := newObservedTelemetryRuntime(t) + service, _, mailSender := newObservedSendService(t, observedSendOptions{ + Telemetry: runtime, + }) + + _, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"}) + require.NoError(t, err) + require.Len(t, mailSender.RecordedInputs(), 1) + + assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{ + "outcome": "sent", + }, 1) +} + +func TestExecuteRecordsBlockedSuppressedMetric(t *testing.T) { + t.Parallel() + + runtime, reader := newObservedTelemetryRuntime(t) + service, _, _ := newObservedSendService(t, observedSendOptions{ + Telemetry: runtime, + SeedBlockedEmail: true, + }) + + _, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"}) + require.NoError(t, err) + + assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{ + "outcome": "suppressed", + "reason": "blocked", + }, 1) +} + +func TestExecuteRecordsThrottledMetric(t *testing.T) { + t.Parallel() + + runtime, reader := newObservedTelemetryRuntime(t) + abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{} + now := time.Unix(10, 0).UTC() + require.NoError(t, reserveSendCooldown(abuseProtector, common.Email("pilot@example.com"), now)) + + service, _, mailSender := newObservedSendService(t, 
observedSendOptions{ + Telemetry: runtime, + AbuseProtector: abuseProtector, + Clock: testkit.FixedClock{Time: now}, + }) + + _, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"}) + require.NoError(t, err) + assert.Empty(t, mailSender.RecordedInputs()) + + assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{ + "outcome": "throttled", + "reason": "throttled", + }, 1) +} + +type observedSendOptions struct { + Telemetry *authtelemetry.Runtime + AbuseProtector *testkit.InMemorySendEmailCodeAbuseProtector + SeedBlockedEmail bool + Clock portsClock +} + +type portsClock interface { + Now() time.Time +} + +func newObservedSendService(t *testing.T, options observedSendOptions) (*Service, *testkit.InMemoryChallengeStore, *testkit.RecordingMailSender) { + t.Helper() + + challengeStore := &testkit.InMemoryChallengeStore{} + userDirectory := &testkit.InMemoryUserDirectory{} + if options.SeedBlockedEmail { + require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block"))) + } + mailSender := &testkit.RecordingMailSender{} + clock := options.Clock + if clock == nil { + clock = testkit.FixedClock{Time: time.Unix(10, 0).UTC()} + } + + service, err := NewWithRuntime( + challengeStore, + userDirectory, + &testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}}, + testkit.FixedCodeGenerator{Code: "654321"}, + testkit.DeterministicCodeHasher{}, + mailSender, + options.AbuseProtector, + clock, + options.Telemetry, + ) + require.NoError(t, err) + + return service, challengeStore, mailSender +} + +func newObservedTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader) { + t.Helper() + + reader := sdkmetric.NewManualReader() + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + + runtime, err := authtelemetry.New(provider) + require.NoError(t, err) + + return runtime, reader +} + +func 
assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) { + t.Helper() + + var resourceMetrics metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &resourceMetrics)) + + for _, scopeMetrics := range resourceMetrics.ScopeMetrics { + for _, metric := range scopeMetrics.Metrics { + if metric.Name != metricName { + continue + } + + sum, ok := metric.Data.(metricdata.Sum[int64]) + require.True(t, ok) + + for _, point := range sum.DataPoints { + if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) { + assert.Equal(t, wantValue, point.Value) + return + } + } + } + } + + require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs) +} + +func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool { + if len(values) != len(want) { + return false + } + + for _, value := range values { + if want[string(value.Key)] != value.Value.AsString() { + return false + } + } + + return true +} diff --git a/authsession/internal/service/shared/doc.go b/authsession/internal/service/shared/doc.go new file mode 100644 index 0000000..79cf9e6 --- /dev/null +++ b/authsession/internal/service/shared/doc.go @@ -0,0 +1,4 @@ +// Package shared provides cross-use-case application helpers for auth/session +// services, including typed service errors, input normalization, DTO mapping, +// and application-level retry helpers. +package shared diff --git a/authsession/internal/service/shared/errors.go b/authsession/internal/service/shared/errors.go new file mode 100644 index 0000000..cdb3ff0 --- /dev/null +++ b/authsession/internal/service/shared/errors.go @@ -0,0 +1,407 @@ +package shared + +import ( + "errors" + "net/http" + "strings" +) + +const ( + // ErrorCodeInvalidRequest reports malformed or semantically invalid service + // input. 
+	ErrorCodeInvalidRequest = "invalid_request"
+
+	// ErrorCodeChallengeNotFound reports that the requested challenge does not
+	// exist.
+	ErrorCodeChallengeNotFound = "challenge_not_found"
+
+	// ErrorCodeChallengeExpired reports that the requested challenge may no
+	// longer be confirmed.
+	ErrorCodeChallengeExpired = "challenge_expired"
+
+	// ErrorCodeInvalidCode reports that the submitted confirmation code does not
+	// match the stored challenge.
+	ErrorCodeInvalidCode = "invalid_code"
+
+	// ErrorCodeInvalidClientPublicKey reports that the submitted client public
+	// key does not satisfy the Ed25519/base64 contract.
+	ErrorCodeInvalidClientPublicKey = "invalid_client_public_key"
+
+	// ErrorCodeBlockedByPolicy reports that the auth flow is denied by current
+	// user or registration policy.
+	ErrorCodeBlockedByPolicy = "blocked_by_policy"
+
+	// ErrorCodeSessionLimitExceeded reports that creating another active session
+	// would violate the configured limit.
+	ErrorCodeSessionLimitExceeded = "session_limit_exceeded"
+
+	// ErrorCodeSessionNotFound reports that the requested device session does
+	// not exist.
+	ErrorCodeSessionNotFound = "session_not_found"
+
+	// ErrorCodeSubjectNotFound reports that the requested trusted internal
+	// subject does not exist.
+	ErrorCodeSubjectNotFound = "subject_not_found"
+
+	// ErrorCodeServiceUnavailable reports that a required dependency or
+	// propagation step is temporarily unavailable.
+	ErrorCodeServiceUnavailable = "service_unavailable"
+
+	// ErrorCodeInternalError reports that local state is inconsistent or an
+	// invariant was broken unexpectedly.
+	ErrorCodeInternalError = "internal_error"
+)
+
+const genericInvalidRequestMessage = "request is invalid"
+
+var publicErrorStatusCodes = map[string]int{
+	ErrorCodeInvalidRequest:         http.StatusBadRequest,
+	ErrorCodeInvalidClientPublicKey: http.StatusBadRequest,
+	ErrorCodeInvalidCode:            http.StatusBadRequest,
+	ErrorCodeChallengeNotFound:      http.StatusNotFound,
+	ErrorCodeChallengeExpired:       http.StatusGone,
+	ErrorCodeBlockedByPolicy:        http.StatusForbidden,
+	ErrorCodeSessionLimitExceeded:   http.StatusConflict,
+	ErrorCodeServiceUnavailable:     http.StatusServiceUnavailable,
+}
+
+var publicStableMessages = map[string]string{
+	ErrorCodeChallengeNotFound:      "challenge not found",
+	ErrorCodeChallengeExpired:       "challenge expired",
+	ErrorCodeInvalidCode:            "confirmation code is invalid",
+	ErrorCodeInvalidClientPublicKey: "client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key",
+	ErrorCodeBlockedByPolicy:        "authentication is blocked by policy",
+	ErrorCodeSessionLimitExceeded:   "active session limit would be exceeded",
+	ErrorCodeServiceUnavailable:     "service is unavailable",
+}
+
+var internalErrorStatusCodes = map[string]int{
+	ErrorCodeInvalidRequest:     http.StatusBadRequest,
+	ErrorCodeSessionNotFound:    http.StatusNotFound,
+	ErrorCodeSubjectNotFound:    http.StatusNotFound,
+	ErrorCodeServiceUnavailable: http.StatusServiceUnavailable,
+	ErrorCodeInternalError:      http.StatusInternalServerError,
+}
+
+var internalStableMessages = map[string]string{
+	ErrorCodeSessionNotFound:    "session not found",
+	ErrorCodeSubjectNotFound:    "subject not found",
+	ErrorCodeServiceUnavailable: "service is unavailable",
+	ErrorCodeInternalError:      "internal server error",
+}
+
+// PublicErrorProjection describes one transport-ready public auth error after
+// internal service errors have been normalized to the frozen client-safe
+// surface.
+type PublicErrorProjection struct {
+	// StatusCode is the HTTP status that should be returned to the public auth
+	// caller.
+	StatusCode int
+
+	// Code is the stable client-safe error code written into the public JSON
+	// envelope.
+	Code string
+
+	// Message is the client-safe error description exposed to the public auth
+	// caller.
+	Message string
+}
+
+// InternalErrorProjection describes one transport-ready internal API error
+// after service-layer failures have been normalized to the frozen trusted
+// caller surface.
+type InternalErrorProjection struct {
+	// StatusCode is the HTTP status that should be returned to the trusted
+	// caller.
+	StatusCode int
+
+	// Code is the stable error code written into the internal JSON envelope.
+	Code string
+
+	// Message is the trusted-caller-safe error description exposed by the
+	// internal HTTP API.
+	Message string
+}
+
+// ServiceError projects one stable application-layer failure with a service
+// error code and a caller-safe message.
+type ServiceError struct {
+	// Code is the stable error code expected by later transport mapping.
+	Code string
+
+	// Message is the caller-safe error description.
+	Message string
+
+	// Err optionally stores the wrapped underlying cause.
+	Err error
+}
+
+// Error returns the caller-safe error description.
+func (e *ServiceError) Error() string {
+	if e == nil {
+		return ""
+	}
+
+	switch {
+	case strings.TrimSpace(e.Message) != "":
+		return e.Message
+	case strings.TrimSpace(e.Code) != "":
+		return e.Code
+	case e.Err != nil:
+		return e.Err.Error()
+	default:
+		return ErrorCodeInternalError
+	}
+}
+
+// Unwrap returns the wrapped cause, if any.
+func (e *ServiceError) Unwrap() error {
+	if e == nil {
+		return nil
+	}
+
+	return e.Err
+}
+
+// NewServiceError returns a new typed application-layer error.
+func NewServiceError(code string, message string, err error) *ServiceError {
+	return &ServiceError{
+		Code:    strings.TrimSpace(code),
+		Message: strings.TrimSpace(message),
+		Err:     err,
+	}
+}
+
+// IsPublicErrorCode reports whether code belongs to the frozen public auth
+// error surface.
+func IsPublicErrorCode(code string) bool {
+	_, ok := publicErrorStatusCodes[strings.TrimSpace(code)]
+	return ok
+}
+
+// IsInternalOnlyErrorCode reports whether code is intentionally excluded from
+// the public auth transport surface.
+func IsInternalOnlyErrorCode(code string) bool {
+	switch strings.TrimSpace(code) {
+	case ErrorCodeSessionNotFound, ErrorCodeSubjectNotFound, ErrorCodeInternalError:
+		return true
+	default:
+		return false
+	}
+}
+
+// IsSendEmailCodePublicErrorCode reports whether code may be exposed by the
+// public send-email-code route after public projection.
+func IsSendEmailCodePublicErrorCode(code string) bool {
+	switch strings.TrimSpace(code) {
+	case ErrorCodeInvalidRequest, ErrorCodeServiceUnavailable:
+		return true
+	default:
+		return false
+	}
+}
+
+// IsConfirmEmailCodePublicErrorCode reports whether code may be exposed by the
+// public confirm-email-code route after public projection.
+func IsConfirmEmailCodePublicErrorCode(code string) bool {
+	switch strings.TrimSpace(code) {
+	case ErrorCodeInvalidRequest,
+		ErrorCodeChallengeNotFound,
+		ErrorCodeChallengeExpired,
+		ErrorCodeInvalidCode,
+		ErrorCodeInvalidClientPublicKey,
+		ErrorCodeBlockedByPolicy,
+		ErrorCodeSessionLimitExceeded,
+		ErrorCodeServiceUnavailable:
+		return true
+	default:
+		return false
+	}
+}
+
+// PublicHTTPStatusCode reports the frozen public HTTP status for code. Unknown
+// or internal-only codes are normalized to 503 service_unavailable.
+func PublicHTTPStatusCode(code string) int {
+	if statusCode, ok := publicErrorStatusCodes[strings.TrimSpace(code)]; ok {
+		return statusCode
+	}
+
+	return http.StatusServiceUnavailable
+}
+
+// ProjectPublicError normalizes err to the frozen public-auth error surface.
+// Unknown and internal-only service failures are intentionally projected as
+// 503 service_unavailable so internal invariants do not leak to public callers.
+func ProjectPublicError(err error) PublicErrorProjection {
+	serviceErr, ok := errors.AsType[*ServiceError](err)
+	code := CodeOf(err)
+	if !IsPublicErrorCode(code) {
+		return PublicErrorProjection{
+			StatusCode: http.StatusServiceUnavailable,
+			Code:       ErrorCodeServiceUnavailable,
+			Message:    publicMessageForCode(ErrorCodeServiceUnavailable, ""),
+		}
+	}
+
+	message := ""
+	if ok && serviceErr != nil {
+		message = serviceErr.Message
+	}
+
+	return PublicErrorProjection{
+		StatusCode: PublicHTTPStatusCode(code),
+		Code:       code,
+		Message:    publicMessageForCode(code, message),
+	}
+}
+
+// InternalHTTPStatusCode reports the frozen internal HTTP status for code.
+// Unknown codes are normalized to 500 internal_error.
+func InternalHTTPStatusCode(code string) int {
+	if statusCode, ok := internalErrorStatusCodes[strings.TrimSpace(code)]; ok {
+		return statusCode
+	}
+
+	return http.StatusInternalServerError
+}
+
+// ProjectInternalError normalizes err to the frozen internal trusted HTTP
+// error surface. Unknown failures are intentionally projected as
+// 500 internal_error so transport callers do not depend on unclassified local
+// failures.
+func ProjectInternalError(err error) InternalErrorProjection {
+	serviceErr, ok := errors.AsType[*ServiceError](err)
+	code := CodeOf(err)
+	if _, known := internalErrorStatusCodes[code]; !known {
+		return InternalErrorProjection{
+			StatusCode: http.StatusInternalServerError,
+			Code:       ErrorCodeInternalError,
+			Message:    internalMessageForCode(ErrorCodeInternalError, ""),
+		}
+	}
+
+	message := ""
+	if ok && serviceErr != nil {
+		message = serviceErr.Message
+	}
+
+	return InternalErrorProjection{
+		StatusCode: InternalHTTPStatusCode(code),
+		Code:       code,
+		Message:    internalMessageForCode(code, message),
+	}
+}
+
+// InvalidRequest reports one malformed or semantically invalid caller input.
+func InvalidRequest(message string) *ServiceError {
+	return NewServiceError(ErrorCodeInvalidRequest, message, nil)
+}
+
+// ChallengeNotFound reports that the requested challenge does not exist.
+func ChallengeNotFound() *ServiceError {
+	return NewServiceError(ErrorCodeChallengeNotFound, "challenge not found", nil)
+}
+
+// ChallengeExpired reports that the requested challenge is expired.
+func ChallengeExpired() *ServiceError {
+	return NewServiceError(ErrorCodeChallengeExpired, "challenge expired", nil)
+}
+
+// InvalidCode reports that the submitted confirmation code is invalid.
+func InvalidCode() *ServiceError {
+	return NewServiceError(ErrorCodeInvalidCode, "confirmation code is invalid", nil)
+}
+
+// InvalidClientPublicKey reports that the submitted client public key does not
+// satisfy the frozen contract.
+func InvalidClientPublicKey() *ServiceError {
+	return NewServiceError(
+		ErrorCodeInvalidClientPublicKey,
+		"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key",
+		nil,
+	)
+}
+
+// BlockedByPolicy reports that the current auth flow is denied by policy.
+func BlockedByPolicy() *ServiceError {
+	return NewServiceError(ErrorCodeBlockedByPolicy, "authentication is blocked by policy", nil)
+}
+
+// SessionLimitExceeded reports that creating another active session would
+// exceed the current configured limit.
+func SessionLimitExceeded() *ServiceError {
+	return NewServiceError(ErrorCodeSessionLimitExceeded, "active session limit would be exceeded", nil)
+}
+
+// SessionNotFound reports that the requested session does not exist.
+func SessionNotFound() *ServiceError {
+	return NewServiceError(ErrorCodeSessionNotFound, "session not found", nil)
+}
+
+// SubjectNotFound reports that the requested internal subject does not exist.
+func SubjectNotFound() *ServiceError {
+	return NewServiceError(ErrorCodeSubjectNotFound, "subject not found", nil)
+}
+
+// ServiceUnavailable reports that a required dependency or propagation step is
+// temporarily unavailable.
+func ServiceUnavailable(err error) *ServiceError {
+	return NewServiceError(ErrorCodeServiceUnavailable, "service is unavailable", err)
+}
+
+// InternalError reports an invariant-breaking local failure.
+func InternalError(err error) *ServiceError {
+	return NewServiceError(ErrorCodeInternalError, "internal error", err)
+}
+
+// CodeOf returns the stable service error code of err when err wraps a
+// ServiceError. Otherwise it returns ErrorCodeInternalError.
+func CodeOf(err error) string {
+	serviceErr, ok := errors.AsType[*ServiceError](err)
+	if !ok || serviceErr == nil || strings.TrimSpace(serviceErr.Code) == "" {
+		return ErrorCodeInternalError
+	}
+
+	return serviceErr.Code
+}
+
+func publicMessageForCode(code string, message string) string {
+	trimmedMessage := strings.TrimSpace(message)
+
+	switch strings.TrimSpace(code) {
+	case ErrorCodeInvalidRequest:
+		if trimmedMessage != "" {
+			return trimmedMessage
+		}
+		return genericInvalidRequestMessage
+	case ErrorCodeServiceUnavailable:
+		return publicStableMessages[ErrorCodeServiceUnavailable]
+	default:
+		if stableMessage, ok := publicStableMessages[strings.TrimSpace(code)]; ok {
+			return stableMessage
+		}
+		return publicStableMessages[ErrorCodeServiceUnavailable]
+	}
+}
+
+func internalMessageForCode(code string, message string) string {
+	trimmedMessage := strings.TrimSpace(message)
+
+	switch strings.TrimSpace(code) {
+	case ErrorCodeInvalidRequest:
+		if trimmedMessage != "" {
+			return trimmedMessage
+		}
+		return genericInvalidRequestMessage
+	case ErrorCodeSessionNotFound,
+		ErrorCodeSubjectNotFound,
+		ErrorCodeServiceUnavailable,
+		ErrorCodeInternalError:
+		if stableMessage, ok := internalStableMessages[strings.TrimSpace(code)]; ok {
+			return stableMessage
+		}
+		return internalStableMessages[ErrorCodeInternalError]
+	default:
+		return internalStableMessages[ErrorCodeInternalError]
+	}
+}
diff --git a/authsession/internal/service/shared/normalize.go b/authsession/internal/service/shared/normalize.go
new file mode 100644
index 0000000..9a58d8b
--- /dev/null
+++ b/authsession/internal/service/shared/normalize.go
@@ -0,0 +1,158 @@
+package shared
+
+import (
+	"crypto/ed25519"
+	"encoding/base64"
+	"fmt"
+	"strings"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+)
+
+// NormalizeString trims surrounding Unicode whitespace from value.
+func NormalizeString(value string) string {
+	return strings.TrimSpace(value)
+}
+
+// ParseEmail trims value and validates it against the frozen public e-mail
+// contract.
+func ParseEmail(value string) (common.Email, error) {
+	email := common.Email(NormalizeString(value))
+	if err := email.Validate(); err != nil {
+		return "", InvalidRequest(err.Error())
+	}
+
+	return email, nil
+}
+
+// ParseChallengeID trims value and validates it as one challenge identifier.
+func ParseChallengeID(value string) (common.ChallengeID, error) {
+	challengeID := common.ChallengeID(NormalizeString(value))
+	if err := challengeID.Validate(); err != nil {
+		return "", InvalidRequest(err.Error())
+	}
+
+	return challengeID, nil
+}
+
+// ParseDeviceSessionID trims value and validates it as one device-session
+// identifier.
+func ParseDeviceSessionID(value string) (common.DeviceSessionID, error) {
+	deviceSessionID := common.DeviceSessionID(NormalizeString(value))
+	if err := deviceSessionID.Validate(); err != nil {
+		return "", InvalidRequest(err.Error())
+	}
+
+	return deviceSessionID, nil
+}
+
+// ParseUserID trims value and validates it as one user identifier.
+func ParseUserID(value string) (common.UserID, error) {
+	userID := common.UserID(NormalizeString(value))
+	if err := userID.Validate(); err != nil {
+		return "", InvalidRequest(err.Error())
+	}
+
+	return userID, nil
+}
+
+// ParseRequiredCode trims value and validates it as a required non-empty
+// confirmation code.
+func ParseRequiredCode(value string) (string, error) {
+	code := NormalizeString(value)
+	if code == "" {
+		return "", InvalidRequest("code must not be empty")
+	}
+
+	return code, nil
+}
+
+// ParseClientPublicKey trims value and validates it as the standard
+// base64-encoded raw 32-byte Ed25519 public key expected by the public auth
+// contract.
+func ParseClientPublicKey(value string) (common.ClientPublicKey, error) {
+	normalized := NormalizeString(value)
+	if normalized == "" {
+		return common.ClientPublicKey{}, InvalidClientPublicKey()
+	}
+
+	decoded, err := base64.StdEncoding.Strict().DecodeString(normalized)
+	if err != nil || len(decoded) != ed25519.PublicKeySize {
+		return common.ClientPublicKey{}, InvalidClientPublicKey()
+	}
+
+	key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded))
+	if err != nil {
+		return common.ClientPublicKey{}, InvalidClientPublicKey()
+	}
+
+	return key, nil
+}
+
+// ParseRevokeReasonCode trims value and validates it as one machine-readable
+// revoke reason code.
+func ParseRevokeReasonCode(value string) (common.RevokeReasonCode, error) {
+	code := common.RevokeReasonCode(NormalizeString(value))
+	if err := code.Validate(); err != nil {
+		return "", InvalidRequest(err.Error())
+	}
+
+	return code, nil
+}
+
+// ParseRevokeActorType trims value and validates it as one machine-readable
+// revoke actor type.
+func ParseRevokeActorType(value string) (common.RevokeActorType, error) {
+	actorType := common.RevokeActorType(NormalizeString(value))
+	if err := actorType.Validate(); err != nil {
+		return "", InvalidRequest(err.Error())
+	}
+
+	return actorType, nil
+}
+
+// ParseOptionalActorID validates value as one optional stable actor
+// identifier. Unlike the other parsers it does not trim: values carrying
+// surrounding whitespace are rejected.
+func ParseOptionalActorID(value string) (string, error) {
+	actorID := NormalizeString(value)
+	if actorID != value {
+		return "", InvalidRequest("actor_id must not contain surrounding whitespace")
+	}
+
+	return actorID, nil
+}
+
+// BuildRevocation validates one revoke request payload and returns the domain
+// revocation metadata applied to a session mutation.
+func BuildRevocation(reasonCode string, actorType string, actorID string, at time.Time) (devicesession.Revocation, error) {
+	if at.IsZero() {
+		return devicesession.Revocation{}, InternalError(fmt.Errorf("revocation time must not be zero"))
+	}
+
+	parsedReasonCode, err := ParseRevokeReasonCode(reasonCode)
+	if err != nil {
+		return devicesession.Revocation{}, err
+	}
+	parsedActorType, err := ParseRevokeActorType(actorType)
+	if err != nil {
+		return devicesession.Revocation{}, err
+	}
+	parsedActorID, err := ParseOptionalActorID(actorID)
+	if err != nil {
+		return devicesession.Revocation{}, err
+	}
+
+	revocation := devicesession.Revocation{
+		At:         at.UTC(),
+		ReasonCode: parsedReasonCode,
+		ActorType:  parsedActorType,
+		ActorID:    parsedActorID,
+	}
+	if err := revocation.Validate(); err != nil {
+		return devicesession.Revocation{}, InternalError(fmt.Errorf("build revocation: %w", err))
+	}
+
+	return revocation, nil
+}
diff --git a/authsession/internal/service/shared/observability.go b/authsession/internal/service/shared/observability.go
new file mode 100644
index 0000000..1fadbf9
--- /dev/null
+++ b/authsession/internal/service/shared/observability.go
@@ -0,0 +1,46 @@
+package shared
+
+import (
+	"context"
+
+	authlogging "galaxy/authsession/internal/logging"
+
+	"go.uber.org/zap"
+)
+
+// LogServiceOutcome writes one structured service-level outcome log with a
+// stable severity derived from err and with trace fields attached when ctx
+// carries an active span.
+func LogServiceOutcome(logger *zap.Logger, ctx context.Context, message string, err error, fields ...zap.Field) {
+	if logger == nil {
+		logger = zap.NewNop()
+	}
+
+	fields = append(fields, authlogging.TraceFieldsFromContext(ctx)...)
+
+	switch {
+	case err == nil:
+		logger.Info(message, fields...)
+	case isExpectedServiceErrorCode(CodeOf(err)):
+		logger.Warn(message, append(fields, zap.Error(err))...)
+	default:
+		logger.Error(message, append(fields, zap.Error(err))...)
+	}
+}
+
+func isExpectedServiceErrorCode(code string) bool {
+	switch code {
+	case ErrorCodeInvalidRequest,
+		ErrorCodeChallengeNotFound,
+		ErrorCodeChallengeExpired,
+		ErrorCodeInvalidCode,
+		ErrorCodeInvalidClientPublicKey,
+		ErrorCodeBlockedByPolicy,
+		ErrorCodeSessionLimitExceeded,
+		ErrorCodeSessionNotFound,
+		ErrorCodeSubjectNotFound:
+		return true
+	default:
+		return false
+	}
+}
diff --git a/authsession/internal/service/shared/policy.go b/authsession/internal/service/shared/policy.go
new file mode 100644
index 0000000..3950b68
--- /dev/null
+++ b/authsession/internal/service/shared/policy.go
@@ -0,0 +1,11 @@
+package shared
+
+const (
+	// MaxCompareAndSwapRetries bounds application-level retry loops around
+	// compare-and-swap challenge updates.
+	MaxCompareAndSwapRetries = 3
+
+	// MaxProjectionPublishAttempts bounds synchronous request-path retries
+	// around gateway session projection publication.
+ MaxProjectionPublishAttempts = 3 +) diff --git a/authsession/internal/service/shared/projection_publish.go b/authsession/internal/service/shared/projection_publish.go new file mode 100644 index 0000000..98a9446 --- /dev/null +++ b/authsession/internal/service/shared/projection_publish.go @@ -0,0 +1,86 @@ +package shared + +import ( + "context" + "errors" + "fmt" + + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/telemetry" +) + +// PublishProjectionSnapshot publishes snapshot through publisher with a small +// bounded retry loop suitable for request-path consistency repair. +func PublishProjectionSnapshot(ctx context.Context, publisher ports.GatewaySessionProjectionPublisher, snapshot gatewayprojection.Snapshot) error { + return PublishProjectionSnapshotWithTelemetry(ctx, publisher, snapshot, nil, "") +} + +// PublishProjectionSnapshotWithTelemetry publishes snapshot through publisher +// with the bounded request-path retry policy and optional publish-failure +// telemetry. 
+func PublishProjectionSnapshotWithTelemetry(
+	ctx context.Context,
+	publisher ports.GatewaySessionProjectionPublisher,
+	snapshot gatewayprojection.Snapshot,
+	telemetryRuntime *telemetry.Runtime,
+	operation string,
+) error {
+	if publisher == nil {
+		return InternalError(errors.New("projection publisher must not be nil"))
+	}
+	if ctx == nil {
+		return ServiceUnavailable(errors.New("projection publish context must not be nil"))
+	}
+	if err := snapshot.Validate(); err != nil {
+		return InternalError(fmt.Errorf("publish projection snapshot: %w", err))
+	}
+
+	var lastErr error
+	for attempt := 0; attempt < MaxProjectionPublishAttempts; attempt++ {
+		if err := ctx.Err(); err != nil {
+			return ServiceUnavailable(err)
+		}
+
+		err := publisher.PublishSession(ctx, snapshot)
+		if err == nil {
+			return nil
+		}
+		lastErr = err
+	}
+
+	telemetryRuntime.RecordProjectionPublishFailure(ctx, operation)
+	return ServiceUnavailable(
+		fmt.Errorf(
+			"publish projection snapshot %q after %d attempts: %w",
+			snapshot.DeviceSessionID,
+			MaxProjectionPublishAttempts,
+			lastErr,
+		),
+	)
+}
+
+// PublishSessionProjection converts record into the gateway-facing snapshot and
+// publishes it with the bounded request-path retry policy.
+func PublishSessionProjection(ctx context.Context, publisher ports.GatewaySessionProjectionPublisher, record devicesession.Session) error {
+	return PublishSessionProjectionWithTelemetry(ctx, publisher, record, nil, "")
+}
+
+// PublishSessionProjectionWithTelemetry converts record into the
+// gateway-facing snapshot and publishes it with the bounded request-path retry
+// policy and optional publish-failure telemetry.
+func PublishSessionProjectionWithTelemetry( + ctx context.Context, + publisher ports.GatewaySessionProjectionPublisher, + record devicesession.Session, + telemetryRuntime *telemetry.Runtime, + operation string, +) error { + snapshot, err := ToGatewayProjectionSnapshot(record) + if err != nil { + return InternalError(err) + } + + return PublishProjectionSnapshotWithTelemetry(ctx, publisher, snapshot, telemetryRuntime, operation) +} diff --git a/authsession/internal/service/shared/projection_publish_test.go b/authsession/internal/service/shared/projection_publish_test.go new file mode 100644 index 0000000..ccb139c --- /dev/null +++ b/authsession/internal/service/shared/projection_publish_test.go @@ -0,0 +1,119 @@ +package shared + +import ( + "context" + "errors" + "testing" + + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPublishSessionProjectionRetriesUntilSuccess(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + errors []error + wantAttempts int + }{ + { + name: "success on second attempt", + errors: []error{errors.New("transient publish failure"), nil}, + wantAttempts: 2, + }, + { + name: "success on third attempt", + errors: []error{errors.New("transient publish failure"), errors.New("transient publish failure"), nil}, + wantAttempts: 3, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + publisher := &testkit.RecordingProjectionPublisher{Errors: tt.errors} + + err := PublishSessionProjection(context.Background(), publisher, revokedSessionFixture()) + require.NoError(t, err) + require.Len(t, publisher.PublishedSnapshots(), tt.wantAttempts) + }) + } +} + +func TestPublishSessionProjectionReturnsServiceUnavailableAfterExhaustedRetries(t *testing.T) { + t.Parallel() + + publisher := 
&testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")} + + err := PublishSessionProjection(context.Background(), publisher, revokedSessionFixture()) + require.Error(t, err) + assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err)) + require.Len(t, publisher.PublishedSnapshots(), MaxProjectionPublishAttempts) +} + +func TestPublishProjectionSnapshotStopsRetriesWhenContextIsCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + publisher := &cancelingProjectionPublisher{ + cancel: cancel, + err: errors.New("publish failed"), + } + + err := PublishProjectionSnapshot(ctx, publisher, mustProjectionSnapshot(t)) + require.Error(t, err) + assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err)) + assert.Equal(t, 1, publisher.attempts) +} + +func TestPublishSessionProjectionReturnsInternalErrorForInvalidLocalRecord(t *testing.T) { + t.Parallel() + + publisher := &testkit.RecordingProjectionPublisher{} + + err := PublishSessionProjection(context.Background(), publisher, invalidSessionFixture()) + require.Error(t, err) + assert.Equal(t, ErrorCodeInternalError, CodeOf(err)) + assert.Empty(t, publisher.PublishedSnapshots()) +} + +type cancelingProjectionPublisher struct { + attempts int + cancel context.CancelFunc + err error +} + +func (p *cancelingProjectionPublisher) PublishSession(_ context.Context, snapshot gatewayprojection.Snapshot) error { + if err := snapshot.Validate(); err != nil { + return err + } + + p.attempts++ + if p.cancel != nil { + p.cancel() + p.cancel = nil + } + + return p.err +} + +func mustProjectionSnapshot(t *testing.T) gatewayprojection.Snapshot { + t.Helper() + + snapshot, err := ToGatewayProjectionSnapshot(revokedSessionFixture()) + require.NoError(t, err) + + return snapshot +} + +func invalidSessionFixture() devicesession.Session { + return devicesession.Session{} +} diff --git a/authsession/internal/service/shared/session.go b/authsession/internal/service/shared/session.go new 
file mode 100644 index 0000000..c3a2bb9 --- /dev/null +++ b/authsession/internal/service/shared/session.go @@ -0,0 +1,134 @@ +package shared + +import ( + "fmt" + "time" + + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" +) + +// Session mirrors the frozen internal read-model DTO used by later trusted +// transport handlers. +type Session struct { + // DeviceSessionID is the stable identifier of one device session. + DeviceSessionID string + + // UserID is the stable identifier of the session owner. + UserID string + + // ClientPublicKey is the base64-encoded raw 32-byte Ed25519 public key of + // the device session. + ClientPublicKey string + + // Status reports whether the session is active or revoked. + Status string + + // CreatedAt is the RFC3339 UTC timestamp at which the session was created. + CreatedAt string + + // RevokedAt is the RFC3339 UTC timestamp at which the session was revoked, + // when the session is revoked. + RevokedAt *string + + // RevokeReasonCode is the machine-readable revoke reason code when the + // session is revoked. + RevokeReasonCode *string + + // RevokeActorType is the machine-readable revoke actor type when the + // session is revoked. + RevokeActorType *string + + // RevokeActorID is the optional stable revoke actor identifier when the + // session is revoked. + RevokeActorID *string +} + +// ToSession converts source-of-truth session into the frozen internal read DTO +// shape. 
+func ToSession(record devicesession.Session) (Session, error) { + if err := record.Validate(); err != nil { + return Session{}, fmt.Errorf("map session: %w", err) + } + + result := Session{ + DeviceSessionID: record.ID.String(), + UserID: record.UserID.String(), + ClientPublicKey: record.ClientPublicKey.String(), + Status: string(record.Status), + CreatedAt: formatTime(record.CreatedAt), + } + + if record.Revocation != nil { + revokedAt := formatTime(record.Revocation.At) + reasonCode := record.Revocation.ReasonCode.String() + actorType := record.Revocation.ActorType.String() + result.RevokedAt = &revokedAt + result.RevokeReasonCode = &reasonCode + result.RevokeActorType = &actorType + if record.Revocation.ActorID != "" { + actorID := record.Revocation.ActorID + result.RevokeActorID = &actorID + } + } + + return result, nil +} + +// ToSessions converts every source-of-truth session into the frozen internal +// read DTO shape. +func ToSessions(records []devicesession.Session) ([]Session, error) { + result := make([]Session, 0, len(records)) + for index, record := range records { + mapped, err := ToSession(record) + if err != nil { + return nil, fmt.Errorf("map session %d: %w", index, err) + } + result = append(result, mapped) + } + + return result, nil +} + +// ToGatewayProjectionSnapshot converts source-of-truth session into the +// separate gateway-facing projection model. 
+func ToGatewayProjectionSnapshot(record devicesession.Session) (gatewayprojection.Snapshot, error) {
+	if err := record.Validate(); err != nil {
+		return gatewayprojection.Snapshot{}, fmt.Errorf("map gateway projection snapshot: %w", err)
+	}
+
+	snapshot := gatewayprojection.Snapshot{
+		DeviceSessionID: record.ID,
+		UserID:          record.UserID,
+		ClientPublicKey: record.ClientPublicKey.String(),
+		Status:          gatewayprojection.Status(record.Status),
+	}
+	if record.Revocation != nil {
+		snapshot.RevokedAt = commonTimePointer(record.Revocation.At.UTC())
+		snapshot.RevokeReasonCode = record.Revocation.ReasonCode
+		snapshot.RevokeActorType = record.Revocation.ActorType
+		snapshot.RevokeActorID = record.Revocation.ActorID
+	}
+	if err := snapshot.Validate(); err != nil {
+		return gatewayprojection.Snapshot{}, fmt.Errorf("map gateway projection snapshot: %w", err)
+	}
+
+	return snapshot, nil
+}
+
+func formatTime(value time.Time) string {
+	return value.UTC().Format(time.RFC3339)
+}
+
+func commonTimePointer(value time.Time) *time.Time {
+	return &value
+}
+
+func cloneTimePointer(value *time.Time) *time.Time {
+	if value == nil {
+		return nil
+	}
+
+	cloned := *value
+	return &cloned
+}
diff --git a/authsession/internal/service/shared/session_limit.go b/authsession/internal/service/shared/session_limit.go
new file mode 100644
index 0000000..1676dbf
--- /dev/null
+++ b/authsession/internal/service/shared/session_limit.go
@@ -0,0 +1,40 @@
+package shared
+
+import (
+	"fmt"
+
+	"galaxy/authsession/internal/domain/sessionlimit"
+	"galaxy/authsession/internal/ports"
+)
+
+// EvaluateSessionLimit evaluates the Stage-4 active-session creation decision
+// from the loaded configuration and current active-session count.
+func EvaluateSessionLimit(config ports.SessionLimitConfig, activeSessionCount int) (sessionlimit.Decision, error) { + if err := config.Validate(); err != nil { + return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: %w", err)) + } + if activeSessionCount < 0 { + return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: active session count %d is negative", activeSessionCount)) + } + + decision := sessionlimit.Decision{ + ActiveSessionCount: activeSessionCount, + NextSessionCount: activeSessionCount + 1, + } + + if config.ActiveSessionLimit == nil { + decision.Kind = sessionlimit.KindDisabled + } else { + decision.ConfiguredLimit = config.ActiveSessionLimit + if decision.NextSessionCount <= *config.ActiveSessionLimit { + decision.Kind = sessionlimit.KindAllowed + } else { + decision.Kind = sessionlimit.KindExceeded + } + } + if err := decision.Validate(); err != nil { + return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: %w", err)) + } + + return decision, nil +} diff --git a/authsession/internal/service/shared/shared_test.go b/authsession/internal/service/shared/shared_test.go new file mode 100644 index 0000000..32b5c9b --- /dev/null +++ b/authsession/internal/service/shared/shared_test.go @@ -0,0 +1,380 @@ +package shared + +import ( + "errors" + "net/http" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/domain/sessionlimit" + "galaxy/authsession/internal/ports" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNormalizeString(t *testing.T) { + t.Parallel() + + assert.Equal(t, "pilot@example.com", NormalizeString(" pilot@example.com \n")) +} + +func TestParseClientPublicKey(t *testing.T) { + t.Parallel() + + key, err := ParseClientPublicKey(" 
AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8= ") + require.NoError(t, err) + assert.Equal(t, "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", key.String()) + + _, err = ParseClientPublicKey("invalid") + require.Error(t, err) + assert.Equal(t, ErrorCodeInvalidClientPublicKey, CodeOf(err)) +} + +func TestToSession(t *testing.T) { + t.Parallel() + + record := revokedSessionFixture() + + dto, err := ToSession(record) + require.NoError(t, err) + assert.Equal(t, record.ID.String(), dto.DeviceSessionID) + require.NotNil(t, dto.RevokedAt) + assert.Equal(t, record.Revocation.At.UTC().Format(time.RFC3339), *dto.RevokedAt) +} + +func TestToGatewayProjectionSnapshot(t *testing.T) { + t.Parallel() + + record := revokedSessionFixture() + + snapshot, err := ToGatewayProjectionSnapshot(record) + require.NoError(t, err) + assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status) +} + +func TestEvaluateSessionLimit(t *testing.T) { + t.Parallel() + + limit := 2 + + tests := []struct { + name string + config ports.SessionLimitConfig + active int + want sessionlimit.Kind + }{ + {name: "disabled", config: ports.SessionLimitConfig{}, active: 3, want: sessionlimit.KindDisabled}, + {name: "allowed", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 1, want: sessionlimit.KindAllowed}, + {name: "exceeded", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 2, want: sessionlimit.KindExceeded}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + decision, err := EvaluateSessionLimit(tt.config, tt.active) + require.NoError(t, err) + assert.Equal(t, tt.want, decision.Kind) + }) + } +} + +func TestServiceErrorCodePreservation(t *testing.T) { + t.Parallel() + + baseErr := errors.New("base") + err := ServiceUnavailable(baseErr) + + assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err)) + assert.ErrorIs(t, err, baseErr) +} + +func TestErrorCodeClassification(t *testing.T) { + t.Parallel() + + 
publicCodes := []string{ + ErrorCodeInvalidRequest, + ErrorCodeChallengeNotFound, + ErrorCodeChallengeExpired, + ErrorCodeInvalidCode, + ErrorCodeInvalidClientPublicKey, + ErrorCodeBlockedByPolicy, + ErrorCodeSessionLimitExceeded, + ErrorCodeServiceUnavailable, + } + for _, code := range publicCodes { + assert.Truef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code) + assert.Falsef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code) + } + + internalOnlyCodes := []string{ + ErrorCodeSessionNotFound, + ErrorCodeSubjectNotFound, + ErrorCodeInternalError, + } + for _, code := range internalOnlyCodes { + assert.Falsef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code) + assert.Truef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code) + } +} + +func TestPublicUseCaseErrorCodeSets(t *testing.T) { + t.Parallel() + + assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeInvalidRequest)) + assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeServiceUnavailable)) + assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeBlockedByPolicy)) + assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeChallengeNotFound)) + + confirmCodes := []string{ + ErrorCodeInvalidRequest, + ErrorCodeChallengeNotFound, + ErrorCodeChallengeExpired, + ErrorCodeInvalidCode, + ErrorCodeInvalidClientPublicKey, + ErrorCodeBlockedByPolicy, + ErrorCodeSessionLimitExceeded, + ErrorCodeServiceUnavailable, + } + for _, code := range confirmCodes { + assert.Truef(t, IsConfirmEmailCodePublicErrorCode(code), "IsConfirmEmailCodePublicErrorCode(%q)", code) + } + assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeInternalError)) + assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeSessionNotFound)) +} + +func TestPublicHTTPStatusCode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + code string + want int + }{ + {name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest}, + {name: "invalid 
client public key", code: ErrorCodeInvalidClientPublicKey, want: http.StatusBadRequest}, + {name: "invalid code", code: ErrorCodeInvalidCode, want: http.StatusBadRequest}, + {name: "challenge not found", code: ErrorCodeChallengeNotFound, want: http.StatusNotFound}, + {name: "challenge expired", code: ErrorCodeChallengeExpired, want: http.StatusGone}, + {name: "blocked by policy", code: ErrorCodeBlockedByPolicy, want: http.StatusForbidden}, + {name: "session limit exceeded", code: ErrorCodeSessionLimitExceeded, want: http.StatusConflict}, + {name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable}, + {name: "internal error normalized", code: ErrorCodeInternalError, want: http.StatusServiceUnavailable}, + {name: "unknown normalized", code: "unknown", want: http.StatusServiceUnavailable}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, PublicHTTPStatusCode(tt.code)) + }) + } +} + +func TestInternalHTTPStatusCode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + code string + want int + }{ + {name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest}, + {name: "session not found", code: ErrorCodeSessionNotFound, want: http.StatusNotFound}, + {name: "subject not found", code: ErrorCodeSubjectNotFound, want: http.StatusNotFound}, + {name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable}, + {name: "internal error", code: ErrorCodeInternalError, want: http.StatusInternalServerError}, + {name: "unknown normalized", code: "unknown", want: http.StatusInternalServerError}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, InternalHTTPStatusCode(tt.code)) + }) + } +} + +func TestProjectPublicError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want 
PublicErrorProjection + }{ + { + name: "invalid request keeps detailed message", + err: InvalidRequest("email must be a single valid email address"), + want: PublicErrorProjection{ + StatusCode: http.StatusBadRequest, + Code: ErrorCodeInvalidRequest, + Message: "email must be a single valid email address", + }, + }, + { + name: "invalid code keeps canonical message", + err: NewServiceError(ErrorCodeInvalidCode, "custom detail should not leak", nil), + want: PublicErrorProjection{ + StatusCode: http.StatusBadRequest, + Code: ErrorCodeInvalidCode, + Message: "confirmation code is invalid", + }, + }, + { + name: "service unavailable keeps generic message", + err: NewServiceError(ErrorCodeServiceUnavailable, "dependency timeout", errors.New("dependency timeout")), + want: PublicErrorProjection{ + StatusCode: http.StatusServiceUnavailable, + Code: ErrorCodeServiceUnavailable, + Message: "service is unavailable", + }, + }, + { + name: "internal error is hidden", + err: InternalError(errors.New("broken invariant")), + want: PublicErrorProjection{ + StatusCode: http.StatusServiceUnavailable, + Code: ErrorCodeServiceUnavailable, + Message: "service is unavailable", + }, + }, + { + name: "internal only session not found is hidden", + err: SessionNotFound(), + want: PublicErrorProjection{ + StatusCode: http.StatusServiceUnavailable, + Code: ErrorCodeServiceUnavailable, + Message: "service is unavailable", + }, + }, + { + name: "non service error is hidden", + err: errors.New("boom"), + want: PublicErrorProjection{ + StatusCode: http.StatusServiceUnavailable, + Code: ErrorCodeServiceUnavailable, + Message: "service is unavailable", + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, ProjectPublicError(tt.err)) + }) + } +} + +func TestProjectInternalError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want InternalErrorProjection + }{ + { + name: "invalid 
request keeps detailed message", + err: InvalidRequest("reason_code must not be empty"), + want: InternalErrorProjection{ + StatusCode: http.StatusBadRequest, + Code: ErrorCodeInvalidRequest, + Message: "reason_code must not be empty", + }, + }, + { + name: "session not found keeps canonical message", + err: NewServiceError(ErrorCodeSessionNotFound, "custom detail should not leak", nil), + want: InternalErrorProjection{ + StatusCode: http.StatusNotFound, + Code: ErrorCodeSessionNotFound, + Message: "session not found", + }, + }, + { + name: "subject not found keeps canonical message", + err: SubjectNotFound(), + want: InternalErrorProjection{ + StatusCode: http.StatusNotFound, + Code: ErrorCodeSubjectNotFound, + Message: "subject not found", + }, + }, + { + name: "service unavailable keeps generic message", + err: NewServiceError(ErrorCodeServiceUnavailable, "redis timeout", errors.New("redis timeout")), + want: InternalErrorProjection{ + StatusCode: http.StatusServiceUnavailable, + Code: ErrorCodeServiceUnavailable, + Message: "service is unavailable", + }, + }, + { + name: "internal error uses internal server error message", + err: InternalError(errors.New("broken invariant")), + want: InternalErrorProjection{ + StatusCode: http.StatusInternalServerError, + Code: ErrorCodeInternalError, + Message: "internal server error", + }, + }, + { + name: "unexpected error is hidden", + err: errors.New("boom"), + want: InternalErrorProjection{ + StatusCode: http.StatusInternalServerError, + Code: ErrorCodeInternalError, + Message: "internal server error", + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, ProjectInternalError(tt.err)) + }) + } +} + +func revokedSessionFixture() devicesession.Session { + key, err := common.NewClientPublicKey(make([]byte, 32)) + if err != nil { + panic(err) + } + + revokedAt := time.Unix(20, 0).UTC() + return devicesession.Session{ + ID: 
common.DeviceSessionID("device-session-1"), + UserID: common.UserID("user-1"), + ClientPublicKey: key, + Status: devicesession.StatusRevoked, + CreatedAt: time.Unix(10, 0).UTC(), + Revocation: &devicesession.Revocation{ + At: revokedAt, + ReasonCode: devicesession.RevokeReasonLogoutAll, + ActorType: common.RevokeActorType("system"), + ActorID: "actor-1", + }, + } +} diff --git a/authsession/internal/telemetry/runtime.go b/authsession/internal/telemetry/runtime.go new file mode 100644 index 0000000..c00e222 --- /dev/null +++ b/authsession/internal/telemetry/runtime.go @@ -0,0 +1,620 @@ +// Package telemetry provides shared OpenTelemetry runtime helpers and +// low-cardinality authsession instruments. +package telemetry + +import ( + "context" + "errors" + "fmt" + "galaxy/authsession/internal/domain/devicesession" + "io" + "os" + "strings" + "sync" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/exporters/stdout/stdoutmetric" + "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/propagation" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + oteltrace "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +const meterName = "galaxy/authsession" + +const ( + processExporterNone = "none" + processExporterOTLP = "otlp" + processProtocolHTTPProtobuf = "http/protobuf" + processProtocolGRPC = "grpc" +) + +// ProcessConfig configures the process-wide OpenTelemetry runtime. +type ProcessConfig struct { + // ServiceName overrides the default OpenTelemetry service name. 
+ ServiceName string + + // TracesExporter selects the external traces exporter. Supported values are + // `none` and `otlp`. + TracesExporter string + + // MetricsExporter selects the external metrics exporter. Supported values + // are `none` and `otlp`. + MetricsExporter string + + // TracesProtocol selects the OTLP traces protocol when TracesExporter is + // `otlp`. + TracesProtocol string + + // MetricsProtocol selects the OTLP metrics protocol when MetricsExporter is + // `otlp`. + MetricsProtocol string + + // StdoutTracesEnabled enables the additional stdout trace exporter used for + // local development and debugging. + StdoutTracesEnabled bool + + // StdoutMetricsEnabled enables the additional stdout metric exporter used + // for local development and debugging. + StdoutMetricsEnabled bool +} + +// Validate reports whether cfg contains a supported OpenTelemetry exporter +// configuration. +func (cfg ProcessConfig) Validate() error { + switch cfg.TracesExporter { + case processExporterNone, processExporterOTLP: + default: + return fmt.Errorf("unsupported traces exporter %q", cfg.TracesExporter) + } + + switch cfg.MetricsExporter { + case processExporterNone, processExporterOTLP: + default: + return fmt.Errorf("unsupported metrics exporter %q", cfg.MetricsExporter) + } + + if cfg.TracesProtocol != "" && cfg.TracesProtocol != processProtocolHTTPProtobuf && cfg.TracesProtocol != processProtocolGRPC { + return fmt.Errorf("unsupported OTLP traces protocol %q", cfg.TracesProtocol) + } + if cfg.MetricsProtocol != "" && cfg.MetricsProtocol != processProtocolHTTPProtobuf && cfg.MetricsProtocol != processProtocolGRPC { + return fmt.Errorf("unsupported OTLP metrics protocol %q", cfg.MetricsProtocol) + } + + return nil +} + +// SendEmailCodeOutcome identifies the coarse send-email-code result recorded +// by authsession metrics. 
+type SendEmailCodeOutcome string + +const ( + // SendEmailCodeOutcomeSent reports that the login code was handed off for + // delivery successfully. + SendEmailCodeOutcomeSent SendEmailCodeOutcome = "sent" + + // SendEmailCodeOutcomeSuppressed reports that outward send stayed + // success-shaped while actual delivery was skipped intentionally. + SendEmailCodeOutcomeSuppressed SendEmailCodeOutcome = "suppressed" + + // SendEmailCodeOutcomeThrottled reports that a fresh challenge was created + // but delivery was skipped because the resend cooldown was active. + SendEmailCodeOutcomeThrottled SendEmailCodeOutcome = "throttled" + + // SendEmailCodeOutcomeFailed reports that the send flow reached an explicit + // failure after a source-of-truth write. + SendEmailCodeOutcomeFailed SendEmailCodeOutcome = "failed" +) + +// IsKnown reports whether SendEmailCodeOutcome belongs to the stable +// authsession send-flow metric surface. +func (o SendEmailCodeOutcome) IsKnown() bool { + switch o { + case SendEmailCodeOutcomeSent, + SendEmailCodeOutcomeSuppressed, + SendEmailCodeOutcomeThrottled, + SendEmailCodeOutcomeFailed: + return true + default: + return false + } +} + +// SendEmailCodeReason identifies the low-cardinality send-flow reason recorded +// for suppressed, throttled, or failed outcomes. +type SendEmailCodeReason string + +const ( + // SendEmailCodeReasonBlocked reports that delivery was suppressed because + // user policy already marked the e-mail as blocked. + SendEmailCodeReasonBlocked SendEmailCodeReason = "blocked" + + // SendEmailCodeReasonMailSender reports that the delivery adapter itself + // suppressed or failed the send attempt. + SendEmailCodeReasonMailSender SendEmailCodeReason = "mail_sender" + + // SendEmailCodeReasonThrottled reports that delivery was skipped because the + // resend cooldown was active. 
+ SendEmailCodeReasonThrottled SendEmailCodeReason = "throttled" +) + +// IsKnown reports whether SendEmailCodeReason belongs to the stable authsession +// send-flow metric surface. +func (r SendEmailCodeReason) IsKnown() bool { + switch r { + case "", + SendEmailCodeReasonBlocked, + SendEmailCodeReasonMailSender, + SendEmailCodeReasonThrottled: + return true + default: + return false + } +} + +// ConfirmEmailCodeOutcome identifies the coarse confirm-email-code result +// recorded by authsession metrics. +type ConfirmEmailCodeOutcome string + +const ( + // ConfirmEmailCodeOutcomeSuccess reports that a device session was created + // or idempotently recovered successfully. + ConfirmEmailCodeOutcomeSuccess ConfirmEmailCodeOutcome = "success" +) + +// Runtime owns the authsession OpenTelemetry providers and custom +// low-cardinality instruments. +type Runtime struct { + tracerProvider oteltrace.TracerProvider + meterProvider metric.MeterProvider + + shutdownMu sync.Mutex + shutdownDone bool + shutdownErr error + shutdownFns []func(context.Context) error + + publicHTTPRequests metric.Int64Counter + publicHTTPDuration metric.Float64Histogram + internalHTTPRequests metric.Int64Counter + internalHTTPDuration metric.Float64Histogram + sendEmailCodeAttempts metric.Int64Counter + confirmEmailCodeAttempts metric.Int64Counter + challengesCreated metric.Int64Counter + sessionsCreated metric.Int64Counter + sessionLimitRejections metric.Int64Counter + projectionPublishFailures metric.Int64Counter + userDirectoryOutcomes metric.Int64Counter + sessionsRevoked metric.Int64Counter +} + +// New constructs a lightweight telemetry runtime around meterProvider for +// tests and embedded use cases that do not need process-level exporter wiring. +func New(meterProvider metric.MeterProvider) (*Runtime, error) { + return NewWithProviders(meterProvider, nil) +} + +// NewWithProviders constructs a telemetry runtime around explicitly supplied +// meterProvider and tracerProvider values. 
+func NewWithProviders(meterProvider metric.MeterProvider, tracerProvider oteltrace.TracerProvider) (*Runtime, error) { + if meterProvider == nil { + meterProvider = otel.GetMeterProvider() + } + if tracerProvider == nil { + tracerProvider = otel.GetTracerProvider() + } + if meterProvider == nil { + return nil, errors.New("new authsession telemetry runtime: nil meter provider") + } + if tracerProvider == nil { + return nil, errors.New("new authsession telemetry runtime: nil tracer provider") + } + + return buildRuntime(meterProvider, tracerProvider, nil) +} + +// NewProcess constructs the process-wide authsession OpenTelemetry runtime from +// cfg, installs the resulting providers globally, and returns the runtime. +func NewProcess(ctx context.Context, cfg ProcessConfig, logger *zap.Logger) (*Runtime, error) { + return newProcess(ctx, cfg, logger, os.Stdout, os.Stdout) +} + +// TracerProvider returns the runtime tracer provider. +func (r *Runtime) TracerProvider() oteltrace.TracerProvider { + if r == nil || r.tracerProvider == nil { + return otel.GetTracerProvider() + } + + return r.tracerProvider +} + +// MeterProvider returns the runtime meter provider. +func (r *Runtime) MeterProvider() metric.MeterProvider { + if r == nil || r.meterProvider == nil { + return otel.GetMeterProvider() + } + + return r.meterProvider +} + +// Shutdown flushes and stops the configured telemetry providers. Shutdown is +// idempotent. +func (r *Runtime) Shutdown(ctx context.Context) error { + if r == nil { + return nil + } + + r.shutdownMu.Lock() + if r.shutdownDone { + err := r.shutdownErr + r.shutdownMu.Unlock() + return err + } + r.shutdownDone = true + shutdownFns := append([]func(context.Context) error(nil), r.shutdownFns...) 
+ r.shutdownMu.Unlock() + + var joined error + for _, shutdownFn := range shutdownFns { + joined = errors.Join(joined, shutdownFn(ctx)) + } + + r.shutdownMu.Lock() + r.shutdownErr = joined + r.shutdownMu.Unlock() + + return joined +} + +// RecordPublicHTTPRequest records one public HTTP request outcome. +func (r *Runtime) RecordPublicHTTPRequest(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) { + if r == nil { + return + } + + options := metric.WithAttributes(attrs...) + r.publicHTTPRequests.Add(normalizeContext(ctx), 1, options) + r.publicHTTPDuration.Record(normalizeContext(ctx), duration.Seconds()*1000, options) +} + +// RecordInternalHTTPRequest records one trusted internal HTTP request outcome. +func (r *Runtime) RecordInternalHTTPRequest(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) { + if r == nil { + return + } + + options := metric.WithAttributes(attrs...) + r.internalHTTPRequests.Add(normalizeContext(ctx), 1, options) + r.internalHTTPDuration.Record(normalizeContext(ctx), duration.Seconds()*1000, options) +} + +// RecordSendEmailCode records one low-cardinality send-email-code outcome. +func (r *Runtime) RecordSendEmailCode(ctx context.Context, outcome SendEmailCodeOutcome, reason SendEmailCodeReason) { + if r == nil || !outcome.IsKnown() || !reason.IsKnown() { + return + } + + attrs := []attribute.KeyValue{ + attribute.String("outcome", string(outcome)), + } + if reason != "" { + attrs = append(attrs, attribute.String("reason", string(reason))) + } + + r.sendEmailCodeAttempts.Add(normalizeContext(ctx), 1, metric.WithAttributes(attrs...)) +} + +// RecordConfirmEmailCode records one low-cardinality confirm-email-code +// outcome. Success uses the stable value `success`; failures should pass the +// stable service/public error code. 
+func (r *Runtime) RecordConfirmEmailCode(ctx context.Context, outcome string) { + if r == nil || outcome == "" { + return + } + + r.confirmEmailCodeAttempts.Add( + normalizeContext(ctx), + 1, + metric.WithAttributes(attribute.String("outcome", outcome)), + ) +} + +// RecordChallengeCreated records one newly persisted challenge. +func (r *Runtime) RecordChallengeCreated(ctx context.Context) { + if r == nil { + return + } + + r.challengesCreated.Add(normalizeContext(ctx), 1) +} + +// RecordSessionCreated records one newly persisted device session. +func (r *Runtime) RecordSessionCreated(ctx context.Context) { + if r == nil { + return + } + + r.sessionsCreated.Add(normalizeContext(ctx), 1) +} + +// RecordSessionLimitRejection records one rejected confirmation caused by the +// active-session limit. +func (r *Runtime) RecordSessionLimitRejection(ctx context.Context) { + if r == nil { + return + } + + r.sessionLimitRejections.Add(normalizeContext(ctx), 1) +} + +// RecordProjectionPublishFailure records one exhausted projection publish +// failure for operation. +func (r *Runtime) RecordProjectionPublishFailure(ctx context.Context, operation string) { + if r == nil || strings.TrimSpace(operation) == "" { + return + } + + r.projectionPublishFailures.Add( + normalizeContext(ctx), + 1, + metric.WithAttributes(attribute.String("operation", operation)), + ) +} + +// RecordUserDirectoryOutcome records one user-directory boundary outcome for +// operation. +func (r *Runtime) RecordUserDirectoryOutcome(ctx context.Context, operation string, outcome string) { + if r == nil || strings.TrimSpace(operation) == "" || strings.TrimSpace(outcome) == "" { + return + } + + r.userDirectoryOutcomes.Add( + normalizeContext(ctx), + 1, + metric.WithAttributes( + attribute.String("operation", operation), + attribute.String("outcome", outcome), + ), + ) +} + +// RecordSessionRevocations records count revoked sessions for operation and a +// low-cardinality revoke-reason bucket. 
+func (r *Runtime) RecordSessionRevocations(ctx context.Context, operation string, reasonCode string, count int64) { + if r == nil || strings.TrimSpace(operation) == "" || count <= 0 { + return + } + + r.sessionsRevoked.Add( + normalizeContext(ctx), + count, + metric.WithAttributes( + attribute.String("operation", operation), + attribute.String("reason_bucket", revokeReasonBucket(reasonCode)), + ), + ) +} + +func newProcess(ctx context.Context, cfg ProcessConfig, logger *zap.Logger, stdoutTraceWriter io.Writer, stdoutMetricWriter io.Writer) (*Runtime, error) { + if ctx == nil { + return nil, errors.New("new authsession process telemetry: nil context") + } + if logger == nil { + logger = zap.NewNop() + } + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("new authsession process telemetry: %w", err) + } + + res, err := resource.New( + ctx, + resource.WithAttributes(attribute.String("service.name", cfg.ServiceName)), + ) + if err != nil { + return nil, fmt.Errorf("new authsession process telemetry: resource: %w", err) + } + + tracerProvider, err := newTracerProvider(ctx, res, cfg, stdoutTraceWriter) + if err != nil { + return nil, fmt.Errorf("new authsession process telemetry: tracer provider: %w", err) + } + + meterProvider, err := newMeterProvider(ctx, res, cfg, stdoutMetricWriter) + if err != nil { + return nil, fmt.Errorf("new authsession process telemetry: meter provider: %w", err) + } + + logger.Info( + "authsession telemetry configured", + zap.String("service_name", cfg.ServiceName), + zap.String("traces_exporter", cfg.TracesExporter), + zap.String("metrics_exporter", cfg.MetricsExporter), + zap.Bool("stdout_traces_enabled", cfg.StdoutTracesEnabled), + zap.Bool("stdout_metrics_enabled", cfg.StdoutMetricsEnabled), + ) + + otel.SetTracerProvider(tracerProvider) + otel.SetMeterProvider(meterProvider) + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + )) + + return 
buildRuntime( + meterProvider, + tracerProvider, + []func(context.Context) error{ + meterProvider.Shutdown, + tracerProvider.Shutdown, + }, + ) +} + +func buildRuntime(meterProvider metric.MeterProvider, tracerProvider oteltrace.TracerProvider, shutdownFns []func(context.Context) error) (*Runtime, error) { + meter := meterProvider.Meter(meterName) + + publicHTTPRequests, err := meter.Int64Counter("authsession.public_http.requests") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: public HTTP requests counter: %w", err) + } + publicHTTPDuration, err := meter.Float64Histogram("authsession.public_http.duration", metric.WithUnit("ms")) + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: public HTTP duration histogram: %w", err) + } + internalHTTPRequests, err := meter.Int64Counter("authsession.internal_http.requests") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: internal HTTP requests counter: %w", err) + } + internalHTTPDuration, err := meter.Float64Histogram("authsession.internal_http.duration", metric.WithUnit("ms")) + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: internal HTTP duration histogram: %w", err) + } + sendEmailCodeAttempts, err := meter.Int64Counter("authsession.send_email_code.attempts") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: send email code attempts counter: %w", err) + } + confirmEmailCodeAttempts, err := meter.Int64Counter("authsession.confirm_email_code.attempts") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: confirm email code attempts counter: %w", err) + } + challengesCreated, err := meter.Int64Counter("authsession.challenges.created") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: challenges created counter: %w", err) + } + sessionsCreated, err := meter.Int64Counter("authsession.sessions.created") + if err != 
nil { + return nil, fmt.Errorf("build authsession telemetry runtime: sessions created counter: %w", err) + } + sessionLimitRejections, err := meter.Int64Counter("authsession.session_limit.rejections") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: session limit rejections counter: %w", err) + } + projectionPublishFailures, err := meter.Int64Counter("authsession.projection.publish_failures") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: projection publish failures counter: %w", err) + } + userDirectoryOutcomes, err := meter.Int64Counter("authsession.user_directory.outcomes") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: user directory outcomes counter: %w", err) + } + sessionsRevoked, err := meter.Int64Counter("authsession.sessions.revoked") + if err != nil { + return nil, fmt.Errorf("build authsession telemetry runtime: sessions revoked counter: %w", err) + } + + return &Runtime{ + tracerProvider: tracerProvider, + meterProvider: meterProvider, + shutdownFns: shutdownFns, + publicHTTPRequests: publicHTTPRequests, + publicHTTPDuration: publicHTTPDuration, + internalHTTPRequests: internalHTTPRequests, + internalHTTPDuration: internalHTTPDuration, + sendEmailCodeAttempts: sendEmailCodeAttempts, + confirmEmailCodeAttempts: confirmEmailCodeAttempts, + challengesCreated: challengesCreated, + sessionsCreated: sessionsCreated, + sessionLimitRejections: sessionLimitRejections, + projectionPublishFailures: projectionPublishFailures, + userDirectoryOutcomes: userDirectoryOutcomes, + sessionsRevoked: sessionsRevoked, + }, nil +} + +func newTracerProvider(ctx context.Context, res *resource.Resource, cfg ProcessConfig, stdoutWriter io.Writer) (*sdktrace.TracerProvider, error) { + options := []sdktrace.TracerProviderOption{sdktrace.WithResource(res)} + + if cfg.TracesExporter == processExporterOTLP { + exporter, err := newOTLPTraceExporter(ctx, cfg.TracesProtocol) + if err != 
nil { + return nil, err + } + options = append(options, sdktrace.WithBatcher(exporter)) + } + if cfg.StdoutTracesEnabled { + exporter, err := stdouttrace.New( + stdouttrace.WithPrettyPrint(), + stdouttrace.WithWriter(stdoutWriter), + ) + if err != nil { + return nil, err + } + options = append(options, sdktrace.WithBatcher(exporter)) + } + + return sdktrace.NewTracerProvider(options...), nil +} + +func newMeterProvider(ctx context.Context, res *resource.Resource, cfg ProcessConfig, stdoutWriter io.Writer) (*sdkmetric.MeterProvider, error) { + options := []sdkmetric.Option{sdkmetric.WithResource(res)} + + if cfg.MetricsExporter == processExporterOTLP { + exporter, err := newOTLPMetricExporter(ctx, cfg.MetricsProtocol) + if err != nil { + return nil, err + } + options = append(options, sdkmetric.WithReader(sdkmetric.NewPeriodicReader(exporter))) + } + if cfg.StdoutMetricsEnabled { + exporter, err := stdoutmetric.New( + stdoutmetric.WithPrettyPrint(), + stdoutmetric.WithWriter(stdoutWriter), + ) + if err != nil { + return nil, err + } + options = append(options, sdkmetric.WithReader(sdkmetric.NewPeriodicReader(exporter))) + } + + return sdkmetric.NewMeterProvider(options...), nil +} + +func newOTLPTraceExporter(ctx context.Context, protocol string) (sdktrace.SpanExporter, error) { + switch protocol { + case "", "http/protobuf": + return otlptracehttp.New(ctx) + case "grpc": + return otlptracegrpc.New(ctx) + default: + return nil, fmt.Errorf("unsupported OTLP traces protocol %q", protocol) + } +} + +func newOTLPMetricExporter(ctx context.Context, protocol string) (sdkmetric.Exporter, error) { + switch protocol { + case "", "http/protobuf": + return otlpmetrichttp.New(ctx) + case "grpc": + return otlpmetricgrpc.New(ctx) + default: + return nil, fmt.Errorf("unsupported OTLP metrics protocol %q", protocol) + } +} + +func revokeReasonBucket(reasonCode string) string { + switch strings.TrimSpace(reasonCode) { + case devicesession.RevokeReasonUserBlocked.String(): + return 
"user_blocked" + case "confirm_race_repair": + return "confirm_race_repair" + default: + return "custom" + } +} + +func normalizeContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + + return ctx +} diff --git a/authsession/internal/telemetry/runtime_test.go b/authsession/internal/telemetry/runtime_test.go new file mode 100644 index 0000000..fc6ae95 --- /dev/null +++ b/authsession/internal/telemetry/runtime_test.go @@ -0,0 +1,124 @@ +package telemetry + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.uber.org/zap" +) + +func TestNewProcessBuildsWithoutExporters(t *testing.T) { + runtime, err := newProcess(context.Background(), ProcessConfig{ + ServiceName: "galaxy-authsession-test", + TracesExporter: processExporterNone, + MetricsExporter: processExporterNone, + }, zap.NewNop(), ioDiscard{}, ioDiscard{}) + require.NoError(t, err) + + assert.NotNil(t, runtime.TracerProvider()) + assert.NotNil(t, runtime.MeterProvider()) + require.NoError(t, runtime.Shutdown(context.Background())) + require.NoError(t, runtime.Shutdown(context.Background())) +} + +func TestNewProcessBuildsWithStdoutExporters(t *testing.T) { + traceBuffer := &bytes.Buffer{} + metricBuffer := &bytes.Buffer{} + + runtime, err := newProcess(context.Background(), ProcessConfig{ + ServiceName: "galaxy-authsession-test", + TracesExporter: processExporterNone, + MetricsExporter: processExporterNone, + StdoutTracesEnabled: true, + StdoutMetricsEnabled: true, + }, zap.NewNop(), traceBuffer, metricBuffer) + require.NoError(t, err) + + ctx, span := runtime.TracerProvider().Tracer("test").Start(context.Background(), "public-request") + runtime.RecordSendEmailCode(ctx, SendEmailCodeOutcomeSent, "") + 
span.End() + + require.NoError(t, runtime.Shutdown(context.Background())) + assert.NotEmpty(t, traceBuffer.String()) + assert.NotEmpty(t, metricBuffer.String()) +} + +func TestNewPreservesBusinessMetrics(t *testing.T) { + reader := sdkmetric.NewManualReader() + meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + tracerProvider := sdktrace.NewTracerProvider() + + runtime, err := NewWithProviders(meterProvider, tracerProvider) + require.NoError(t, err) + + runtime.RecordSendEmailCode(context.Background(), SendEmailCodeOutcomeSuppressed, SendEmailCodeReasonBlocked) + runtime.RecordUserDirectoryOutcome(context.Background(), "ensure_user_by_email", "created") + runtime.RecordSessionRevocations(context.Background(), "block_user", "user_blocked", 2) + + assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{ + "outcome": "suppressed", + "reason": "blocked", + }, 1) + assertMetricCount(t, reader, "authsession.user_directory.outcomes", map[string]string{ + "operation": "ensure_user_by_email", + "outcome": "created", + }, 1) + assertMetricCount(t, reader, "authsession.sessions.revoked", map[string]string{ + "operation": "block_user", + "reason_bucket": "user_blocked", + }, 2) +} + +type ioDiscard struct{} + +func (ioDiscard) Write(p []byte) (int, error) { + return len(p), nil +} + +func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) { + t.Helper() + + var resourceMetrics metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &resourceMetrics)) + + for _, scopeMetrics := range resourceMetrics.ScopeMetrics { + for _, metric := range scopeMetrics.Metrics { + if metric.Name != metricName { + continue + } + + sum, ok := metric.Data.(metricdata.Sum[int64]) + require.True(t, ok) + + for _, point := range sum.DataPoints { + if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) { + assert.Equal(t, wantValue, 
point.Value) + return + } + } + } + } + + require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs) +} + +func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool { + if len(values) != len(want) { + return false + } + + for _, value := range values { + if want[string(value.Key)] != value.Value.AsString() { + return false + } + } + + return true +} diff --git a/authsession/internal/testkit/challenge_store.go b/authsession/internal/testkit/challenge_store.go new file mode 100644 index 0000000..7403477 --- /dev/null +++ b/authsession/internal/testkit/challenge_store.go @@ -0,0 +1,122 @@ +package testkit + +import ( + "context" + "errors" + "fmt" + "reflect" + "sync" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" +) + +// InMemoryChallengeStore is a deterministic map-backed ChallengeStore double +// suitable for service tests. +type InMemoryChallengeStore struct { + mu sync.Mutex + records map[common.ChallengeID]challenge.Challenge +} + +// Get returns the stored challenge for challengeID. +func (s *InMemoryChallengeStore) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) { + if err := ctx.Err(); err != nil { + return challenge.Challenge{}, err + } + if err := challengeID.Validate(); err != nil { + return challenge.Challenge{}, fmt.Errorf("get challenge: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + record, ok := s.records[challengeID] + if !ok { + return challenge.Challenge{}, fmt.Errorf("get challenge %q: %w", challengeID, ports.ErrNotFound) + } + + cloned, err := cloneChallenge(record) + if err != nil { + return challenge.Challenge{}, err + } + + return cloned, nil +} + +// Create stores record as a new challenge. 
+func (s *InMemoryChallengeStore) Create(ctx context.Context, record challenge.Challenge) error { + if err := ctx.Err(); err != nil { + return err + } + if err := record.Validate(); err != nil { + return fmt.Errorf("create challenge: %w", err) + } + + cloned, err := cloneChallenge(record) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.records == nil { + s.records = make(map[common.ChallengeID]challenge.Challenge) + } + if _, exists := s.records[record.ID]; exists { + return fmt.Errorf("create challenge %q: %w", record.ID, ports.ErrConflict) + } + + s.records[record.ID] = cloned + return nil +} + +// CompareAndSwap replaces previous with next when the currently stored +// challenge matches previous exactly. +func (s *InMemoryChallengeStore) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error { + if err := ctx.Err(); err != nil { + return err + } + if err := ports.ValidateComparableChallenges(previous, next); err != nil { + return fmt.Errorf("compare and swap challenge: %w", err) + } + + clonedPrevious, err := cloneChallenge(previous) + if err != nil { + return err + } + clonedNext, err := cloneChallenge(next) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + current, ok := s.records[previous.ID] + if !ok { + return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrNotFound) + } + if !reflect.DeepEqual(current, clonedPrevious) { + return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrConflict) + } + + s.records[next.ID] = clonedNext + return nil +} + +var _ ports.ChallengeStore = (*InMemoryChallengeStore)(nil) + +func mustGetChallenge(store *InMemoryChallengeStore, challengeID common.ChallengeID) challenge.Challenge { + record, err := store.Get(context.Background(), challengeID) + if err != nil { + panic(err) + } + + return record +} + +func isNotFound(err error) bool { + return errors.Is(err, ports.ErrNotFound) +} 
diff --git a/authsession/internal/testkit/challenge_store_test.go b/authsession/internal/testkit/challenge_store_test.go
new file mode 100644
index 0000000..593b3ad
--- /dev/null
+++ b/authsession/internal/testkit/challenge_store_test.go
@@ -0,0 +1,66 @@
+package testkit
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"galaxy/authsession/internal/domain/challenge"
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/ports"
+)
+
+func TestInMemoryChallengeStoreCreateAndGet(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemoryChallengeStore{}
+	record := challengeFixture()
+
+	require.NoError(t, store.Create(context.Background(), record))
+
+	got, err := store.Get(context.Background(), record.ID)
+	require.NoError(t, err)
+	require.Equal(t, record.ID, got.ID)
+	require.NotSame(t, &record.CodeHash[0], &got.CodeHash[0], "Get() must not alias the stored code hash slice")
+}
+
+func TestInMemoryChallengeStoreGetNotFound(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemoryChallengeStore{}
+
+	_, err := store.Get(context.Background(), common.ChallengeID("missing"))
+	require.ErrorIs(t, err, ports.ErrNotFound)
+}
+
+func TestInMemoryChallengeStoreCompareAndSwapConflict(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemoryChallengeStore{}
+	record := challengeFixture()
+	require.NoError(t, store.Create(context.Background(), record))
+
+	previous := record
+	previous.Attempts.Confirm = 1
+	next := record
+	next.Status = challenge.StatusSent
+	next.DeliveryState = challenge.DeliverySent
+
+	err := store.CompareAndSwap(context.Background(), previous, next)
+	require.ErrorIs(t, err, ports.ErrConflict)
+}
+
+func challengeFixture() challenge.Challenge {
+	timestamp := time.Unix(20, 0).UTC()
+	return challenge.Challenge{
+		ID:            common.ChallengeID("challenge-1"),
+		Email:         common.Email("pilot@example.com"),
+		CodeHash:      []byte("hash"),
+		Status:        challenge.StatusPendingSend,
+		DeliveryState: challenge.DeliveryPending,
+		CreatedAt:     timestamp,
+		ExpiresAt:     timestamp.Add(10 * time.Minute),
+	}
+}
diff --git a/authsession/internal/testkit/clock.go b/authsession/internal/testkit/clock.go
new file mode 100644
index 0000000..23a4967
--- /dev/null
+++ b/authsession/internal/testkit/clock.go
@@ -0,0 +1,15 @@
+package testkit
+
+import "time"
+
+// FixedClock is a deterministic Clock double that always returns the same
+// instant.
+type FixedClock struct {
+	// Time is the instant returned by Now.
+	Time time.Time
+}
+
+// Now returns the configured instant.
+func (c FixedClock) Now() time.Time { + return c.Time +} diff --git a/authsession/internal/testkit/clones.go b/authsession/internal/testkit/clones.go new file mode 100644 index 0000000..2f63365 --- /dev/null +++ b/authsession/internal/testkit/clones.go @@ -0,0 +1,130 @@ +package testkit + +import ( + "bytes" + "fmt" + "slices" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" +) + +func cloneChallenge(record challenge.Challenge) (challenge.Challenge, error) { + cloned := record + cloned.CodeHash = bytes.Clone(record.CodeHash) + cloned.Abuse = cloneAbuseMetadata(record.Abuse) + + if record.Confirmation != nil { + confirmation, err := cloneChallengeConfirmation(*record.Confirmation) + if err != nil { + return challenge.Challenge{}, err + } + cloned.Confirmation = &confirmation + } + + return cloned, nil +} + +func cloneChallengeConfirmation(value challenge.Confirmation) (challenge.Confirmation, error) { + cloned := value + + if value.ClientPublicKey.IsZero() { + cloned.ClientPublicKey = common.ClientPublicKey{} + return cloned, nil + } + + key, err := common.NewClientPublicKey(value.ClientPublicKey.PublicKey()) + if err != nil { + return challenge.Confirmation{}, fmt.Errorf("clone challenge confirmation client public key: %w", err) + } + cloned.ClientPublicKey = key + + return cloned, nil +} + +func cloneAbuseMetadata(value challenge.AbuseMetadata) challenge.AbuseMetadata { + cloned := value + if value.LastAttemptAt != nil { + lastAttemptAt := *value.LastAttemptAt + cloned.LastAttemptAt = &lastAttemptAt + } + + return cloned +} + +func cloneSession(record devicesession.Session) (devicesession.Session, error) { + cloned := record + + if !record.ClientPublicKey.IsZero() { + key, err := common.NewClientPublicKey(record.ClientPublicKey.PublicKey()) + if err != nil { + return devicesession.Session{}, 
fmt.Errorf("clone session client public key: %w", err) + } + cloned.ClientPublicKey = key + } + if record.Revocation != nil { + revocation := *record.Revocation + cloned.Revocation = &revocation + } + + return cloned, nil +} + +func cloneSessions(records []devicesession.Session) ([]devicesession.Session, error) { + cloned := make([]devicesession.Session, 0, len(records)) + for _, record := range records { + session, err := cloneSession(record) + if err != nil { + return nil, err + } + cloned = append(cloned, session) + } + + return cloned, nil +} + +func cloneProjectionSnapshot(snapshot gatewayprojection.Snapshot) gatewayprojection.Snapshot { + cloned := snapshot + if snapshot.RevokedAt != nil { + revokedAt := *snapshot.RevokedAt + cloned.RevokedAt = &revokedAt + } + + return cloned +} + +func sortSessionsNewestFirst(records []devicesession.Session) { + slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int { + switch { + case left.CreatedAt.Equal(right.CreatedAt): + return compareStrings(left.ID.String(), right.ID.String()) + case left.CreatedAt.After(right.CreatedAt): + return -1 + default: + return 1 + } + }) +} + +func compareStrings(left string, right string) int { + switch { + case left < right: + return -1 + case left > right: + return 1 + default: + return 0 + } +} + +func cloneTimePointer(value *time.Time) *time.Time { + if value == nil { + return nil + } + + cloned := *value + return &cloned +} diff --git a/authsession/internal/testkit/code_generator.go b/authsession/internal/testkit/code_generator.go new file mode 100644 index 0000000..ffe7266 --- /dev/null +++ b/authsession/internal/testkit/code_generator.go @@ -0,0 +1,35 @@ +package testkit + +import ( + "errors" + "strings" + + "galaxy/authsession/internal/ports" +) + +// FixedCodeGenerator is a deterministic CodeGenerator double that always +// returns the same code or error. 
+type FixedCodeGenerator struct { + // Code stores the fixed code returned by Generate when Err is nil. + Code string + + // Err is returned directly from Generate when set. + Err error +} + +// Generate returns the configured fixed code. +func (g FixedCodeGenerator) Generate() (string, error) { + if g.Err != nil { + return "", g.Err + } + switch { + case strings.TrimSpace(g.Code) == "": + return "", errors.New("fixed code generator code must not be empty") + case strings.TrimSpace(g.Code) != g.Code: + return "", errors.New("fixed code generator code must not contain surrounding whitespace") + default: + return g.Code, nil + } +} + +var _ ports.CodeGenerator = FixedCodeGenerator{} diff --git a/authsession/internal/testkit/code_hasher.go b/authsession/internal/testkit/code_hasher.go new file mode 100644 index 0000000..433ed7f --- /dev/null +++ b/authsession/internal/testkit/code_hasher.go @@ -0,0 +1,51 @@ +package testkit + +import ( + "crypto/sha256" + "crypto/subtle" + "errors" + "strings" + + "galaxy/authsession/internal/ports" +) + +// DeterministicCodeHasher is a deterministic CodeHasher double backed by +// SHA-256 for test stability. +type DeterministicCodeHasher struct{} + +// Hash returns the SHA-256 digest of code. +func (DeterministicCodeHasher) Hash(code string) ([]byte, error) { + if err := validateCode(code); err != nil { + return nil, err + } + + sum := sha256.Sum256([]byte(code)) + return sum[:], nil +} + +// Compare reports whether hash equals the deterministic hash of code. 
+func (h DeterministicCodeHasher) Compare(hash []byte, code string) (bool, error) { + if err := validateCode(code); err != nil { + return false, err + } + + expected, err := h.Hash(code) + if err != nil { + return false, err + } + + return subtle.ConstantTimeCompare(hash, expected) == 1, nil +} + +var _ ports.CodeHasher = DeterministicCodeHasher{} + +func validateCode(code string) error { + switch { + case strings.TrimSpace(code) == "": + return errors.New("code must not be empty") + case strings.TrimSpace(code) != code: + return errors.New("code must not contain surrounding whitespace") + default: + return nil + } +} diff --git a/authsession/internal/testkit/config_provider.go b/authsession/internal/testkit/config_provider.go new file mode 100644 index 0000000..ca4fc63 --- /dev/null +++ b/authsession/internal/testkit/config_provider.go @@ -0,0 +1,34 @@ +package testkit + +import ( + "context" + + "galaxy/authsession/internal/ports" +) + +// StaticConfigProvider is a deterministic ConfigProvider double that returns a +// preconfigured session-limit value or error. +type StaticConfigProvider struct { + // Config stores the configuration returned when Err is nil. + Config ports.SessionLimitConfig + + // Err is returned directly from LoadSessionLimit when set. + Err error +} + +// LoadSessionLimit returns the preconfigured session-limit result. 
+func (p StaticConfigProvider) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) { + if err := ctx.Err(); err != nil { + return ports.SessionLimitConfig{}, err + } + if p.Err != nil { + return ports.SessionLimitConfig{}, p.Err + } + if err := p.Config.Validate(); err != nil { + return ports.SessionLimitConfig{}, err + } + + return p.Config, nil +} + +var _ ports.ConfigProvider = StaticConfigProvider{} diff --git a/authsession/internal/testkit/doc.go b/authsession/internal/testkit/doc.go new file mode 100644 index 0000000..fe2647b --- /dev/null +++ b/authsession/internal/testkit/doc.go @@ -0,0 +1,4 @@ +// Package testkit provides deterministic in-memory doubles for auth/session +// service ports so later service tests can run without Redis, HTTP, or other +// external dependencies. +package testkit diff --git a/authsession/internal/testkit/id_generator.go b/authsession/internal/testkit/id_generator.go new file mode 100644 index 0000000..3edcc8c --- /dev/null +++ b/authsession/internal/testkit/id_generator.go @@ -0,0 +1,101 @@ +package testkit + +import ( + "fmt" + "sync" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" +) + +// SequenceIDGenerator is a deterministic IDGenerator double that consumes +// queued identifiers before falling back to monotonic generated ids. +type SequenceIDGenerator struct { + mu sync.Mutex + + // ChallengeIDs stores queued challenge identifiers returned by + // NewChallengeID before generated ids are used. + ChallengeIDs []common.ChallengeID + + // DeviceSessionIDs stores queued device-session identifiers returned by + // NewDeviceSessionID before generated ids are used. + DeviceSessionIDs []common.DeviceSessionID + + // ChallengeErr is returned directly from NewChallengeID when set. + ChallengeErr error + + // DeviceSessionErr is returned directly from NewDeviceSessionID when set. 
+ DeviceSessionErr error + + ChallengePrefix string + DeviceSessionPrefix string + nextChallengeNumber int + nextSessionNumber int +} + +// NewChallengeID returns the next deterministic challenge identifier. +func (g *SequenceIDGenerator) NewChallengeID() (common.ChallengeID, error) { + if g.ChallengeErr != nil { + return "", g.ChallengeErr + } + + g.mu.Lock() + defer g.mu.Unlock() + + if len(g.ChallengeIDs) > 0 { + id := g.ChallengeIDs[0] + g.ChallengeIDs = g.ChallengeIDs[1:] + if err := id.Validate(); err != nil { + return "", err + } + return id, nil + } + + g.nextChallengeNumber++ + prefix := g.ChallengePrefix + if prefix == "" { + prefix = "challenge-" + } + + id := common.ChallengeID(fmt.Sprintf("%s%d", prefix, g.nextChallengeNumber)) + if err := id.Validate(); err != nil { + return "", err + } + + return id, nil +} + +// NewDeviceSessionID returns the next deterministic device-session +// identifier. +func (g *SequenceIDGenerator) NewDeviceSessionID() (common.DeviceSessionID, error) { + if g.DeviceSessionErr != nil { + return "", g.DeviceSessionErr + } + + g.mu.Lock() + defer g.mu.Unlock() + + if len(g.DeviceSessionIDs) > 0 { + id := g.DeviceSessionIDs[0] + g.DeviceSessionIDs = g.DeviceSessionIDs[1:] + if err := id.Validate(); err != nil { + return "", err + } + return id, nil + } + + g.nextSessionNumber++ + prefix := g.DeviceSessionPrefix + if prefix == "" { + prefix = "device-session-" + } + + id := common.DeviceSessionID(fmt.Sprintf("%s%d", prefix, g.nextSessionNumber)) + if err := id.Validate(); err != nil { + return "", err + } + + return id, nil +} + +var _ ports.IDGenerator = (*SequenceIDGenerator)(nil) diff --git a/authsession/internal/testkit/mail_sender.go b/authsession/internal/testkit/mail_sender.go new file mode 100644 index 0000000..5eb86c8 --- /dev/null +++ b/authsession/internal/testkit/mail_sender.go @@ -0,0 +1,73 @@ +package testkit + +import ( + "context" + "sync" + + "galaxy/authsession/internal/ports" +) + +// RecordingMailSender is a 
deterministic MailSender double that records every +// delivery request and returns preconfigured outcomes or errors. +type RecordingMailSender struct { + mu sync.Mutex + + // Results stores queued results consumed by SendLoginCode before + // DefaultResult is used. + Results []ports.SendLoginCodeResult + + // DefaultResult stores the result used when Results is empty. + DefaultResult ports.SendLoginCodeResult + + // Err is returned directly from SendLoginCode when set. + Err error + + recordedInputs []ports.SendLoginCodeInput +} + +// SendLoginCode records input and returns the next configured result. +func (s *RecordingMailSender) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) { + if err := ctx.Err(); err != nil { + return ports.SendLoginCodeResult{}, err + } + if err := input.Validate(); err != nil { + return ports.SendLoginCodeResult{}, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.recordedInputs = append(s.recordedInputs, input) + if s.Err != nil { + return ports.SendLoginCodeResult{}, s.Err + } + + if len(s.Results) > 0 { + result := s.Results[0] + s.Results = s.Results[1:] + if err := result.Validate(); err != nil { + return ports.SendLoginCodeResult{}, err + } + return result, nil + } + + result := s.DefaultResult + if result.Outcome == "" { + result.Outcome = ports.SendLoginCodeOutcomeSent + } + if err := result.Validate(); err != nil { + return ports.SendLoginCodeResult{}, err + } + + return result, nil +} + +// RecordedInputs returns a stable snapshot of every recorded mail request. +func (s *RecordingMailSender) RecordedInputs() []ports.SendLoginCodeInput { + s.mu.Lock() + defer s.mu.Unlock() + + return append([]ports.SendLoginCodeInput(nil), s.recordedInputs...) 
+} + +var _ ports.MailSender = (*RecordingMailSender)(nil) diff --git a/authsession/internal/testkit/projection_publisher.go b/authsession/internal/testkit/projection_publisher.go new file mode 100644 index 0000000..39a66bc --- /dev/null +++ b/authsession/internal/testkit/projection_publisher.go @@ -0,0 +1,62 @@ +package testkit + +import ( + "context" + "sync" + + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/ports" +) + +// RecordingProjectionPublisher is a deterministic +// GatewaySessionProjectionPublisher double that records every published +// snapshot. +type RecordingProjectionPublisher struct { + mu sync.Mutex + + // Err is returned directly from PublishSession when set. + Err error + + // Errors is an optional FIFO error script consumed before Err. Nil entries + // represent successful publish attempts. + Errors []error + + published []gatewayprojection.Snapshot +} + +// PublishSession records snapshot and returns the configured error, if any. +func (p *RecordingProjectionPublisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error { + if err := ctx.Err(); err != nil { + return err + } + if err := snapshot.Validate(); err != nil { + return err + } + + p.mu.Lock() + defer p.mu.Unlock() + + p.published = append(p.published, cloneProjectionSnapshot(snapshot)) + if len(p.Errors) > 0 { + err := p.Errors[0] + p.Errors = append([]error(nil), p.Errors[1:]...) + return err + } + + return p.Err +} + +// PublishedSnapshots returns a stable snapshot of every published projection. 
+func (p *RecordingProjectionPublisher) PublishedSnapshots() []gatewayprojection.Snapshot { + p.mu.Lock() + defer p.mu.Unlock() + + snapshots := make([]gatewayprojection.Snapshot, 0, len(p.published)) + for _, snapshot := range p.published { + snapshots = append(snapshots, cloneProjectionSnapshot(snapshot)) + } + + return snapshots +} + +var _ ports.GatewaySessionProjectionPublisher = (*RecordingProjectionPublisher)(nil) diff --git a/authsession/internal/testkit/projection_publisher_test.go b/authsession/internal/testkit/projection_publisher_test.go new file mode 100644 index 0000000..36e3707 --- /dev/null +++ b/authsession/internal/testkit/projection_publisher_test.go @@ -0,0 +1,48 @@ +package testkit + +import ( + "context" + "errors" + "testing" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/gatewayprojection" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRecordingProjectionPublisherConsumesScriptedErrorsAndRecordsAttempts(t *testing.T) { + t.Parallel() + + publisher := &RecordingProjectionPublisher{ + Errors: []error{errors.New("first publish failed"), nil}, + } + snapshot := projectionSnapshotFixture() + + err := publisher.PublishSession(context.Background(), snapshot) + require.Error(t, err) + + err = publisher.PublishSession(context.Background(), snapshot) + require.NoError(t, err) + + published := publisher.PublishedSnapshots() + require.Len(t, published, 2) + assert.Equal(t, snapshot.DeviceSessionID, published[0].DeviceSessionID) + assert.Equal(t, snapshot.DeviceSessionID, published[1].DeviceSessionID) + + published[0].ClientPublicKey = "mutated" + + stable := publisher.PublishedSnapshots() + require.Len(t, stable, 2) + assert.Equal(t, snapshot.ClientPublicKey, stable[0].ClientPublicKey) +} + +func projectionSnapshotFixture() gatewayprojection.Snapshot { + return gatewayprojection.Snapshot{ + DeviceSessionID: common.DeviceSessionID("device-session-1"), + UserID: 
common.UserID("user-1"), + ClientPublicKey: "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", + Status: gatewayprojection.StatusActive, + } +} diff --git a/authsession/internal/testkit/send_email_code_abuse.go b/authsession/internal/testkit/send_email_code_abuse.go new file mode 100644 index 0000000..1a4d9c6 --- /dev/null +++ b/authsession/internal/testkit/send_email_code_abuse.go @@ -0,0 +1,58 @@ +package testkit + +import ( + "context" + "fmt" + "sync" + "time" + + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" +) + +// InMemorySendEmailCodeAbuseProtector is a deterministic map-backed +// SendEmailCodeAbuseProtector double suitable for service tests. +type InMemorySendEmailCodeAbuseProtector struct { + mu sync.Mutex + + // Err is returned directly from CheckAndReserve when set. + Err error + + reservedUntil map[common.Email]time.Time +} + +// CheckAndReserve applies the fixed resend cooldown using input.Now as the +// authoritative decision timestamp. 
+func (p *InMemorySendEmailCodeAbuseProtector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) { + if err := ctx.Err(); err != nil { + return ports.SendEmailCodeAbuseResult{}, err + } + if err := input.Validate(); err != nil { + return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check send email code abuse: %w", err) + } + if p.Err != nil { + return ports.SendEmailCodeAbuseResult{}, p.Err + } + + p.mu.Lock() + defer p.mu.Unlock() + + if p.reservedUntil == nil { + p.reservedUntil = make(map[common.Email]time.Time) + } + + reservedUntil, exists := p.reservedUntil[input.Email] + if exists && input.Now.Before(reservedUntil) { + return ports.SendEmailCodeAbuseResult{ + Outcome: ports.SendEmailCodeAbuseOutcomeThrottled, + }, nil + } + + p.reservedUntil[input.Email] = input.Now.UTC().Add(challenge.ResendThrottleCooldown) + return ports.SendEmailCodeAbuseResult{ + Outcome: ports.SendEmailCodeAbuseOutcomeAllowed, + }, nil +} + +var _ ports.SendEmailCodeAbuseProtector = (*InMemorySendEmailCodeAbuseProtector)(nil) diff --git a/authsession/internal/testkit/send_email_code_abuse_test.go b/authsession/internal/testkit/send_email_code_abuse_test.go new file mode 100644 index 0000000..67b5203 --- /dev/null +++ b/authsession/internal/testkit/send_email_code_abuse_test.go @@ -0,0 +1,42 @@ +package testkit + +import ( + "context" + "testing" + "time" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/ports" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInMemorySendEmailCodeAbuseProtector(t *testing.T) { + t.Parallel() + + protector := &InMemorySendEmailCodeAbuseProtector{} + email := common.Email("pilot@example.com") + now := time.Unix(10, 0).UTC() + + result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{ + Email: email, + Now: now, + }) + require.NoError(t, err) + assert.Equal(t, 
ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome) + + result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{ + Email: email, + Now: now.Add(30 * time.Second), + }) + require.NoError(t, err) + assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome) + + result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{ + Email: email, + Now: now.Add(time.Minute), + }) + require.NoError(t, err) + assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome) +} diff --git a/authsession/internal/testkit/session_store.go b/authsession/internal/testkit/session_store.go new file mode 100644 index 0000000..725a8cb --- /dev/null +++ b/authsession/internal/testkit/session_store.go @@ -0,0 +1,229 @@ +package testkit + +import ( + "context" + "fmt" + "slices" + "sync" + + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/ports" +) + +// InMemorySessionStore is a deterministic map-backed SessionStore double +// suitable for service tests. +type InMemorySessionStore struct { + mu sync.Mutex + records map[common.DeviceSessionID]devicesession.Session +} + +// Get returns the stored device session for deviceSessionID. 
+func (s *InMemorySessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) { + if err := ctx.Err(); err != nil { + return devicesession.Session{}, err + } + if err := deviceSessionID.Validate(); err != nil { + return devicesession.Session{}, fmt.Errorf("get session: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + record, ok := s.records[deviceSessionID] + if !ok { + return devicesession.Session{}, fmt.Errorf("get session %q: %w", deviceSessionID, ports.ErrNotFound) + } + + cloned, err := cloneSession(record) + if err != nil { + return devicesession.Session{}, err + } + + return cloned, nil +} + +// ListByUserID returns every stored session for userID in newest-first order. +func (s *InMemorySessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if err := userID.Validate(); err != nil { + return nil, fmt.Errorf("list sessions by user id: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + var records []devicesession.Session + for _, record := range s.records { + if record.UserID == userID { + cloned, err := cloneSession(record) + if err != nil { + return nil, err + } + records = append(records, cloned) + } + } + sortSessionsNewestFirst(records) + + return records, nil +} + +// CountActiveByUserID returns the number of active sessions currently stored +// for userID. 
+func (s *InMemorySessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) { + if err := ctx.Err(); err != nil { + return 0, err + } + if err := userID.Validate(); err != nil { + return 0, fmt.Errorf("count active sessions by user id: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + count := 0 + for _, record := range s.records { + if record.UserID == userID && record.Status == devicesession.StatusActive { + count++ + } + } + + return count, nil +} + +// Create stores record as a new device session. +func (s *InMemorySessionStore) Create(ctx context.Context, record devicesession.Session) error { + if err := ctx.Err(); err != nil { + return err + } + if err := record.Validate(); err != nil { + return fmt.Errorf("create session: %w", err) + } + + cloned, err := cloneSession(record) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.records == nil { + s.records = make(map[common.DeviceSessionID]devicesession.Session) + } + if _, exists := s.records[record.ID]; exists { + return fmt.Errorf("create session %q: %w", record.ID, ports.ErrConflict) + } + + s.records[record.ID] = cloned + return nil +} + +// Revoke stores a revoked view of one target session. 
+func (s *InMemorySessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) { + if err := ctx.Err(); err != nil { + return ports.RevokeSessionResult{}, err + } + if err := input.Validate(); err != nil { + return ports.RevokeSessionResult{}, fmt.Errorf("revoke session: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + record, ok := s.records[input.DeviceSessionID] + if !ok { + return ports.RevokeSessionResult{}, fmt.Errorf("revoke session %q: %w", input.DeviceSessionID, ports.ErrNotFound) + } + + if record.Status == devicesession.StatusRevoked { + cloned, err := cloneSession(record) + if err != nil { + return ports.RevokeSessionResult{}, err + } + + result := ports.RevokeSessionResult{ + Outcome: ports.RevokeSessionOutcomeAlreadyRevoked, + Session: cloned, + } + if err := result.Validate(); err != nil { + return ports.RevokeSessionResult{}, err + } + + return result, nil + } + + record.Status = devicesession.StatusRevoked + revocation := input.Revocation + record.Revocation = &revocation + + cloned, err := cloneSession(record) + if err != nil { + return ports.RevokeSessionResult{}, err + } + s.records[input.DeviceSessionID] = cloned + + result := ports.RevokeSessionResult{ + Outcome: ports.RevokeSessionOutcomeRevoked, + Session: cloned, + } + if err := result.Validate(); err != nil { + return ports.RevokeSessionResult{}, err + } + + return result, nil +} + +// RevokeAllByUserID stores revoked views for all currently active sessions +// owned by input.UserID. 
+func (s *InMemorySessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
+	if err := ctx.Err(); err != nil {
+		return ports.RevokeUserSessionsResult{}, err
+	}
+	if err := input.Validate(); err != nil {
+		return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions: %w", err)
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	var affected []devicesession.Session
+	for id, record := range s.records {
+		if record.UserID != input.UserID || record.Status != devicesession.StatusActive {
+			continue
+		}
+
+		record.Status = devicesession.StatusRevoked
+		revocation := input.Revocation
+		record.Revocation = &revocation
+
+		cloned, err := cloneSession(record)
+		if err != nil {
+			return ports.RevokeUserSessionsResult{}, err
+		}
+		s.records[id] = cloned
+		affected = append(affected, cloned)
+	}
+
+	sortSessionsNewestFirst(affected)
+
+	outcome := ports.RevokeUserSessionsOutcomeNoActiveSessions
+	if len(affected) > 0 {
+		outcome = ports.RevokeUserSessionsOutcomeRevoked
+	}
+
+	result := ports.RevokeUserSessionsResult{
+		Outcome:  outcome,
+		UserID:   input.UserID,
+		Sessions: slices.Clone(affected),
+	}
+	if err := result.Validate(); err != nil {
+		return ports.RevokeUserSessionsResult{}, err
+	}
+
+	return result, nil
+}
+
+var _ ports.SessionStore = (*InMemorySessionStore)(nil)
diff --git a/authsession/internal/testkit/session_store_test.go b/authsession/internal/testkit/session_store_test.go
new file mode 100644
index 0000000..20fdb53
--- /dev/null
+++ b/authsession/internal/testkit/session_store_test.go
@@ -0,0 +1,182 @@
+package testkit
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestInMemorySessionStoreCreateAndGet(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemorySessionStore{}
+	record := 
activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+
+	require.NoError(t, store.Create(context.Background(), record))
+
+	got, err := store.Get(context.Background(), record.ID)
+	require.NoError(t, err)
+	require.Equal(t, record.ID, got.ID)
+}
+
+func TestInMemorySessionStoreListByUserIDNewestFirst(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemorySessionStore{}
+	older := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	newer := activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
+	otherUser := activeSessionFixture("device-session-3", "user-2", time.Unix(30, 0).UTC())
+
+	for _, record := range []devicesession.Session{older, newer, otherUser} {
+		require.NoError(t, store.Create(context.Background(), record))
+	}
+
+	got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
+	require.NoError(t, err)
+	require.Len(t, got, 2)
+	require.Equal(t, newer.ID, got[0].ID)
+	require.Equal(t, older.ID, got[1].ID)
+}
+
+func TestInMemorySessionStoreCountActiveByUserID(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemorySessionStore{}
+	active := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	revoked := revokedSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
+
+	for _, record := range []devicesession.Session{active, revoked} {
+		require.NoError(t, store.Create(context.Background(), record))
+	}
+
+	got, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
+	require.NoError(t, err)
+	require.Equal(t, 1, got)
+}
+
+func TestInMemorySessionStoreRevokeIsIdempotent(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemorySessionStore{}
+	record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	require.NoError(t, store.Create(context.Background(), record))
+
+	input := ports.RevokeSessionInput{
+		DeviceSessionID: record.ID,
+		Revocation: devicesession.Revocation{
+			At:         time.Unix(30, 0).UTC(),
+			ReasonCode: devicesession.RevokeReasonLogoutAll,
+			ActorType:  common.RevokeActorType("system"),
+		},
+	}
+
+	first, err := store.Revoke(context.Background(), input)
+	require.NoError(t, err)
+	require.Equal(t, ports.RevokeSessionOutcomeRevoked, first.Outcome)
+
+	second, err := store.Revoke(context.Background(), input)
+	require.NoError(t, err)
+	require.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, second.Outcome)
+}
+
+func TestInMemorySessionStoreRevokeAllNoActiveSessions(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemorySessionStore{}
+	record := revokedSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
+	require.NoError(t, store.Create(context.Background(), record))
+
+	input := ports.RevokeUserSessionsInput{
+		UserID: common.UserID("user-1"),
+		Revocation: devicesession.Revocation{
+			At:         time.Unix(40, 0).UTC(),
+			ReasonCode: devicesession.RevokeReasonAdminRevoke,
+			ActorType:  common.RevokeActorType("admin"),
+		},
+	}
+
+	result, err := store.RevokeAllByUserID(context.Background(), input)
+	require.NoError(t, err)
+	require.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
+	require.Empty(t, result.Sessions)
+}
+
+func TestInMemorySessionStoreGetNotFound(t *testing.T) {
+	t.Parallel()
+
+	store := &InMemorySessionStore{}
+
+	_, err := store.Get(context.Background(), common.DeviceSessionID("missing"))
+	require.True(t, errors.Is(err, ports.ErrNotFound), "Get() error = %v, want ErrNotFound", err)
+}
+
+func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
+	key, err := common.NewClientPublicKey(make([]byte, 32))
+	if err != nil {
+		panic(err)
+	}
+
+	return devicesession.Session{
+		ID:              common.DeviceSessionID(deviceSessionID),
+		UserID:          common.UserID(userID),
+		ClientPublicKey: key,
+		Status:          devicesession.StatusActive,
+		CreatedAt:       createdAt,
+	}
+}
+
+func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
+	record := activeSessionFixture(deviceSessionID, userID, createdAt)
+	record.Status = devicesession.StatusRevoked
+	record.Revocation = &devicesession.Revocation{
+		At:         createdAt.Add(time.Minute),
+		ReasonCode: devicesession.RevokeReasonDeviceLogout,
+		ActorType:  common.RevokeActorType("user"),
+	}
+	return 
record
+}
diff --git a/authsession/internal/testkit/support_test.go b/authsession/internal/testkit/support_test.go
new file mode 100644
index 0000000..4328107
--- /dev/null
+++ b/authsession/internal/testkit/support_test.go
@@ -0,0 +1,147 @@
+package testkit
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/gatewayprojection"
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestStaticConfigProvider(t *testing.T) {
+	t.Parallel()
+
+	limit := 4
+	provider := StaticConfigProvider{
+		Config: ports.SessionLimitConfig{ActiveSessionLimit: &limit},
+	}
+
+	got, err := provider.LoadSessionLimit(context.Background())
+	require.NoError(t, err)
+	require.NotNil(t, got.ActiveSessionLimit)
+	require.Equal(t, limit, *got.ActiveSessionLimit)
+}
+
+func TestSequenceIDGenerator(t *testing.T) {
+	t.Parallel()
+
+	generator := &SequenceIDGenerator{
+		ChallengeIDs:     []common.ChallengeID{"challenge-queue"},
+		DeviceSessionIDs: []common.DeviceSessionID{"device-session-queue"},
+	}
+
+	challengeID, err := generator.NewChallengeID()
+	require.NoError(t, err)
+	require.Equal(t, common.ChallengeID("challenge-queue"), challengeID)
+
+	deviceSessionID, err := generator.NewDeviceSessionID()
+	require.NoError(t, err)
+	require.Equal(t, common.DeviceSessionID("device-session-queue"), deviceSessionID)
+}
+
+func TestFixedCodeGenerator(t *testing.T) {
+	t.Parallel()
+
+	generator := FixedCodeGenerator{Code: "123456"}
+
+	got, err := generator.Generate()
+	require.NoError(t, err)
+	require.Equal(t, "123456", got)
+}
+
+func TestDeterministicCodeHasher(t *testing.T) {
+	t.Parallel()
+
+	hasher := DeterministicCodeHasher{}
+
+	hash, err := hasher.Hash("123456")
+	require.NoError(t, err)
+
+	match, err := hasher.Compare(hash, "123456")
+	require.NoError(t, err)
+	require.True(t, match, "Compare() = false, want true")
+}
+
+func TestRecordingMailSender(t *testing.T) {
+	t.Parallel()
+
+	sender := &RecordingMailSender{
+		Results: []ports.SendLoginCodeResult{
+			{Outcome: ports.SendLoginCodeOutcomeSuppressed},
+		},
+	}
+
+	result, err := sender.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
+		Email: common.Email("pilot@example.com"),
+		Code:  "654321",
+	})
+	require.NoError(t, err)
+	require.Equal(t, ports.SendLoginCodeOutcomeSuppressed, result.Outcome)
+	require.Len(t, sender.RecordedInputs(), 1)
+}
+
+func TestRecordingProjectionPublisher(t *testing.T) {
+	t.Parallel()
+
+	publisher := &RecordingProjectionPublisher{}
+	revokedAt := time.Unix(30, 0).UTC()
+	snapshot := gatewayprojection.Snapshot{
+		DeviceSessionID:  common.DeviceSessionID("device-session-1"),
+		UserID:           common.UserID("user-1"),
+		ClientPublicKey:  "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
+		Status:           gatewayprojection.StatusRevoked,
+		RevokedAt:        &revokedAt,
+		RevokeReasonCode: common.RevokeReasonCode("logout_all"),
+		RevokeActorType:  common.RevokeActorType("system"),
+	}
+
+	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
+	require.Len(t, publisher.PublishedSnapshots(), 1)
+}
+
+func TestStaticConfigProviderReturnsConfiguredError(t *testing.T) {
+	t.Parallel()
+
+	wantErr := errors.New("config failed")
+	provider := StaticConfigProvider{Err: wantErr}
+
+	_, err := provider.LoadSessionLimit(context.Background())
+	require.ErrorIs(t, err, wantErr)
+}
diff --git a/authsession/internal/testkit/user_directory.go b/authsession/internal/testkit/user_directory.go
new file mode 100644
index 0000000..e5e4f00
--- /dev/null
+++ b/authsession/internal/testkit/user_directory.go
@@ -0,0 +1,309 @@
+package testkit
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/userresolution"
+	"galaxy/authsession/internal/ports"
+)
+
+type userDirectoryEntry struct {
+	UserID          common.UserID
+	BlockReasonCode userresolution.BlockReasonCode
+}
+
+// InMemoryUserDirectory is a deterministic map-backed UserDirectory double
+// suitable for service tests.
+type InMemoryUserDirectory struct {
+	mu             sync.Mutex
+	byEmail        map[common.Email]userDirectoryEntry
+	emailByUserID  map[common.UserID]common.Email
+	createdUserIDs []common.UserID
+	nextUserNumber int
+}
+
+// ResolveByEmail returns the current resolution state for email without
+// creating a new user.
+func (d *InMemoryUserDirectory) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) { + if err := ctx.Err(); err != nil { + return userresolution.Result{}, err + } + if err := email.Validate(); err != nil { + return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + result, err := d.resolveLocked(email) + if err != nil { + return userresolution.Result{}, err + } + + return result, nil +} + +// ExistsByUserID reports whether userID currently identifies a stored user +// record. +func (d *InMemoryUserDirectory) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) { + if err := ctx.Err(); err != nil { + return false, err + } + if err := userID.Validate(); err != nil { + return false, fmt.Errorf("exists by user id: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + _, ok := d.emailByUserID[userID] + return ok, nil +} + +// EnsureUserByEmail returns an existing user for email, creates a new user +// when registration is allowed, or reports a blocked outcome. 
+func (d *InMemoryUserDirectory) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) { + if err := ctx.Err(); err != nil { + return ports.EnsureUserResult{}, err + } + if err := email.Validate(); err != nil { + return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + if d.byEmail == nil { + d.byEmail = make(map[common.Email]userDirectoryEntry) + } + if d.emailByUserID == nil { + d.emailByUserID = make(map[common.UserID]common.Email) + } + + entry, ok := d.byEmail[email] + if ok { + if !entry.BlockReasonCode.IsZero() { + result := ports.EnsureUserResult{ + Outcome: ports.EnsureUserOutcomeBlocked, + BlockReasonCode: entry.BlockReasonCode, + } + return result, result.Validate() + } + + result := ports.EnsureUserResult{ + Outcome: ports.EnsureUserOutcomeExisting, + UserID: entry.UserID, + } + return result, result.Validate() + } + + userID, err := d.nextCreatedUserIDLocked() + if err != nil { + return ports.EnsureUserResult{}, err + } + d.byEmail[email] = userDirectoryEntry{UserID: userID} + d.emailByUserID[userID] = email + + result := ports.EnsureUserResult{ + Outcome: ports.EnsureUserOutcomeCreated, + UserID: userID, + } + return result, result.Validate() +} + +// BlockByUserID applies a block state to the user identified by input.UserID. 
+func (d *InMemoryUserDirectory) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) { + if err := ctx.Err(); err != nil { + return ports.BlockUserResult{}, err + } + if err := input.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + email, ok := d.emailByUserID[input.UserID] + if !ok { + return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound) + } + entry := d.byEmail[email] + if !entry.BlockReasonCode.IsZero() { + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeAlreadyBlocked, + UserID: input.UserID, + } + return result, result.Validate() + } + + entry.BlockReasonCode = input.ReasonCode + d.byEmail[email] = entry + + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeBlocked, + UserID: input.UserID, + } + return result, result.Validate() +} + +// BlockByEmail applies a block state to input.Email even when no user record +// currently exists for that e-mail address. 
+func (d *InMemoryUserDirectory) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) { + if err := ctx.Err(); err != nil { + return ports.BlockUserResult{}, err + } + if err := input.Validate(); err != nil { + return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + if d.byEmail == nil { + d.byEmail = make(map[common.Email]userDirectoryEntry) + } + if d.emailByUserID == nil { + d.emailByUserID = make(map[common.UserID]common.Email) + } + + entry := d.byEmail[input.Email] + if !entry.BlockReasonCode.IsZero() { + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeAlreadyBlocked, + UserID: entry.UserID, + } + return result, result.Validate() + } + + entry.BlockReasonCode = input.ReasonCode + d.byEmail[input.Email] = entry + if !entry.UserID.IsZero() { + d.emailByUserID[entry.UserID] = input.Email + } + + result := ports.BlockUserResult{ + Outcome: ports.BlockUserOutcomeBlocked, + UserID: entry.UserID, + } + return result, result.Validate() +} + +// SeedExisting preloads one existing unblocked user record for service tests. +func (d *InMemoryUserDirectory) SeedExisting(email common.Email, userID common.UserID) error { + if err := email.Validate(); err != nil { + return fmt.Errorf("seed existing email: %w", err) + } + if err := userID.Validate(); err != nil { + return fmt.Errorf("seed existing user id: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + if d.byEmail == nil { + d.byEmail = make(map[common.Email]userDirectoryEntry) + } + if d.emailByUserID == nil { + d.emailByUserID = make(map[common.UserID]common.Email) + } + + d.byEmail[email] = userDirectoryEntry{UserID: userID} + d.emailByUserID[userID] = email + + return nil +} + +// SeedBlockedEmail preloads one blocked e-mail address that does not +// necessarily belong to an existing user record. 
+func (d *InMemoryUserDirectory) SeedBlockedEmail(email common.Email, reasonCode userresolution.BlockReasonCode) error { + if err := email.Validate(); err != nil { + return fmt.Errorf("seed blocked email: %w", err) + } + if err := reasonCode.Validate(); err != nil { + return fmt.Errorf("seed blocked email reason code: %w", err) + } + + d.mu.Lock() + defer d.mu.Unlock() + + if d.byEmail == nil { + d.byEmail = make(map[common.Email]userDirectoryEntry) + } + + d.byEmail[email] = userDirectoryEntry{BlockReasonCode: reasonCode} + return nil +} + +// SeedBlockedUser preloads one blocked existing user record for service tests. +func (d *InMemoryUserDirectory) SeedBlockedUser(email common.Email, userID common.UserID, reasonCode userresolution.BlockReasonCode) error { + if err := d.SeedExisting(email, userID); err != nil { + return err + } + + d.mu.Lock() + defer d.mu.Unlock() + + entry := d.byEmail[email] + entry.BlockReasonCode = reasonCode + d.byEmail[email] = entry + + return nil +} + +// QueueCreatedUserIDs appends deterministic user identifiers that +// EnsureUserByEmail will consume before falling back to generated ids. +func (d *InMemoryUserDirectory) QueueCreatedUserIDs(userIDs ...common.UserID) error { + for index, userID := range userIDs { + if err := userID.Validate(); err != nil { + return fmt.Errorf("queue created user id %d: %w", index, err) + } + } + + d.mu.Lock() + defer d.mu.Unlock() + + d.createdUserIDs = append(d.createdUserIDs, userIDs...) 
+	return nil
+}
+
+var _ ports.UserDirectory = (*InMemoryUserDirectory)(nil)
+
+func (d *InMemoryUserDirectory) resolveLocked(email common.Email) (userresolution.Result, error) {
+	entry, ok := d.byEmail[email]
+	if !ok {
+		result := userresolution.Result{Kind: userresolution.KindCreatable}
+		return result, result.Validate()
+	}
+	if !entry.BlockReasonCode.IsZero() {
+		result := userresolution.Result{
+			Kind:            userresolution.KindBlocked,
+			BlockReasonCode: entry.BlockReasonCode,
+		}
+		return result, result.Validate()
+	}
+
+	result := userresolution.Result{
+		Kind:   userresolution.KindExisting,
+		UserID: entry.UserID,
+	}
+	return result, result.Validate()
+}
+
+func (d *InMemoryUserDirectory) nextCreatedUserIDLocked() (common.UserID, error) {
+	if len(d.createdUserIDs) > 0 {
+		userID := d.createdUserIDs[0]
+		d.createdUserIDs = d.createdUserIDs[1:]
+		return userID, nil
+	}
+
+	d.nextUserNumber++
+	userID := common.UserID(fmt.Sprintf("user-%d", d.nextUserNumber))
+	if err := userID.Validate(); err != nil {
+		return "", err
+	}
+
+	return userID, nil
+}
diff --git a/authsession/internal/testkit/user_directory_test.go b/authsession/internal/testkit/user_directory_test.go
new file mode 100644
index 0000000..bb620ea
--- /dev/null
+++ b/authsession/internal/testkit/user_directory_test.go
@@ -0,0 +1,204 @@
+package testkit
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/userresolution"
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestInMemoryUserDirectoryResolveExistingCreatableAndBlocked(t *testing.T) {
+	t.Parallel()
+
+	directory := &InMemoryUserDirectory{}
+	if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil {
+		require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
+	}
+	if err := directory.SeedBlockedEmail(common.Email("blocked@example.com"), 
userresolution.BlockReasonCode("policy_block")); err != nil { + require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err) + } + + tests := []struct { + name string + email common.Email + wantKind userresolution.Kind + }{ + {name: "existing", email: common.Email("existing@example.com"), wantKind: userresolution.KindExisting}, + {name: "creatable", email: common.Email("new@example.com"), wantKind: userresolution.KindCreatable}, + {name: "blocked", email: common.Email("blocked@example.com"), wantKind: userresolution.KindBlocked}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := directory.ResolveByEmail(context.Background(), tt.email) + if err != nil { + require.Failf(t, "test failed", "ResolveByEmail() returned error: %v", err) + } + if got.Kind != tt.wantKind { + require.Failf(t, "test failed", "ResolveByEmail().Kind = %q, want %q", got.Kind, tt.wantKind) + } + }) + } +} + +func TestInMemoryUserDirectoryEnsureUserExistingCreatedAndBlocked(t *testing.T) { + t.Parallel() + + directory := &InMemoryUserDirectory{} + if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + if err := directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil { + require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err) + } + if err := directory.QueueCreatedUserIDs(common.UserID("user-created")); err != nil { + require.Failf(t, "test failed", "QueueCreatedUserIDs() returned error: %v", err) + } + + tests := []struct { + name string + email common.Email + wantOutcome ports.EnsureUserOutcome + wantUserID common.UserID + }{ + { + name: "existing", + email: common.Email("existing@example.com"), + wantOutcome: ports.EnsureUserOutcomeExisting, + wantUserID: 
common.UserID("user-existing"), + }, + { + name: "created", + email: common.Email("created@example.com"), + wantOutcome: ports.EnsureUserOutcomeCreated, + wantUserID: common.UserID("user-created"), + }, + { + name: "blocked", + email: common.Email("blocked@example.com"), + wantOutcome: ports.EnsureUserOutcomeBlocked, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := directory.EnsureUserByEmail(context.Background(), tt.email) + if err != nil { + require.Failf(t, "test failed", "EnsureUserByEmail() returned error: %v", err) + } + if got.Outcome != tt.wantOutcome { + require.Failf(t, "test failed", "EnsureUserByEmail().Outcome = %q, want %q", got.Outcome, tt.wantOutcome) + } + if got.UserID != tt.wantUserID { + require.Failf(t, "test failed", "EnsureUserByEmail().UserID = %q, want %q", got.UserID, tt.wantUserID) + } + }) + } +} + +func TestInMemoryUserDirectoryExistsByUserID(t *testing.T) { + t.Parallel() + + directory := &InMemoryUserDirectory{} + if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + + exists, err := directory.ExistsByUserID(context.Background(), common.UserID("user-existing")) + if err != nil { + require.Failf(t, "test failed", "ExistsByUserID() returned error: %v", err) + } + if !exists { + require.FailNow(t, "ExistsByUserID() = false, want true") + } + + exists, err = directory.ExistsByUserID(context.Background(), common.UserID("missing")) + if err != nil { + require.Failf(t, "test failed", "ExistsByUserID() returned error: %v", err) + } + if exists { + require.FailNow(t, "ExistsByUserID() = true, want false") + } +} + +func TestInMemoryUserDirectoryBlockByEmail(t *testing.T) { + t.Parallel() + + directory := &InMemoryUserDirectory{} + result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{ + Email: 
common.Email("blocked@example.com"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + if err != nil { + require.Failf(t, "test failed", "BlockByEmail() returned error: %v", err) + } + if result.Outcome != ports.BlockUserOutcomeBlocked { + require.Failf(t, "test failed", "BlockByEmail().Outcome = %q, want %q", result.Outcome, ports.BlockUserOutcomeBlocked) + } + + resolution, err := directory.ResolveByEmail(context.Background(), common.Email("blocked@example.com")) + if err != nil { + require.Failf(t, "test failed", "ResolveByEmail() returned error: %v", err) + } + if resolution.Kind != userresolution.KindBlocked { + require.Failf(t, "test failed", "ResolveByEmail().Kind = %q, want %q", resolution.Kind, userresolution.KindBlocked) + } +} + +func TestInMemoryUserDirectoryBlockByUserID(t *testing.T) { + t.Parallel() + + directory := &InMemoryUserDirectory{} + if err := directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil { + require.Failf(t, "test failed", "SeedExisting() returned error: %v", err) + } + + result, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("user-1"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + if err != nil { + require.Failf(t, "test failed", "BlockByUserID() returned error: %v", err) + } + if result.Outcome != ports.BlockUserOutcomeBlocked { + require.Failf(t, "test failed", "BlockByUserID().Outcome = %q, want %q", result.Outcome, ports.BlockUserOutcomeBlocked) + } + + second, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("user-1"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + if err != nil { + require.Failf(t, "test failed", "second BlockByUserID() returned error: %v", err) + } + if second.Outcome != ports.BlockUserOutcomeAlreadyBlocked { + require.Failf(t, "test failed", "second BlockByUserID().Outcome = %q, want %q", 
second.Outcome, ports.BlockUserOutcomeAlreadyBlocked) + } +} + +func TestInMemoryUserDirectoryBlockByUserIDNotFound(t *testing.T) { + t.Parallel() + + directory := &InMemoryUserDirectory{} + + _, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{ + UserID: common.UserID("missing"), + ReasonCode: userresolution.BlockReasonCode("policy_block"), + }) + if !errors.Is(err, ports.ErrNotFound) { + require.Failf(t, "test failed", "BlockByUserID() error = %v, want ErrNotFound", err) + } +} diff --git a/authsession/mail_service_rest_compatibility_test.go b/authsession/mail_service_rest_compatibility_test.go new file mode 100644 index 0000000..c64689a --- /dev/null +++ b/authsession/mail_service_rest_compatibility_test.go @@ -0,0 +1,245 @@ +package authsession + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + mailadapter "galaxy/authsession/internal/adapters/mail" + "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/api/publichttp" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestMailServiceRESTCompatibilitySendEmailCodeSent(t *testing.T) { + t.Parallel() + + harness := newMailServiceRESTCompatibilityHarness(t, mailServiceRESTCompatibilityOptions{ + MailStatusCode: http.StatusOK, + MailResponse: `{"outcome":"sent"}`, + }) + + response := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, response.Body) + assert.Equal(t, 1, harness.mailServer.CallCount()) +} + +func 
TestMailServiceRESTCompatibilitySendEmailCodeSuppressed(t *testing.T) { + t.Parallel() + + harness := newMailServiceRESTCompatibilityHarness(t, mailServiceRESTCompatibilityOptions{ + MailStatusCode: http.StatusOK, + MailResponse: `{"outcome":"suppressed"}`, + }) + + response := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, response.Body) + assert.Equal(t, 1, harness.mailServer.CallCount()) +} + +func TestMailServiceRESTCompatibilitySendEmailCodeExplicitFailure(t *testing.T) { + t.Parallel() + + harness := newMailServiceRESTCompatibilityHarness(t, mailServiceRESTCompatibilityOptions{ + MailStatusCode: http.StatusServiceUnavailable, + MailResponse: `{"error":"temporary"}`, + }) + + response := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusServiceUnavailable, response.StatusCode) + assert.JSONEq(t, `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, response.Body) + assert.Equal(t, 1, harness.mailServer.CallCount()) +} + +func TestMailServiceRESTCompatibilityBlockedSendSkipsMailService(t *testing.T) { + t.Parallel() + + harness := newMailServiceRESTCompatibilityHarness(t, mailServiceRESTCompatibilityOptions{ + MailStatusCode: http.StatusOK, + MailResponse: `{"outcome":"sent"}`, + SeedBlockedEmail: true, + }) + + response := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, response.Body) + assert.Equal(t, 0, harness.mailServer.CallCount()) +} + +func TestMailServiceRESTCompatibilityThrottledSendSkipsMailService(t *testing.T) { + t.Parallel() + + harness := 
newMailServiceRESTCompatibilityHarness(t, mailServiceRESTCompatibilityOptions{ + MailStatusCode: http.StatusOK, + MailResponse: `{"outcome":"sent"}`, + AbuseProtector: &testkit.InMemorySendEmailCodeAbuseProtector{}, + }) + + first := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + second := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`) + + assert.Equal(t, http.StatusOK, first.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, first.Body) + assert.Equal(t, http.StatusOK, second.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-2"}`, second.Body) + assert.Equal(t, 1, harness.mailServer.CallCount()) +} + +type mailServiceRESTCompatibilityOptions struct { + MailStatusCode int + MailResponse string + SeedBlockedEmail bool + AbuseProtector *testkit.InMemorySendEmailCodeAbuseProtector +} + +type mailServiceRESTCompatibilityHarness struct { + publicBaseURL string + mailServer *mailServiceStubServer +} + +func newMailServiceRESTCompatibilityHarness(t *testing.T, options mailServiceRESTCompatibilityOptions) mailServiceRESTCompatibilityHarness { + t.Helper() + + challengeStore := &testkit.InMemoryChallengeStore{} + sessionStore := &testkit.InMemorySessionStore{} + userDirectory := &userservice.StubDirectory{} + if options.SeedBlockedEmail { + require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_blocked"))) + } + + mailServer := newMailServiceStubServer(options.MailStatusCode, options.MailResponse) + httpServer := httptest.NewServer(mailServer.Handler()) + t.Cleanup(httpServer.Close) + + mailSender, err := mailadapter.NewRESTClient(mailadapter.Config{ + BaseURL: httpServer.URL, + RequestTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, mailSender.Close()) + }) + + 
idGenerator := &testkit.SequenceIDGenerator{} + codeGenerator := testkit.FixedCodeGenerator{Code: "123456"} + codeHasher := testkit.DeterministicCodeHasher{} + clock := testkit.FixedClock{Time: time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)} + configProvider := testkit.StaticConfigProvider{} + projectionPublisher := &testkit.RecordingProjectionPublisher{} + + sendEmailCodeService, err := sendemailcode.NewWithObservability( + challengeStore, + userDirectory, + idGenerator, + codeGenerator, + codeHasher, + mailSender, + options.AbuseProtector, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + confirmEmailCodeService, err := confirmemailcode.NewWithObservability( + challengeStore, + sessionStore, + userDirectory, + configProvider, + projectionPublisher, + idGenerator, + codeHasher, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + publicCfg := publichttp.DefaultConfig() + publicCfg.Addr = gatewayCompatibilityFreeAddr(t) + publicServer, err := publichttp.NewServer(publicCfg, publichttp.Dependencies{ + SendEmailCode: sendEmailCodeService, + ConfirmEmailCode: confirmEmailCodeService, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + gatewayCompatibilityRunServer(t, publicServer.Run, publicServer.Shutdown, publicCfg.Addr) + + return mailServiceRESTCompatibilityHarness{ + publicBaseURL: "http://" + publicCfg.Addr, + mailServer: mailServer, + } +} + +type mailServiceStubServer struct { + mu sync.Mutex + statusCode int + response string + callCount int +} + +func newMailServiceStubServer(statusCode int, response string) *mailServiceStubServer { + return &mailServiceStubServer{ + statusCode: statusCode, + response: response, + } +} + +func (s *mailServiceStubServer) Handler() http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if request.Method != http.MethodPost || request.URL.Path != "/api/v1/internal/login-code-deliveries" { + http.NotFound(writer, request) + return + } + + s.mu.Lock() 
+		s.callCount++
+		s.mu.Unlock()
+
+		decoder := json.NewDecoder(request.Body)
+		decoder.DisallowUnknownFields()
+
+		var body struct {
+			Email string `json:"email"`
+			Code  string `json:"code"`
+		}
+		if err := decoder.Decode(&body); err != nil {
+			http.Error(writer, err.Error(), http.StatusBadRequest)
+			return
+		}
+		if err := decoder.Decode(&struct{}{}); err != io.EOF {
+			if err == nil {
+				http.Error(writer, "unexpected trailing JSON input", http.StatusBadRequest)
+				return
+			}
+			http.Error(writer, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		writer.Header().Set("Content-Type", "application/json")
+		writer.WriteHeader(s.statusCode)
+		_, _ = io.WriteString(writer, s.response)
+	})
+}
+
+func (s *mailServiceStubServer) CallCount() int {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return s.callCount
+}
diff --git a/authsession/production_hardening_concurrency_test.go b/authsession/production_hardening_concurrency_test.go
new file mode 100644
index 0000000..72f6aef
--- /dev/null
+++ b/authsession/production_hardening_concurrency_test.go
@@ -0,0 +1,334 @@
+package authsession
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+
+	"galaxy/authsession/internal/domain/common"
+	"galaxy/authsession/internal/domain/devicesession"
+	"galaxy/authsession/internal/ports"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// gatedCreateSessionStore blocks the first `target` successful Create calls
+// after they have persisted the session, which lets concurrency tests force
+// overlap between the confirm flow and competing revoke/block flows.
+type gatedCreateSessionStore struct {
+	delegate ports.SessionStore
+	target   int
+
+	arrived chan common.DeviceSessionID
+	release chan struct{}
+
+	mu          sync.Mutex
+	seenCreates int
+	releaseOnce sync.Once
+}
+
+// newGatedCreateSessionStore wraps delegate with deterministic post-create
+// gating for the first `target` successful session creations. 
+func newGatedCreateSessionStore(delegate ports.SessionStore, target int) *gatedCreateSessionStore { + return &gatedCreateSessionStore{ + delegate: delegate, + target: target, + arrived: make(chan common.DeviceSessionID, target), + release: make(chan struct{}), + } +} + +// Create delegates persistence first and then blocks the first configured +// number of successful creations until Release is called. +func (s *gatedCreateSessionStore) Create(ctx context.Context, record devicesession.Session) error { + if err := s.delegate.Create(ctx, record); err != nil { + return err + } + + s.mu.Lock() + shouldGate := s.seenCreates < s.target + if shouldGate { + s.seenCreates++ + } + s.mu.Unlock() + + if !shouldGate { + return nil + } + + s.arrived <- record.ID + + select { + case <-s.release: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// WaitForCreates waits for count gated successful Create calls and returns the +// corresponding device session identifiers in arrival order. +func (s *gatedCreateSessionStore) WaitForCreates(t *testing.T, count int) []common.DeviceSessionID { + t.Helper() + + ids := make([]common.DeviceSessionID, 0, count) + timeout := time.After(5 * time.Second) + + for len(ids) < count { + select { + case id := <-s.arrived: + ids = append(ids, id) + case <-timeout: + require.FailNowf(t, "test failed", "timed out waiting for %d gated session creations", count) + } + } + + return ids +} + +// Release unblocks every gated Create call. +func (s *gatedCreateSessionStore) Release() { + s.releaseOnce.Do(func() { + close(s.release) + }) +} + +// Get delegates to the wrapped session store. +func (s *gatedCreateSessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) { + return s.delegate.Get(ctx, deviceSessionID) +} + +// ListByUserID delegates to the wrapped session store. 
+func (s *gatedCreateSessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) { + return s.delegate.ListByUserID(ctx, userID) +} + +// CountActiveByUserID delegates to the wrapped session store. +func (s *gatedCreateSessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) { + return s.delegate.CountActiveByUserID(ctx, userID) +} + +// Revoke delegates to the wrapped session store. +func (s *gatedCreateSessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) { + return s.delegate.Revoke(ctx, input) +} + +// RevokeAllByUserID delegates to the wrapped session store. +func (s *gatedCreateSessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) { + return s.delegate.RevokeAllByUserID(ctx, input) +} + +var _ ports.SessionStore = (*gatedCreateSessionStore)(nil) + +func TestProductionHardeningConcurrentIdenticalConfirmsConvergeToOneActiveSession(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + var gate *gatedCreateSessionStore + app := newHardeningApp(t, env, hardeningAppOptions{ + SeedExistingUser: true, + WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore { + gate = newGatedCreateSessionStore(delegate, 2) + return gate + }, + }) + + challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail) + requestBody := map[string]string{ + "challenge_id": challengeID, + "code": code, + "client_public_key": gatewayCompatibilityClientPublicKey, + } + + responses := make([]gatewayCompatibilityHTTPResponse, 2) + start := make(chan struct{}) + + var requests sync.WaitGroup + requests.Add(2) + for index := range responses { + go func(index int) { + defer requests.Done() + <-start + responses[index] = gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", requestBody) + }(index) + } + + 
close(start) + createdIDs := gate.WaitForCreates(t, 2) + require.Len(t, createdIDs, 2) + assert.NotEqual(t, createdIDs[0], createdIDs[1]) + + gate.Release() + requests.Wait() + + var deviceSessionIDs []string + for _, response := range responses { + assert.Equal(t, http.StatusOK, response.StatusCode) + + var body struct { + DeviceSessionID string `json:"device_session_id"` + } + require.NoError(t, json.Unmarshal([]byte(response.Body), &body)) + deviceSessionIDs = append(deviceSessionIDs, body.DeviceSessionID) + } + require.Len(t, deviceSessionIDs, 2) + assert.Equal(t, deviceSessionIDs[0], deviceSessionIDs[1]) + + records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + require.Len(t, records, 2) + + activeCount := 0 + revokedCount := 0 + for _, record := range records { + switch record.Status { + case devicesession.StatusActive: + activeCount++ + assert.Equal(t, common.DeviceSessionID(deviceSessionIDs[0]), record.ID) + case devicesession.StatusRevoked: + revokedCount++ + require.NotNil(t, record.Revocation) + assert.Equal(t, common.RevokeReasonCode("confirm_race_repair"), record.Revocation.ReasonCode) + default: + require.Failf(t, "test failed", "unexpected final session status %q", record.Status) + } + } + assert.Equal(t, 1, activeCount) + assert.Equal(t, 1, revokedCount) + + cacheRecord := env.MustReadGatewayCacheRecord(t, deviceSessionIDs[0]) + assert.Equal(t, "active", cacheRecord.Status) +} + +func TestProductionHardeningConcurrentConfirmAndRevokeAllKeepProjectionConsistent(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + var gate *gatedCreateSessionStore + app := newHardeningApp(t, env, hardeningAppOptions{ + SeedExistingUser: true, + WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore { + gate = newGatedCreateSessionStore(delegate, 1) + return gate + }, + }) + + challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail) + confirmResponseCh := 
make(chan gatewayCompatibilityHTTPResponse, 1) + go func() { + confirmResponseCh <- gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": challengeID, + "code": code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + }() + + createdIDs := gate.WaitForCreates(t, 1) + sessionID := createdIDs[0].String() + + revokeAllResponse := gatewayCompatibilityPostJSON( + t, + app.internalBaseURL+"/api/v1/internal/users/user-1/sessions/revoke-all", + `{"reason_code":"logout_all","actor":{"type":"system"}}`, + ) + assert.Equal(t, http.StatusOK, revokeAllResponse.StatusCode) + assert.JSONEq(t, `{"outcome":"revoked","user_id":"user-1","affected_session_count":1,"affected_device_session_ids":["`+sessionID+`"]}`, revokeAllResponse.Body) + + gate.Release() + confirmResponse := <-confirmResponseCh + assert.Equal(t, http.StatusOK, confirmResponse.StatusCode) + + var confirmBody struct { + DeviceSessionID string `json:"device_session_id"` + } + require.NoError(t, json.Unmarshal([]byte(confirmResponse.Body), &confirmBody)) + assert.Equal(t, sessionID, confirmBody.DeviceSessionID) + + records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + require.Len(t, records, 1) + assert.Equal(t, devicesession.StatusRevoked, records[0].Status) + require.NotNil(t, records[0].Revocation) + assert.Equal(t, devicesession.RevokeReasonLogoutAll, records[0].Revocation.ReasonCode) + + cacheRecord := env.MustReadGatewayCacheRecord(t, sessionID) + assert.Equal(t, "revoked", cacheRecord.Status) + require.NotNil(t, cacheRecord.RevokedAtMS) +} + +func TestProductionHardeningConcurrentBlockUserAndConfirmDoNotLeakActiveSession(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + var gate *gatedCreateSessionStore + app := newHardeningApp(t, env, hardeningAppOptions{ + SeedExistingUser: true, + WrapSessionStore: func(delegate 
ports.SessionStore) ports.SessionStore { + gate = newGatedCreateSessionStore(delegate, 1) + return gate + }, + }) + + challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail) + initialAttempts := app.mailSender.RecordedAttempts() + require.Len(t, initialAttempts, 1) + + confirmResponseCh := make(chan gatewayCompatibilityHTTPResponse, 1) + go func() { + confirmResponseCh <- gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": challengeID, + "code": code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + }() + + createdIDs := gate.WaitForCreates(t, 1) + sessionID := createdIDs[0].String() + + blockResponse := gatewayCompatibilityPostJSON( + t, + app.internalBaseURL+"/api/v1/internal/user-blocks", + `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`, + ) + assert.Equal(t, http.StatusOK, blockResponse.StatusCode) + assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":1,"affected_device_session_ids":["`+sessionID+`"]}`, blockResponse.Body) + + gate.Release() + confirmResponse := <-confirmResponseCh + assert.Contains(t, []int{http.StatusOK, http.StatusForbidden}, confirmResponse.StatusCode) + + records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + require.Len(t, records, 1) + assert.Equal(t, devicesession.StatusRevoked, records[0].Status) + require.NotNil(t, records[0].Revocation) + assert.Equal(t, devicesession.RevokeReasonUserBlocked, records[0].Revocation.ReasonCode) + + cacheRecord := env.MustReadGatewayCacheRecord(t, sessionID) + assert.Equal(t, "revoked", cacheRecord.Status) + require.NotNil(t, cacheRecord.RevokedAtMS) + + followupSend := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", map[string]string{ + "email": gatewayCompatibilityEmail, 
+ }) + assert.Equal(t, http.StatusOK, followupSend.StatusCode) + + var sendBody struct { + ChallengeID string `json:"challenge_id"` + } + require.NoError(t, json.Unmarshal([]byte(followupSend.Body), &sendBody)) + assert.NotEmpty(t, sendBody.ChallengeID) + assert.Len(t, app.mailSender.RecordedAttempts(), 1) + + followupConfirm := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": sendBody.ChallengeID, + "code": gatewayCompatibilityCode, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusForbidden, followupConfirm.StatusCode) + assert.JSONEq(t, `{"error":{"code":"blocked_by_policy","message":"authentication is blocked by policy"}}`, followupConfirm.Body) +} diff --git a/authsession/production_hardening_test.go b/authsession/production_hardening_test.go new file mode 100644 index 0000000..35e1546 --- /dev/null +++ b/authsession/production_hardening_test.go @@ -0,0 +1,837 @@ +package authsession + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "sync" + "testing" + "time" + + "galaxy/authsession/internal/adapters/mail" + "galaxy/authsession/internal/adapters/redis/challengestore" + "galaxy/authsession/internal/adapters/redis/configprovider" + "galaxy/authsession/internal/adapters/redis/projectionpublisher" + "galaxy/authsession/internal/adapters/redis/sessionstore" + "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/api/internalhttp" + "galaxy/authsession/internal/api/publichttp" + "galaxy/authsession/internal/domain/challenge" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/devicesession" + "galaxy/authsession/internal/domain/gatewayprojection" + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/confirmemailcode" + 
"galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/service/shared" + "galaxy/authsession/internal/testkit" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +const hardeningLargeSessionCount = 256 + +// hardeningEnvironment owns one reusable Redis-backed integration environment +// for Stage 22 tests. +type hardeningEnvironment struct { + redisAddr string + redisServer *miniredis.Miniredis + redisClient *redis.Client + now time.Time +} + +// newHardeningEnvironment starts one miniredis-backed environment on a stable +// local address so tests can restart Redis on the same endpoint when needed. +func newHardeningEnvironment(t *testing.T) *hardeningEnvironment { + t.Helper() + + env := &hardeningEnvironment{ + redisAddr: gatewayCompatibilityFreeAddr(t), + now: time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC), + } + env.startRedis(t) + + env.redisClient = redis.NewClient(&redis.Options{ + Addr: env.redisAddr, + Protocol: 2, + DisableIdentity: true, + }) + + t.Cleanup(func() { + env.Close() + }) + + return env +} + +// startRedis starts one miniredis instance on the environment's configured +// address. +func (e *hardeningEnvironment) startRedis(t *testing.T) { + t.Helper() + + if e.redisServer != nil { + require.Fail(t, "hardening environment redis already running") + } + + server := miniredis.NewMiniRedis() + require.NoError(t, server.StartAddr(e.redisAddr)) + e.redisServer = server +} + +// StopRedis stops the current Redis server and keeps the configured address +// reserved for later restart tests. 
+func (e *hardeningEnvironment) StopRedis() { + if e == nil || e.redisServer == nil { + return + } + + e.redisServer.Close() + e.redisServer = nil +} + +// RestartRedis starts a fresh Redis server on the same configured address. +func (e *hardeningEnvironment) RestartRedis(t *testing.T) { + t.Helper() + + e.StopRedis() + e.startRedis(t) +} + +// FastForward advances miniredis time to exercise TTL-based cleanup behavior. +func (e *hardeningEnvironment) FastForward(t *testing.T, duration time.Duration) { + t.Helper() + + require.NotNil(t, e.redisServer) + e.redisServer.FastForward(duration) +} + +// Close releases the Redis client and any still-running Redis server. +func (e *hardeningEnvironment) Close() { + if e == nil { + return + } + if e.redisClient != nil { + _ = e.redisClient.Close() + e.redisClient = nil + } + if e.redisServer != nil { + e.redisServer.Close() + e.redisServer = nil + } +} + +// GatewayCacheExists reports whether the gateway-compatible cache record for +// deviceSessionID is currently present in Redis. +func (e *hardeningEnvironment) GatewayCacheExists(ctx context.Context, deviceSessionID string) bool { + if e == nil || e.redisClient == nil { + return false + } + + _, err := e.redisClient.Get(ctx, gatewayCompatibilitySessionCacheKeyPrefix+deviceSessionID).Bytes() + return err == nil +} + +// MustReadGatewayCacheRecord reads one strict gateway-compatible cache record +// from Redis. 
+func (e *hardeningEnvironment) MustReadGatewayCacheRecord(t *testing.T, deviceSessionID string) gatewayCacheRecord { + t.Helper() + + payload, err := e.redisClient.Get(context.Background(), gatewayCompatibilitySessionCacheKeyPrefix+deviceSessionID).Bytes() + require.NoError(t, err) + + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.DisallowUnknownFields() + + var record gatewayCacheRecord + require.NoError(t, decoder.Decode(&record)) + + err = decoder.Decode(&struct{}{}) + require.ErrorIs(t, err, io.EOF) + + require.Equal(t, deviceSessionID, record.DeviceSessionID) + require.NotEmpty(t, record.UserID) + require.NotEmpty(t, record.ClientPublicKey) + require.Contains(t, []string{"active", "revoked"}, record.Status) + + return record +} + +// MustReadGatewaySessionEvents reads every gateway-compatible stream event for +// deviceSessionID from the shared session-events stream. +func (e *hardeningEnvironment) MustReadGatewaySessionEvents(t *testing.T, deviceSessionID string) []gatewaySessionEventRecord { + t.Helper() + + entries, err := e.redisClient.XRange(context.Background(), gatewayCompatibilitySessionEventsStream, "-", "+").Result() + require.NoError(t, err) + + records := make([]gatewaySessionEventRecord, 0, len(entries)) + for _, entry := range entries { + record := decodeGatewaySessionEvent(t, entry.Values) + if record.DeviceSessionID == deviceSessionID { + records = append(records, record) + } + } + require.NotEmpty(t, records) + + return records +} + +// hardeningAppOptions configures one runnable Stage-22 integration app. +type hardeningAppOptions struct { + SeedExistingUser bool + SeedBlockedEmail bool + SessionLimit *int + SeedSessions []devicesession.Session + PublisherErrors []error + WrapSessionStore func(ports.SessionStore) ports.SessionStore +} + +// hardeningApp owns one pair of real public and internal HTTP servers backed +// by real Redis adapters and seedable stub dependencies. 
+type hardeningApp struct { + publicBaseURL string + internalBaseURL string + + challengeStore *challengestore.Store + sessionStore *sessionstore.Store + configStore *configprovider.Store + publisher *projectionpublisher.Publisher + + mailSender *mail.StubSender + userDirectory *userservice.StubDirectory + + closeOnce sync.Once + closeFn func() +} + +// newHardeningApp builds and starts one real authsession HTTP pair over the +// shared hardening environment. +func newHardeningApp(t *testing.T, env *hardeningEnvironment, options hardeningAppOptions) *hardeningApp { + t.Helper() + + require.NotNil(t, env) + + if options.SessionLimit == nil { + require.NoError(t, env.redisClient.Del(context.Background(), gatewayCompatibilitySessionLimitKey).Err()) + } else { + require.NoError(t, env.redisServer.Set(gatewayCompatibilitySessionLimitKey, strconv.Itoa(*options.SessionLimit))) + } + + challengeStore, err := challengestore.New(challengestore.Config{ + Addr: env.redisAddr, + DB: 0, + KeyPrefix: gatewayCompatibilityChallengeKeyPrefix, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + + redisSessionStore, err := sessionstore.New(sessionstore.Config{ + Addr: env.redisAddr, + DB: 0, + SessionKeyPrefix: gatewayCompatibilitySessionKeyPrefix, + UserSessionsKeyPrefix: gatewayCompatibilityUserSessionsKeyPrefix, + UserActiveSessionsKeyPrefix: gatewayCompatibilityUserActiveKeyPrefix, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + + configStore, err := configprovider.New(configprovider.Config{ + Addr: env.redisAddr, + DB: 0, + SessionLimitKey: gatewayCompatibilitySessionLimitKey, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + + redisPublisher, err := projectionpublisher.New(projectionpublisher.Config{ + Addr: env.redisAddr, + DB: 0, + SessionCacheKeyPrefix: gatewayCompatibilitySessionCacheKeyPrefix, + SessionEventsStream: gatewayCompatibilitySessionEventsStream, + StreamMaxLen: gatewayCompatibilityStreamMaxLen, +
OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + + userDirectory := &userservice.StubDirectory{} + if options.SeedBlockedEmail { + require.NoError(t, userDirectory.SeedBlockedEmail(common.Email(gatewayCompatibilityEmail), "policy_blocked")) + } + if options.SeedExistingUser { + require.NoError(t, userDirectory.SeedExisting(common.Email(gatewayCompatibilityEmail), common.UserID("user-1"))) + } + + for _, session := range options.SeedSessions { + require.NoError(t, redisSessionStore.Create(context.Background(), session)) + } + + publisherPort := ports.GatewaySessionProjectionPublisher(redisPublisher) + if len(options.PublisherErrors) > 0 { + publisherPort = &scriptedProjectionPublisher{ + delegate: redisPublisher, + errors: append([]error(nil), options.PublisherErrors...), + } + } + + sessionStorePort := ports.SessionStore(redisSessionStore) + if options.WrapSessionStore != nil { + sessionStorePort = options.WrapSessionStore(sessionStorePort) + } + + mailSender := &mail.StubSender{} + idGenerator := &testkit.SequenceIDGenerator{} + codeHasher := testkit.DeterministicCodeHasher{} + clock := testkit.FixedClock{Time: env.now} + + sendEmailCodeService, err := sendemailcode.NewWithObservability( + challengeStore, + userDirectory, + idGenerator, + testkit.FixedCodeGenerator{Code: gatewayCompatibilityCode}, + codeHasher, + mailSender, + nil, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + confirmEmailCodeService, err := confirmemailcode.NewWithObservability( + challengeStore, + sessionStorePort, + userDirectory, + configStore, + publisherPort, + idGenerator, + codeHasher, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + getSessionService, err := getsession.New(sessionStorePort) + require.NoError(t, err) + listUserSessionsService, err := listusersessions.New(sessionStorePort) + require.NoError(t, err) + revokeDeviceSessionService, err := revokedevicesession.NewWithObservability(sessionStorePort, publisherPort, 
clock, zap.NewNop(), nil) + require.NoError(t, err) + revokeAllUserSessionsService, err := revokeallusersessions.NewWithObservability(sessionStorePort, userDirectory, publisherPort, clock, zap.NewNop(), nil) + require.NoError(t, err) + blockUserService, err := blockuser.NewWithObservability(userDirectory, sessionStorePort, publisherPort, clock, zap.NewNop(), nil) + require.NoError(t, err) + + publicCfg := publichttp.DefaultConfig() + publicCfg.Addr = gatewayCompatibilityFreeAddr(t) + publicServer, err := publichttp.NewServer(publicCfg, publichttp.Dependencies{ + SendEmailCode: sendEmailCodeService, + ConfirmEmailCode: confirmEmailCodeService, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + internalCfg := internalhttp.DefaultConfig() + internalCfg.Addr = gatewayCompatibilityFreeAddr(t) + internalServer, err := internalhttp.NewServer(internalCfg, internalhttp.Dependencies{ + GetSession: getSessionService, + ListUserSessions: listUserSessionsService, + RevokeDeviceSession: revokeDeviceSessionService, + RevokeAllUserSessions: revokeAllUserSessionsService, + BlockUser: blockUserService, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + stopPublic := startHardeningServer(t, publicServer.Run, publicServer.Shutdown, publicCfg.Addr) + stopInternal := startHardeningServer(t, internalServer.Run, internalServer.Shutdown, internalCfg.Addr) + + app := &hardeningApp{ + publicBaseURL: "http://" + publicCfg.Addr, + internalBaseURL: "http://" + internalCfg.Addr, + challengeStore: challengeStore, + sessionStore: redisSessionStore, + configStore: configStore, + publisher: redisPublisher, + mailSender: mailSender, + userDirectory: userDirectory, + } + app.closeFn = func() { + stopPublic() + stopInternal() + assert.NoError(t, challengeStore.Close()) + assert.NoError(t, redisSessionStore.Close()) + assert.NoError(t, configStore.Close()) + assert.NoError(t, redisPublisher.Close()) + } + t.Cleanup(func() { + app.Close() + }) + + return app +} + +// Close stops the app 
servers and releases the real Redis adapters. +func (a *hardeningApp) Close() { + if a == nil { + return + } + + a.closeOnce.Do(func() { + if a.closeFn != nil { + a.closeFn() + } + }) +} + +// SendChallenge exercises the public send endpoint and returns the issued +// challenge identifier together with the cleartext code observed by the stub +// mail sender. +func (a *hardeningApp) SendChallenge(t *testing.T, email string) (string, string) { + t.Helper() + + response := gatewayCompatibilityPostJSONValue(t, a.publicBaseURL+"/api/v1/public/auth/send-email-code", map[string]string{ + "email": email, + }) + assert.Equal(t, http.StatusOK, response.StatusCode) + + var body struct { + ChallengeID string `json:"challenge_id"` + } + require.NoError(t, json.Unmarshal([]byte(response.Body), &body)) + + attempts := a.mailSender.RecordedAttempts() + require.NotEmpty(t, attempts) + + return body.ChallengeID, attempts[len(attempts)-1].Input.Code +} + +// CreateSessionThroughPublicFlow creates one active user session through the +// real public send and confirm handlers. +func (a *hardeningApp) CreateSessionThroughPublicFlow(t *testing.T) string { + t.Helper() + + challengeID, code := a.SendChallenge(t, gatewayCompatibilityEmail) + response := gatewayCompatibilityPostJSONValue(t, a.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": challengeID, + "code": code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusOK, response.StatusCode) + + var body struct { + DeviceSessionID string `json:"device_session_id"` + } + require.NoError(t, json.Unmarshal([]byte(response.Body), &body)) + + return body.DeviceSessionID +} + +// scriptedProjectionPublisher fails selected publish attempts before +// delegating to the real Redis projection publisher. 
+type scriptedProjectionPublisher struct { + mu sync.Mutex + + delegate ports.GatewaySessionProjectionPublisher + errors []error +} + +// PublishSession returns scripted errors first and delegates only after the +// script is exhausted. +func (p *scriptedProjectionPublisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error { + if err := ctx.Err(); err != nil { + return err + } + if err := snapshot.Validate(); err != nil { + return err + } + + p.mu.Lock() + if len(p.errors) > 0 { + err := p.errors[0] + p.errors = p.errors[1:] + p.mu.Unlock() + return err + } + p.mu.Unlock() + + return p.delegate.PublishSession(ctx, snapshot) +} + +var _ ports.GatewaySessionProjectionPublisher = (*scriptedProjectionPublisher)(nil) + +// startHardeningServer starts one HTTP server and returns a stop function that +// performs graceful shutdown exactly once. +func startHardeningServer( + t *testing.T, + run func(context.Context) error, + shutdown func(context.Context) error, + addr string, +) func() { + t.Helper() + + errCh := make(chan error, 1) + go func() { + errCh <- run(context.Background()) + }() + + gatewayCompatibilityWaitForTCP(t, addr) + + var once sync.Once + return func() { + once.Do(func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + assert.NoError(t, shutdown(shutdownCtx)) + assert.NoError(t, <-errCh) + }) + } +} + +// hardeningGetJSON sends one GET request and returns the captured response. 
+func hardeningGetJSON(t *testing.T, url string) gatewayCompatibilityHTTPResponse { + t.Helper() + + response, err := http.Get(url) + require.NoError(t, err) + defer response.Body.Close() + + payload, err := io.ReadAll(response.Body) + require.NoError(t, err) + + return gatewayCompatibilityHTTPResponse{ + StatusCode: response.StatusCode, + Body: string(payload), + } +} + +func TestProductionHardeningRedisReconnectRecoversOnSameLiveProcess(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + app := newHardeningApp(t, env, hardeningAppOptions{}) + + _, _ = app.SendChallenge(t, gatewayCompatibilityEmail) + + env.StopRedis() + + require.Eventually(t, func() bool { + response := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", map[string]string{ + "email": gatewayCompatibilityEmail, + }) + return response.StatusCode == http.StatusServiceUnavailable + }, 5*time.Second, 50*time.Millisecond) + + env.RestartRedis(t) + + require.Eventually(t, func() bool { + response := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", map[string]string{ + "email": gatewayCompatibilityEmail, + }) + return response.StatusCode == http.StatusOK + }, 5*time.Second, 50*time.Millisecond) +} + +func TestProductionHardeningConfirmRetryRepairsProjectionAfterProcessRestart(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + publishErr := errors.New("hardening publish failure") + + failingApp := newHardeningApp(t, env, hardeningAppOptions{ + PublisherErrors: repeatHardeningError(publishErr, shared.MaxProjectionPublishAttempts), + }) + + challengeID, code := failingApp.SendChallenge(t, gatewayCompatibilityEmail) + firstConfirm := gatewayCompatibilityPostJSONValue(t, failingApp.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": challengeID, + "code": code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, 
http.StatusServiceUnavailable, firstConfirm.StatusCode) + assert.False(t, env.GatewayCacheExists(context.Background(), "device-session-1")) + + failingApp.Close() + + healthyApp := newHardeningApp(t, env, hardeningAppOptions{}) + secondConfirm := gatewayCompatibilityPostJSONValue(t, healthyApp.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": challengeID, + "code": code, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusOK, secondConfirm.StatusCode) + + var body struct { + DeviceSessionID string `json:"device_session_id"` + } + require.NoError(t, json.Unmarshal([]byte(secondConfirm.Body), &body)) + assert.Equal(t, "device-session-1", body.DeviceSessionID) + + record := env.MustReadGatewayCacheRecord(t, body.DeviceSessionID) + assert.Equal(t, gatewayCacheRecord{ + DeviceSessionID: "device-session-1", + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "active", + }, record) +} + +func TestProductionHardeningRepeatedRevokeRepairsProjectionAfterProcessRestart(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + createApp := newHardeningApp(t, env, hardeningAppOptions{SeedExistingUser: true}) + sessionID := createApp.CreateSessionThroughPublicFlow(t) + createApp.Close() + + publishErr := errors.New("hardening publish failure") + failingApp := newHardeningApp(t, env, hardeningAppOptions{ + SeedExistingUser: true, + PublisherErrors: repeatHardeningError(publishErr, shared.MaxProjectionPublishAttempts), + }) + + firstRevoke := gatewayCompatibilityPostJSON( + t, + failingApp.internalBaseURL+"/api/v1/internal/sessions/"+sessionID+"/revoke", + `{"reason_code":"admin_revoke","actor":{"type":"system"}}`, + ) + assert.Equal(t, http.StatusServiceUnavailable, firstRevoke.StatusCode) + + activeRecord := env.MustReadGatewayCacheRecord(t, sessionID) + assert.Equal(t, "active", activeRecord.Status) + + failingApp.Close() + + healthyApp := 
newHardeningApp(t, env, hardeningAppOptions{SeedExistingUser: true}) + secondRevoke := gatewayCompatibilityPostJSON( + t, + healthyApp.internalBaseURL+"/api/v1/internal/sessions/"+sessionID+"/revoke", + `{"reason_code":"admin_revoke","actor":{"type":"system"}}`, + ) + assert.Equal(t, http.StatusOK, secondRevoke.StatusCode) + assert.JSONEq(t, `{"outcome":"already_revoked","device_session_id":"`+sessionID+`","affected_session_count":0}`, secondRevoke.Body) + + revokedRecord := env.MustReadGatewayCacheRecord(t, sessionID) + require.NotNil(t, revokedRecord.RevokedAtMS) + assert.Equal(t, "revoked", revokedRecord.Status) +} + +func TestProductionHardeningRepeatedRevokeAllRepairsProjectionAfterProcessRestart(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + createApp := newHardeningApp(t, env, hardeningAppOptions{SeedExistingUser: true}) + firstSessionID := createApp.CreateSessionThroughPublicFlow(t) + secondSessionID := createApp.CreateSessionThroughPublicFlow(t) + createApp.Close() + + publishErr := errors.New("hardening publish failure") + failingApp := newHardeningApp(t, env, hardeningAppOptions{ + SeedExistingUser: true, + PublisherErrors: repeatHardeningError(publishErr, shared.MaxProjectionPublishAttempts), + }) + + firstRevokeAll := gatewayCompatibilityPostJSON( + t, + failingApp.internalBaseURL+"/api/v1/internal/users/user-1/sessions/revoke-all", + `{"reason_code":"logout_all","actor":{"type":"system"}}`, + ) + assert.Equal(t, http.StatusServiceUnavailable, firstRevokeAll.StatusCode) + + assert.Equal(t, "active", env.MustReadGatewayCacheRecord(t, firstSessionID).Status) + assert.Equal(t, "active", env.MustReadGatewayCacheRecord(t, secondSessionID).Status) + + failingApp.Close() + + healthyApp := newHardeningApp(t, env, hardeningAppOptions{SeedExistingUser: true}) + secondRevokeAll := gatewayCompatibilityPostJSON( + t, + healthyApp.internalBaseURL+"/api/v1/internal/users/user-1/sessions/revoke-all", + 
`{"reason_code":"logout_all","actor":{"type":"system"}}`, + ) + assert.Equal(t, http.StatusOK, secondRevokeAll.StatusCode) + assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-1","affected_session_count":0,"affected_device_session_ids":[]}`, secondRevokeAll.Body) + + firstRecord := env.MustReadGatewayCacheRecord(t, firstSessionID) + secondRecord := env.MustReadGatewayCacheRecord(t, secondSessionID) + require.NotNil(t, firstRecord.RevokedAtMS) + require.NotNil(t, secondRecord.RevokedAtMS) + assert.Equal(t, "revoked", firstRecord.Status) + assert.Equal(t, "revoked", secondRecord.Status) +} + +func TestProductionHardeningDuplicatePublishKeepsGatewayCacheCanonical(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + publisher, err := projectionpublisher.New(projectionpublisher.Config{ + Addr: env.redisAddr, + DB: 0, + SessionCacheKeyPrefix: gatewayCompatibilitySessionCacheKeyPrefix, + SessionEventsStream: gatewayCompatibilitySessionEventsStream, + StreamMaxLen: gatewayCompatibilityStreamMaxLen, + OperationTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + defer func() { + assert.NoError(t, publisher.Close()) + }() + + snapshot := gatewayprojection.Snapshot{ + DeviceSessionID: common.DeviceSessionID("device-session-1"), + UserID: common.UserID("user-1"), + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: gatewayprojection.StatusActive, + } + require.NoError(t, snapshot.Validate()) + + require.NoError(t, publisher.PublishSession(context.Background(), snapshot)) + require.NoError(t, publisher.PublishSession(context.Background(), snapshot)) + + record := env.MustReadGatewayCacheRecord(t, "device-session-1") + assert.Equal(t, gatewayCacheRecord{ + DeviceSessionID: "device-session-1", + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "active", + }, record) + + events := env.MustReadGatewaySessionEvents(t, "device-session-1") + require.Len(t, events, 2) + assert.Equal(t, 
gatewaySessionEventRecord{ + DeviceSessionID: "device-session-1", + UserID: "user-1", + ClientPublicKey: gatewayCompatibilityClientPublicKey, + Status: "active", + }, events[0]) + assert.Equal(t, events[0], events[1]) +} + +func TestProductionHardeningExpiredChallengeReturnsExpiredDuringGraceAndNotFoundAfterGC(t *testing.T) { + t.Parallel() + + env := newHardeningEnvironment(t) + app := newHardeningApp(t, env, hardeningAppOptions{}) + + hasher := testkit.DeterministicCodeHasher{} + codeHash, err := hasher.Hash(gatewayCompatibilityCode) + require.NoError(t, err) + + record := challenge.Challenge{ + ID: common.ChallengeID("challenge-expired"), + Email: common.Email(gatewayCompatibilityEmail), + CodeHash: codeHash, + Status: challenge.StatusSent, + DeliveryState: challenge.DeliverySent, + CreatedAt: env.now.Add(-2 * time.Minute), + ExpiresAt: env.now.Add(-time.Second), + } + require.NoError(t, record.Validate()) + require.NoError(t, app.challengeStore.Create(context.Background(), record)) + + firstConfirm := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "challenge-expired", + "code": gatewayCompatibilityCode, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusGone, firstConfirm.StatusCode) + assert.JSONEq(t, `{"error":{"code":"challenge_expired","message":"challenge expired"}}`, firstConfirm.Body) + + env.FastForward(t, 5*time.Minute+time.Second) + + secondConfirm := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": "challenge-expired", + "code": gatewayCompatibilityCode, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + assert.Equal(t, http.StatusNotFound, secondConfirm.StatusCode) + assert.JSONEq(t, `{"error":{"code":"challenge_not_found","message":"challenge not found"}}`, secondConfirm.Body) +} + +func 
TestProductionHardeningLargeUserSessionListAndRevokeAllStayStable(t *testing.T) { + t.Parallel() + + sessions := make([]devicesession.Session, 0, hardeningLargeSessionCount) + for index := 0; index < hardeningLargeSessionCount; index++ { + sessions = append(sessions, gatewayCompatibilityActiveSession( + t, + fmt.Sprintf("bulk-session-%03d", index+1), + "user-1", + gatewayCompatibilityClientPublicKey, + time.Date(2026, 4, 5, 10, 0, index, 0, time.UTC), + )) + } + + env := newHardeningEnvironment(t) + app := newHardeningApp(t, env, hardeningAppOptions{ + SeedExistingUser: true, + SeedSessions: sessions, + }) + + listResponse := hardeningGetJSON(t, app.internalBaseURL+"/api/v1/internal/users/user-1/sessions") + assert.Equal(t, http.StatusOK, listResponse.StatusCode) + + var listBody struct { + Sessions []struct { + DeviceSessionID string `json:"device_session_id"` + Status string `json:"status"` + } `json:"sessions"` + } + require.NoError(t, json.Unmarshal([]byte(listResponse.Body), &listBody)) + require.Len(t, listBody.Sessions, hardeningLargeSessionCount) + assert.Equal(t, "bulk-session-256", listBody.Sessions[0].DeviceSessionID) + assert.Equal(t, "bulk-session-001", listBody.Sessions[len(listBody.Sessions)-1].DeviceSessionID) + for _, session := range listBody.Sessions { + assert.Equal(t, "active", session.Status) + } + + revokeResponse := gatewayCompatibilityPostJSON( + t, + app.internalBaseURL+"/api/v1/internal/users/user-1/sessions/revoke-all", + `{"reason_code":"logout_all","actor":{"type":"system"}}`, + ) + assert.Equal(t, http.StatusOK, revokeResponse.StatusCode) + + var revokeBody struct { + Outcome string `json:"outcome"` + UserID string `json:"user_id"` + AffectedSessionCount int `json:"affected_session_count"` + AffectedDeviceSessionIDs []string `json:"affected_device_session_ids"` + } + require.NoError(t, json.Unmarshal([]byte(revokeResponse.Body), &revokeBody)) + assert.Equal(t, "revoked", revokeBody.Outcome) + assert.Equal(t, "user-1", 
revokeBody.UserID) + assert.Equal(t, hardeningLargeSessionCount, revokeBody.AffectedSessionCount) + require.Len(t, revokeBody.AffectedDeviceSessionIDs, hardeningLargeSessionCount) + assert.Equal(t, "bulk-session-256", revokeBody.AffectedDeviceSessionIDs[0]) + assert.Equal(t, "bulk-session-001", revokeBody.AffectedDeviceSessionIDs[len(revokeBody.AffectedDeviceSessionIDs)-1]) + + activeCount, err := app.sessionStore.CountActiveByUserID(context.Background(), common.UserID("user-1")) + require.NoError(t, err) + assert.Zero(t, activeCount) +} + +// repeatHardeningError builds a stable FIFO error script for retry-oriented +// publisher hardening tests. +func repeatHardeningError(err error, count int) []error { + script := make([]error, 0, count) + for index := 0; index < count; index++ { + script = append(script, err) + } + + return script +} diff --git a/authsession/storage_boundary_test.go b/authsession/storage_boundary_test.go new file mode 100644 index 0000000..0c1e717 --- /dev/null +++ b/authsession/storage_boundary_test.go @@ -0,0 +1,72 @@ +package authsession + +import ( + "fmt" + "go/parser" + "go/token" + "io/fs" + "path/filepath" + "runtime" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestProductionCoreStaysStorageAgnostic(t *testing.T) { + t.Parallel() + + root := authsessionRootDir(t) + for _, relativeDir := range []string{ + filepath.Join("internal", "domain"), + filepath.Join("internal", "service"), + filepath.Join("internal", "ports"), + } { + checkStorageAgnosticImports(t, filepath.Join(root, relativeDir)) + } +} + +func authsessionRootDir(t *testing.T) string { + t.Helper() + + _, thisFile, _, ok := runtime.Caller(0) + require.True(t, ok, "runtime.Caller failed") + + return filepath.Dir(thisFile) +} + +func checkStorageAgnosticImports(t *testing.T, dir string) { + t.Helper() + + fileSet := token.NewFileSet() + err := filepath.WalkDir(dir, func(path string, entry fs.DirEntry, walkErr error) error { + if walkErr 
!= nil { + return walkErr + } + if entry.IsDir() { + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + file, err := parser.ParseFile(fileSet, path, nil, parser.ImportsOnly) + if err != nil { + return err + } + + for _, importSpec := range file.Imports { + importPath, err := strconv.Unquote(importSpec.Path.Value) + if err != nil { + return err + } + if importPath == "github.com/redis/go-redis/v9" || strings.Contains(importPath, "internal/adapters/redis") { + return fmt.Errorf("storage-specific import %q found in %s", importPath, path) + } + } + + return nil + }) + require.NoError(t, err) +} diff --git a/authsession/user_service_rest_compatibility_test.go b/authsession/user_service_rest_compatibility_test.go new file mode 100644 index 0000000..d47b44d --- /dev/null +++ b/authsession/user_service_rest_compatibility_test.go @@ -0,0 +1,445 @@ +package authsession + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "galaxy/authsession/internal/adapters/mail" + "galaxy/authsession/internal/adapters/userservice" + "galaxy/authsession/internal/api/internalhttp" + "galaxy/authsession/internal/api/publichttp" + "galaxy/authsession/internal/domain/common" + "galaxy/authsession/internal/domain/userresolution" + "galaxy/authsession/internal/ports" + "galaxy/authsession/internal/service/blockuser" + "galaxy/authsession/internal/service/confirmemailcode" + "galaxy/authsession/internal/service/getsession" + "galaxy/authsession/internal/service/listusersessions" + "galaxy/authsession/internal/service/revokeallusersessions" + "galaxy/authsession/internal/service/revokedevicesession" + "galaxy/authsession/internal/service/sendemailcode" + "galaxy/authsession/internal/testkit" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +const userServiceRESTCompatibilityCode = 
"123456" + +func TestUserServiceRESTCompatibilityPublicSendUsesResolveByEmailOutcomes(t *testing.T) { + t.Parallel() + + harness := newUserServiceRESTCompatibilityHarness(t) + require.NoError(t, harness.directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing"))) + require.NoError(t, harness.directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_blocked"))) + + existing := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"existing@example.com"}`) + creatable := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"creatable@example.com"}`) + blocked := gatewayCompatibilityPostJSON(t, harness.publicBaseURL+"/api/v1/public/auth/send-email-code", `{"email":"blocked@example.com"}`) + + assert.Equal(t, http.StatusOK, existing.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, existing.Body) + assert.Equal(t, http.StatusOK, creatable.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-2"}`, creatable.Body) + assert.Equal(t, http.StatusOK, blocked.StatusCode) + assert.JSONEq(t, `{"challenge_id":"challenge-3"}`, blocked.Body) + + attempts := harness.mailSender.RecordedAttempts() + require.Len(t, attempts, 2) + assert.Equal(t, common.Email("existing@example.com"), attempts[0].Input.Email) + assert.Equal(t, common.Email("creatable@example.com"), attempts[1].Input.Email) +} + +func TestUserServiceRESTCompatibilityPublicConfirmUsesEnsureOutcomes(t *testing.T) { + t.Parallel() + + harness := newUserServiceRESTCompatibilityHarness(t) + require.NoError(t, harness.directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing"))) + require.NoError(t, harness.directory.QueueCreatedUserIDs(common.UserID("user-created"))) + require.NoError(t, harness.directory.SeedBlockedEmail(common.Email("blocked@example.com"), 
userresolution.BlockReasonCode("policy_blocked"))) + + existingChallengeID := harness.sendChallengeID(t, "existing@example.com") + createdChallengeID := harness.sendChallengeID(t, "created@example.com") + blockedChallengeID := harness.sendChallengeID(t, "blocked@example.com") + + existing := gatewayCompatibilityPostJSONValue(t, harness.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": existingChallengeID, + "code": userServiceRESTCompatibilityCode, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + created := gatewayCompatibilityPostJSONValue(t, harness.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": createdChallengeID, + "code": userServiceRESTCompatibilityCode, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + blocked := gatewayCompatibilityPostJSONValue(t, harness.publicBaseURL+"/api/v1/public/auth/confirm-email-code", map[string]string{ + "challenge_id": blockedChallengeID, + "code": userServiceRESTCompatibilityCode, + "client_public_key": gatewayCompatibilityClientPublicKey, + }) + + assert.Equal(t, http.StatusOK, existing.StatusCode) + assert.JSONEq(t, `{"device_session_id":"device-session-1"}`, existing.Body) + assert.Equal(t, http.StatusOK, created.StatusCode) + assert.JSONEq(t, `{"device_session_id":"device-session-2"}`, created.Body) + assert.Equal(t, http.StatusForbidden, blocked.StatusCode) + assert.JSONEq(t, `{"error":{"code":"blocked_by_policy","message":"authentication is blocked by policy"}}`, blocked.Body) + + existingSession, err := harness.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1")) + require.NoError(t, err) + assert.Equal(t, common.UserID("user-existing"), existingSession.UserID) + + createdSession, err := harness.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-2")) + require.NoError(t, err) + assert.Equal(t, common.UserID("user-created"), 
createdSession.UserID) +} + +func TestUserServiceRESTCompatibilityInternalRevokeAllUsesExistsByUserID(t *testing.T) { + t.Parallel() + + harness := newUserServiceRESTCompatibilityHarness(t) + require.NoError(t, harness.directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + + existing := gatewayCompatibilityPostJSON( + t, + harness.internalBaseURL+"/api/v1/internal/users/user-1/sessions/revoke-all", + `{"reason_code":"logout_all","actor":{"type":"system"}}`, + ) + missing := gatewayCompatibilityPostJSON( + t, + harness.internalBaseURL+"/api/v1/internal/users/missing-user/sessions/revoke-all", + `{"reason_code":"logout_all","actor":{"type":"system"}}`, + ) + + assert.Equal(t, http.StatusOK, existing.StatusCode) + assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-1","affected_session_count":0,"affected_device_session_ids":[]}`, existing.Body) + assert.Equal(t, http.StatusNotFound, missing.StatusCode) + assert.JSONEq(t, `{"error":{"code":"subject_not_found","message":"subject not found"}}`, missing.Body) +} + +func TestUserServiceRESTCompatibilityInternalBlockUserUsesRESTClient(t *testing.T) { + t.Parallel() + + t.Run("block by user id", func(t *testing.T) { + t.Parallel() + + harness := newUserServiceRESTCompatibilityHarness(t) + require.NoError(t, harness.directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))) + + response := gatewayCompatibilityPostJSON( + t, + harness.internalBaseURL+"/api/v1/internal/user-blocks", + `{"user_id":"user-1","reason_code":"policy_blocked","actor":{"type":"admin"}}`, + ) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"user_id","subject_value":"user-1","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body) + }) + + t.Run("block by email", func(t *testing.T) { + t.Parallel() + + harness := newUserServiceRESTCompatibilityHarness(t) + + response := gatewayCompatibilityPostJSON( + 
t, + harness.internalBaseURL+"/api/v1/internal/user-blocks", + `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`, + ) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body) + }) +} + +type userServiceRESTCompatibilityHarness struct { + publicBaseURL string + internalBaseURL string + mailSender *mail.StubSender + sessionStore *testkit.InMemorySessionStore + directory *userservice.StubDirectory +} + +func newUserServiceRESTCompatibilityHarness(t *testing.T) userServiceRESTCompatibilityHarness { + t.Helper() + + challengeStore := &testkit.InMemoryChallengeStore{} + sessionStore := &testkit.InMemorySessionStore{} + directory := &userservice.StubDirectory{} + + userServiceServer := httptest.NewServer(newUserServiceStubHandler(directory)) + t.Cleanup(userServiceServer.Close) + + userDirectory, err := userservice.NewRESTClient(userservice.Config{ + BaseURL: userServiceServer.URL, + RequestTimeout: 250 * time.Millisecond, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, userDirectory.Close()) + }) + + configProvider := testkit.StaticConfigProvider{} + publisher := &testkit.RecordingProjectionPublisher{} + mailSender := &mail.StubSender{} + idGenerator := &testkit.SequenceIDGenerator{} + codeGenerator := testkit.FixedCodeGenerator{Code: userServiceRESTCompatibilityCode} + codeHasher := testkit.DeterministicCodeHasher{} + clock := testkit.FixedClock{Time: time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)} + + sendEmailCodeService, err := sendemailcode.NewWithObservability( + challengeStore, + userDirectory, + idGenerator, + codeGenerator, + codeHasher, + mailSender, + nil, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + confirmEmailCodeService, err := confirmemailcode.NewWithObservability( + challengeStore, + sessionStore, + 
userDirectory, + configProvider, + publisher, + idGenerator, + codeHasher, + clock, + zap.NewNop(), + nil, + ) + require.NoError(t, err) + + getSessionService, err := getsession.New(sessionStore) + require.NoError(t, err) + listUserSessionsService, err := listusersessions.New(sessionStore) + require.NoError(t, err) + revokeDeviceSessionService, err := revokedevicesession.NewWithObservability(sessionStore, publisher, clock, zap.NewNop(), nil) + require.NoError(t, err) + revokeAllUserSessionsService, err := revokeallusersessions.NewWithObservability(sessionStore, userDirectory, publisher, clock, zap.NewNop(), nil) + require.NoError(t, err) + blockUserService, err := blockuser.NewWithObservability(userDirectory, sessionStore, publisher, clock, zap.NewNop(), nil) + require.NoError(t, err) + + publicCfg := publichttp.DefaultConfig() + publicCfg.Addr = gatewayCompatibilityFreeAddr(t) + publicServer, err := publichttp.NewServer(publicCfg, publichttp.Dependencies{ + SendEmailCode: sendEmailCodeService, + ConfirmEmailCode: confirmEmailCodeService, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + internalCfg := internalhttp.DefaultConfig() + internalCfg.Addr = gatewayCompatibilityFreeAddr(t) + internalServer, err := internalhttp.NewServer(internalCfg, internalhttp.Dependencies{ + GetSession: getSessionService, + ListUserSessions: listUserSessionsService, + RevokeDeviceSession: revokeDeviceSessionService, + RevokeAllUserSessions: revokeAllUserSessionsService, + BlockUser: blockUserService, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + gatewayCompatibilityRunServer(t, publicServer.Run, publicServer.Shutdown, publicCfg.Addr) + gatewayCompatibilityRunServer(t, internalServer.Run, internalServer.Shutdown, internalCfg.Addr) + + return userServiceRESTCompatibilityHarness{ + publicBaseURL: "http://" + publicCfg.Addr, + internalBaseURL: "http://" + internalCfg.Addr, + mailSender: mailSender, + sessionStore: sessionStore, + directory: directory, + } +} + +func 
(h userServiceRESTCompatibilityHarness) sendChallengeID(t *testing.T, email string) string { + t.Helper() + + response := gatewayCompatibilityPostJSON(t, h.publicBaseURL+"/api/v1/public/auth/send-email-code", fmt.Sprintf(`{"email":"%s"}`, email)) + assert.Equal(t, http.StatusOK, response.StatusCode) + + var body struct { + ChallengeID string `json:"challenge_id"` + } + require.NoError(t, json.Unmarshal([]byte(response.Body), &body)) + require.NotEmpty(t, body.ChallengeID) + + return body.ChallengeID +} + +func newUserServiceStubHandler(directory *userservice.StubDirectory) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + switch { + case request.Method == http.MethodPost && request.URL.Path == "/api/v1/internal/user-resolutions/by-email": + var input struct { + Email string `json:"email"` + } + if !decodeUserServiceStubRequest(writer, request, &input) { + return + } + + result, err := directory.ResolveByEmail(request.Context(), common.Email(input.Email)) + if err != nil { + writeUserServiceStubError(writer, http.StatusInternalServerError, err) + return + } + + response := map[string]any{"kind": result.Kind} + if !result.UserID.IsZero() { + response["user_id"] = result.UserID.String() + } + if !result.BlockReasonCode.IsZero() { + response["block_reason_code"] = result.BlockReasonCode.String() + } + writeUserServiceStubJSON(writer, http.StatusOK, response) + case request.Method == http.MethodGet && strings.HasPrefix(request.URL.Path, "/api/v1/internal/users/") && strings.HasSuffix(request.URL.Path, "/exists"): + userIDValue := strings.TrimSuffix(strings.TrimPrefix(request.URL.Path, "/api/v1/internal/users/"), "/exists") + userIDValue, err := url.PathUnescape(userIDValue) + if err != nil { + writeUserServiceStubError(writer, http.StatusBadRequest, err) + return + } + + exists, err := directory.ExistsByUserID(request.Context(), common.UserID(userIDValue)) + if err != nil { + writeUserServiceStubError(writer, 
http.StatusInternalServerError, err) + return + } + + writeUserServiceStubJSON(writer, http.StatusOK, map[string]bool{"exists": exists}) + case request.Method == http.MethodPost && request.URL.Path == "/api/v1/internal/users/ensure-by-email": + var input struct { + Email string `json:"email"` + } + if !decodeUserServiceStubRequest(writer, request, &input) { + return + } + + result, err := directory.EnsureUserByEmail(request.Context(), common.Email(input.Email)) + if err != nil { + writeUserServiceStubError(writer, http.StatusInternalServerError, err) + return + } + + response := map[string]any{"outcome": result.Outcome} + if !result.UserID.IsZero() { + response["user_id"] = result.UserID.String() + } + if !result.BlockReasonCode.IsZero() { + response["block_reason_code"] = result.BlockReasonCode.String() + } + writeUserServiceStubJSON(writer, http.StatusOK, response) + case request.Method == http.MethodPost && strings.HasPrefix(request.URL.Path, "/api/v1/internal/users/") && strings.HasSuffix(request.URL.Path, "/block"): + userIDValue := strings.TrimSuffix(strings.TrimPrefix(request.URL.Path, "/api/v1/internal/users/"), "/block") + userIDValue, err := url.PathUnescape(userIDValue) + if err != nil { + writeUserServiceStubError(writer, http.StatusBadRequest, err) + return + } + + var input struct { + ReasonCode string `json:"reason_code"` + } + if !decodeUserServiceStubRequest(writer, request, &input) { + return + } + + result, err := directory.BlockByUserID(request.Context(), ports.BlockUserByIDInput{ + UserID: common.UserID(userIDValue), + ReasonCode: userresolution.BlockReasonCode(input.ReasonCode), + }) + if err != nil { + if errors.Is(err, ports.ErrNotFound) { + writeUserServiceStubJSON(writer, http.StatusNotFound, map[string]string{"error": "not found"}) + return + } + writeUserServiceStubError(writer, http.StatusInternalServerError, err) + return + } + + response := map[string]any{"outcome": result.Outcome} + if !result.UserID.IsZero() { + response["user_id"] 
= result.UserID.String() + } + writeUserServiceStubJSON(writer, http.StatusOK, response) + case request.Method == http.MethodPost && request.URL.Path == "/api/v1/internal/user-blocks/by-email": + var input struct { + Email string `json:"email"` + ReasonCode string `json:"reason_code"` + } + if !decodeUserServiceStubRequest(writer, request, &input) { + return + } + + result, err := directory.BlockByEmail(request.Context(), ports.BlockUserByEmailInput{ + Email: common.Email(input.Email), + ReasonCode: userresolution.BlockReasonCode(input.ReasonCode), + }) + if err != nil { + writeUserServiceStubError(writer, http.StatusInternalServerError, err) + return + } + + response := map[string]any{"outcome": result.Outcome} + if !result.UserID.IsZero() { + response["user_id"] = result.UserID.String() + } + writeUserServiceStubJSON(writer, http.StatusOK, response) + default: + http.NotFound(writer, request) + } + }) +} + +func decodeUserServiceStubRequest(writer http.ResponseWriter, request *http.Request, target any) bool { + decoder := json.NewDecoder(request.Body) + decoder.DisallowUnknownFields() + + if err := decoder.Decode(target); err != nil { + writeUserServiceStubError(writer, http.StatusBadRequest, err) + return false + } + if err := decoder.Decode(&struct{}{}); err != io.EOF { + if err == nil { + writeUserServiceStubError(writer, http.StatusBadRequest, errors.New("unexpected trailing JSON input")) + return false + } + writeUserServiceStubError(writer, http.StatusBadRequest, err) + return false + } + + return true +} + +func writeUserServiceStubJSON(writer http.ResponseWriter, statusCode int, value any) { + payload, err := json.Marshal(value) + if err != nil { + writeUserServiceStubError(writer, http.StatusInternalServerError, err) + return + } + + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(statusCode) + _, _ = writer.Write(payload) +} + +func writeUserServiceStubError(writer http.ResponseWriter, statusCode int, err error) { + 
http.Error(writer, err.Error(), statusCode) +} diff --git a/client/go.mod b/client/go.mod index a4fa64f..15f4e51 100644 --- a/client/go.mod +++ b/client/go.mod @@ -11,7 +11,7 @@ require ( require ( fyne.io/systray v1.12.0 // indirect github.com/BurntSushi/toml v1.6.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fredbi/uri v1.1.1 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fyne-io/gl-js v0.2.0 // indirect @@ -31,16 +31,16 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect github.com/nicksnyder/go-i18n/v2 v2.6.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/rymdport/portal v0.4.2 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/yuin/goldmark v1.7.16 // indirect golang.org/x/image v0.36.0 // indirect - golang.org/x/net v0.50.0 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/client/go.sum b/client/go.sum index 0f202ed..d63ecdb 100644 --- a/client/go.sum +++ b/client/go.sum @@ -5,8 +5,7 @@ fyne.io/systray v1.12.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/creack/pty 
v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8= @@ -57,10 +56,8 @@ github.com/nicksnyder/go-i18n/v2 v2.6.1 h1:JDEJraFsQE17Dut9HFDHzCoAWGEQJom5s0TRd github.com/nicksnyder/go-i18n/v2 v2.6.1/go.mod h1:Vee0/9RD3Quc/NmwEjzzD7VTZ+Ir7QbXocrkhOzmUKA= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= @@ -75,12 +72,9 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/image 
v0.36.0 h1:Iknbfm1afbgtwPTmHnS2gTM/6PPZfH+z2EFuOkSbqwc= golang.org/x/image v0.36.0/go.mod h1:YsWD2TyyGKiIX1kZlu9QfKIsQ4nAAK9bdgdrIsE7xy4= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/game/go.mod b/game/go.mod index c4b882e..fbc7c1f 100644 --- a/game/go.mod +++ b/game/go.mod @@ -3,7 +3,7 @@ module galaxy/game go 1.26.0 require ( - github.com/gin-gonic/gin v1.11.0 + github.com/gin-gonic/gin v1.12.0 github.com/go-playground/validator/v10 v10.30.1 github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.11.1 @@ -14,7 +14,7 @@ require ( github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -23,23 +23,23 @@ require ( github.com/goccy/go-yaml 
v1.19.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect - github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect - go.uber.org/mock v0.6.0 // indirect + go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect golang.org/x/arch v0.24.0 // indirect - golang.org/x/crypto v0.48.0 // indirect - golang.org/x/net v0.50.0 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect + golang.org/x/crypto v0.49.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/game/go.sum b/game/go.sum index d4150ce..724011a 100644 --- a/game/go.sum +++ b/game/go.sum @@ -6,16 +6,14 @@ github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiD github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= -github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= -github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= +github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -52,14 +50,13 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod 
h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -75,19 +72,16 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= 
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
-golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
-golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/gateway/PLAN.md b/gateway/PLAN.md
index fc66e88..a00c3a7 100644
--- a/gateway/PLAN.md
+++ b/gateway/PLAN.md
@@ -1,5 +1,9 @@
 # Edge Gateway Implementation Plan
 
+This plan has already been implemented and stays here for historical reasons.
+
+It should NOT be treated as the source of truth for service functionality.
+
 ## Summary
 
 This plan breaks implementation into small, reviewable phases.
diff --git a/gateway/docs/examples.md b/gateway/docs/examples.md
index 8651347..2dc3fc9 100644
--- a/gateway/docs/examples.md
+++ b/gateway/docs/examples.md
@@ -127,7 +127,7 @@ gateway:session:device-session-123
 Example session snapshot entry:
 
 ```bash
-redis-cli XADD gateway:session-events '*' \
+redis-cli XADD gateway:session_events '*' \
   device_session_id device-session-123 \
   user_id user-123 \
   client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \
@@ -137,7 +137,7 @@ redis-cli XADD gateway:session-events '*' \
 Revocation entry:
 
 ```bash
-redis-cli XADD gateway:session-events '*' \
+redis-cli XADD gateway:session_events '*' \
   device_session_id device-session-123 \
   user_id user-123 \
   client_public_key 11qYAYdk8v3K6Yw8QK6ZlQ2nP4Wm8Cq5g1H0K8vT9no= \
diff --git a/gateway/go.mod b/gateway/go.mod
index f8afb6d..a215cbb 100644
--- a/gateway/go.mod
+++ b/gateway/go.mod
@@ -38,7 +38,7 @@ require (
 	github.com/cenkalti/backoff/v5 v5.0.3 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/cloudwego/base64x v0.1.6 // indirect
-	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 	github.com/gabriel-vasile/mimetype v1.4.13 // indirect
 	github.com/gin-contrib/sse v1.1.0 // indirect
@@ -68,7 +68,7 @@ require (
 	github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b // indirect
 	github.com/pelletier/go-toml/v2 v2.2.4 // indirect
 	github.com/perimeterx/marshmallow v1.1.5 // indirect
-	github.com/pmezard/go-difflib v1.0.0 // indirect
+	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/prometheus/client_model v0.6.2 // indirect
 	github.com/prometheus/common v0.67.5 // indirect
 	github.com/prometheus/otlptranslator v1.0.0 // indirect
@@ -82,17 +82,17 @@ require (
 	go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
 	go.opentelemetry.io/auto/sdk v1.2.1 // indirect
 	go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 // indirect
-	go.opentelemetry.io/proto/otlp v1.9.0 // indirect
+	go.opentelemetry.io/proto/otlp v1.10.0 // indirect
 	go.uber.org/atomic v1.11.0 // indirect
 	go.uber.org/multierr v1.10.0 // indirect
 	go.yaml.in/yaml/v2 v2.4.3 // indirect
 	golang.org/x/arch v0.24.0 // indirect
-	golang.org/x/crypto v0.48.0 // indirect
+	golang.org/x/crypto v0.49.0 // indirect
 	golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 // indirect
-	golang.org/x/net v0.51.0 // indirect
-	golang.org/x/sys v0.41.0 // indirect
-	golang.org/x/text v0.34.0 // indirect
-	google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect
-	google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
+	golang.org/x/net v0.52.0 // indirect
+	golang.org/x/sys v0.42.0 // indirect
+	golang.org/x/text v0.35.0 // indirect
+	google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
+	google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/gateway/go.sum b/gateway/go.sum
index bbb0ce6..4b3657e 100644
--- a/gateway/go.sum
+++ b/gateway/go.sum
@@ -29,8 +29,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
 github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
 github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
 github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
@@ -114,8 +114,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
 github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
 github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
 github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
 github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
 github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
@@ -187,8 +187,7 @@ go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9
 go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
 go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
 go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
-go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
-go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
+go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
 go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
 go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
 go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
@@ -205,25 +204,19 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
 go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
 golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
 golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
-golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
-golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
 golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 h1:SbTAbRFnd5kjQXbczszQ0hdk3ctwYf3qBNH9jIsGclE=
 golang.org/x/exp v0.0.0-20250813145105-42675adae3e6/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4=
-golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
-golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
+golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
-golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
-golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
 golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
 golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
 gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
 gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
-google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0=
-google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
+google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
 google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
 google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
 google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
diff --git a/go.work b/go.work
index 4ca55b0..781e8b8 100644
--- a/go.work
+++ b/go.work
@@ -1,6 +1,7 @@
 go 1.26.0
 
 use (
+	./authsession
 	./client
 	./game
 	./gateway
diff --git a/go.work.sum b/go.work.sum
index 105382c..d4bef3c 100644
--- a/go.work.sum
+++ b/go.work.sum
@@ -1,7 +1,11 @@
 buf.build/go/hyperpb v0.1.3/go.mod h1:IHXAM5qnS0/Fsnd7/HGDghFNvUET646WoHmq1FDZXIE=
+cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
 cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
 github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0=
 github.com/akavel/rsrc v0.10.2/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c=
+github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE=
+github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b/go.mod h1:fvzegU4vN3H1qMT+8wDmzjAcDONcgo2/SZ/TyfdUOFs=
+github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
 github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
 github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
@@ -18,37 +22,26 @@ github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiD
 github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
 github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
 github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
-github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
-github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
-github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
-github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
+github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
 github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
 github.com/jackmordaunt/icns/v2 v2.2.6/go.mod h1:DqlVnR5iafSphrId7aSD06r3jg0KRC9V6lEBBp504ZQ=
 github.com/jordanlewis/gcassert v0.0.0-20250430164644-389ef753e22e/go.mod h1:ZybsQk6DWyN5t7An1MuPm1gtSZ1xDaTXS9ZjIOxvQrk=
 github.com/josephspurrier/goversioninfo v1.4.0/go.mod h1:JWzv5rKQr+MmW+LvM412ToT/IkYDZjaclF2pKDss8IY=
-github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
-github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
+github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
+github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
 github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
 github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/lucor/goinfo v0.9.0/go.mod h1:L6m6tN5Rlova5Z83h1ZaKsMP1iiaoZ9vGTNzu5QKOD4=
-github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
-github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
 github.com/mcuadros/go-version v0.0.0-20190830083331-035f6764e8d2/go.mod h1:76rfSfYPWj01Z85hUf/ituArm797mNKcvINh1OlsZKo=
-github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
-github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
+github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM=
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
-github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c h1:7ACFcSaQsrWtrH4WHHfUqE1C+f8r2uv8KGaW0jTNjus=
-github.com/oasdiff/yaml v0.0.0-20260313112342-a3ea61cb4d4c/go.mod h1:JKox4Gszkxt57kj27u7rvi7IFoIULvCZHUsBTUmQM/s=
-github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b h1:vivRhVUAa9t1q0Db4ZmezBP8pWQWnXHFokZj0AOea2g=
-github.com/oasdiff/yaml3 v0.0.0-20260224194419-61cd415a242b/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
-github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
-github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
 github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
@@ -57,17 +50,33 @@ github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJL
 github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
 github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
 github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
+github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
 github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 github.com/timandy/routine v1.1.6/go.mod h1:kXslgIosdY8LW0byTyPnenDgn4/azt2euufAq9rK51w=
 github.com/urfave/cli/v2 v2.4.0/go.mod h1:NX9W0zmTvedE5oDoOMs2RTC8RvdK98NTYZE5LbaEYPg=
-github.com/woodsbury/decimal128 v1.3.0 h1:8pffMNWIlC0O5vbyHWFZAt5yWvWcrHA+3ovIIjVWss0=
-github.com/woodsbury/decimal128 v1.3.0/go.mod h1:C5UTmyTjW3JftjUFzOVhC20BEQa2a4ZKOB5I6Zjb+ds=
+github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
+github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
+github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
+github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
+github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 go.opentelemetry.io/contrib/detectors/gcp v1.39.0/go.mod h1:t/OGqzHBa5v6RHZwrDBJ2OirWc+4q/w2fTbLZwAKjTk=
+go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 h1:MdKucPl/HbzckWWEisiNqMPhRrAOQX8r4jTuGr636gk=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0/go.mod h1:RolT8tWtfHcjajEH5wFIZ4Dgh5jpPdFXYV9pTAk/qjc=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 h1:H7O6RlGOMTizyl3R08Kn5pdM06bnH8oscSj7o11tmLA=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0/go.mod h1:mBFWu/WOVDkWWsR7Tx7h6EpQB8wsv7P0Yrh0Pb7othc=
+go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0 h1:lSZHgNHfbmQTPfuTmWVkEu8J8qXaQwuV30pjCcAUvP8=
+go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.42.0/go.mod h1:so9ounLcuoRDu033MW/E0AD4hhUjVqswrMF5FoZlBcw=
+go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
+go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
+go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
+go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
 go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
 golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
@@ -85,6 +94,7 @@ golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
 golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
 golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
+golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@@ -95,15 +105,18 @@ golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
 golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
 golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
 golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
+golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
 golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
+golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
 golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
+golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -132,6 +145,7 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
 golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
 golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
 golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
+golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@@ -142,6 +156,7 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
 golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
 golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
 golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
+golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
@@ -152,13 +167,20 @@ golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0
 golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
 golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
 golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
+golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
 golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
 golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
 golang.org/x/tools/go/vcs v0.1.0-deprecated/go.mod h1:zUrvATBAvEI9535oC0yWYsLsHIV4Z7g63sNPVMtuBy8=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/genproto/googleapis/api v0.0.0-20260120221211-b8f7ae30c516/go.mod h1:p3MLuOwURrGBRoEyFHBT3GjUwaCQVKeNqqWxlcISGdw=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
 google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260401001100-f93e5f3e9f0f/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
 google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec=
+google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
+google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
diff --git a/pkg/connector/go.mod b/pkg/connector/go.mod
index dae7b36..aa79029 100644
--- a/pkg/connector/go.mod
+++ b/pkg/connector/go.mod
@@ -5,10 +5,10 @@ go 1.26.0
 require github.com/stretchr/testify v1.11.1
 
 require (
-	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 	github.com/kr/pretty v0.3.1 // indirect
-	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/rogpeppe/go-internal v1.10.0 // indirect
+	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
+	github.com/rogpeppe/go-internal v1.14.1 // indirect
 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/pkg/connector/go.sum b/pkg/connector/go.sum
index 240dea8..b679674 100644
--- a/pkg/connector/go.sum
+++ b/pkg/connector/go.sum
@@ -1,8 +1,8 @@
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/pkg/error/go.mod b/pkg/error/go.mod
index 16a2dfd..a348ca2 100644
--- a/pkg/error/go.mod
+++ b/pkg/error/go.mod
@@ -5,10 +5,10 @@ go 1.26.0
 require github.com/stretchr/testify v1.11.1
 
 require (
-	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 	github.com/kr/pretty v0.3.1 // indirect
-	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/rogpeppe/go-internal v1.10.0 // indirect
+	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
+	github.com/rogpeppe/go-internal v1.14.1 // indirect
 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/pkg/error/go.sum b/pkg/error/go.sum
index 85e31c2..309ca2e 100644
--- a/pkg/error/go.sum
+++ b/pkg/error/go.sum
@@ -1,10 +1,8 @@
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
 github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/pkg/util/go.mod b/pkg/util/go.mod
index 3522a7c..bf8d6b0 100644
--- a/pkg/util/go.mod
+++ b/pkg/util/go.mod
@@ -5,14 +5,14 @@ go 1.26.0
 require (
 	github.com/google/uuid v1.6.0
 	github.com/stretchr/testify v1.11.1
-	golang.org/x/sys v0.41.0
+	golang.org/x/sys v0.42.0
 )
 
 require (
-	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 	github.com/kr/pretty v0.3.1 // indirect
-	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/rogpeppe/go-internal v1.10.0 // indirect
+	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
+	github.com/rogpeppe/go-internal v1.14.1 // indirect
 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/pkg/util/go.sum b/pkg/util/go.sum
index d3f9635..9645866 100644
--- a/pkg/util/go.sum
+++ b/pkg/util/go.sum
@@ -1,10 +1,10 @@
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=