feat: use postgres
This commit is contained in:
@@ -0,0 +1,729 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
pg "github.com/go-jet/jet/v2/postgres"
|
||||
)
|
||||
|
||||
// entitlementPeriodSelectColumns is the canonical SELECT list for
|
||||
// entitlement_records, matching scanEntitlementPeriod's column order.
|
||||
var entitlementPeriodSelectColumns = pg.ColumnList{
|
||||
pgtable.EntitlementRecords.RecordID,
|
||||
pgtable.EntitlementRecords.UserID,
|
||||
pgtable.EntitlementRecords.PlanCode,
|
||||
pgtable.EntitlementRecords.Source,
|
||||
pgtable.EntitlementRecords.ActorType,
|
||||
pgtable.EntitlementRecords.ActorID,
|
||||
pgtable.EntitlementRecords.ReasonCode,
|
||||
pgtable.EntitlementRecords.StartsAt,
|
||||
pgtable.EntitlementRecords.EndsAt,
|
||||
pgtable.EntitlementRecords.CreatedAt,
|
||||
pgtable.EntitlementRecords.ClosedAt,
|
||||
pgtable.EntitlementRecords.ClosedByType,
|
||||
pgtable.EntitlementRecords.ClosedByID,
|
||||
pgtable.EntitlementRecords.ClosedReasonCode,
|
||||
}
|
||||
|
||||
// entitlementSnapshotSelectColumns is the canonical SELECT list for
|
||||
// entitlement_snapshots, matching scanEntitlementSnapshotRow's column order.
|
||||
var entitlementSnapshotSelectColumns = pg.ColumnList{
|
||||
pgtable.EntitlementSnapshots.UserID,
|
||||
pgtable.EntitlementSnapshots.PlanCode,
|
||||
pgtable.EntitlementSnapshots.IsPaid,
|
||||
pgtable.EntitlementSnapshots.StartsAt,
|
||||
pgtable.EntitlementSnapshots.EndsAt,
|
||||
pgtable.EntitlementSnapshots.Source,
|
||||
pgtable.EntitlementSnapshots.ActorType,
|
||||
pgtable.EntitlementSnapshots.ActorID,
|
||||
pgtable.EntitlementSnapshots.ReasonCode,
|
||||
pgtable.EntitlementSnapshots.UpdatedAt,
|
||||
}
|
||||
|
||||
// CreateEntitlementRecord stores one new entitlement period history record.
|
||||
// The unique key is record_id; a duplicate record_id returns
|
||||
// ports.ErrConflict.
|
||||
func (store *Store) CreateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create entitlement record in postgres: %w", err)
|
||||
}
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "create entitlement record in postgres")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
return insertEntitlementPeriod(operationCtx, store.db, record)
|
||||
}
|
||||
|
||||
// GetEntitlementRecordByID returns the entitlement period record identified
|
||||
// by recordID.
|
||||
func (store *Store) GetEntitlementRecordByID(ctx context.Context, recordID entitlement.EntitlementRecordID) (entitlement.PeriodRecord, error) {
|
||||
if err := recordID.Validate(); err != nil {
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record from postgres: %w", err)
|
||||
}
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "get entitlement record from postgres")
|
||||
if err != nil {
|
||||
return entitlement.PeriodRecord{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
stmt := pg.SELECT(entitlementPeriodSelectColumns).
|
||||
FROM(pgtable.EntitlementRecords).
|
||||
WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(recordID.String())))
|
||||
|
||||
query, args := stmt.Sql()
|
||||
row := store.db.QueryRowContext(operationCtx, query, args...)
|
||||
record, err := scanEntitlementPeriodRow(row)
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record %q from postgres: %w", recordID, ports.ErrNotFound)
|
||||
case err != nil:
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record %q from postgres: %w", recordID, err)
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// ListEntitlementRecordsByUserID returns every entitlement period record
|
||||
// owned by userID, ordered by created_at ascending so historical replay is
|
||||
// deterministic.
|
||||
func (store *Store) ListEntitlementRecordsByUserID(ctx context.Context, userID common.UserID) ([]entitlement.PeriodRecord, error) {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("list entitlement records from postgres: %w", err)
|
||||
}
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "list entitlement records from postgres")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
stmt := pg.SELECT(entitlementPeriodSelectColumns).
|
||||
FROM(pgtable.EntitlementRecords).
|
||||
WHERE(pgtable.EntitlementRecords.UserID.EQ(pg.String(userID.String()))).
|
||||
ORDER_BY(pgtable.EntitlementRecords.CreatedAt.ASC(), pgtable.EntitlementRecords.RecordID.ASC())
|
||||
|
||||
query, args := stmt.Sql()
|
||||
rows, err := store.db.QueryContext(operationCtx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list entitlement records for %q from postgres: %w", userID, err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := make([]entitlement.PeriodRecord, 0)
|
||||
for rows.Next() {
|
||||
record, err := scanEntitlementPeriodRows(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list entitlement records for %q from postgres: %w", userID, err)
|
||||
}
|
||||
out = append(out, record)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("list entitlement records for %q from postgres: %w", userID, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// UpdateEntitlementRecord replaces one stored entitlement period record. The
|
||||
// statement matches by record_id; ports.ErrNotFound is returned when the
|
||||
// record does not exist.
|
||||
func (store *Store) UpdateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("update entitlement record in postgres: %w", err)
|
||||
}
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "update entitlement record in postgres")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
rows, err := updateEntitlementPeriod(operationCtx, store.db, record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update entitlement record %q in postgres: %w", record.RecordID, err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("update entitlement record %q in postgres: %w", record.RecordID, ports.ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateEntitlementPeriod(ctx context.Context, q queryer, record entitlement.PeriodRecord) (int64, error) {
|
||||
stmt := pgtable.EntitlementRecords.UPDATE(
|
||||
pgtable.EntitlementRecords.PlanCode,
|
||||
pgtable.EntitlementRecords.Source,
|
||||
pgtable.EntitlementRecords.ActorType,
|
||||
pgtable.EntitlementRecords.ActorID,
|
||||
pgtable.EntitlementRecords.ReasonCode,
|
||||
pgtable.EntitlementRecords.StartsAt,
|
||||
pgtable.EntitlementRecords.EndsAt,
|
||||
pgtable.EntitlementRecords.CreatedAt,
|
||||
pgtable.EntitlementRecords.ClosedAt,
|
||||
pgtable.EntitlementRecords.ClosedByType,
|
||||
pgtable.EntitlementRecords.ClosedByID,
|
||||
pgtable.EntitlementRecords.ClosedReasonCode,
|
||||
).SET(
|
||||
string(record.PlanCode),
|
||||
record.Source.String(),
|
||||
record.Actor.Type.String(),
|
||||
nullableActorID(record.Actor.ID),
|
||||
record.ReasonCode.String(),
|
||||
record.StartsAt.UTC(),
|
||||
nullableTime(record.EndsAt),
|
||||
record.CreatedAt.UTC(),
|
||||
nullableTime(record.ClosedAt),
|
||||
nullableActorType(record.ClosedBy.Type),
|
||||
nullableActorID(record.ClosedBy.ID),
|
||||
nullableReasonCode(record.ClosedReasonCode),
|
||||
).WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(record.RecordID.String())))
|
||||
|
||||
query, args := stmt.Sql()
|
||||
res, err := q.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func insertEntitlementPeriod(ctx context.Context, q queryer, record entitlement.PeriodRecord) error {
|
||||
stmt := pgtable.EntitlementRecords.INSERT(
|
||||
pgtable.EntitlementRecords.RecordID,
|
||||
pgtable.EntitlementRecords.UserID,
|
||||
pgtable.EntitlementRecords.PlanCode,
|
||||
pgtable.EntitlementRecords.Source,
|
||||
pgtable.EntitlementRecords.ActorType,
|
||||
pgtable.EntitlementRecords.ActorID,
|
||||
pgtable.EntitlementRecords.ReasonCode,
|
||||
pgtable.EntitlementRecords.StartsAt,
|
||||
pgtable.EntitlementRecords.EndsAt,
|
||||
pgtable.EntitlementRecords.CreatedAt,
|
||||
pgtable.EntitlementRecords.ClosedAt,
|
||||
pgtable.EntitlementRecords.ClosedByType,
|
||||
pgtable.EntitlementRecords.ClosedByID,
|
||||
pgtable.EntitlementRecords.ClosedReasonCode,
|
||||
).VALUES(
|
||||
record.RecordID.String(),
|
||||
record.UserID.String(),
|
||||
string(record.PlanCode),
|
||||
record.Source.String(),
|
||||
record.Actor.Type.String(),
|
||||
nullableActorID(record.Actor.ID),
|
||||
record.ReasonCode.String(),
|
||||
record.StartsAt.UTC(),
|
||||
nullableTime(record.EndsAt),
|
||||
record.CreatedAt.UTC(),
|
||||
nullableTime(record.ClosedAt),
|
||||
nullableActorType(record.ClosedBy.Type),
|
||||
nullableActorID(record.ClosedBy.ID),
|
||||
nullableReasonCode(record.ClosedReasonCode),
|
||||
)
|
||||
|
||||
query, args := stmt.Sql()
|
||||
_, err := q.ExecContext(ctx, query, args...)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if isUniqueViolation(err) {
|
||||
return fmt.Errorf("create entitlement record %q in postgres: %w", record.RecordID, ports.ErrConflict)
|
||||
}
|
||||
return fmt.Errorf("create entitlement record %q in postgres: %w", record.RecordID, err)
|
||||
}
|
||||
|
||||
// scannableRow abstracts *sql.Row and *sql.Rows so the row-scanner can be
|
||||
// shared by single-row and iterating callers.
|
||||
type scannableRow interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanEntitlementPeriodRow(row *sql.Row) (entitlement.PeriodRecord, error) {
|
||||
record, err := scanEntitlementPeriod(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return entitlement.PeriodRecord{}, ports.ErrNotFound
|
||||
}
|
||||
return record, err
|
||||
}
|
||||
|
||||
func scanEntitlementPeriodRows(rows *sql.Rows) (entitlement.PeriodRecord, error) {
|
||||
return scanEntitlementPeriod(rows)
|
||||
}
|
||||
|
||||
func scanEntitlementPeriod(row scannableRow) (entitlement.PeriodRecord, error) {
|
||||
var (
|
||||
recordID string
|
||||
userID string
|
||||
planCode string
|
||||
source string
|
||||
actorType string
|
||||
actorID *string
|
||||
reasonCode string
|
||||
startsAt time.Time
|
||||
endsAt *time.Time
|
||||
createdAt time.Time
|
||||
closedAt *time.Time
|
||||
closedByType *string
|
||||
closedByID *string
|
||||
closedReason *string
|
||||
)
|
||||
if err := row.Scan(
|
||||
&recordID, &userID, &planCode, &source,
|
||||
&actorType, &actorID, &reasonCode,
|
||||
&startsAt, &endsAt, &createdAt,
|
||||
&closedAt, &closedByType, &closedByID, &closedReason,
|
||||
); err != nil {
|
||||
return entitlement.PeriodRecord{}, err
|
||||
}
|
||||
record := entitlement.PeriodRecord{
|
||||
RecordID: entitlement.EntitlementRecordID(recordID),
|
||||
UserID: common.UserID(userID),
|
||||
PlanCode: entitlement.PlanCode(planCode),
|
||||
Source: common.Source(source),
|
||||
Actor: common.ActorRef{Type: common.ActorType(actorType)},
|
||||
ReasonCode: common.ReasonCode(reasonCode),
|
||||
StartsAt: startsAt.UTC(),
|
||||
EndsAt: timeFromNullable(endsAt),
|
||||
CreatedAt: createdAt.UTC(),
|
||||
ClosedAt: timeFromNullable(closedAt),
|
||||
}
|
||||
if actorID != nil {
|
||||
record.Actor.ID = common.ActorID(*actorID)
|
||||
}
|
||||
if closedByType != nil {
|
||||
record.ClosedBy.Type = common.ActorType(*closedByType)
|
||||
}
|
||||
if closedByID != nil {
|
||||
record.ClosedBy.ID = common.ActorID(*closedByID)
|
||||
}
|
||||
if closedReason != nil {
|
||||
record.ClosedReasonCode = common.ReasonCode(*closedReason)
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetEntitlementByUserID returns the current entitlement snapshot for userID.
|
||||
func (store *Store) GetEntitlementByUserID(ctx context.Context, userID common.UserID) (entitlement.CurrentSnapshot, error) {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot from postgres: %w", err)
|
||||
}
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "get entitlement snapshot from postgres")
|
||||
if err != nil {
|
||||
return entitlement.CurrentSnapshot{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
stmt := pg.SELECT(entitlementSnapshotSelectColumns).
|
||||
FROM(pgtable.EntitlementSnapshots).
|
||||
WHERE(pgtable.EntitlementSnapshots.UserID.EQ(pg.String(userID.String())))
|
||||
|
||||
query, args := stmt.Sql()
|
||||
row := store.db.QueryRowContext(operationCtx, query, args...)
|
||||
record, err := scanEntitlementSnapshotRow(row)
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot for %q from postgres: %w", userID, ports.ErrNotFound)
|
||||
case err != nil:
|
||||
return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot for %q from postgres: %w", userID, err)
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// PutEntitlement stores the current entitlement snapshot for record.UserID.
|
||||
// It is an UPSERT so the runtime path can call it on creation and on
|
||||
// replacement uniformly.
|
||||
func (store *Store) PutEntitlement(ctx context.Context, record entitlement.CurrentSnapshot) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("put entitlement snapshot in postgres: %w", err)
|
||||
}
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "put entitlement snapshot in postgres")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
return upsertEntitlementSnapshot(operationCtx, store.db, record)
|
||||
}
|
||||
|
||||
func upsertEntitlementSnapshot(ctx context.Context, q queryer, record entitlement.CurrentSnapshot) error {
|
||||
stmt := pgtable.EntitlementSnapshots.INSERT(
|
||||
pgtable.EntitlementSnapshots.UserID,
|
||||
pgtable.EntitlementSnapshots.PlanCode,
|
||||
pgtable.EntitlementSnapshots.IsPaid,
|
||||
pgtable.EntitlementSnapshots.StartsAt,
|
||||
pgtable.EntitlementSnapshots.EndsAt,
|
||||
pgtable.EntitlementSnapshots.Source,
|
||||
pgtable.EntitlementSnapshots.ActorType,
|
||||
pgtable.EntitlementSnapshots.ActorID,
|
||||
pgtable.EntitlementSnapshots.ReasonCode,
|
||||
pgtable.EntitlementSnapshots.UpdatedAt,
|
||||
).VALUES(
|
||||
record.UserID.String(),
|
||||
string(record.PlanCode),
|
||||
record.IsPaid,
|
||||
record.StartsAt.UTC(),
|
||||
nullableTime(record.EndsAt),
|
||||
record.Source.String(),
|
||||
record.Actor.Type.String(),
|
||||
nullableActorID(record.Actor.ID),
|
||||
record.ReasonCode.String(),
|
||||
record.UpdatedAt.UTC(),
|
||||
).ON_CONFLICT(pgtable.EntitlementSnapshots.UserID).DO_UPDATE(
|
||||
pg.SET(
|
||||
pgtable.EntitlementSnapshots.PlanCode.SET(pgtable.EntitlementSnapshots.EXCLUDED.PlanCode),
|
||||
pgtable.EntitlementSnapshots.IsPaid.SET(pgtable.EntitlementSnapshots.EXCLUDED.IsPaid),
|
||||
pgtable.EntitlementSnapshots.StartsAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.StartsAt),
|
||||
pgtable.EntitlementSnapshots.EndsAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.EndsAt),
|
||||
pgtable.EntitlementSnapshots.Source.SET(pgtable.EntitlementSnapshots.EXCLUDED.Source),
|
||||
pgtable.EntitlementSnapshots.ActorType.SET(pgtable.EntitlementSnapshots.EXCLUDED.ActorType),
|
||||
pgtable.EntitlementSnapshots.ActorID.SET(pgtable.EntitlementSnapshots.EXCLUDED.ActorID),
|
||||
pgtable.EntitlementSnapshots.ReasonCode.SET(pgtable.EntitlementSnapshots.EXCLUDED.ReasonCode),
|
||||
pgtable.EntitlementSnapshots.UpdatedAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.UpdatedAt),
|
||||
),
|
||||
)
|
||||
|
||||
query, args := stmt.Sql()
|
||||
if _, err := q.ExecContext(ctx, query, args...); err != nil {
|
||||
return fmt.Errorf("upsert entitlement snapshot for %q in postgres: %w", record.UserID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanEntitlementSnapshotRow(row *sql.Row) (entitlement.CurrentSnapshot, error) {
|
||||
var (
|
||||
userID string
|
||||
planCode string
|
||||
isPaid bool
|
||||
startsAt time.Time
|
||||
endsAt *time.Time
|
||||
source string
|
||||
actorType string
|
||||
actorID *string
|
||||
reasonCode string
|
||||
updatedAt time.Time
|
||||
)
|
||||
err := row.Scan(
|
||||
&userID, &planCode, &isPaid,
|
||||
&startsAt, &endsAt,
|
||||
&source, &actorType, &actorID, &reasonCode,
|
||||
&updatedAt,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return entitlement.CurrentSnapshot{}, ports.ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return entitlement.CurrentSnapshot{}, err
|
||||
}
|
||||
record := entitlement.CurrentSnapshot{
|
||||
UserID: common.UserID(userID),
|
||||
PlanCode: entitlement.PlanCode(planCode),
|
||||
IsPaid: isPaid,
|
||||
StartsAt: startsAt.UTC(),
|
||||
EndsAt: timeFromNullable(endsAt),
|
||||
Source: common.Source(source),
|
||||
Actor: common.ActorRef{Type: common.ActorType(actorType)},
|
||||
ReasonCode: common.ReasonCode(reasonCode),
|
||||
UpdatedAt: updatedAt.UTC(),
|
||||
}
|
||||
if actorID != nil {
|
||||
record.Actor.ID = common.ActorID(*actorID)
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GrantEntitlement atomically closes the current free period, inserts the
|
||||
// new paid period, and replaces the snapshot.
|
||||
func (store *Store) GrantEntitlement(ctx context.Context, input ports.GrantEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("grant entitlement in postgres: %w", err)
|
||||
}
|
||||
return store.withTx(ctx, "grant entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
||||
if err := lockSnapshotMatching(ctx, tx, input.ExpectedCurrentSnapshot); err != nil {
|
||||
return fmt.Errorf("grant entitlement for %q in postgres: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if err := lockPeriodMatching(ctx, tx, input.ExpectedCurrentRecord); err != nil {
|
||||
return fmt.Errorf("grant entitlement for %q in postgres: %w", input.ExpectedCurrentRecord.RecordID, err)
|
||||
}
|
||||
if err := updateEntitlementPeriodTx(ctx, tx, input.UpdatedCurrentRecord); err != nil {
|
||||
return fmt.Errorf("grant entitlement for %q in postgres: %w", input.UpdatedCurrentRecord.RecordID, err)
|
||||
}
|
||||
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ExtendEntitlement atomically appends a new paid history segment and
|
||||
// replaces the snapshot.
|
||||
func (store *Store) ExtendEntitlement(ctx context.Context, input ports.ExtendEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("extend entitlement in postgres: %w", err)
|
||||
}
|
||||
return store.withTx(ctx, "extend entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
||||
if err := lockSnapshotMatching(ctx, tx, input.ExpectedCurrentSnapshot); err != nil {
|
||||
return fmt.Errorf("extend entitlement for %q in postgres: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RevokeEntitlement atomically closes the current paid period, inserts a new
|
||||
// free period, and replaces the snapshot.
|
||||
func (store *Store) RevokeEntitlement(ctx context.Context, input ports.RevokeEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("revoke entitlement in postgres: %w", err)
|
||||
}
|
||||
return store.withTx(ctx, "revoke entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
||||
if err := lockSnapshotMatching(ctx, tx, input.ExpectedCurrentSnapshot); err != nil {
|
||||
return fmt.Errorf("revoke entitlement for %q in postgres: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if err := lockPeriodMatching(ctx, tx, input.ExpectedCurrentRecord); err != nil {
|
||||
return fmt.Errorf("revoke entitlement for %q in postgres: %w", input.ExpectedCurrentRecord.RecordID, err)
|
||||
}
|
||||
if err := updateEntitlementPeriodTx(ctx, tx, input.UpdatedCurrentRecord); err != nil {
|
||||
return fmt.Errorf("revoke entitlement for %q in postgres: %w", input.UpdatedCurrentRecord.RecordID, err)
|
||||
}
|
||||
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RepairExpiredEntitlement atomically replaces an expired finite paid
|
||||
// snapshot with a materialised free state.
|
||||
func (store *Store) RepairExpiredEntitlement(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("repair expired entitlement in postgres: %w", err)
|
||||
}
|
||||
return store.withTx(ctx, "repair expired entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
||||
if err := lockSnapshotMatching(ctx, tx, input.ExpectedExpiredSnapshot); err != nil {
|
||||
return fmt.Errorf("repair expired entitlement for %q in postgres: %w", input.ExpectedExpiredSnapshot.UserID, err)
|
||||
}
|
||||
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// lockSnapshotMatching loads the current snapshot under FOR UPDATE and
|
||||
// verifies it matches expected. Mismatches surface as ports.ErrConflict so
|
||||
// optimistic-replacement callers can retry.
|
||||
func lockSnapshotMatching(ctx context.Context, tx *sql.Tx, expected entitlement.CurrentSnapshot) error {
|
||||
stmt := pg.SELECT(entitlementSnapshotSelectColumns).
|
||||
FROM(pgtable.EntitlementSnapshots).
|
||||
WHERE(pgtable.EntitlementSnapshots.UserID.EQ(pg.String(expected.UserID.String()))).
|
||||
FOR(pg.UPDATE())
|
||||
|
||||
query, args := stmt.Sql()
|
||||
row := tx.QueryRowContext(ctx, query, args...)
|
||||
current, err := scanEntitlementSnapshotRow(row)
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return ports.ErrNotFound
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
if !snapshotsEqual(current, expected) {
|
||||
return ports.ErrConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lockPeriodMatching(ctx context.Context, tx *sql.Tx, expected entitlement.PeriodRecord) error {
|
||||
stmt := pg.SELECT(entitlementPeriodSelectColumns).
|
||||
FROM(pgtable.EntitlementRecords).
|
||||
WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(expected.RecordID.String()))).
|
||||
FOR(pg.UPDATE())
|
||||
|
||||
query, args := stmt.Sql()
|
||||
row := tx.QueryRowContext(ctx, query, args...)
|
||||
current, err := scanEntitlementPeriodRow(row)
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return ports.ErrNotFound
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
if !periodsEqual(current, expected) {
|
||||
return ports.ErrConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateEntitlementPeriodTx(ctx context.Context, tx *sql.Tx, record entitlement.PeriodRecord) error {
|
||||
rows, err := updateEntitlementPeriod(ctx, tx, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return ports.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func snapshotsEqual(left entitlement.CurrentSnapshot, right entitlement.CurrentSnapshot) bool {
|
||||
if left.UserID != right.UserID ||
|
||||
left.PlanCode != right.PlanCode ||
|
||||
left.IsPaid != right.IsPaid ||
|
||||
left.Source != right.Source ||
|
||||
left.Actor != right.Actor ||
|
||||
left.ReasonCode != right.ReasonCode {
|
||||
return false
|
||||
}
|
||||
if !left.StartsAt.Equal(right.StartsAt) || !left.UpdatedAt.Equal(right.UpdatedAt) {
|
||||
return false
|
||||
}
|
||||
return optionalTimeEqual(left.EndsAt, right.EndsAt)
|
||||
}
|
||||
|
||||
func periodsEqual(left entitlement.PeriodRecord, right entitlement.PeriodRecord) bool {
|
||||
if left.RecordID != right.RecordID ||
|
||||
left.UserID != right.UserID ||
|
||||
left.PlanCode != right.PlanCode ||
|
||||
left.Source != right.Source ||
|
||||
left.Actor != right.Actor ||
|
||||
left.ReasonCode != right.ReasonCode ||
|
||||
left.ClosedBy != right.ClosedBy ||
|
||||
left.ClosedReasonCode != right.ClosedReasonCode {
|
||||
return false
|
||||
}
|
||||
if !left.StartsAt.Equal(right.StartsAt) || !left.CreatedAt.Equal(right.CreatedAt) {
|
||||
return false
|
||||
}
|
||||
if !optionalTimeEqual(left.EndsAt, right.EndsAt) {
|
||||
return false
|
||||
}
|
||||
return optionalTimeEqual(left.ClosedAt, right.ClosedAt)
|
||||
}
|
||||
|
||||
func optionalTimeEqual(left *time.Time, right *time.Time) bool {
|
||||
switch {
|
||||
case left == nil && right == nil:
|
||||
return true
|
||||
case left == nil || right == nil:
|
||||
return false
|
||||
default:
|
||||
return left.Equal(*right)
|
||||
}
|
||||
}
|
||||
|
||||
// EntitlementSnapshotStore adapts Store to the EntitlementSnapshotStore port.
|
||||
type EntitlementSnapshotStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// EntitlementSnapshots returns one adapter that exposes the entitlement-
|
||||
// snapshot store port over Store.
|
||||
func (store *Store) EntitlementSnapshots() *EntitlementSnapshotStore {
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
return &EntitlementSnapshotStore{store: store}
|
||||
}
|
||||
|
||||
// GetByUserID returns the current entitlement snapshot for userID.
|
||||
func (adapter *EntitlementSnapshotStore) GetByUserID(ctx context.Context, userID common.UserID) (entitlement.CurrentSnapshot, error) {
|
||||
return adapter.store.GetEntitlementByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// Put stores the current entitlement snapshot for record.UserID.
|
||||
func (adapter *EntitlementSnapshotStore) Put(ctx context.Context, record entitlement.CurrentSnapshot) error {
|
||||
return adapter.store.PutEntitlement(ctx, record)
|
||||
}
|
||||
|
||||
var _ ports.EntitlementSnapshotStore = (*EntitlementSnapshotStore)(nil)
|
||||
|
||||
// EntitlementHistoryStore adapts Store to the EntitlementHistoryStore port.
|
||||
type EntitlementHistoryStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// EntitlementHistory returns one adapter that exposes the entitlement
|
||||
// history store port over Store.
|
||||
func (store *Store) EntitlementHistory() *EntitlementHistoryStore {
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
return &EntitlementHistoryStore{store: store}
|
||||
}
|
||||
|
||||
// Create stores one new entitlement history record.
|
||||
func (adapter *EntitlementHistoryStore) Create(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
return adapter.store.CreateEntitlementRecord(ctx, record)
|
||||
}
|
||||
|
||||
// GetByRecordID returns the entitlement history record identified by
|
||||
// recordID.
|
||||
func (adapter *EntitlementHistoryStore) GetByRecordID(ctx context.Context, recordID entitlement.EntitlementRecordID) (entitlement.PeriodRecord, error) {
|
||||
return adapter.store.GetEntitlementRecordByID(ctx, recordID)
|
||||
}
|
||||
|
||||
// ListByUserID returns every entitlement history record owned by userID.
|
||||
func (adapter *EntitlementHistoryStore) ListByUserID(ctx context.Context, userID common.UserID) ([]entitlement.PeriodRecord, error) {
|
||||
return adapter.store.ListEntitlementRecordsByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// Update replaces one stored entitlement history record.
|
||||
func (adapter *EntitlementHistoryStore) Update(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
return adapter.store.UpdateEntitlementRecord(ctx, record)
|
||||
}
|
||||
|
||||
var _ ports.EntitlementHistoryStore = (*EntitlementHistoryStore)(nil)
|
||||
|
||||
// EntitlementLifecycleStore adapts Store to the EntitlementLifecycleStore
|
||||
// port.
|
||||
type EntitlementLifecycleStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// EntitlementLifecycle returns one adapter that exposes the entitlement
|
||||
// lifecycle store port over Store.
|
||||
func (store *Store) EntitlementLifecycle() *EntitlementLifecycleStore {
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
return &EntitlementLifecycleStore{store: store}
|
||||
}
|
||||
|
||||
// Grant atomically closes the current free period and starts a new paid
|
||||
// period.
|
||||
func (adapter *EntitlementLifecycleStore) Grant(ctx context.Context, input ports.GrantEntitlementInput) error {
|
||||
return adapter.store.GrantEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
// Extend appends a paid history segment.
|
||||
func (adapter *EntitlementLifecycleStore) Extend(ctx context.Context, input ports.ExtendEntitlementInput) error {
|
||||
return adapter.store.ExtendEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
// Revoke closes the current paid period and starts a fresh free period.
|
||||
func (adapter *EntitlementLifecycleStore) Revoke(ctx context.Context, input ports.RevokeEntitlementInput) error {
|
||||
return adapter.store.RevokeEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
// RepairExpired replaces an expired finite paid snapshot with a free state.
|
||||
func (adapter *EntitlementLifecycleStore) RepairExpired(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
|
||||
return adapter.store.RepairExpiredEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
var _ ports.EntitlementLifecycleStore = (*EntitlementLifecycleStore)(nil)
|
||||
Reference in New Issue
Block a user