730 lines
26 KiB
Go
730 lines
26 KiB
Go
package userstore
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
|
|
"galaxy/user/internal/domain/common"
|
|
"galaxy/user/internal/domain/entitlement"
|
|
"galaxy/user/internal/ports"
|
|
|
|
pg "github.com/go-jet/jet/v2/postgres"
|
|
)
|
|
|
|
// entitlementPeriodSelectColumns is the canonical SELECT projection for the
// entitlement_records table. The column order here must match, one for one,
// the order of the Scan destinations in scanEntitlementPeriod. If one is
// reordered without the other, values land in the wrong fields or Scan fails.
var entitlementPeriodSelectColumns = pg.ColumnList{
	pgtable.EntitlementRecords.RecordID,
	pgtable.EntitlementRecords.UserID,
	pgtable.EntitlementRecords.PlanCode,
	pgtable.EntitlementRecords.Source,
	pgtable.EntitlementRecords.ActorType,
	pgtable.EntitlementRecords.ActorID,
	pgtable.EntitlementRecords.ReasonCode,
	pgtable.EntitlementRecords.StartsAt,
	pgtable.EntitlementRecords.EndsAt,
	pgtable.EntitlementRecords.CreatedAt,
	pgtable.EntitlementRecords.ClosedAt,
	pgtable.EntitlementRecords.ClosedByType,
	pgtable.EntitlementRecords.ClosedByID,
	pgtable.EntitlementRecords.ClosedReasonCode,
}
|
|
|
|
// entitlementSnapshotSelectColumns is the canonical SELECT projection for the
// entitlement_snapshots table. The column order here must match, one for one,
// the order of the Scan destinations in scanEntitlementSnapshotRow. If one is
// reordered without the other, values land in the wrong fields or Scan fails.
var entitlementSnapshotSelectColumns = pg.ColumnList{
	pgtable.EntitlementSnapshots.UserID,
	pgtable.EntitlementSnapshots.PlanCode,
	pgtable.EntitlementSnapshots.IsPaid,
	pgtable.EntitlementSnapshots.StartsAt,
	pgtable.EntitlementSnapshots.EndsAt,
	pgtable.EntitlementSnapshots.Source,
	pgtable.EntitlementSnapshots.ActorType,
	pgtable.EntitlementSnapshots.ActorID,
	pgtable.EntitlementSnapshots.ReasonCode,
	pgtable.EntitlementSnapshots.UpdatedAt,
}
|
|
|
|
// CreateEntitlementRecord stores one new entitlement period history record.
|
|
// The unique key is record_id; a duplicate record_id returns
|
|
// ports.ErrConflict.
|
|
func (store *Store) CreateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("create entitlement record in postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "create entitlement record in postgres")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cancel()
|
|
return insertEntitlementPeriod(operationCtx, store.db, record)
|
|
}
|
|
|
|
// GetEntitlementRecordByID loads the entitlement period record whose
// record_id equals recordID. A missing row is reported as an error wrapping
// ports.ErrNotFound. Any other query or scan failure is wrapped with the
// record ID for context.
func (store *Store) GetEntitlementRecordByID(ctx context.Context, recordID entitlement.EntitlementRecordID) (entitlement.PeriodRecord, error) {
	if validationErr := recordID.Validate(); validationErr != nil {
		return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record from postgres: %w", validationErr)
	}

	opCtx, release, ctxErr := store.operationContext(ctx, "get entitlement record from postgres")
	if ctxErr != nil {
		return entitlement.PeriodRecord{}, ctxErr
	}
	defer release()

	byID := pgtable.EntitlementRecords.RecordID.EQ(pg.String(recordID.String()))
	query, args := pg.SELECT(entitlementPeriodSelectColumns).
		FROM(pgtable.EntitlementRecords).
		WHERE(byID).
		Sql()

	found, scanErr := scanEntitlementPeriodRow(store.db.QueryRowContext(opCtx, query, args...))
	if scanErr != nil {
		// scanEntitlementPeriodRow already turns sql.ErrNoRows into the bare
		// ports.ErrNotFound sentinel, so one wrap covers both outcomes.
		return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record %q from postgres: %w", recordID, scanErr)
	}
	return found, nil
}
|
|
|
|
// ListEntitlementRecordsByUserID returns all entitlement period records that
// belong to userID. Rows are ordered by created_at and then record_id, both
// ascending, so history replay is deterministic even when two records share a
// creation timestamp. A user without records gets an empty, non-nil slice.
func (store *Store) ListEntitlementRecordsByUserID(ctx context.Context, userID common.UserID) ([]entitlement.PeriodRecord, error) {
	if validationErr := userID.Validate(); validationErr != nil {
		return nil, fmt.Errorf("list entitlement records from postgres: %w", validationErr)
	}

	opCtx, release, ctxErr := store.operationContext(ctx, "list entitlement records from postgres")
	if ctxErr != nil {
		return nil, ctxErr
	}
	defer release()

	tbl := pgtable.EntitlementRecords
	query, args := pg.SELECT(entitlementPeriodSelectColumns).
		FROM(tbl).
		WHERE(tbl.UserID.EQ(pg.String(userID.String()))).
		ORDER_BY(tbl.CreatedAt.ASC(), tbl.RecordID.ASC()).
		Sql()

	wrap := func(cause error) error {
		return fmt.Errorf("list entitlement records for %q from postgres: %w", userID, cause)
	}

	rows, queryErr := store.db.QueryContext(opCtx, query, args...)
	if queryErr != nil {
		return nil, wrap(queryErr)
	}
	// The Close error is deliberately dropped; iteration failures are
	// surfaced through rows.Err below.
	defer func() { _ = rows.Close() }()

	records := []entitlement.PeriodRecord{}
	for rows.Next() {
		next, scanErr := scanEntitlementPeriodRows(rows)
		if scanErr != nil {
			return nil, wrap(scanErr)
		}
		records = append(records, next)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, wrap(iterErr)
	}
	return records, nil
}
|
|
|
|
// UpdateEntitlementRecord validates record and overwrites the stored
// entitlement period row that has the same record_id. The user_id column is
// not part of the UPDATE. When no row matches, the returned error wraps
// ports.ErrNotFound.
func (store *Store) UpdateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
	if validationErr := record.Validate(); validationErr != nil {
		return fmt.Errorf("update entitlement record in postgres: %w", validationErr)
	}

	opCtx, release, ctxErr := store.operationContext(ctx, "update entitlement record in postgres")
	if ctxErr != nil {
		return ctxErr
	}
	defer release()

	affected, updateErr := updateEntitlementPeriod(opCtx, store.db, record)
	switch {
	case updateErr != nil:
		return fmt.Errorf("update entitlement record %q in postgres: %w", record.RecordID, updateErr)
	case affected == 0:
		return fmt.Errorf("update entitlement record %q in postgres: %w", record.RecordID, ports.ErrNotFound)
	default:
		return nil
	}
}
|
|
|
|
// updateEntitlementPeriod overwrites the entitlement_records row whose
// record_id equals record.RecordID. It returns the affected-row count reported
// by the driver. A zero count is not treated as an error here; callers decide
// what it means.
//
// q is whatever queryer the caller supplies (this file passes both store.db
// and an *sql.Tx). record_id and user_id are absent from the SET list, so this
// statement never changes them.
func updateEntitlementPeriod(ctx context.Context, q queryer, record entitlement.PeriodRecord) (int64, error) {
	// SET values are positional: the n-th value is written to the n-th column
	// listed in UPDATE, so both lists must be edited together.
	stmt := pgtable.EntitlementRecords.UPDATE(
		pgtable.EntitlementRecords.PlanCode,
		pgtable.EntitlementRecords.Source,
		pgtable.EntitlementRecords.ActorType,
		pgtable.EntitlementRecords.ActorID,
		pgtable.EntitlementRecords.ReasonCode,
		pgtable.EntitlementRecords.StartsAt,
		pgtable.EntitlementRecords.EndsAt,
		pgtable.EntitlementRecords.CreatedAt,
		pgtable.EntitlementRecords.ClosedAt,
		pgtable.EntitlementRecords.ClosedByType,
		pgtable.EntitlementRecords.ClosedByID,
		pgtable.EntitlementRecords.ClosedReasonCode,
	).SET(
		string(record.PlanCode),
		record.Source.String(),
		record.Actor.Type.String(),
		// NOTE(review): the nullable* helpers are defined elsewhere; presumably
		// they bind SQL NULL for absent/zero values — confirm.
		nullableActorID(record.Actor.ID),
		record.ReasonCode.String(),
		// Required timestamps are normalized to UTC before binding.
		record.StartsAt.UTC(),
		nullableTime(record.EndsAt),
		record.CreatedAt.UTC(),
		nullableTime(record.ClosedAt),
		nullableActorType(record.ClosedBy.Type),
		nullableActorID(record.ClosedBy.ID),
		nullableReasonCode(record.ClosedReasonCode),
	).WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(record.RecordID.String())))

	query, args := stmt.Sql()
	res, err := q.ExecContext(ctx, query, args...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
|
|
|
|
// insertEntitlementPeriod writes record as a new entitlement_records row
// through q, which callers pass as either the connection pool or an open
// transaction. All fourteen columns are written explicitly. The VALUES list
// is positional and must stay aligned with the column list.
//
// A unique-constraint violation (for example, a duplicate record_id) is
// reported as an error wrapping ports.ErrConflict. Every other failure is
// wrapped with the record ID.
func insertEntitlementPeriod(ctx context.Context, q queryer, record entitlement.PeriodRecord) error {
	tbl := pgtable.EntitlementRecords
	query, args := tbl.INSERT(
		tbl.RecordID,
		tbl.UserID,
		tbl.PlanCode,
		tbl.Source,
		tbl.ActorType,
		tbl.ActorID,
		tbl.ReasonCode,
		tbl.StartsAt,
		tbl.EndsAt,
		tbl.CreatedAt,
		tbl.ClosedAt,
		tbl.ClosedByType,
		tbl.ClosedByID,
		tbl.ClosedReasonCode,
	).VALUES(
		record.RecordID.String(),
		record.UserID.String(),
		string(record.PlanCode),
		record.Source.String(),
		record.Actor.Type.String(),
		nullableActorID(record.Actor.ID),
		record.ReasonCode.String(),
		record.StartsAt.UTC(),
		nullableTime(record.EndsAt),
		record.CreatedAt.UTC(),
		nullableTime(record.ClosedAt),
		nullableActorType(record.ClosedBy.Type),
		nullableActorID(record.ClosedBy.ID),
		nullableReasonCode(record.ClosedReasonCode),
	).Sql()

	_, execErr := q.ExecContext(ctx, query, args...)
	switch {
	case execErr == nil:
		return nil
	case isUniqueViolation(execErr):
		return fmt.Errorf("create entitlement record %q in postgres: %w", record.RecordID, ports.ErrConflict)
	default:
		return fmt.Errorf("create entitlement record %q in postgres: %w", record.RecordID, execErr)
	}
}
|
|
|
|
// scannableRow is the one-method interface satisfied by both *sql.Row and
// *sql.Rows. Accepting it lets scanEntitlementPeriod serve single-row
// lookups and multi-row iteration with a single implementation.
type scannableRow interface {
	// Scan copies the current row's columns into dest, positionally.
	Scan(dest ...any) error
}
|
|
|
|
// scanEntitlementPeriodRow decodes the single row returned by QueryRowContext.
// An absent row (sql.ErrNoRows) becomes the bare ports.ErrNotFound sentinel.
// Any other error is returned unchanged alongside the decoder's result.
func scanEntitlementPeriodRow(row *sql.Row) (entitlement.PeriodRecord, error) {
	decoded, err := scanEntitlementPeriod(row)
	if err != nil && errors.Is(err, sql.ErrNoRows) {
		return entitlement.PeriodRecord{}, ports.ErrNotFound
	}
	return decoded, err
}
|
|
|
|
// scanEntitlementPeriodRows decodes the row that rows is currently positioned
// on. The caller advances rows with Next and checks rows.Err afterwards.
// Unlike scanEntitlementPeriodRow, no sql.ErrNoRows translation happens here.
func scanEntitlementPeriodRows(rows *sql.Rows) (entitlement.PeriodRecord, error) {
	decoded, err := scanEntitlementPeriod(rows)
	return decoded, err
}
|
|
|
|
// scanEntitlementPeriod reads one entitlement_records row from row into a
// PeriodRecord. The Scan destinations follow entitlementPeriodSelectColumns.
//
// Required timestamps are normalized to UTC. Nullable text columns
// (actor_id, closed_by_type, closed_by_id, closed_reason_code) become the
// zero value of their field when NULL. A Scan error is returned unwrapped.
func scanEntitlementPeriod(row scannableRow) (entitlement.PeriodRecord, error) {
	var (
		recordID, userID, planCode, source string
		actorType, reasonCode              string
		startsAt, createdAt                time.Time
		endsAt, closedAt                   *time.Time
		actorID, closedByType, closedByID  *string
		closedReason                       *string
	)

	// Destinations are positional and mirror entitlementPeriodSelectColumns.
	scanErr := row.Scan(
		&recordID, &userID, &planCode, &source,
		&actorType, &actorID, &reasonCode,
		&startsAt, &endsAt, &createdAt,
		&closedAt, &closedByType, &closedByID, &closedReason,
	)
	if scanErr != nil {
		return entitlement.PeriodRecord{}, scanErr
	}

	// orEmpty maps a NULL text column to "", which converts to the same zero
	// value the field would hold if left unset.
	orEmpty := func(value *string) string {
		if value == nil {
			return ""
		}
		return *value
	}

	out := entitlement.PeriodRecord{
		RecordID:         entitlement.EntitlementRecordID(recordID),
		UserID:           common.UserID(userID),
		PlanCode:         entitlement.PlanCode(planCode),
		Source:           common.Source(source),
		Actor:            common.ActorRef{Type: common.ActorType(actorType), ID: common.ActorID(orEmpty(actorID))},
		ReasonCode:       common.ReasonCode(reasonCode),
		StartsAt:         startsAt.UTC(),
		EndsAt:           timeFromNullable(endsAt),
		CreatedAt:        createdAt.UTC(),
		ClosedAt:         timeFromNullable(closedAt),
		ClosedReasonCode: common.ReasonCode(orEmpty(closedReason)),
	}
	out.ClosedBy.Type = common.ActorType(orEmpty(closedByType))
	out.ClosedBy.ID = common.ActorID(orEmpty(closedByID))
	return out, nil
}
|
|
|
|
// GetEntitlementByUserID loads the current entitlement snapshot stored for
// userID. A missing snapshot is reported as an error wrapping
// ports.ErrNotFound. Other query or scan failures are wrapped with the user ID.
func (store *Store) GetEntitlementByUserID(ctx context.Context, userID common.UserID) (entitlement.CurrentSnapshot, error) {
	if validationErr := userID.Validate(); validationErr != nil {
		return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot from postgres: %w", validationErr)
	}

	opCtx, release, ctxErr := store.operationContext(ctx, "get entitlement snapshot from postgres")
	if ctxErr != nil {
		return entitlement.CurrentSnapshot{}, ctxErr
	}
	defer release()

	byUser := pgtable.EntitlementSnapshots.UserID.EQ(pg.String(userID.String()))
	query, args := pg.SELECT(entitlementSnapshotSelectColumns).
		FROM(pgtable.EntitlementSnapshots).
		WHERE(byUser).
		Sql()

	snapshot, scanErr := scanEntitlementSnapshotRow(store.db.QueryRowContext(opCtx, query, args...))
	if scanErr != nil {
		// scanEntitlementSnapshotRow already maps sql.ErrNoRows to the bare
		// ports.ErrNotFound sentinel, so a single wrap covers both outcomes.
		return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot for %q from postgres: %w", userID, scanErr)
	}
	return snapshot, nil
}
|
|
|
|
// PutEntitlement stores the current entitlement snapshot for record.UserID.
|
|
// It is an UPSERT so the runtime path can call it on creation and on
|
|
// replacement uniformly.
|
|
func (store *Store) PutEntitlement(ctx context.Context, record entitlement.CurrentSnapshot) error {
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("put entitlement snapshot in postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "put entitlement snapshot in postgres")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cancel()
|
|
return upsertEntitlementSnapshot(operationCtx, store.db, record)
|
|
}
|
|
|
|
// upsertEntitlementSnapshot writes record into entitlement_snapshots through q.
// This file passes either store.db or an open *sql.Tx as q.
//
// The statement is INSERT ... ON CONFLICT (user_id) DO UPDATE. If a row for
// record.UserID already exists, every non-key column is replaced with the
// newly supplied value via EXCLUDED. ON CONFLICT (user_id) relies on a unique
// constraint on user_id. Errors are wrapped with the user ID; no sentinel
// mapping is applied.
func upsertEntitlementSnapshot(ctx context.Context, q queryer, record entitlement.CurrentSnapshot) error {
	// VALUES are positional and must stay aligned with the INSERT column list.
	stmt := pgtable.EntitlementSnapshots.INSERT(
		pgtable.EntitlementSnapshots.UserID,
		pgtable.EntitlementSnapshots.PlanCode,
		pgtable.EntitlementSnapshots.IsPaid,
		pgtable.EntitlementSnapshots.StartsAt,
		pgtable.EntitlementSnapshots.EndsAt,
		pgtable.EntitlementSnapshots.Source,
		pgtable.EntitlementSnapshots.ActorType,
		pgtable.EntitlementSnapshots.ActorID,
		pgtable.EntitlementSnapshots.ReasonCode,
		pgtable.EntitlementSnapshots.UpdatedAt,
	).VALUES(
		record.UserID.String(),
		string(record.PlanCode),
		record.IsPaid,
		// Required timestamps are normalized to UTC before binding.
		record.StartsAt.UTC(),
		nullableTime(record.EndsAt),
		record.Source.String(),
		record.Actor.Type.String(),
		nullableActorID(record.Actor.ID),
		record.ReasonCode.String(),
		record.UpdatedAt.UTC(),
	).ON_CONFLICT(pgtable.EntitlementSnapshots.UserID).DO_UPDATE(
		// On conflict, copy every non-key column from the rejected row
		// (EXCLUDED), i.e. the values supplied in VALUES above.
		pg.SET(
			pgtable.EntitlementSnapshots.PlanCode.SET(pgtable.EntitlementSnapshots.EXCLUDED.PlanCode),
			pgtable.EntitlementSnapshots.IsPaid.SET(pgtable.EntitlementSnapshots.EXCLUDED.IsPaid),
			pgtable.EntitlementSnapshots.StartsAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.StartsAt),
			pgtable.EntitlementSnapshots.EndsAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.EndsAt),
			pgtable.EntitlementSnapshots.Source.SET(pgtable.EntitlementSnapshots.EXCLUDED.Source),
			pgtable.EntitlementSnapshots.ActorType.SET(pgtable.EntitlementSnapshots.EXCLUDED.ActorType),
			pgtable.EntitlementSnapshots.ActorID.SET(pgtable.EntitlementSnapshots.EXCLUDED.ActorID),
			pgtable.EntitlementSnapshots.ReasonCode.SET(pgtable.EntitlementSnapshots.EXCLUDED.ReasonCode),
			pgtable.EntitlementSnapshots.UpdatedAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.UpdatedAt),
		),
	)

	query, args := stmt.Sql()
	if _, err := q.ExecContext(ctx, query, args...); err != nil {
		return fmt.Errorf("upsert entitlement snapshot for %q in postgres: %w", record.UserID, err)
	}
	return nil
}
|
|
|
|
// scanEntitlementSnapshotRow decodes the single entitlement_snapshots row
// returned by QueryRowContext. The Scan destinations follow
// entitlementSnapshotSelectColumns.
//
// An absent row becomes the bare ports.ErrNotFound sentinel, and other Scan
// errors are returned unwrapped. Required timestamps are normalized to UTC,
// and a NULL actor_id leaves Actor.ID at its zero value.
func scanEntitlementSnapshotRow(row *sql.Row) (entitlement.CurrentSnapshot, error) {
	var (
		userID, planCode, source, actorType, reasonCode string
		isPaid                                          bool
		startsAt, updatedAt                             time.Time
		endsAt                                          *time.Time
		actorID                                         *string
	)

	scanErr := row.Scan(
		&userID, &planCode, &isPaid,
		&startsAt, &endsAt,
		&source, &actorType, &actorID, &reasonCode,
		&updatedAt,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return entitlement.CurrentSnapshot{}, ports.ErrNotFound
	case scanErr != nil:
		return entitlement.CurrentSnapshot{}, scanErr
	}

	snapshot := entitlement.CurrentSnapshot{
		UserID:     common.UserID(userID),
		PlanCode:   entitlement.PlanCode(planCode),
		IsPaid:     isPaid,
		StartsAt:   startsAt.UTC(),
		EndsAt:     timeFromNullable(endsAt),
		Source:     common.Source(source),
		Actor:      common.ActorRef{Type: common.ActorType(actorType)},
		ReasonCode: common.ReasonCode(reasonCode),
		UpdatedAt:  updatedAt.UTC(),
	}
	if actorID != nil {
		snapshot.Actor.ID = common.ActorID(*actorID)
	}
	return snapshot, nil
}
|
|
|
|
// GrantEntitlement moves a user from the current free period to a new paid
// period in one transaction. Inside store.withTx it:
//  1. locks the user's snapshot row (SELECT ... FOR UPDATE) and checks that it
//     still equals input.ExpectedCurrentSnapshot;
//  2. locks the current period row and checks that it equals
//     input.ExpectedCurrentRecord;
//  3. overwrites that period with input.UpdatedCurrentRecord;
//  4. inserts input.NewRecord; and
//  5. upserts input.NewSnapshot.
//
// A state mismatch in step 1 or 2 yields an error wrapping ports.ErrConflict.
// A missing row yields one wrapping ports.ErrNotFound. Any error aborts the
// callback; rollback is presumably handled by withTx, which is defined
// elsewhere.
func (store *Store) GrantEntitlement(ctx context.Context, input ports.GrantEntitlementInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("grant entitlement in postgres: %w", err)
	}

	return store.withTx(ctx, "grant entitlement in postgres", func(txCtx context.Context, tx *sql.Tx) error {
		expectedSnapshot := input.ExpectedCurrentSnapshot
		if err := lockSnapshotMatching(txCtx, tx, expectedSnapshot); err != nil {
			return fmt.Errorf("grant entitlement for %q in postgres: %w", expectedSnapshot.UserID, err)
		}

		expectedRecord := input.ExpectedCurrentRecord
		if err := lockPeriodMatching(txCtx, tx, expectedRecord); err != nil {
			return fmt.Errorf("grant entitlement for %q in postgres: %w", expectedRecord.RecordID, err)
		}

		closedRecord := input.UpdatedCurrentRecord
		if err := updateEntitlementPeriodTx(txCtx, tx, closedRecord); err != nil {
			return fmt.Errorf("grant entitlement for %q in postgres: %w", closedRecord.RecordID, err)
		}

		if err := insertEntitlementPeriod(txCtx, tx, input.NewRecord); err != nil {
			return err
		}
		return upsertEntitlementSnapshot(txCtx, tx, input.NewSnapshot)
	})
}
|
|
|
|
// ExtendEntitlement appends a new paid history segment and replaces the
// snapshot in one transaction. Inside store.withTx it locks the snapshot row
// and checks it against input.ExpectedCurrentSnapshot, which yields a
// ports.ErrConflict-wrapped error on mismatch. It then inserts
// input.NewRecord and upserts input.NewSnapshot. Existing period rows are not
// modified.
func (store *Store) ExtendEntitlement(ctx context.Context, input ports.ExtendEntitlementInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("extend entitlement in postgres: %w", err)
	}

	return store.withTx(ctx, "extend entitlement in postgres", func(txCtx context.Context, tx *sql.Tx) error {
		expectedSnapshot := input.ExpectedCurrentSnapshot
		if err := lockSnapshotMatching(txCtx, tx, expectedSnapshot); err != nil {
			return fmt.Errorf("extend entitlement for %q in postgres: %w", expectedSnapshot.UserID, err)
		}

		if err := insertEntitlementPeriod(txCtx, tx, input.NewRecord); err != nil {
			return err
		}
		return upsertEntitlementSnapshot(txCtx, tx, input.NewSnapshot)
	})
}
|
|
|
|
// RevokeEntitlement closes the current paid period, opens a new free period
// and replaces the snapshot in one transaction. It uses the same lock order as
// GrantEntitlement: the snapshot row first, then the current period row.
// Each locked row is compared with the expected value from input, and a
// mismatch yields a ports.ErrConflict-wrapped error. Then
// input.UpdatedCurrentRecord is written, input.NewRecord is inserted and
// input.NewSnapshot is upserted.
func (store *Store) RevokeEntitlement(ctx context.Context, input ports.RevokeEntitlementInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("revoke entitlement in postgres: %w", err)
	}

	return store.withTx(ctx, "revoke entitlement in postgres", func(txCtx context.Context, tx *sql.Tx) error {
		expectedSnapshot := input.ExpectedCurrentSnapshot
		if err := lockSnapshotMatching(txCtx, tx, expectedSnapshot); err != nil {
			return fmt.Errorf("revoke entitlement for %q in postgres: %w", expectedSnapshot.UserID, err)
		}

		expectedRecord := input.ExpectedCurrentRecord
		if err := lockPeriodMatching(txCtx, tx, expectedRecord); err != nil {
			return fmt.Errorf("revoke entitlement for %q in postgres: %w", expectedRecord.RecordID, err)
		}

		closedRecord := input.UpdatedCurrentRecord
		if err := updateEntitlementPeriodTx(txCtx, tx, closedRecord); err != nil {
			return fmt.Errorf("revoke entitlement for %q in postgres: %w", closedRecord.RecordID, err)
		}

		if err := insertEntitlementPeriod(txCtx, tx, input.NewRecord); err != nil {
			return err
		}
		return upsertEntitlementSnapshot(txCtx, tx, input.NewSnapshot)
	})
}
|
|
|
|
// RepairExpiredEntitlement replaces an expired finite paid snapshot with a
// materialised free state in one transaction. It locks the snapshot row and
// checks it against input.ExpectedExpiredSnapshot, which yields a
// ports.ErrConflict-wrapped error on mismatch. It then inserts
// input.NewRecord and upserts input.NewSnapshot.
func (store *Store) RepairExpiredEntitlement(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("repair expired entitlement in postgres: %w", err)
	}

	return store.withTx(ctx, "repair expired entitlement in postgres", func(txCtx context.Context, tx *sql.Tx) error {
		expiredSnapshot := input.ExpectedExpiredSnapshot
		if err := lockSnapshotMatching(txCtx, tx, expiredSnapshot); err != nil {
			return fmt.Errorf("repair expired entitlement for %q in postgres: %w", expiredSnapshot.UserID, err)
		}

		if err := insertEntitlementPeriod(txCtx, tx, input.NewRecord); err != nil {
			return err
		}
		return upsertEntitlementSnapshot(txCtx, tx, input.NewSnapshot)
	})
}
|
|
|
|
// lockSnapshotMatching row-locks (SELECT ... FOR UPDATE) the snapshot for
// expected.UserID inside tx and verifies that it still equals expected.
//
// Return values:
//   - the bare ports.ErrNotFound when no snapshot row exists;
//   - the bare ports.ErrConflict when the stored snapshot differs, so that
//     optimistic-replacement callers can reload and retry;
//   - any other query or scan error, unwrapped.
func lockSnapshotMatching(ctx context.Context, tx *sql.Tx, expected entitlement.CurrentSnapshot) error {
	query, args := pg.SELECT(entitlementSnapshotSelectColumns).
		FROM(pgtable.EntitlementSnapshots).
		WHERE(pgtable.EntitlementSnapshots.UserID.EQ(pg.String(expected.UserID.String()))).
		FOR(pg.UPDATE()).
		Sql()

	current, err := scanEntitlementSnapshotRow(tx.QueryRowContext(ctx, query, args...))
	if err != nil {
		// Already the bare ports.ErrNotFound when the row is missing.
		return err
	}
	if snapshotsEqual(current, expected) {
		return nil
	}
	return ports.ErrConflict
}
|
|
|
|
// lockPeriodMatching row-locks (SELECT ... FOR UPDATE) the entitlement period
// identified by expected.RecordID inside tx and verifies that it still equals
// expected.
//
// Return values:
//   - the bare ports.ErrNotFound when the record does not exist;
//   - the bare ports.ErrConflict when the stored record differs;
//   - any other query or scan error, unwrapped.
func lockPeriodMatching(ctx context.Context, tx *sql.Tx, expected entitlement.PeriodRecord) error {
	query, args := pg.SELECT(entitlementPeriodSelectColumns).
		FROM(pgtable.EntitlementRecords).
		WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(expected.RecordID.String()))).
		FOR(pg.UPDATE()).
		Sql()

	current, err := scanEntitlementPeriodRow(tx.QueryRowContext(ctx, query, args...))
	if err != nil {
		// Already the bare ports.ErrNotFound when the row is missing.
		return err
	}
	if periodsEqual(current, expected) {
		return nil
	}
	return ports.ErrConflict
}
|
|
|
|
// updateEntitlementPeriodTx runs updateEntitlementPeriod inside tx. An UPDATE
// that matched no row becomes the bare ports.ErrNotFound sentinel. Execution
// errors are returned unwrapped.
func updateEntitlementPeriodTx(ctx context.Context, tx *sql.Tx, record entitlement.PeriodRecord) error {
	affected, err := updateEntitlementPeriod(ctx, tx, record)
	switch {
	case err != nil:
		return err
	case affected == 0:
		return ports.ErrNotFound
	}
	return nil
}
|
|
|
|
func snapshotsEqual(left entitlement.CurrentSnapshot, right entitlement.CurrentSnapshot) bool {
|
|
if left.UserID != right.UserID ||
|
|
left.PlanCode != right.PlanCode ||
|
|
left.IsPaid != right.IsPaid ||
|
|
left.Source != right.Source ||
|
|
left.Actor != right.Actor ||
|
|
left.ReasonCode != right.ReasonCode {
|
|
return false
|
|
}
|
|
if !left.StartsAt.Equal(right.StartsAt) || !left.UpdatedAt.Equal(right.UpdatedAt) {
|
|
return false
|
|
}
|
|
return optionalTimeEqual(left.EndsAt, right.EndsAt)
|
|
}
|
|
|
|
func periodsEqual(left entitlement.PeriodRecord, right entitlement.PeriodRecord) bool {
|
|
if left.RecordID != right.RecordID ||
|
|
left.UserID != right.UserID ||
|
|
left.PlanCode != right.PlanCode ||
|
|
left.Source != right.Source ||
|
|
left.Actor != right.Actor ||
|
|
left.ReasonCode != right.ReasonCode ||
|
|
left.ClosedBy != right.ClosedBy ||
|
|
left.ClosedReasonCode != right.ClosedReasonCode {
|
|
return false
|
|
}
|
|
if !left.StartsAt.Equal(right.StartsAt) || !left.CreatedAt.Equal(right.CreatedAt) {
|
|
return false
|
|
}
|
|
if !optionalTimeEqual(left.EndsAt, right.EndsAt) {
|
|
return false
|
|
}
|
|
return optionalTimeEqual(left.ClosedAt, right.ClosedAt)
|
|
}
|
|
|
|
// optionalTimeEqual reports whether two optional instants match: either both
// are absent (nil), or both are present and denote the same instant according
// to time.Time.Equal, so differing locations do not matter.
func optionalTimeEqual(left *time.Time, right *time.Time) bool {
	if left == nil || right == nil {
		// True only when both are nil; exactly one nil means "different".
		return left == right
	}
	return left.Equal(*right)
}
|
|
|
|
// EntitlementSnapshotStore adapts Store to the EntitlementSnapshotStore port.
|
|
type EntitlementSnapshotStore struct {
|
|
store *Store
|
|
}
|
|
|
|
// EntitlementSnapshots returns one adapter that exposes the entitlement-
|
|
// snapshot store port over Store.
|
|
func (store *Store) EntitlementSnapshots() *EntitlementSnapshotStore {
|
|
if store == nil {
|
|
return nil
|
|
}
|
|
return &EntitlementSnapshotStore{store: store}
|
|
}
|
|
|
|
// GetByUserID returns the current entitlement snapshot for userID.
|
|
func (adapter *EntitlementSnapshotStore) GetByUserID(ctx context.Context, userID common.UserID) (entitlement.CurrentSnapshot, error) {
|
|
return adapter.store.GetEntitlementByUserID(ctx, userID)
|
|
}
|
|
|
|
// Put stores the current entitlement snapshot for record.UserID.
|
|
func (adapter *EntitlementSnapshotStore) Put(ctx context.Context, record entitlement.CurrentSnapshot) error {
|
|
return adapter.store.PutEntitlement(ctx, record)
|
|
}
|
|
|
|
var _ ports.EntitlementSnapshotStore = (*EntitlementSnapshotStore)(nil)
|
|
|
|
// EntitlementHistoryStore adapts Store to the EntitlementHistoryStore port.
|
|
type EntitlementHistoryStore struct {
|
|
store *Store
|
|
}
|
|
|
|
// EntitlementHistory returns one adapter that exposes the entitlement
|
|
// history store port over Store.
|
|
func (store *Store) EntitlementHistory() *EntitlementHistoryStore {
|
|
if store == nil {
|
|
return nil
|
|
}
|
|
return &EntitlementHistoryStore{store: store}
|
|
}
|
|
|
|
// Create stores one new entitlement history record.
|
|
func (adapter *EntitlementHistoryStore) Create(ctx context.Context, record entitlement.PeriodRecord) error {
|
|
return adapter.store.CreateEntitlementRecord(ctx, record)
|
|
}
|
|
|
|
// GetByRecordID returns the entitlement history record identified by
|
|
// recordID.
|
|
func (adapter *EntitlementHistoryStore) GetByRecordID(ctx context.Context, recordID entitlement.EntitlementRecordID) (entitlement.PeriodRecord, error) {
|
|
return adapter.store.GetEntitlementRecordByID(ctx, recordID)
|
|
}
|
|
|
|
// ListByUserID returns every entitlement history record owned by userID.
|
|
func (adapter *EntitlementHistoryStore) ListByUserID(ctx context.Context, userID common.UserID) ([]entitlement.PeriodRecord, error) {
|
|
return adapter.store.ListEntitlementRecordsByUserID(ctx, userID)
|
|
}
|
|
|
|
// Update replaces one stored entitlement history record.
|
|
func (adapter *EntitlementHistoryStore) Update(ctx context.Context, record entitlement.PeriodRecord) error {
|
|
return adapter.store.UpdateEntitlementRecord(ctx, record)
|
|
}
|
|
|
|
var _ ports.EntitlementHistoryStore = (*EntitlementHistoryStore)(nil)
|
|
|
|
// EntitlementLifecycleStore adapts Store to the EntitlementLifecycleStore
|
|
// port.
|
|
type EntitlementLifecycleStore struct {
|
|
store *Store
|
|
}
|
|
|
|
// EntitlementLifecycle returns one adapter that exposes the entitlement
|
|
// lifecycle store port over Store.
|
|
func (store *Store) EntitlementLifecycle() *EntitlementLifecycleStore {
|
|
if store == nil {
|
|
return nil
|
|
}
|
|
return &EntitlementLifecycleStore{store: store}
|
|
}
|
|
|
|
// Grant atomically closes the current free period and starts a new paid
|
|
// period.
|
|
func (adapter *EntitlementLifecycleStore) Grant(ctx context.Context, input ports.GrantEntitlementInput) error {
|
|
return adapter.store.GrantEntitlement(ctx, input)
|
|
}
|
|
|
|
// Extend appends a paid history segment.
|
|
func (adapter *EntitlementLifecycleStore) Extend(ctx context.Context, input ports.ExtendEntitlementInput) error {
|
|
return adapter.store.ExtendEntitlement(ctx, input)
|
|
}
|
|
|
|
// Revoke closes the current paid period and starts a fresh free period.
|
|
func (adapter *EntitlementLifecycleStore) Revoke(ctx context.Context, input ports.RevokeEntitlementInput) error {
|
|
return adapter.store.RevokeEntitlement(ctx, input)
|
|
}
|
|
|
|
// RepairExpired replaces an expired finite paid snapshot with a free state.
|
|
func (adapter *EntitlementLifecycleStore) RepairExpired(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
|
|
return adapter.store.RepairExpiredEntitlement(ctx, input)
|
|
}
|
|
|
|
var _ ports.EntitlementLifecycleStore = (*EntitlementLifecycleStore)(nil)
|