package auth

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"galaxy/backend/internal/postgres/jet/backend/model"
	"galaxy/backend/internal/postgres/jet/backend/table"
	"github.com/go-jet/jet/v2/postgres"
	"github.com/go-jet/jet/v2/qrm"
	"github.com/google/uuid"
)

// Challenge mirrors a row in `backend.auth_challenges` enriched with the
// PreferredLanguage column added by migration 00002. The CodeHash slice
// is the raw bcrypt hash; verifyCode wraps the comparison.
type Challenge struct {
	ChallengeID uuid.UUID
	Email       string
	CodeHash    []byte
	// Attempts is the verification-attempt counter. LoadAndIncrementChallenge
	// bumps it by one each time it successfully locks a live challenge.
	Attempts  int32
	CreatedAt time.Time
	ExpiresAt time.Time
	// ConsumedAt is nil while the challenge is still usable; the store's
	// queries treat a non-nil value as "already consumed".
	ConsumedAt        *time.Time
	PreferredLanguage string
}

// Session mirrors a row in `backend.device_sessions`. The
// ClientPublicKey slice is the raw 32-byte Ed25519 key; the handler
// layer is responsible for base64 encoding/decoding on the wire.
type Session struct {
	DeviceSessionID uuid.UUID
	UserID          uuid.UUID
	// Status is the raw `status` column value, e.g. SessionStatusActive
	// or SessionStatusRevoked.
	Status          string
	ClientPublicKey []byte
	CreatedAt       time.Time
	// RevokedAt is nil until a revoke path stamps it.
	RevokedAt *time.Time
	// LastSeenAt is nil until TouchSessionLastSeen first records activity.
	LastSeenAt *time.Time
}

// SessionStatusActive and SessionStatusRevoked enumerate the values
// auth writes. The CHECK constraint on `device_sessions.status` also
// allows 'blocked', which the user package emits when applying a
// `permanent_block` sanction.
const (
	SessionStatusActive  = "active"
	SessionStatusRevoked = "revoked"
)

// Store is the Postgres-backed query surface for `backend.auth_challenges`,
// `backend.device_sessions`, the `backend.session_revocations` audit
// trail, and the read-side `backend.accounts` lookup auth needs to
// detect permanently-blocked emails. Its only state is the *sql.DB
// handle.
type Store struct {
	db *sql.DB
}

// NewStore constructs a Store wrapping db. Nothing in this file closes
// db; managing its lifecycle is left to the caller.
func NewStore(db *sql.DB) *Store {
	return &Store{db: db}
}

// challengeColumns lists the projection used by every read of
// `auth_challenges`. It covers every column modelToChallenge copies
// into a Challenge, so reads scanned into model.AuthChallenges come back
// fully populated.
func challengeColumns() postgres.ColumnList {
	ac := table.AuthChallenges
	return postgres.ColumnList{
		ac.ChallengeID,
		ac.Email,
		ac.CodeHash,
		ac.Attempts,
		ac.CreatedAt,
		ac.ExpiresAt,
		ac.ConsumedAt,
		ac.PreferredLanguage,
	}
}

// sessionColumns is the shared projection for every `device_sessions`
// read (and for the RETURNING clause of session UPDATEs), matching the
// fields modelToSession copies.
func sessionColumns() postgres.ColumnList {
	ds := table.DeviceSessions
	return postgres.ColumnList{
		ds.DeviceSessionID,
		ds.UserID,
		ds.ClientPublicKey,
		ds.Status,
		ds.CreatedAt,
		ds.RevokedAt,
		ds.LastSeenAt,
	}
}

// IsEmailPermanentlyBlocked looks up the non-deleted `accounts` row for
// email and reports its permanent_block flag.
//
// The comparison is an exact string match, so callers must normalise
// email (trim + lowercase) before calling. When no live account exists
// the answer is (false, nil): an unknown email is free to register.
func (s *Store) IsEmailPermanentlyBlocked(ctx context.Context, email string) (bool, error) {
	liveAccount := table.Accounts.Email.EQ(postgres.String(email)).
		AND(table.Accounts.DeletedAt.IS_NULL())

	query := postgres.
		SELECT(table.Accounts.PermanentBlock).
		FROM(table.Accounts).
		WHERE(liveAccount).
		LIMIT(1)

	var account model.Accounts
	err := query.QueryContext(ctx, s.db, &account)
	switch {
	case err == nil:
		return account.PermanentBlock, nil
	case errors.Is(err, qrm.ErrNoRows):
		return false, nil
	default:
		return false, fmt.Errorf("auth store: query permanent_block for %q: %w", email, err)
	}
}

