package notification import ( "context" "database/sql" "encoding/json" "errors" "fmt" "strings" "time" "galaxy/backend/internal/postgres/jet/backend/model" "galaxy/backend/internal/postgres/jet/backend/table" "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" ) // Store is the Postgres-backed query surface for notifications, // notification_routes, notification_dead_letters, and // notification_malformed_intents. All queries are built through go-jet // against the generated table bindings under // `backend/internal/postgres/jet/backend/table`. type Store struct { db *sql.DB } // NewStore constructs a Store wrapping db. func NewStore(db *sql.DB) *Store { return &Store{db: db} } // BeginTx exposes the transaction handle to the worker so the // claim-dispatch-mark cycle stays within a single commit boundary. func (s *Store) BeginTx(ctx context.Context) (*sql.Tx, error) { return s.db.BeginTx(ctx, nil) } // RouteSeed describes one freshly-materialised route destined for an // `INSERT INTO notification_routes` inside InsertNotification. type RouteSeed struct { RouteID uuid.UUID Channel string Status string MaxAttempts int32 NextAttemptAt *time.Time ResolvedEmail string ResolvedLocale string UserID *uuid.UUID DeviceSessionID *uuid.UUID SkippedAt *time.Time LastError string } // InsertNotificationArgs aggregates the inputs to InsertNotification. type InsertNotificationArgs struct { NotificationID uuid.UUID Kind string IdempotencyKey string UserID *uuid.UUID Payload map[string]any Routes []RouteSeed } // InsertNotification persists a notification row together with its // route rows in a single transaction. The (kind, idempotency_key) // UNIQUE constraint serves the idempotency contract: the second // caller observes inserted=false and the existing notification_id is // returned. On the duplicate path no route rows are inserted and the // transaction rolls back so an orphan notification cannot exist. func (s *Store) InsertNotification(ctx context.Context, args InsertNotificationArgs) (uuid.UUID, bool, error) { payload, err := encodePayload(args.Payload) if err != nil { return uuid.Nil, false, fmt.Errorf("encode payload: %w", err) } var ( storedID uuid.UUID inserted bool ) err = withTx(ctx, s.db, func(tx *sql.Tx) error { insertStmt := table.Notifications.INSERT( table.Notifications.NotificationID, table.Notifications.Kind, table.Notifications.IdempotencyKey, table.Notifications.UserID, table.Notifications.Payload, ).VALUES( args.NotificationID, args.Kind, args.IdempotencyKey, args.UserID, string(payload), ). ON_CONFLICT(table.Notifications.Kind, table.Notifications.IdempotencyKey). DO_NOTHING(). RETURNING(table.Notifications.NotificationID) var freshRow model.Notifications err := insertStmt.QueryContext(ctx, tx, &freshRow) switch { case errors.Is(err, qrm.ErrNoRows): // Idempotent re-submit. Look up the existing row id and bail. lookupStmt := postgres.SELECT(table.Notifications.NotificationID). FROM(table.Notifications). WHERE( table.Notifications.Kind.EQ(postgres.String(args.Kind)). AND(table.Notifications.IdempotencyKey.EQ(postgres.String(args.IdempotencyKey))), ). LIMIT(1) var existing model.Notifications if scanErr := lookupStmt.QueryContext(ctx, tx, &existing); scanErr != nil { return fmt.Errorf("lookup existing notification: %w", scanErr) } storedID = existing.NotificationID return errIdempotentNoop case err != nil: return fmt.Errorf("insert notification: %w", err) } storedID = freshRow.NotificationID inserted = true for _, r := range args.Routes { routeStmt := table.NotificationRoutes.INSERT( table.NotificationRoutes.RouteID, table.NotificationRoutes.NotificationID, table.NotificationRoutes.Channel, table.NotificationRoutes.Status, table.NotificationRoutes.MaxAttempts, table.NotificationRoutes.NextAttemptAt, table.NotificationRoutes.ResolvedEmail, table.NotificationRoutes.ResolvedLocale, table.NotificationRoutes.LastError, table.NotificationRoutes.SkippedAt, ).VALUES( r.RouteID, args.NotificationID, r.Channel, r.Status, r.MaxAttempts, r.NextAttemptAt, r.ResolvedEmail, r.ResolvedLocale, r.LastError, r.SkippedAt, ) if _, err := routeStmt.ExecContext(ctx, tx); err != nil { return fmt.Errorf("insert route %s: %w", r.RouteID, err) } } return nil }) if errors.Is(err, errIdempotentNoop) { return storedID, false, nil } if err != nil { return uuid.Nil, false, err } return storedID, inserted, nil } // errIdempotentNoop tells withTx to roll back the transaction without // surfacing an error to the caller. It must never escape this package. var errIdempotentNoop = errors.New("notification store: idempotent noop") // MarkRoutePublished flips a route to status='published', clears the // retry schedule, stamps published_at and last_attempt_at, and clears // last_error. func (s *Store) MarkRoutePublished(ctx context.Context, tx *sql.Tx, routeID uuid.UUID, at time.Time) error { r := table.NotificationRoutes stmt := r.UPDATE(). SET( r.Status.SET(postgres.String(RouteStatusPublished)), r.Attempts.SET(r.Attempts.ADD(postgres.Int(1))), r.LastAttemptAt.SET(postgres.TimestampzT(at)), r.PublishedAt.SET(postgres.TimestampzT(at)), r.NextAttemptAt.SET(postgres.TimestampzExp(postgres.NULL)), r.LastError.SET(postgres.String("")), r.UpdatedAt.SET(postgres.TimestampzT(at)), ). WHERE(r.RouteID.EQ(postgres.UUID(routeID))) if _, err := stmt.ExecContext(ctx, tx); err != nil { return fmt.Errorf("mark route published: %w", err) } return nil } // ScheduleRouteRetry flips a route to status='retrying', bumps // attempts, arms next_attempt_at, and stamps the diagnostic message. func (s *Store) ScheduleRouteRetry(ctx context.Context, tx *sql.Tx, routeID uuid.UUID, at time.Time, nextAt time.Time, errMsg string) error { r := table.NotificationRoutes stmt := r.UPDATE(). SET( r.Status.SET(postgres.String(RouteStatusRetrying)), r.Attempts.SET(r.Attempts.ADD(postgres.Int(1))), r.LastAttemptAt.SET(postgres.TimestampzT(at)), r.NextAttemptAt.SET(postgres.TimestampzT(nextAt)), r.LastError.SET(postgres.String(errMsg)), r.UpdatedAt.SET(postgres.TimestampzT(at)), ). WHERE(r.RouteID.EQ(postgres.UUID(routeID))) if _, err := stmt.ExecContext(ctx, tx); err != nil { return fmt.Errorf("schedule route retry: %w", err) } return nil } // MarkRouteDeadLettered moves the route to the terminal `dead_lettered` // state and inserts a notification_dead_letters row under the same // transaction. func (s *Store) MarkRouteDeadLettered(ctx context.Context, tx *sql.Tx, notificationID, routeID uuid.UUID, at time.Time, reason string) error { r := table.NotificationRoutes updateStmt := r.UPDATE(). SET( r.Status.SET(postgres.String(RouteStatusDeadLettered)), r.Attempts.SET(r.Attempts.ADD(postgres.Int(1))), r.LastAttemptAt.SET(postgres.TimestampzT(at)), r.NextAttemptAt.SET(postgres.TimestampzExp(postgres.NULL)), r.DeadLetteredAt.SET(postgres.TimestampzT(at)), r.LastError.SET(postgres.String(reason)), r.UpdatedAt.SET(postgres.TimestampzT(at)), ). WHERE(r.RouteID.EQ(postgres.UUID(routeID))) if _, err := updateStmt.ExecContext(ctx, tx); err != nil { return fmt.Errorf("mark route dead-lettered: %w", err) } dl := table.NotificationDeadLetters insertStmt := dl.INSERT( dl.DeadLetterID, dl.NotificationID, dl.RouteID, dl.ArchivedAt, dl.Reason, ).VALUES(uuid.New(), notificationID, routeID, at, reason) if _, err := insertStmt.ExecContext(ctx, tx); err != nil { return fmt.Errorf("insert notification dead-letter: %w", err) } return nil } // ClaimedRoute bundles a locked route row with its parent notification // so the worker has every field it needs in one trip. type ClaimedRoute struct { Route Route Notification Notification } // ClaimDueRoutes locks up to `limit` due routes with FOR UPDATE SKIP // LOCKED, joins the parent notification to surface kind/payload, and // returns them. exclude is the list of route_ids already handled in // the current tick — they are filtered out so the same row cannot // chew through MaxAttempts inside a single tick when its retry // schedule lands at <= now(). func (s *Store) ClaimDueRoutes(ctx context.Context, tx *sql.Tx, limit int, exclude ...uuid.UUID) ([]ClaimedRoute, error) { r := table.NotificationRoutes n := table.Notifications condition := r.Status.IN(postgres.String(RouteStatusPending), postgres.String(RouteStatusRetrying)). AND(r.NextAttemptAt.IS_NULL().OR(r.NextAttemptAt.LT_EQ(postgres.NOW()))) if len(exclude) > 0 { excludeExprs := make([]postgres.Expression, 0, len(exclude)) for _, id := range exclude { excludeExprs = append(excludeExprs, postgres.UUID(id)) } condition = condition.AND(r.RouteID.NOT_IN(excludeExprs...)) } stmt := postgres.SELECT( r.AllColumns, n.Kind, n.IdempotencyKey, n.UserID, n.Payload, n.CreatedAt, ). FROM(r.INNER_JOIN(n, n.NotificationID.EQ(r.NotificationID))). WHERE(condition). ORDER_BY(postgres.COALESCE(r.NextAttemptAt, r.CreatedAt).ASC()). LIMIT(int64(limit)). FOR(postgres.UPDATE().OF(r).SKIP_LOCKED()) var rows []struct { model.NotificationRoutes Notifications struct { Kind string IdempotencyKey string UserID *uuid.UUID Payload *string CreatedAt time.Time } } if err := stmt.QueryContext(ctx, tx, &rows); err != nil { return nil, fmt.Errorf("claim due routes: %w", err) } out := make([]ClaimedRoute, 0, len(rows)) for _, row := range rows { route := modelToRoute(row.NotificationRoutes) route.UserID = row.Notifications.UserID notif := Notification{ NotificationID: row.NotificationRoutes.NotificationID, Kind: row.Notifications.Kind, IdempotencyKey: row.Notifications.IdempotencyKey, UserID: row.Notifications.UserID, CreatedAt: row.Notifications.CreatedAt, } decoded, err := decodePayload(payloadBytesFromPtr(row.Notifications.Payload)) if err != nil { return nil, fmt.Errorf("decode notification payload: %w", err) } notif.Payload = decoded out = append(out, ClaimedRoute{Route: route, Notification: notif}) } return out, nil } // ListNotificationsResult bundles a page of notifications and the // total-row count. Layout mirrors `mail.AdminListDeliveriesPage`. type ListNotificationsResult struct { Items []Notification Total int64 } // ListNotifications returns the page newest-first. func (s *Store) ListNotifications(ctx context.Context, offset, limit int) (ListNotificationsResult, error) { total, err := countAll(ctx, s.db, table.Notifications) if err != nil { return ListNotificationsResult{}, fmt.Errorf("count notifications: %w", err) } n := table.Notifications stmt := postgres.SELECT( n.NotificationID, n.Kind, n.IdempotencyKey, n.UserID, n.Payload, n.CreatedAt, ). FROM(n). ORDER_BY(n.CreatedAt.DESC(), n.NotificationID.DESC()). LIMIT(int64(limit)).OFFSET(int64(offset)) var rows []model.Notifications if err := stmt.QueryContext(ctx, s.db, &rows); err != nil { return ListNotificationsResult{}, fmt.Errorf("list notifications: %w", err) } items := make([]Notification, 0, len(rows)) for _, row := range rows { notif, err := modelToNotification(row) if err != nil { return ListNotificationsResult{}, err } items = append(items, notif) } return ListNotificationsResult{Items: items, Total: total}, nil } // GetNotification loads a notification by primary key. The sentinel // ErrNotificationNotFound is returned when no row matches. func (s *Store) GetNotification(ctx context.Context, id uuid.UUID) (Notification, error) { n := table.Notifications stmt := postgres.SELECT( n.NotificationID, n.Kind, n.IdempotencyKey, n.UserID, n.Payload, n.CreatedAt, ). FROM(n). WHERE(n.NotificationID.EQ(postgres.UUID(id))). LIMIT(1) var row model.Notifications if err := stmt.QueryContext(ctx, s.db, &row); err != nil { if errors.Is(err, qrm.ErrNoRows) { return Notification{}, ErrNotificationNotFound } return Notification{}, fmt.Errorf("get notification: %w", err) } return modelToNotification(row) } // ListDeadLettersResult bundles a page of dead-letters and the total // row count. type ListDeadLettersResult struct { Items []DeadLetter Total int64 } // ListDeadLetters returns the dead-letter page newest-first. func (s *Store) ListDeadLetters(ctx context.Context, offset, limit int) (ListDeadLettersResult, error) { total, err := countAll(ctx, s.db, table.NotificationDeadLetters) if err != nil { return ListDeadLettersResult{}, fmt.Errorf("count dead-letters: %w", err) } dl := table.NotificationDeadLetters stmt := postgres.SELECT( dl.DeadLetterID, dl.NotificationID, dl.RouteID, dl.ArchivedAt, dl.Reason, ). FROM(dl). ORDER_BY(dl.ArchivedAt.DESC(), dl.DeadLetterID.DESC()). LIMIT(int64(limit)).OFFSET(int64(offset)) var rows []model.NotificationDeadLetters if err := stmt.QueryContext(ctx, s.db, &rows); err != nil { return ListDeadLettersResult{}, fmt.Errorf("list dead-letters: %w", err) } items := make([]DeadLetter, 0, len(rows)) for _, row := range rows { items = append(items, DeadLetter{ DeadLetterID: row.DeadLetterID, NotificationID: row.NotificationID, RouteID: row.RouteID, ArchivedAt: row.ArchivedAt, Reason: row.Reason, }) } return ListDeadLettersResult{Items: items, Total: total}, nil } // ListMalformedResult bundles a page of malformed intents and the // total row count. type ListMalformedResult struct { Items []MalformedIntent Total int64 } // ListMalformed returns the malformed page newest-first. func (s *Store) ListMalformed(ctx context.Context, offset, limit int) (ListMalformedResult, error) { total, err := countAll(ctx, s.db, table.NotificationMalformedIntents) if err != nil { return ListMalformedResult{}, fmt.Errorf("count malformed intents: %w", err) } m := table.NotificationMalformedIntents stmt := postgres.SELECT(m.ID, m.ReceivedAt, m.Payload, m.Reason). FROM(m). ORDER_BY(m.ReceivedAt.DESC(), m.ID.DESC()). LIMIT(int64(limit)).OFFSET(int64(offset)) var rows []model.NotificationMalformedIntents if err := stmt.QueryContext(ctx, s.db, &rows); err != nil { return ListMalformedResult{}, fmt.Errorf("list malformed intents: %w", err) } items := make([]MalformedIntent, 0, len(rows)) for _, row := range rows { decoded, err := decodePayload([]byte(row.Payload)) if err != nil { return ListMalformedResult{}, fmt.Errorf("decode malformed payload: %w", err) } items = append(items, MalformedIntent{ ID: row.ID, ReceivedAt: row.ReceivedAt, Payload: decoded, Reason: row.Reason, }) } return ListMalformedResult{Items: items, Total: total}, nil } // InsertMalformed records a producer-supplied intent that failed // validation. The payload is best-effort JSON-encoded by the caller; // the row never blocks the producer. func (s *Store) InsertMalformed(ctx context.Context, payload map[string]any, reason string) error { encoded, err := encodePayload(payload) if err != nil { return fmt.Errorf("encode malformed payload: %w", err) } m := table.NotificationMalformedIntents stmt := m.INSERT(m.ID, m.Payload, m.Reason). VALUES(uuid.New(), string(encoded), reason) if _, err := stmt.ExecContext(ctx, s.db); err != nil { return fmt.Errorf("insert malformed intent: %w", err) } return nil } // SkipPendingRoutesForUser flips every pending or retrying route owned // by userID to status='skipped'. The `OnUserDeleted` cascade calls it so // the worker stops trying to deliver notifications to a vanished // account; published rows are kept as audit trail. func (s *Store) SkipPendingRoutesForUser(ctx context.Context, userID uuid.UUID, at time.Time) (int64, error) { r := table.NotificationRoutes n := table.Notifications notifSubquery := postgres.SELECT(n.NotificationID). FROM(n). WHERE(n.UserID.EQ(postgres.UUID(userID))) stmt := r.UPDATE(). SET( r.Status.SET(postgres.String(RouteStatusSkipped)), r.NextAttemptAt.SET(postgres.TimestampzExp(postgres.NULL)), r.SkippedAt.SET(postgres.TimestampzT(at)), r.UpdatedAt.SET(postgres.TimestampzT(at)), r.LastError.SET(postgres.String("recipient soft-deleted")), ). WHERE( r.Status.IN(postgres.String(RouteStatusPending), postgres.String(RouteStatusRetrying)). AND(r.NotificationID.IN(notifSubquery)), ) res, err := stmt.ExecContext(ctx, s.db) if err != nil { return 0, fmt.Errorf("skip pending routes: %w", err) } affected, err := res.RowsAffected() if err != nil { return 0, fmt.Errorf("rows affected: %w", err) } return affected, nil } // withTx wraps fn in a Postgres transaction. fn's return value // determines commit (nil) vs rollback (non-nil). Rollback errors are // swallowed when fn already returned an error, since the latter is // more actionable. func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error { tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("notification store: begin tx: %w", err) } if err := fn(tx); err != nil { _ = tx.Rollback() return err } if err := tx.Commit(); err != nil { return fmt.Errorf("notification store: commit tx: %w", err) } return nil } // modelToRoute projects a generated model row onto the public Route // struct (without the user-id which lives on the parent notification). func modelToRoute(row model.NotificationRoutes) Route { r := Route{ RouteID: row.RouteID, NotificationID: row.NotificationID, Channel: row.Channel, Status: row.Status, Attempts: row.Attempts, MaxAttempts: row.MaxAttempts, LastError: row.LastError, ResolvedEmail: row.ResolvedEmail, ResolvedLocale: row.ResolvedLocale, CreatedAt: row.CreatedAt, UpdatedAt: row.UpdatedAt, } if row.NextAttemptAt != nil { t := *row.NextAttemptAt r.NextAttemptAt = &t } if row.LastAttemptAt != nil { t := *row.LastAttemptAt r.LastAttemptAt = &t } if row.PublishedAt != nil { t := *row.PublishedAt r.PublishedAt = &t } if row.DeadLetteredAt != nil { t := *row.DeadLetteredAt r.DeadLetteredAt = &t } if row.SkippedAt != nil { t := *row.SkippedAt r.SkippedAt = &t } return r } // modelToNotification decodes a generated model row into the public // Notification struct, including the JSON payload. func modelToNotification(row model.Notifications) (Notification, error) { decoded, err := decodePayload(payloadBytesFromPtr(row.Payload)) if err != nil { return Notification{}, fmt.Errorf("decode payload: %w", err) } return Notification{ NotificationID: row.NotificationID, Kind: row.Kind, IdempotencyKey: row.IdempotencyKey, UserID: row.UserID, Payload: decoded, CreatedAt: row.CreatedAt, }, nil } // payloadBytesFromPtr converts the nullable string from the generated // jsonb-as-text model into the byte slice expected by decodePayload. func payloadBytesFromPtr(p *string) []byte { if p == nil { return nil } return []byte(*p) } // encodePayload renders a map[string]any to JSON for storage in // jsonb columns. A nil map encodes as JSON null; this is harmless on // the read path because decodePayload returns nil for it. func encodePayload(payload map[string]any) ([]byte, error) { if payload == nil { return []byte("null"), nil } return json.Marshal(payload) } // decodePayload parses a jsonb column back into the producer's map. // A NULL or empty buffer round-trips to nil. func decodePayload(buf []byte) (map[string]any, error) { if len(buf) == 0 || strings.EqualFold(strings.TrimSpace(string(buf)), "null") { return nil, nil } out := map[string]any{} if err := json.Unmarshal(buf, &out); err != nil { return nil, err } return out, nil } // countAll runs `SELECT COUNT(*) FROM ` through jet and returns // the result. The destination uses an alias-tagged scalar so QRM can // map the un-prefixed alias produced by AS("count"). func countAll(ctx context.Context, db qrm.DB, tbl postgres.ReadableTable) (int64, error) { stmt := postgres.SELECT(postgres.COUNT(postgres.STAR).AS("count")).FROM(tbl) var dest struct { Count int64 `alias:"count"` } if err := stmt.QueryContext(ctx, db, &dest); err != nil { return 0, err } return dest.Count, nil }