445 lines
15 KiB
Go
445 lines
15 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"galaxy/backend/internal/postgres/jet/backend/model"
|
|
"galaxy/backend/internal/postgres/jet/backend/table"
|
|
|
|
"github.com/go-jet/jet/v2/postgres"
|
|
"github.com/go-jet/jet/v2/qrm"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Challenge mirrors a row in `backend.auth_challenges` enriched with the
|
|
// PreferredLanguage column added by migration 00002. The CodeHash slice
|
|
// is the raw bcrypt hash; verifyCode wraps the comparison.
|
|
type Challenge struct {
|
|
ChallengeID uuid.UUID
|
|
Email string
|
|
CodeHash []byte
|
|
Attempts int32
|
|
CreatedAt time.Time
|
|
ExpiresAt time.Time
|
|
ConsumedAt *time.Time
|
|
PreferredLanguage string
|
|
}
|
|
|
|
// Session mirrors a row in `backend.device_sessions`. The
|
|
// ClientPublicKey slice is the raw 32-byte Ed25519 key; the handler
|
|
// layer is responsible for base64 encoding/decoding on the wire.
|
|
type Session struct {
|
|
DeviceSessionID uuid.UUID
|
|
UserID uuid.UUID
|
|
Status string
|
|
ClientPublicKey []byte
|
|
CreatedAt time.Time
|
|
RevokedAt *time.Time
|
|
LastSeenAt *time.Time
|
|
}
|
|
|
|
// Session status values written by the auth package. The CHECK
// constraint on `device_sessions.status` additionally permits
// 'blocked', which the user package emits when it applies a
// `permanent_block` sanction.
const (
	// SessionStatusActive is the status given to a freshly inserted
	// session.
	SessionStatusActive = "active"
	// SessionStatusRevoked is the status active sessions are moved to
	// by RevokeSession and RevokeAllForUser.
	SessionStatusRevoked = "revoked"
)
|
|
|
|
// Store bundles the Postgres queries the auth package runs against
// `backend.auth_challenges` and `backend.device_sessions`, plus the
// read-only `backend.accounts` lookup used to detect
// permanently-blocked emails.
type Store struct {
	// db is the connection pool every query and transaction runs on.
	db *sql.DB
}
|
|
|
|
// NewStore constructs a Store wrapping db.
|
|
func NewStore(db *sql.DB) *Store {
|
|
return &Store{db: db}
|
|
}
|
|
|
|
// challengeColumns lists the projection used by every read of
|
|
// `auth_challenges`. The order matches model.AuthChallenges field order
|
|
// inside QueryContext destination scans.
|
|
func challengeColumns() postgres.ColumnList {
|
|
return postgres.ColumnList{
|
|
table.AuthChallenges.ChallengeID,
|
|
table.AuthChallenges.Email,
|
|
table.AuthChallenges.CodeHash,
|
|
table.AuthChallenges.Attempts,
|
|
table.AuthChallenges.CreatedAt,
|
|
table.AuthChallenges.ExpiresAt,
|
|
table.AuthChallenges.ConsumedAt,
|
|
table.AuthChallenges.PreferredLanguage,
|
|
}
|
|
}
|
|
|
|
// sessionColumns lists the projection used by every read of
|
|
// `device_sessions`.
|
|
func sessionColumns() postgres.ColumnList {
|
|
return postgres.ColumnList{
|
|
table.DeviceSessions.DeviceSessionID,
|
|
table.DeviceSessions.UserID,
|
|
table.DeviceSessions.ClientPublicKey,
|
|
table.DeviceSessions.Status,
|
|
table.DeviceSessions.CreatedAt,
|
|
table.DeviceSessions.RevokedAt,
|
|
table.DeviceSessions.LastSeenAt,
|
|
}
|
|
}
|
|
|
|
// IsEmailPermanentlyBlocked reports whether email maps to a live
|
|
// `accounts` row whose permanent_block column is true. The lookup is
|
|
// case-sensitive: callers are expected to pass an already-normalised
|
|
// (lowercase, trimmed) email.
|
|
//
|
|
// A non-existent account returns (false, nil) — the auth flow treats
|
|
// such emails as eligible for fresh registration.
|
|
func (s *Store) IsEmailPermanentlyBlocked(ctx context.Context, email string) (bool, error) {
|
|
stmt := postgres.SELECT(table.Accounts.PermanentBlock).
|
|
FROM(table.Accounts).
|
|
WHERE(
|
|
table.Accounts.Email.EQ(postgres.String(email)).
|
|
AND(table.Accounts.DeletedAt.IS_NULL()),
|
|
).
|
|
LIMIT(1)
|
|
|
|
var row model.Accounts
|
|
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
|
if errors.Is(err, qrm.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("auth store: query permanent_block for %q: %w", email, err)
|
|
}
|
|
return row.PermanentBlock, nil
|
|
}
|
|
|
|
// LatestUnconsumedChallenge returns the most recently issued
|
|
// un-consumed, non-expired challenge for email created at or after
|
|
// since. Returns sql.ErrNoRows when no such challenge exists. The
|
|
// throttle path uses this method to reuse the existing challenge_id
|
|
// rather than emit a fresh row.
|
|
func (s *Store) LatestUnconsumedChallenge(ctx context.Context, email string, since time.Time) (Challenge, error) {
|
|
stmt := postgres.SELECT(challengeColumns()).
|
|
FROM(table.AuthChallenges).
|
|
WHERE(
|
|
table.AuthChallenges.Email.EQ(postgres.String(email)).
|
|
AND(table.AuthChallenges.ConsumedAt.IS_NULL()).
|
|
AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())).
|
|
AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))),
|
|
).
|
|
ORDER_BY(table.AuthChallenges.CreatedAt.DESC()).
|
|
LIMIT(1)
|
|
|
|
var row model.AuthChallenges
|
|
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
|
if errors.Is(err, qrm.ErrNoRows) {
|
|
return Challenge{}, sql.ErrNoRows
|
|
}
|
|
return Challenge{}, err
|
|
}
|
|
return modelToChallenge(row), nil
|
|
}
|
|
|
|
// CountRecentChallenges returns the number of un-consumed, non-expired
|
|
// challenges issued for email at or after since. Used by the throttle
|
|
// gate in SendEmailCode.
|
|
func (s *Store) CountRecentChallenges(ctx context.Context, email string, since time.Time) (int, error) {
|
|
stmt := postgres.SELECT(postgres.COUNT(postgres.STAR).AS("count")).
|
|
FROM(table.AuthChallenges).
|
|
WHERE(
|
|
table.AuthChallenges.Email.EQ(postgres.String(email)).
|
|
AND(table.AuthChallenges.ConsumedAt.IS_NULL()).
|
|
AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())).
|
|
AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))),
|
|
)
|
|
|
|
var dest struct {
|
|
Count int64 `alias:"count"`
|
|
}
|
|
if err := stmt.QueryContext(ctx, s.db, &dest); err != nil {
|
|
return 0, fmt.Errorf("auth store: count recent challenges: %w", err)
|
|
}
|
|
return int(dest.Count), nil
|
|
}
|
|
|
|
// InsertChallenge persists a fresh `auth_challenges` row. The caller
|
|
// owns the primary-key, the bcrypt hash, the expires_at timestamp and
|
|
// the captured locale. created_at and attempts default at the schema
|
|
// level.
|
|
func (s *Store) InsertChallenge(ctx context.Context, c Challenge) error {
|
|
stmt := table.AuthChallenges.INSERT(
|
|
table.AuthChallenges.ChallengeID,
|
|
table.AuthChallenges.Email,
|
|
table.AuthChallenges.CodeHash,
|
|
table.AuthChallenges.ExpiresAt,
|
|
table.AuthChallenges.PreferredLanguage,
|
|
).VALUES(c.ChallengeID, c.Email, c.CodeHash, c.ExpiresAt, c.PreferredLanguage)
|
|
|
|
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
|
|
return fmt.Errorf("auth store: insert challenge: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LoadAndIncrementChallenge atomically locks the challenge row,
|
|
// validates that it is still un-consumed and non-expired, and increments
|
|
// its `attempts` counter. The returned Challenge carries the
|
|
// post-increment counter so the caller can compare it against the
|
|
// configured ceiling without a second query.
|
|
//
|
|
// Returns ErrChallengeNotFound when the row does not exist, has been
|
|
// consumed, or has expired. Any other error is wrapped with the auth
|
|
// store prefix.
|
|
func (s *Store) LoadAndIncrementChallenge(ctx context.Context, challengeID uuid.UUID) (Challenge, error) {
|
|
var loaded Challenge
|
|
err := withTx(ctx, s.db, func(tx *sql.Tx) error {
|
|
selectStmt := postgres.SELECT(challengeColumns()).
|
|
FROM(table.AuthChallenges).
|
|
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))).
|
|
FOR(postgres.UPDATE())
|
|
|
|
var row model.AuthChallenges
|
|
if err := selectStmt.QueryContext(ctx, tx, &row); err != nil {
|
|
if errors.Is(err, qrm.ErrNoRows) {
|
|
return ErrChallengeNotFound
|
|
}
|
|
return err
|
|
}
|
|
loaded = modelToChallenge(row)
|
|
if loaded.ConsumedAt != nil {
|
|
return ErrChallengeNotFound
|
|
}
|
|
if !loaded.ExpiresAt.After(time.Now()) {
|
|
return ErrChallengeNotFound
|
|
}
|
|
updateStmt := table.AuthChallenges.
|
|
UPDATE(table.AuthChallenges.Attempts).
|
|
SET(table.AuthChallenges.Attempts.ADD(postgres.Int(1))).
|
|
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID)))
|
|
if _, err := updateStmt.ExecContext(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
loaded.Attempts++
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, ErrChallengeNotFound) {
|
|
return Challenge{}, err
|
|
}
|
|
return Challenge{}, fmt.Errorf("auth store: load and increment challenge: %w", err)
|
|
}
|
|
return loaded, nil
|
|
}
|
|
|
|
// MarkConsumedAndInsertSession atomically:
|
|
//
|
|
// 1. Locks the challenge row.
|
|
// 2. Validates that it is still un-consumed and non-expired.
|
|
// 3. Sets consumed_at = now().
|
|
// 4. Inserts the supplied Session into device_sessions with status =
|
|
// 'active'.
|
|
//
|
|
// The two writes are committed together so a single challenge yields at
|
|
// most one device session even under concurrent confirm-email-code
|
|
// callers.
|
|
//
|
|
// Returns ErrChallengeNotFound when the challenge has been consumed (by
|
|
// a concurrent caller) or has expired in the gap between the
|
|
// LoadAndIncrementChallenge call and this one.
|
|
func (s *Store) MarkConsumedAndInsertSession(ctx context.Context, challengeID uuid.UUID, session Session) error {
|
|
err := withTx(ctx, s.db, func(tx *sql.Tx) error {
|
|
lockStmt := postgres.SELECT(table.AuthChallenges.ConsumedAt, table.AuthChallenges.ExpiresAt).
|
|
FROM(table.AuthChallenges).
|
|
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))).
|
|
FOR(postgres.UPDATE())
|
|
|
|
var locked model.AuthChallenges
|
|
if err := lockStmt.QueryContext(ctx, tx, &locked); err != nil {
|
|
if errors.Is(err, qrm.ErrNoRows) {
|
|
return ErrChallengeNotFound
|
|
}
|
|
return err
|
|
}
|
|
if locked.ConsumedAt != nil || !locked.ExpiresAt.After(time.Now()) {
|
|
return ErrChallengeNotFound
|
|
}
|
|
consumeStmt := table.AuthChallenges.
|
|
UPDATE(table.AuthChallenges.ConsumedAt).
|
|
SET(postgres.NOW()).
|
|
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID)))
|
|
if _, err := consumeStmt.ExecContext(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
insertStmt := table.DeviceSessions.INSERT(
|
|
table.DeviceSessions.DeviceSessionID,
|
|
table.DeviceSessions.UserID,
|
|
table.DeviceSessions.ClientPublicKey,
|
|
table.DeviceSessions.Status,
|
|
).VALUES(session.DeviceSessionID, session.UserID, session.ClientPublicKey, SessionStatusActive)
|
|
if _, err := insertStmt.ExecContext(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, ErrChallengeNotFound) {
|
|
return err
|
|
}
|
|
return fmt.Errorf("auth store: mark consumed and insert session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ListActiveSessions loads every row from device_sessions whose status
|
|
// is 'active'. Cache.Warm calls this at process boot.
|
|
func (s *Store) ListActiveSessions(ctx context.Context) ([]Session, error) {
|
|
stmt := postgres.SELECT(sessionColumns()).
|
|
FROM(table.DeviceSessions).
|
|
WHERE(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive)))
|
|
|
|
var rows []model.DeviceSessions
|
|
if err := stmt.QueryContext(ctx, s.db, &rows); err != nil {
|
|
return nil, fmt.Errorf("auth store: list active sessions: %w", err)
|
|
}
|
|
out := make([]Session, 0, len(rows))
|
|
for _, row := range rows {
|
|
out = append(out, modelToSession(row))
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// LoadSession returns the row for deviceSessionID regardless of status.
|
|
// Returns ErrSessionNotFound on missing row.
|
|
func (s *Store) LoadSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, error) {
|
|
stmt := postgres.SELECT(sessionColumns()).
|
|
FROM(table.DeviceSessions).
|
|
WHERE(table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID))).
|
|
LIMIT(1)
|
|
|
|
var row model.DeviceSessions
|
|
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
|
if errors.Is(err, qrm.ErrNoRows) {
|
|
return Session{}, ErrSessionNotFound
|
|
}
|
|
return Session{}, fmt.Errorf("auth store: load session %s: %w", deviceSessionID, err)
|
|
}
|
|
return modelToSession(row), nil
|
|
}
|
|
|
|
// RevokeSession transitions an active row to status='revoked' and
|
|
// returns the row as it stands after the update. The boolean reports
|
|
// whether the UPDATE actually changed a row — false means the row was
|
|
// already revoked or did not exist; the auth Service then falls back to
|
|
// LoadSession for idempotent-revoke responses.
|
|
func (s *Store) RevokeSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, bool, error) {
|
|
stmt := table.DeviceSessions.
|
|
UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt).
|
|
SET(postgres.String(SessionStatusRevoked), postgres.NOW()).
|
|
WHERE(
|
|
table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID)).
|
|
AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))),
|
|
).
|
|
RETURNING(sessionColumns())
|
|
|
|
var row model.DeviceSessions
|
|
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
|
if errors.Is(err, qrm.ErrNoRows) {
|
|
return Session{}, false, nil
|
|
}
|
|
return Session{}, false, fmt.Errorf("auth store: revoke session %s: %w", deviceSessionID, err)
|
|
}
|
|
return modelToSession(row), true, nil
|
|
}
|
|
|
|
// RevokeAllForUser transitions every active row for userID to
|
|
// status='revoked' and returns the rows as they stand after the update.
|
|
// An empty slice with a nil error is returned when the user owned no
|
|
// active sessions; the caller must treat that as a successful idempotent
|
|
// revoke (the API surface returns revoked_count=0 in that case).
|
|
func (s *Store) RevokeAllForUser(ctx context.Context, userID uuid.UUID) ([]Session, error) {
|
|
stmt := table.DeviceSessions.
|
|
UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt).
|
|
SET(postgres.String(SessionStatusRevoked), postgres.NOW()).
|
|
WHERE(
|
|
table.DeviceSessions.UserID.EQ(postgres.UUID(userID)).
|
|
AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))),
|
|
).
|
|
RETURNING(sessionColumns())
|
|
|
|
var rows []model.DeviceSessions
|
|
if err := stmt.QueryContext(ctx, s.db, &rows); err != nil {
|
|
return nil, fmt.Errorf("auth store: revoke all for user %s: %w", userID, err)
|
|
}
|
|
out := make([]Session, 0, len(rows))
|
|
for _, row := range rows {
|
|
out = append(out, modelToSession(row))
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// modelToChallenge projects a generated model row into the public
|
|
// Challenge struct. Pointer fields are copied so callers cannot mutate
|
|
// the underlying scan buffer.
|
|
func modelToChallenge(row model.AuthChallenges) Challenge {
|
|
c := Challenge{
|
|
ChallengeID: row.ChallengeID,
|
|
Email: row.Email,
|
|
CodeHash: row.CodeHash,
|
|
Attempts: row.Attempts,
|
|
CreatedAt: row.CreatedAt,
|
|
ExpiresAt: row.ExpiresAt,
|
|
PreferredLanguage: row.PreferredLanguage,
|
|
}
|
|
if row.ConsumedAt != nil {
|
|
t := *row.ConsumedAt
|
|
c.ConsumedAt = &t
|
|
}
|
|
return c
|
|
}
|
|
|
|
// modelToSession projects a generated model row into the public Session
|
|
// struct.
|
|
func modelToSession(row model.DeviceSessions) Session {
|
|
s := Session{
|
|
DeviceSessionID: row.DeviceSessionID,
|
|
UserID: row.UserID,
|
|
Status: row.Status,
|
|
ClientPublicKey: row.ClientPublicKey,
|
|
CreatedAt: row.CreatedAt,
|
|
}
|
|
if row.RevokedAt != nil {
|
|
t := *row.RevokedAt
|
|
s.RevokedAt = &t
|
|
}
|
|
if row.LastSeenAt != nil {
|
|
t := *row.LastSeenAt
|
|
s.LastSeenAt = &t
|
|
}
|
|
return s
|
|
}
|
|
|
|
// withTx runs fn inside a transaction opened on db with ctx and the
// driver's default options.
//
// The transaction commits when fn returns nil and rolls back when fn
// returns a non-nil error, which is handed back to the caller
// unchanged. Failures to begin or commit are wrapped with the auth
// store prefix. The rollback is deferred, so the transaction is also
// released when fn panics; the panic itself keeps propagating. Rollback
// errors are deliberately dropped: by that point fn's error, the commit
// error, or the panic is the more actionable signal.
func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("auth store: begin tx: %w", err)
	}

	committed := false
	defer func() {
		if !committed {
			// Covers fn errors, commit failures and panics. After a
			// failed Commit the transaction is already finished and
			// Rollback merely reports sql.ErrTxDone.
			_ = tx.Rollback()
		}
	}()

	if err := fn(tx); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("auth store: commit tx: %w", err)
	}
	committed = true
	return nil
}
|