// LatestUnconsumedChallenge returns the most recently issued
// un-consumed, non-expired challenge for email created at or after
// since; expiry is judged against the database clock (NOW()). It
// returns sql.ErrNoRows when no such challenge exists. The throttle
// path uses this method to reuse the existing challenge_id rather than
// emit a fresh row.
func (s *Store) LatestUnconsumedChallenge(ctx context.Context, email string, since time.Time) (Challenge, error) { stmt := postgres.SELECT(challengeColumns()). FROM(table.AuthChallenges). WHERE( table.AuthChallenges.Email.EQ(postgres.String(email)). AND(table.AuthChallenges.ConsumedAt.IS_NULL()). AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())). AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))), ). ORDER_BY(table.AuthChallenges.CreatedAt.DESC()). LIMIT(1) var row model.AuthChallenges if err := stmt.QueryContext(ctx, s.db, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return Challenge{}, sql.ErrNoRows } return Challenge{}, err } return modelToChallenge(row), nil } // CountRecentChallenges returns the number of un-consumed, non-expired // challenges issued for email at or after since. Used by the throttle // gate in SendEmailCode. func (s *Store) CountRecentChallenges(ctx context.Context, email string, since time.Time) (int, error) { stmt := postgres.SELECT(postgres.COUNT(postgres.STAR).AS("count")). FROM(table.AuthChallenges). WHERE( table.AuthChallenges.Email.EQ(postgres.String(email)). AND(table.AuthChallenges.ConsumedAt.IS_NULL()). AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())). AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))), ) var dest struct { Count int64 `alias:"count"` } if err := stmt.QueryContext(ctx, s.db, &dest); err != nil { return 0, fmt.Errorf("auth store: count recent challenges: %w", err) } return int(dest.Count), nil } // InsertChallenge persists a fresh `auth_challenges` row. The caller // owns the primary-key, the bcrypt hash, the expires_at timestamp and // the captured locale. created_at and attempts default at the schema // level. 
func (s *Store) InsertChallenge(ctx context.Context, c Challenge) error { stmt := table.AuthChallenges.INSERT( table.AuthChallenges.ChallengeID, table.AuthChallenges.Email, table.AuthChallenges.CodeHash, table.AuthChallenges.ExpiresAt, table.AuthChallenges.PreferredLanguage, ).VALUES(c.ChallengeID, c.Email, c.CodeHash, c.ExpiresAt, c.PreferredLanguage) if _, err := stmt.ExecContext(ctx, s.db); err != nil { return fmt.Errorf("auth store: insert challenge: %w", err) } return nil } // LoadAndIncrementChallenge atomically locks the challenge row, // validates that it is still un-consumed and non-expired, and increments // its `attempts` counter. The returned Challenge carries the // post-increment counter so the caller can compare it against the // configured ceiling without a second query. // // Returns ErrChallengeNotFound when the row does not exist, has been // consumed, or has expired. Any other error is wrapped with the auth // store prefix. func (s *Store) LoadAndIncrementChallenge(ctx context.Context, challengeID uuid.UUID) (Challenge, error) { var loaded Challenge err := withTx(ctx, s.db, func(tx *sql.Tx) error { selectStmt := postgres.SELECT(challengeColumns()). FROM(table.AuthChallenges). WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))). FOR(postgres.UPDATE()) var row model.AuthChallenges if err := selectStmt.QueryContext(ctx, tx, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return ErrChallengeNotFound } return err } loaded = modelToChallenge(row) if loaded.ConsumedAt != nil { return ErrChallengeNotFound } if !loaded.ExpiresAt.After(time.Now()) { return ErrChallengeNotFound } updateStmt := table.AuthChallenges. UPDATE(table.AuthChallenges.Attempts). SET(table.AuthChallenges.Attempts.ADD(postgres.Int(1))). 
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))) if _, err := updateStmt.ExecContext(ctx, tx); err != nil { return err } loaded.Attempts++ return nil }) if err != nil { if errors.Is(err, ErrChallengeNotFound) { return Challenge{}, err } return Challenge{}, fmt.Errorf("auth store: load and increment challenge: %w", err) } return loaded, nil } // MarkConsumedAndInsertSession atomically: // // 1. Locks the challenge row. // 2. Validates that it is still un-consumed and non-expired. // 3. Sets consumed_at = now(). // 4. Inserts the supplied Session into device_sessions with status = // 'active'. // // The two writes are committed together so a single challenge yields at // most one device session even under concurrent confirm-email-code // callers. // // Returns ErrChallengeNotFound when the challenge has been consumed (by // a concurrent caller) or has expired in the gap between the // LoadAndIncrementChallenge call and this one. func (s *Store) MarkConsumedAndInsertSession(ctx context.Context, challengeID uuid.UUID, session Session) error { err := withTx(ctx, s.db, func(tx *sql.Tx) error { lockStmt := postgres.SELECT(table.AuthChallenges.ConsumedAt, table.AuthChallenges.ExpiresAt). FROM(table.AuthChallenges). WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))). FOR(postgres.UPDATE()) var locked model.AuthChallenges if err := lockStmt.QueryContext(ctx, tx, &locked); err != nil { if errors.Is(err, qrm.ErrNoRows) { return ErrChallengeNotFound } return err } if locked.ConsumedAt != nil || !locked.ExpiresAt.After(time.Now()) { return ErrChallengeNotFound } consumeStmt := table.AuthChallenges. UPDATE(table.AuthChallenges.ConsumedAt). SET(postgres.NOW()). 
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))) if _, err := consumeStmt.ExecContext(ctx, tx); err != nil { return err } insertStmt := table.DeviceSessions.INSERT( table.DeviceSessions.DeviceSessionID, table.DeviceSessions.UserID, table.DeviceSessions.ClientPublicKey, table.DeviceSessions.Status, ).VALUES(session.DeviceSessionID, session.UserID, session.ClientPublicKey, SessionStatusActive) if _, err := insertStmt.ExecContext(ctx, tx); err != nil { return err } return nil }) if err != nil { if errors.Is(err, ErrChallengeNotFound) { return err } return fmt.Errorf("auth store: mark consumed and insert session: %w", err) } return nil } // ListActiveSessions loads every row from device_sessions whose status // is 'active'. Cache.Warm calls this at process boot. func (s *Store) ListActiveSessions(ctx context.Context) ([]Session, error) { stmt := postgres.SELECT(sessionColumns()). FROM(table.DeviceSessions). WHERE(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))) var rows []model.DeviceSessions if err := stmt.QueryContext(ctx, s.db, &rows); err != nil { return nil, fmt.Errorf("auth store: list active sessions: %w", err) } out := make([]Session, 0, len(rows)) for _, row := range rows { out = append(out, modelToSession(row)) } return out, nil } // LoadSession returns the row for deviceSessionID regardless of status. // Returns ErrSessionNotFound on missing row. func (s *Store) LoadSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, error) { stmt := postgres.SELECT(sessionColumns()). FROM(table.DeviceSessions). WHERE(table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID))). 
LIMIT(1) var row model.DeviceSessions if err := stmt.QueryContext(ctx, s.db, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return Session{}, ErrSessionNotFound } return Session{}, fmt.Errorf("auth store: load session %s: %w", deviceSessionID, err) } return modelToSession(row), nil } // TouchSessionLastSeen sets `last_seen_at` to at on the row keyed by // deviceSessionID. The UPDATE is gated by `status='active'` so a // revoked or absent row reports ErrSessionNotFound. Returns the post- // update row so the cache can be refreshed without a second read. func (s *Store) TouchSessionLastSeen(ctx context.Context, deviceSessionID uuid.UUID, at time.Time) (Session, error) { stmt := table.DeviceSessions. UPDATE(table.DeviceSessions.LastSeenAt). SET(postgres.TimestampzT(at)). WHERE( table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID)). AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))), ). RETURNING(sessionColumns()) var row model.DeviceSessions if err := stmt.QueryContext(ctx, s.db, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return Session{}, ErrSessionNotFound } return Session{}, fmt.Errorf("auth store: touch last_seen %s: %w", deviceSessionID, err) } return modelToSession(row), nil } // RevokeSession transitions an active row to status='revoked' and // inserts the matching audit row into session_revocations atomically // inside one transaction. The boolean reports whether the UPDATE // actually changed a row — false means the row was already revoked or // did not exist, in which case no audit row is written and the auth // Service falls back to LoadSession for the idempotent-revoke // response. func (s *Store) RevokeSession(ctx context.Context, deviceSessionID uuid.UUID, rc RevokeContext, at time.Time) (Session, bool, error) { var ( revoked Session ok bool ) err := withTx(ctx, s.db, func(tx *sql.Tx) error { updateStmt := table.DeviceSessions. 
UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt). SET(postgres.String(SessionStatusRevoked), postgres.TimestampzT(at)). WHERE( table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID)). AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))), ). RETURNING(sessionColumns()) var row model.DeviceSessions if err := updateStmt.QueryContext(ctx, tx, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return nil } return err } revoked = modelToSession(row) ok = true return insertRevocationTx(ctx, tx, deviceSessionID, revoked.UserID, rc, at) }) if err != nil { return Session{}, false, fmt.Errorf("auth store: revoke session %s: %w", deviceSessionID, err) } return revoked, ok, nil } // RevokeAllForUser transitions every active row for userID to // status='revoked', writes one session_revocations row per revoked // session, and returns the rows as they stand after the update. The // UPDATE and the audit inserts run inside one transaction. An empty // slice with a nil error is returned when the user owned no active // sessions; the caller treats that as a successful idempotent revoke // (the API surface returns revoked_count=0). func (s *Store) RevokeAllForUser(ctx context.Context, userID uuid.UUID, rc RevokeContext, at time.Time) ([]Session, error) { var out []Session err := withTx(ctx, s.db, func(tx *sql.Tx) error { updateStmt := table.DeviceSessions. UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt). SET(postgres.String(SessionStatusRevoked), postgres.TimestampzT(at)). WHERE( table.DeviceSessions.UserID.EQ(postgres.UUID(userID)). AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))), ). 
RETURNING(sessionColumns()) var rows []model.DeviceSessions if err := updateStmt.QueryContext(ctx, tx, &rows); err != nil { return err } out = make([]Session, 0, len(rows)) for _, row := range rows { sess := modelToSession(row) out = append(out, sess) if err := insertRevocationTx(ctx, tx, sess.DeviceSessionID, sess.UserID, rc, at); err != nil { return err } } return nil }) if err != nil { return nil, fmt.Errorf("auth store: revoke all for user %s: %w", userID, err) } return out, nil } // insertRevocationTx writes a single audit row inside an existing // transaction. Callers are expected to mint a fresh revocation_id per // row; collisions are not retried because revocation_id is a uuid.New // in the only call sites. func insertRevocationTx(ctx context.Context, tx *sql.Tx, deviceSessionID, userID uuid.UUID, rc RevokeContext, at time.Time) error { actorUserID, actorUsername, err := revokeContextToColumns(rc) if err != nil { return err } stmt := table.SessionRevocations.INSERT( table.SessionRevocations.RevocationID, table.SessionRevocations.DeviceSessionID, table.SessionRevocations.UserID, table.SessionRevocations.ActorKind, table.SessionRevocations.ActorUserID, table.SessionRevocations.ActorUsername, table.SessionRevocations.Reason, table.SessionRevocations.RevokedAt, ).VALUES(uuid.New(), deviceSessionID, userID, string(rc.ActorKind), actorUserID, actorUsername, rc.Reason, at) if _, err := stmt.ExecContext(ctx, tx); err != nil { return fmt.Errorf("insert session_revocations: %w", err) } return nil } // revokeContextToColumns splits RevokeContext.ActorID into the // (actor_user_id, actor_username) pair persisted by session_revocations. // User-driven kinds parse ActorID as a UUID; admin-driven kinds keep it // as the operator username. Empty ActorID lands as NULL/NULL. 
func revokeContextToColumns(rc RevokeContext) (any, any, error) { if rc.ActorID == "" { return nil, nil, nil } switch rc.ActorKind { case ActorKindUserSelf, ActorKindSoftDeleteUser: uid, err := uuid.Parse(rc.ActorID) if err != nil { return nil, nil, fmt.Errorf("auth store: actor_id %q is not a uuid: %w", rc.ActorID, err) } return uid, nil, nil case ActorKindAdminSanction, ActorKindSoftDeleteAdmin: return nil, rc.ActorID, nil default: return nil, nil, nil } } // modelToChallenge projects a generated model row into the public // Challenge struct. Pointer fields are copied so callers cannot mutate // the underlying scan buffer. func modelToChallenge(row model.AuthChallenges) Challenge { c := Challenge{ ChallengeID: row.ChallengeID, Email: row.Email, CodeHash: row.CodeHash, Attempts: row.Attempts, CreatedAt: row.CreatedAt, ExpiresAt: row.ExpiresAt, PreferredLanguage: row.PreferredLanguage, } if row.ConsumedAt != nil { t := *row.ConsumedAt c.ConsumedAt = &t } return c } // modelToSession projects a generated model row into the public Session // struct. func modelToSession(row model.DeviceSessions) Session { s := Session{ DeviceSessionID: row.DeviceSessionID, UserID: row.UserID, Status: row.Status, ClientPublicKey: row.ClientPublicKey, CreatedAt: row.CreatedAt, } if row.RevokedAt != nil { t := *row.RevokedAt s.RevokedAt = &t } if row.LastSeenAt != nil { t := *row.LastSeenAt s.LastSeenAt = &t } return s } // withTx wraps fn in a Postgres transaction. fn's return value // determines commit (nil) vs rollback (non-nil). Rollback errors are // swallowed when fn already returned an error, since the latter is more // actionable. func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error { tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("auth store: begin tx: %w", err) } if err := fn(tx); err != nil { _ = tx.Rollback() return err } if err := tx.Commit(); err != nil { return fmt.Errorf("auth store: commit tx: %w", err) } return nil }