package auth import ( "context" "database/sql" "errors" "fmt" "time" "galaxy/backend/internal/postgres/jet/backend/model" "galaxy/backend/internal/postgres/jet/backend/table" "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" ) // Challenge mirrors a row in `backend.auth_challenges` enriched with the // PreferredLanguage column added by migration 00002. The CodeHash slice // is the raw bcrypt hash; verifyCode wraps the comparison. type Challenge struct { ChallengeID uuid.UUID Email string CodeHash []byte Attempts int32 CreatedAt time.Time ExpiresAt time.Time ConsumedAt *time.Time PreferredLanguage string } // Session mirrors a row in `backend.device_sessions`. The // ClientPublicKey slice is the raw 32-byte Ed25519 key; the handler // layer is responsible for base64 encoding/decoding on the wire. type Session struct { DeviceSessionID uuid.UUID UserID uuid.UUID Status string ClientPublicKey []byte CreatedAt time.Time RevokedAt *time.Time LastSeenAt *time.Time } // SessionStatusActive and SessionStatusRevoked enumerate the values // auth writes. The CHECK constraint on `device_sessions.status` also // allows 'blocked', which the user package emits when applying a // `permanent_block` sanction. const ( SessionStatusActive = "active" SessionStatusRevoked = "revoked" ) // Store is the Postgres-backed query surface for `backend.auth_challenges`, // `backend.device_sessions` and the read-side `backend.accounts` lookup // auth needs to detect permanently-blocked emails. type Store struct { db *sql.DB } // NewStore constructs a Store wrapping db. func NewStore(db *sql.DB) *Store { return &Store{db: db} } // challengeColumns lists the projection used by every read of // `auth_challenges`. The order matches model.AuthChallenges field order // inside QueryContext destination scans. 
func challengeColumns() postgres.ColumnList {
	ac := table.AuthChallenges
	return postgres.ColumnList{
		ac.ChallengeID,
		ac.Email,
		ac.CodeHash,
		ac.Attempts,
		ac.CreatedAt,
		ac.ExpiresAt,
		ac.ConsumedAt,
		ac.PreferredLanguage,
	}
}

// sessionColumns returns the projection shared by every read of
// `device_sessions`, including the RETURNING clauses of the revoke
// statements.
func sessionColumns() postgres.ColumnList {
	ds := table.DeviceSessions
	return postgres.ColumnList{
		ds.DeviceSessionID,
		ds.UserID,
		ds.ClientPublicKey,
		ds.Status,
		ds.CreatedAt,
		ds.RevokedAt,
		ds.LastSeenAt,
	}
}

// IsEmailPermanentlyBlocked reports whether email belongs to a live
// `accounts` row (deleted_at IS NULL) whose permanent_block column is
// true. The match is case-sensitive, so callers must pass an
// already-normalised (lowercase, trimmed) email.
//
// When no live account exists the result is (false, nil). The auth flow
// treats such emails as eligible for fresh registration.
func (s *Store) IsEmailPermanentlyBlocked(ctx context.Context, email string) (bool, error) {
	acc := table.Accounts
	liveAccountWithEmail := acc.Email.EQ(postgres.String(email)).
		AND(acc.DeletedAt.IS_NULL())

	query := postgres.
		SELECT(acc.PermanentBlock).
		FROM(acc).
		WHERE(liveAccountWithEmail).
		LIMIT(1)

	var account model.Accounts
	err := query.QueryContext(ctx, s.db, &account)
	switch {
	case err == nil:
		return account.PermanentBlock, nil
	case errors.Is(err, qrm.ErrNoRows):
		return false, nil
	default:
		return false, fmt.Errorf("auth store: query permanent_block for %q: %w", email, err)
	}
}

// LatestUnconsumedChallenge returns the most recently issued challenge
// for email that is un-consumed, not yet expired, and was created at or
// after since. It returns sql.ErrNoRows when no such challenge exists.
// The throttle path uses it to reuse the existing challenge_id instead
// of inserting a fresh row.
func (s *Store) LatestUnconsumedChallenge(ctx context.Context, email string, since time.Time) (Challenge, error) { stmt := postgres.SELECT(challengeColumns()). FROM(table.AuthChallenges). WHERE( table.AuthChallenges.Email.EQ(postgres.String(email)). AND(table.AuthChallenges.ConsumedAt.IS_NULL()). AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())). AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))), ). ORDER_BY(table.AuthChallenges.CreatedAt.DESC()). LIMIT(1) var row model.AuthChallenges if err := stmt.QueryContext(ctx, s.db, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return Challenge{}, sql.ErrNoRows } return Challenge{}, err } return modelToChallenge(row), nil } // CountRecentChallenges returns the number of un-consumed, non-expired // challenges issued for email at or after since. Used by the throttle // gate in SendEmailCode. func (s *Store) CountRecentChallenges(ctx context.Context, email string, since time.Time) (int, error) { stmt := postgres.SELECT(postgres.COUNT(postgres.STAR).AS("count")). FROM(table.AuthChallenges). WHERE( table.AuthChallenges.Email.EQ(postgres.String(email)). AND(table.AuthChallenges.ConsumedAt.IS_NULL()). AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())). AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))), ) var dest struct { Count int64 `alias:"count"` } if err := stmt.QueryContext(ctx, s.db, &dest); err != nil { return 0, fmt.Errorf("auth store: count recent challenges: %w", err) } return int(dest.Count), nil } // InsertChallenge persists a fresh `auth_challenges` row. The caller // owns the primary-key, the bcrypt hash, the expires_at timestamp and // the captured locale. created_at and attempts default at the schema // level. 
func (s *Store) InsertChallenge(ctx context.Context, c Challenge) error { stmt := table.AuthChallenges.INSERT( table.AuthChallenges.ChallengeID, table.AuthChallenges.Email, table.AuthChallenges.CodeHash, table.AuthChallenges.ExpiresAt, table.AuthChallenges.PreferredLanguage, ).VALUES(c.ChallengeID, c.Email, c.CodeHash, c.ExpiresAt, c.PreferredLanguage) if _, err := stmt.ExecContext(ctx, s.db); err != nil { return fmt.Errorf("auth store: insert challenge: %w", err) } return nil } // LoadAndIncrementChallenge atomically locks the challenge row, // validates that it is still un-consumed and non-expired, and increments // its `attempts` counter. The returned Challenge carries the // post-increment counter so the caller can compare it against the // configured ceiling without a second query. // // Returns ErrChallengeNotFound when the row does not exist, has been // consumed, or has expired. Any other error is wrapped with the auth // store prefix. func (s *Store) LoadAndIncrementChallenge(ctx context.Context, challengeID uuid.UUID) (Challenge, error) { var loaded Challenge err := withTx(ctx, s.db, func(tx *sql.Tx) error { selectStmt := postgres.SELECT(challengeColumns()). FROM(table.AuthChallenges). WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))). FOR(postgres.UPDATE()) var row model.AuthChallenges if err := selectStmt.QueryContext(ctx, tx, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return ErrChallengeNotFound } return err } loaded = modelToChallenge(row) if loaded.ConsumedAt != nil { return ErrChallengeNotFound } if !loaded.ExpiresAt.After(time.Now()) { return ErrChallengeNotFound } updateStmt := table.AuthChallenges. UPDATE(table.AuthChallenges.Attempts). SET(table.AuthChallenges.Attempts.ADD(postgres.Int(1))). 
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))) if _, err := updateStmt.ExecContext(ctx, tx); err != nil { return err } loaded.Attempts++ return nil }) if err != nil { if errors.Is(err, ErrChallengeNotFound) { return Challenge{}, err } return Challenge{}, fmt.Errorf("auth store: load and increment challenge: %w", err) } return loaded, nil } // MarkConsumedAndInsertSession atomically: // // 1. Locks the challenge row. // 2. Validates that it is still un-consumed and non-expired. // 3. Sets consumed_at = now(). // 4. Inserts the supplied Session into device_sessions with status = // 'active'. // // The two writes are committed together so a single challenge yields at // most one device session even under concurrent confirm-email-code // callers. // // Returns ErrChallengeNotFound when the challenge has been consumed (by // a concurrent caller) or has expired in the gap between the // LoadAndIncrementChallenge call and this one. func (s *Store) MarkConsumedAndInsertSession(ctx context.Context, challengeID uuid.UUID, session Session) error { err := withTx(ctx, s.db, func(tx *sql.Tx) error { lockStmt := postgres.SELECT(table.AuthChallenges.ConsumedAt, table.AuthChallenges.ExpiresAt). FROM(table.AuthChallenges). WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))). FOR(postgres.UPDATE()) var locked model.AuthChallenges if err := lockStmt.QueryContext(ctx, tx, &locked); err != nil { if errors.Is(err, qrm.ErrNoRows) { return ErrChallengeNotFound } return err } if locked.ConsumedAt != nil || !locked.ExpiresAt.After(time.Now()) { return ErrChallengeNotFound } consumeStmt := table.AuthChallenges. UPDATE(table.AuthChallenges.ConsumedAt). SET(postgres.NOW()). 
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))) if _, err := consumeStmt.ExecContext(ctx, tx); err != nil { return err } insertStmt := table.DeviceSessions.INSERT( table.DeviceSessions.DeviceSessionID, table.DeviceSessions.UserID, table.DeviceSessions.ClientPublicKey, table.DeviceSessions.Status, ).VALUES(session.DeviceSessionID, session.UserID, session.ClientPublicKey, SessionStatusActive) if _, err := insertStmt.ExecContext(ctx, tx); err != nil { return err } return nil }) if err != nil { if errors.Is(err, ErrChallengeNotFound) { return err } return fmt.Errorf("auth store: mark consumed and insert session: %w", err) } return nil } // ListActiveSessions loads every row from device_sessions whose status // is 'active'. Cache.Warm calls this at process boot. func (s *Store) ListActiveSessions(ctx context.Context) ([]Session, error) { stmt := postgres.SELECT(sessionColumns()). FROM(table.DeviceSessions). WHERE(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))) var rows []model.DeviceSessions if err := stmt.QueryContext(ctx, s.db, &rows); err != nil { return nil, fmt.Errorf("auth store: list active sessions: %w", err) } out := make([]Session, 0, len(rows)) for _, row := range rows { out = append(out, modelToSession(row)) } return out, nil } // LoadSession returns the row for deviceSessionID regardless of status. // Returns ErrSessionNotFound on missing row. func (s *Store) LoadSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, error) { stmt := postgres.SELECT(sessionColumns()). FROM(table.DeviceSessions). WHERE(table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID))). 
LIMIT(1) var row model.DeviceSessions if err := stmt.QueryContext(ctx, s.db, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return Session{}, ErrSessionNotFound } return Session{}, fmt.Errorf("auth store: load session %s: %w", deviceSessionID, err) } return modelToSession(row), nil } // RevokeSession transitions an active row to status='revoked' and // returns the row as it stands after the update. The boolean reports // whether the UPDATE actually changed a row — false means the row was // already revoked or did not exist; the auth Service then falls back to // LoadSession for idempotent-revoke responses. func (s *Store) RevokeSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, bool, error) { stmt := table.DeviceSessions. UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt). SET(postgres.String(SessionStatusRevoked), postgres.NOW()). WHERE( table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID)). AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))), ). RETURNING(sessionColumns()) var row model.DeviceSessions if err := stmt.QueryContext(ctx, s.db, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return Session{}, false, nil } return Session{}, false, fmt.Errorf("auth store: revoke session %s: %w", deviceSessionID, err) } return modelToSession(row), true, nil } // RevokeAllForUser transitions every active row for userID to // status='revoked' and returns the rows as they stand after the update. // An empty slice with a nil error is returned when the user owned no // active sessions; the caller must treat that as a successful idempotent // revoke (the API surface returns revoked_count=0 in that case). func (s *Store) RevokeAllForUser(ctx context.Context, userID uuid.UUID) ([]Session, error) { stmt := table.DeviceSessions. UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt). SET(postgres.String(SessionStatusRevoked), postgres.NOW()). 
WHERE( table.DeviceSessions.UserID.EQ(postgres.UUID(userID)). AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))), ). RETURNING(sessionColumns()) var rows []model.DeviceSessions if err := stmt.QueryContext(ctx, s.db, &rows); err != nil { return nil, fmt.Errorf("auth store: revoke all for user %s: %w", userID, err) } out := make([]Session, 0, len(rows)) for _, row := range rows { out = append(out, modelToSession(row)) } return out, nil } // modelToChallenge projects a generated model row into the public // Challenge struct. Pointer fields are copied so callers cannot mutate // the underlying scan buffer. func modelToChallenge(row model.AuthChallenges) Challenge { c := Challenge{ ChallengeID: row.ChallengeID, Email: row.Email, CodeHash: row.CodeHash, Attempts: row.Attempts, CreatedAt: row.CreatedAt, ExpiresAt: row.ExpiresAt, PreferredLanguage: row.PreferredLanguage, } if row.ConsumedAt != nil { t := *row.ConsumedAt c.ConsumedAt = &t } return c } // modelToSession projects a generated model row into the public Session // struct. func modelToSession(row model.DeviceSessions) Session { s := Session{ DeviceSessionID: row.DeviceSessionID, UserID: row.UserID, Status: row.Status, ClientPublicKey: row.ClientPublicKey, CreatedAt: row.CreatedAt, } if row.RevokedAt != nil { t := *row.RevokedAt s.RevokedAt = &t } if row.LastSeenAt != nil { t := *row.LastSeenAt s.LastSeenAt = &t } return s } // withTx wraps fn in a Postgres transaction. fn's return value // determines commit (nil) vs rollback (non-nil). Rollback errors are // swallowed when fn already returned an error, since the latter is more // actionable. func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error { tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("auth store: begin tx: %w", err) } if err := fn(tx); err != nil { _ = tx.Rollback() return err } if err := tx.Commit(); err != nil { return fmt.Errorf("auth store: commit tx: %w", err) } return nil }