Files
galaxy-game/user/internal/adapters/postgres/userstore/policy_store.go
T
2026-04-26 20:34:39 +02:00

871 lines
29 KiB
Go

package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
pg "github.com/go-jet/jet/v2/postgres"
)
// sanctionSelectColumns enumerates, in scan order, every sanction_records
// column fetched by the read paths of this file. scanSanction consumes the
// result positionally, so any change here must be mirrored there.
var sanctionSelectColumns = pg.ColumnList{
	// Identity and classification.
	pgtable.SanctionRecords.RecordID, pgtable.SanctionRecords.UserID,
	pgtable.SanctionRecords.SanctionCode, pgtable.SanctionRecords.Scope,
	pgtable.SanctionRecords.ReasonCode,
	// Who applied the sanction, when, and until when.
	pgtable.SanctionRecords.ActorType, pgtable.SanctionRecords.ActorID,
	pgtable.SanctionRecords.AppliedAt, pgtable.SanctionRecords.ExpiresAt,
	// Removal metadata.
	pgtable.SanctionRecords.RemovedAt, pgtable.SanctionRecords.RemovedByType,
	pgtable.SanctionRecords.RemovedByID, pgtable.SanctionRecords.RemovedReasonCode,
}
// limitSelectColumns enumerates, in scan order, every limit_records column
// fetched by the read paths of this file. scanLimit consumes the result
// positionally, so any change here must be mirrored there.
var limitSelectColumns = pg.ColumnList{
	// Identity, limit code and its numeric value.
	pgtable.LimitRecords.RecordID, pgtable.LimitRecords.UserID,
	pgtable.LimitRecords.LimitCode, pgtable.LimitRecords.Value,
	pgtable.LimitRecords.ReasonCode,
	// Who applied the limit, when, and until when.
	pgtable.LimitRecords.ActorType, pgtable.LimitRecords.ActorID,
	pgtable.LimitRecords.AppliedAt, pgtable.LimitRecords.ExpiresAt,
	// Removal metadata.
	pgtable.LimitRecords.RemovedAt, pgtable.LimitRecords.RemovedByType,
	pgtable.LimitRecords.RemovedByID, pgtable.LimitRecords.RemovedReasonCode,
}
// CreateSanction validates record and persists it as a brand-new row in
// sanction_records. A duplicate record ID surfaces as a wrapped
// ports.ErrConflict (see insertSanctionRecord).
func (store *Store) CreateSanction(ctx context.Context, record policy.SanctionRecord) error {
	const op = "create sanction in postgres"

	if err := record.Validate(); err != nil {
		return fmt.Errorf(op+": %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, op)
	if err != nil {
		return err
	}
	defer cancel()

	return insertSanctionRecord(opCtx, store.db, record)
}
// insertSanctionRecord writes record into sanction_records through q (the
// store's db handle or an open *sql.Tx, depending on the caller). A
// unique-key violation is reported as ports.ErrConflict; any other driver
// error is wrapped as-is.
func insertSanctionRecord(ctx context.Context, q queryer, record policy.SanctionRecord) error {
	t := pgtable.SanctionRecords
	query, args := t.INSERT(
		t.RecordID, t.UserID, t.SanctionCode, t.Scope, t.ReasonCode,
		t.ActorType, t.ActorID, t.AppliedAt, t.ExpiresAt,
		t.RemovedAt, t.RemovedByType, t.RemovedByID, t.RemovedReasonCode,
	).VALUES(
		record.RecordID.String(),
		record.UserID.String(),
		string(record.SanctionCode),
		record.Scope.String(),
		record.ReasonCode.String(),
		record.Actor.Type.String(),
		nullableActorID(record.Actor.ID),
		record.AppliedAt.UTC(),
		nullableTime(record.ExpiresAt),
		nullableTime(record.RemovedAt),
		nullableActorType(record.RemovedBy.Type),
		nullableActorID(record.RemovedBy.ID),
		nullableReasonCode(record.RemovedReasonCode),
	).Sql()

	if _, err := q.ExecContext(ctx, query, args...); err != nil {
		if isUniqueViolation(err) {
			err = ports.ErrConflict
		}
		return fmt.Errorf("create sanction %q in postgres: %w", record.RecordID, err)
	}
	return nil
}
// GetSanctionByRecordID loads the sanction_records row whose record_id equals
// recordID. When no such row exists the returned error wraps
// ports.ErrNotFound.
func (store *Store) GetSanctionByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) {
	if err := recordID.Validate(); err != nil {
		return policy.SanctionRecord{}, fmt.Errorf("get sanction from postgres: %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, "get sanction from postgres")
	if err != nil {
		return policy.SanctionRecord{}, err
	}
	defer cancel()

	query, args := pg.SELECT(sanctionSelectColumns).
		FROM(pgtable.SanctionRecords).
		WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(recordID.String()))).
		Sql()

	record, err := scanSanctionRow(store.db.QueryRowContext(opCtx, query, args...))
	if err != nil {
		// scanSanctionRow already turns sql.ErrNoRows into ports.ErrNotFound,
		// so wrapping err keeps errors.Is(err, ports.ErrNotFound) working.
		return policy.SanctionRecord{}, fmt.Errorf("get sanction %q from postgres: %w", recordID, err)
	}
	return record, nil
}
// ListSanctionsByUserID fetches the full sanction history of userID, sorted
// by applied_at and then record_id (both ascending) so ties order
// deterministically. A user without sanctions yields an empty, non-nil
// slice.
func (store *Store) ListSanctionsByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) {
	if err := userID.Validate(); err != nil {
		return nil, fmt.Errorf("list sanctions from postgres: %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, "list sanctions from postgres")
	if err != nil {
		return nil, err
	}
	defer cancel()

	wrap := func(cause error) error {
		return fmt.Errorf("list sanctions for %q from postgres: %w", userID, cause)
	}

	t := pgtable.SanctionRecords
	query, args := pg.SELECT(sanctionSelectColumns).
		FROM(t).
		WHERE(t.UserID.EQ(pg.String(userID.String()))).
		ORDER_BY(t.AppliedAt.ASC(), t.RecordID.ASC()).
		Sql()

	rows, err := store.db.QueryContext(opCtx, query, args...)
	if err != nil {
		return nil, wrap(err)
	}
	// The Close error is discarded on purpose; iteration failures are
	// surfaced through rows.Err below.
	defer func() { _ = rows.Close() }()

	history := []policy.SanctionRecord{}
	for rows.Next() {
		record, scanErr := scanSanction(rows)
		if scanErr != nil {
			return nil, wrap(scanErr)
		}
		history = append(history, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, wrap(iterErr)
	}
	return history, nil
}
// UpdateSanction validates record and overwrites the stored sanction_records
// row with the same record_id. The returned error wraps ports.ErrNotFound
// when no row matches.
func (store *Store) UpdateSanction(ctx context.Context, record policy.SanctionRecord) error {
	const op = "update sanction in postgres"

	if err := record.Validate(); err != nil {
		return fmt.Errorf(op+": %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, op)
	if err != nil {
		return err
	}
	defer cancel()

	return updateSanctionRecordTx(opCtx, store.db, record)
}
// updateSanctionRecordTx rewrites every non-key column of the
// sanction_records row keyed by record.RecordID through q (the store's db
// handle or an open *sql.Tx). A statement that affects zero rows is
// reported as ports.ErrNotFound.
func updateSanctionRecordTx(ctx context.Context, q queryer, record policy.SanctionRecord) error {
	wrap := func(cause error) error {
		return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, cause)
	}

	t := pgtable.SanctionRecords
	query, args := t.UPDATE(
		t.UserID, t.SanctionCode, t.Scope, t.ReasonCode,
		t.ActorType, t.ActorID, t.AppliedAt, t.ExpiresAt,
		t.RemovedAt, t.RemovedByType, t.RemovedByID, t.RemovedReasonCode,
	).SET(
		record.UserID.String(),
		string(record.SanctionCode),
		record.Scope.String(),
		record.ReasonCode.String(),
		record.Actor.Type.String(),
		nullableActorID(record.Actor.ID),
		record.AppliedAt.UTC(),
		nullableTime(record.ExpiresAt),
		nullableTime(record.RemovedAt),
		nullableActorType(record.RemovedBy.Type),
		nullableActorID(record.RemovedBy.ID),
		nullableReasonCode(record.RemovedReasonCode),
	).WHERE(t.RecordID.EQ(pg.String(record.RecordID.String()))).Sql()

	res, err := q.ExecContext(ctx, query, args...)
	if err != nil {
		return wrap(err)
	}
	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return wrap(err)
	case affected == 0:
		return wrap(ports.ErrNotFound)
	default:
		return nil
	}
}
// scanSanctionRow decodes a single-row result and translates sql.ErrNoRows
// into the bare ports.ErrNotFound sentinel. Any other error is passed
// through unchanged.
func scanSanctionRow(row *sql.Row) (policy.SanctionRecord, error) {
	if record, err := scanSanction(row); !errors.Is(err, sql.ErrNoRows) {
		return record, err
	}
	return policy.SanctionRecord{}, ports.ErrNotFound
}
// scanSanction decodes one row shaped like sanctionSelectColumns into a
// policy.SanctionRecord. NULL text columns become the zero value of their
// domain type. applied_at is converted to UTC, and the optional timestamps
// are converted with timeFromNullable.
func scanSanction(row scannableRow) (policy.SanctionRecord, error) {
	var (
		recordID, userID, code, scope, reason, actorType    string
		appliedAt                                           time.Time
		expiresAt, removedAt                                *time.Time
		actorID, removedByType, removedByID, removedReason sql.NullString
	)
	if err := row.Scan(
		&recordID, &userID, &code, &scope, &reason,
		&actorType, &actorID, &appliedAt,
		&expiresAt, &removedAt,
		&removedByType, &removedByID, &removedReason,
	); err != nil {
		return policy.SanctionRecord{}, err
	}

	record := policy.SanctionRecord{
		RecordID:     policy.SanctionRecordID(recordID),
		UserID:       common.UserID(userID),
		SanctionCode: policy.SanctionCode(code),
		Scope:        common.Scope(scope),
		ReasonCode:   common.ReasonCode(reason),
		Actor: common.ActorRef{
			Type: common.ActorType(actorType),
			ID:   common.ActorID(actorID.String),
		},
		AppliedAt:         appliedAt.UTC(),
		ExpiresAt:         timeFromNullable(expiresAt),
		RemovedAt:         timeFromNullable(removedAt),
		RemovedReasonCode: common.ReasonCode(removedReason.String),
	}
	// A NULL column leaves .String empty, which is exactly the zero value
	// these fields would otherwise keep.
	record.RemovedBy.Type = common.ActorType(removedByType.String)
	record.RemovedBy.ID = common.ActorID(removedByID.String)
	return record, nil
}
// CreateLimit validates record and persists it as a brand-new row in
// limit_records. A duplicate record ID surfaces as a wrapped
// ports.ErrConflict (see insertLimitRecord).
func (store *Store) CreateLimit(ctx context.Context, record policy.LimitRecord) error {
	const op = "create limit in postgres"

	if err := record.Validate(); err != nil {
		return fmt.Errorf(op+": %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, op)
	if err != nil {
		return err
	}
	defer cancel()

	return insertLimitRecord(opCtx, store.db, record)
}
// insertLimitRecord writes record into limit_records through q (the store's
// db handle or an open *sql.Tx, depending on the caller). A unique-key
// violation is reported as ports.ErrConflict; any other driver error is
// wrapped as-is.
func insertLimitRecord(ctx context.Context, q queryer, record policy.LimitRecord) error {
	t := pgtable.LimitRecords
	query, args := t.INSERT(
		t.RecordID, t.UserID, t.LimitCode, t.Value, t.ReasonCode,
		t.ActorType, t.ActorID, t.AppliedAt, t.ExpiresAt,
		t.RemovedAt, t.RemovedByType, t.RemovedByID, t.RemovedReasonCode,
	).VALUES(
		record.RecordID.String(),
		record.UserID.String(),
		string(record.LimitCode),
		record.Value,
		record.ReasonCode.String(),
		record.Actor.Type.String(),
		nullableActorID(record.Actor.ID),
		record.AppliedAt.UTC(),
		nullableTime(record.ExpiresAt),
		nullableTime(record.RemovedAt),
		nullableActorType(record.RemovedBy.Type),
		nullableActorID(record.RemovedBy.ID),
		nullableReasonCode(record.RemovedReasonCode),
	).Sql()

	if _, err := q.ExecContext(ctx, query, args...); err != nil {
		if isUniqueViolation(err) {
			err = ports.ErrConflict
		}
		return fmt.Errorf("create limit %q in postgres: %w", record.RecordID, err)
	}
	return nil
}
// GetLimitByRecordID loads the limit_records row whose record_id equals
// recordID. When no such row exists the returned error wraps
// ports.ErrNotFound.
func (store *Store) GetLimitByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) {
	if err := recordID.Validate(); err != nil {
		return policy.LimitRecord{}, fmt.Errorf("get limit from postgres: %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, "get limit from postgres")
	if err != nil {
		return policy.LimitRecord{}, err
	}
	defer cancel()

	query, args := pg.SELECT(limitSelectColumns).
		FROM(pgtable.LimitRecords).
		WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(recordID.String()))).
		Sql()

	record, err := scanLimitRow(store.db.QueryRowContext(opCtx, query, args...))
	if err != nil {
		// scanLimitRow already turns sql.ErrNoRows into ports.ErrNotFound,
		// so wrapping err keeps errors.Is(err, ports.ErrNotFound) working.
		return policy.LimitRecord{}, fmt.Errorf("get limit %q from postgres: %w", recordID, err)
	}
	return record, nil
}
// ListLimitsByUserID fetches the full limit history of userID, sorted by
// applied_at and then record_id (both ascending) so ties order
// deterministically. A user without limits yields an empty, non-nil slice.
func (store *Store) ListLimitsByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) {
	if err := userID.Validate(); err != nil {
		return nil, fmt.Errorf("list limits from postgres: %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, "list limits from postgres")
	if err != nil {
		return nil, err
	}
	defer cancel()

	wrap := func(cause error) error {
		return fmt.Errorf("list limits for %q from postgres: %w", userID, cause)
	}

	t := pgtable.LimitRecords
	query, args := pg.SELECT(limitSelectColumns).
		FROM(t).
		WHERE(t.UserID.EQ(pg.String(userID.String()))).
		ORDER_BY(t.AppliedAt.ASC(), t.RecordID.ASC()).
		Sql()

	rows, err := store.db.QueryContext(opCtx, query, args...)
	if err != nil {
		return nil, wrap(err)
	}
	// The Close error is discarded on purpose; iteration failures are
	// surfaced through rows.Err below.
	defer func() { _ = rows.Close() }()

	history := []policy.LimitRecord{}
	for rows.Next() {
		record, scanErr := scanLimit(rows)
		if scanErr != nil {
			return nil, wrap(scanErr)
		}
		history = append(history, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, wrap(iterErr)
	}
	return history, nil
}
// UpdateLimit validates record and overwrites the stored limit_records row
// with the same record_id. The returned error wraps ports.ErrNotFound when
// no row matches.
func (store *Store) UpdateLimit(ctx context.Context, record policy.LimitRecord) error {
	const op = "update limit in postgres"

	if err := record.Validate(); err != nil {
		return fmt.Errorf(op+": %w", err)
	}

	opCtx, cancel, err := store.operationContext(ctx, op)
	if err != nil {
		return err
	}
	defer cancel()

	return updateLimitRecordTx(opCtx, store.db, record)
}
// updateLimitRecordTx rewrites every non-key column of the limit_records row
// keyed by record.RecordID through q (the store's db handle or an open
// *sql.Tx). A statement that affects zero rows is reported as
// ports.ErrNotFound.
func updateLimitRecordTx(ctx context.Context, q queryer, record policy.LimitRecord) error {
	wrap := func(cause error) error {
		return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, cause)
	}

	t := pgtable.LimitRecords
	query, args := t.UPDATE(
		t.UserID, t.LimitCode, t.Value, t.ReasonCode,
		t.ActorType, t.ActorID, t.AppliedAt, t.ExpiresAt,
		t.RemovedAt, t.RemovedByType, t.RemovedByID, t.RemovedReasonCode,
	).SET(
		record.UserID.String(),
		string(record.LimitCode),
		record.Value,
		record.ReasonCode.String(),
		record.Actor.Type.String(),
		nullableActorID(record.Actor.ID),
		record.AppliedAt.UTC(),
		nullableTime(record.ExpiresAt),
		nullableTime(record.RemovedAt),
		nullableActorType(record.RemovedBy.Type),
		nullableActorID(record.RemovedBy.ID),
		nullableReasonCode(record.RemovedReasonCode),
	).WHERE(t.RecordID.EQ(pg.String(record.RecordID.String()))).Sql()

	res, err := q.ExecContext(ctx, query, args...)
	if err != nil {
		return wrap(err)
	}
	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return wrap(err)
	case affected == 0:
		return wrap(ports.ErrNotFound)
	default:
		return nil
	}
}
// scanLimitRow decodes a single-row result and translates sql.ErrNoRows into
// the bare ports.ErrNotFound sentinel. Any other error is passed through
// unchanged.
func scanLimitRow(row *sql.Row) (policy.LimitRecord, error) {
	if record, err := scanLimit(row); !errors.Is(err, sql.ErrNoRows) {
		return record, err
	}
	return policy.LimitRecord{}, ports.ErrNotFound
}
// scanLimit decodes one row shaped like limitSelectColumns into a
// policy.LimitRecord. NULL text columns become the zero value of their
// domain type. applied_at is converted to UTC, and the optional timestamps
// are converted with timeFromNullable.
func scanLimit(row scannableRow) (policy.LimitRecord, error) {
	var (
		recordID, userID, code, reason, actorType           string
		value                                               int
		appliedAt                                           time.Time
		expiresAt, removedAt                                *time.Time
		actorID, removedByType, removedByID, removedReason sql.NullString
	)
	if err := row.Scan(
		&recordID, &userID, &code, &value, &reason,
		&actorType, &actorID, &appliedAt,
		&expiresAt, &removedAt,
		&removedByType, &removedByID, &removedReason,
	); err != nil {
		return policy.LimitRecord{}, err
	}

	record := policy.LimitRecord{
		RecordID:   policy.LimitRecordID(recordID),
		UserID:     common.UserID(userID),
		LimitCode:  policy.LimitCode(code),
		Value:      value,
		ReasonCode: common.ReasonCode(reason),
		Actor: common.ActorRef{
			Type: common.ActorType(actorType),
			ID:   common.ActorID(actorID.String),
		},
		AppliedAt:         appliedAt.UTC(),
		ExpiresAt:         timeFromNullable(expiresAt),
		RemovedAt:         timeFromNullable(removedAt),
		RemovedReasonCode: common.ReasonCode(removedReason.String),
	}
	// A NULL column leaves .String empty, which is exactly the zero value
	// these fields would otherwise keep.
	record.RemovedBy.Type = common.ActorType(removedByType.String)
	record.RemovedBy.ID = common.ActorID(removedByID.String)
	return record, nil
}
// ApplySanction records a newly applied sanction in one transaction. It
// inserts input.NewRecord into sanction_records and then adds a
// sanction_active row pointing at it. A unique violation on that
// sanction_active insert is reported as ports.ErrConflict. Per the port
// contract, this is the "same code is already active for this user" case.
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("apply sanction in postgres: %w", err)
	}

	record := input.NewRecord
	return store.withTx(ctx, "apply sanction in postgres", func(ctx context.Context, tx *sql.Tx) error {
		if err := insertSanctionRecord(ctx, tx, record); err != nil {
			return err
		}

		t := pgtable.SanctionActive
		query, args := t.INSERT(t.UserID, t.SanctionCode, t.RecordID).
			VALUES(record.UserID.String(), string(record.SanctionCode), record.RecordID.String()).
			Sql()

		_, err := tx.ExecContext(ctx, query, args...)
		switch {
		case err == nil:
			return nil
		case isUniqueViolation(err):
			return fmt.Errorf("apply sanction %q in postgres: %w", record.RecordID, ports.ErrConflict)
		default:
			return fmt.Errorf("apply sanction %q in postgres: %w", record.RecordID, err)
		}
	})
}
// RemoveSanction ends an active sanction in one transaction:
//  1. lock the history row and check it still equals
//     input.ExpectedActiveRecord (ports.ErrNotFound / ports.ErrConflict
//     otherwise);
//  2. overwrite it with input.UpdatedRecord (the remove metadata);
//  3. delete the sanction_active row, which must still reference the
//     expected record, or ports.ErrConflict is returned.
func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("remove sanction in postgres: %w", err)
	}

	expected := input.ExpectedActiveRecord
	wrap := func(cause error) error {
		return fmt.Errorf("remove sanction %q in postgres: %w", expected.RecordID, cause)
	}

	return store.withTx(ctx, "remove sanction in postgres", func(ctx context.Context, tx *sql.Tx) error {
		if err := lockSanctionMatching(ctx, tx, expected); err != nil {
			return wrap(err)
		}
		if err := updateSanctionRecordTx(ctx, tx, input.UpdatedRecord); err != nil {
			return err
		}

		t := pgtable.SanctionActive
		query, args := t.DELETE().
			WHERE(pg.AND(
				t.UserID.EQ(pg.String(expected.UserID.String())),
				t.SanctionCode.EQ(pg.String(string(expected.SanctionCode))),
				t.RecordID.EQ(pg.String(expected.RecordID.String())),
			)).
			Sql()

		res, err := tx.ExecContext(ctx, query, args...)
		if err != nil {
			return wrap(err)
		}
		deleted, err := res.RowsAffected()
		if err != nil {
			return wrap(err)
		}
		if deleted == 0 {
			// sanction_active no longer points at the expected record.
			return wrap(ports.ErrConflict)
		}
		return nil
	})
}
// SetLimit installs input.NewRecord as the active limit for its
// (user_id, limit_code) pair inside a single transaction.
//
// When ExpectedActiveRecord is nil the call succeeds only if no active row
// exists for that pair; otherwise the returned error wraps
// ports.ErrConflict. This holds even against a concurrent creator.
//
// When ExpectedActiveRecord is non-nil, the stored history row must still
// equal it (ports.ErrNotFound / ports.ErrConflict otherwise). That row is
// overwritten with UpdatedActiveRecord (the remove metadata), and the
// active row is repointed at NewRecord.
func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("set limit in postgres: %w", err)
	}
	return store.withTx(ctx, "set limit in postgres", func(ctx context.Context, tx *sql.Tx) error {
		replacing := input.ExpectedActiveRecord != nil

		if replacing {
			if err := lockLimitMatching(ctx, tx, *input.ExpectedActiveRecord); err != nil {
				return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
			}
			if err := updateLimitRecordTx(ctx, tx, *input.UpdatedActiveRecord); err != nil {
				return err
			}
		} else {
			// Fail fast when an active row is already visible. The probe
			// alone cannot close the race with a concurrent creator, because
			// FOR UPDATE locks nothing when no row exists. The plain INSERT
			// below handles that case.
			probeQuery, probeArgs := pg.SELECT(pgtable.LimitActive.RecordID).
				FROM(pgtable.LimitActive).
				WHERE(pg.AND(
					pgtable.LimitActive.UserID.EQ(pg.String(input.NewRecord.UserID.String())),
					pgtable.LimitActive.LimitCode.EQ(pg.String(string(input.NewRecord.LimitCode))),
				)).
				FOR(pg.UPDATE()).
				Sql()
			var marker string
			err := tx.QueryRowContext(ctx, probeQuery, probeArgs...).Scan(&marker)
			switch {
			case err == nil:
				return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict)
			case !errors.Is(err, sql.ErrNoRows):
				return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
			}
		}

		if err := insertLimitRecord(ctx, tx, input.NewRecord); err != nil {
			return err
		}

		active := pgtable.LimitActive.INSERT(
			pgtable.LimitActive.UserID,
			pgtable.LimitActive.LimitCode,
			pgtable.LimitActive.RecordID,
			pgtable.LimitActive.Value,
		).VALUES(
			input.NewRecord.UserID.String(),
			string(input.NewRecord.LimitCode),
			input.NewRecord.RecordID.String(),
			input.NewRecord.Value,
		)
		if replacing {
			// Only the replace path may overwrite an existing active row.
			// Its ownership was verified by lockLimitMatching above.
			active = active.ON_CONFLICT(pgtable.LimitActive.UserID, pgtable.LimitActive.LimitCode).DO_UPDATE(
				pg.SET(
					pgtable.LimitActive.RecordID.SET(pgtable.LimitActive.EXCLUDED.RecordID),
					pgtable.LimitActive.Value.SET(pgtable.LimitActive.EXCLUDED.Value),
				),
			)
		}
		// On the create path there is no ON CONFLICT. An active row committed
		// by a concurrent creator after the probe therefore produces a
		// unique violation here instead of being silently overwritten.
		activeQuery, activeArgs := active.Sql()
		if _, err := tx.ExecContext(ctx, activeQuery, activeArgs...); err != nil {
			if isUniqueViolation(err) {
				return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict)
			}
			return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
		}
		return nil
	})
}
// RemoveLimit ends an active limit in one transaction:
//  1. lock the history row and check it still equals
//     input.ExpectedActiveRecord (ports.ErrNotFound / ports.ErrConflict
//     otherwise);
//  2. overwrite it with input.UpdatedRecord (the remove metadata);
//  3. delete the limit_active row, which must still reference the expected
//     record, or ports.ErrConflict is returned.
func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("remove limit in postgres: %w", err)
	}

	expected := input.ExpectedActiveRecord
	wrap := func(cause error) error {
		return fmt.Errorf("remove limit %q in postgres: %w", expected.RecordID, cause)
	}

	return store.withTx(ctx, "remove limit in postgres", func(ctx context.Context, tx *sql.Tx) error {
		if err := lockLimitMatching(ctx, tx, expected); err != nil {
			return wrap(err)
		}
		if err := updateLimitRecordTx(ctx, tx, input.UpdatedRecord); err != nil {
			return err
		}

		t := pgtable.LimitActive
		query, args := t.DELETE().
			WHERE(pg.AND(
				t.UserID.EQ(pg.String(expected.UserID.String())),
				t.LimitCode.EQ(pg.String(string(expected.LimitCode))),
				t.RecordID.EQ(pg.String(expected.RecordID.String())),
			)).
			Sql()

		res, err := tx.ExecContext(ctx, query, args...)
		if err != nil {
			return wrap(err)
		}
		deleted, err := res.RowsAffected()
		if err != nil {
			return wrap(err)
		}
		if deleted == 0 {
			// limit_active no longer points at the expected record.
			return wrap(ports.ErrConflict)
		}
		return nil
	})
}
// lockSanctionMatching row-locks (SELECT ... FOR UPDATE) the sanction_records
// row identified by expected.RecordID within tx and verifies that its stored
// contents still equal expected. It returns ports.ErrNotFound for a missing
// row, ports.ErrConflict when the row has diverged, and any other scan or
// driver error unchanged.
func lockSanctionMatching(ctx context.Context, tx *sql.Tx, expected policy.SanctionRecord) error {
	t := pgtable.SanctionRecords
	query, args := pg.SELECT(sanctionSelectColumns).
		FROM(t).
		WHERE(t.RecordID.EQ(pg.String(expected.RecordID.String()))).
		FOR(pg.UPDATE()).
		Sql()

	current, err := scanSanctionRow(tx.QueryRowContext(ctx, query, args...))
	if err != nil {
		// A missing row already arrives as the bare ports.ErrNotFound.
		return err
	}
	if sanctionsEqual(current, expected) {
		return nil
	}
	return ports.ErrConflict
}
// lockLimitMatching row-locks (SELECT ... FOR UPDATE) the limit_records row
// identified by expected.RecordID within tx and verifies that its stored
// contents still equal expected. It returns ports.ErrNotFound for a missing
// row, ports.ErrConflict when the row has diverged, and any other scan or
// driver error unchanged.
func lockLimitMatching(ctx context.Context, tx *sql.Tx, expected policy.LimitRecord) error {
	t := pgtable.LimitRecords
	query, args := pg.SELECT(limitSelectColumns).
		FROM(t).
		WHERE(t.RecordID.EQ(pg.String(expected.RecordID.String()))).
		FOR(pg.UPDATE()).
		Sql()

	current, err := scanLimitRow(tx.QueryRowContext(ctx, query, args...))
	if err != nil {
		// A missing row already arrives as the bare ports.ErrNotFound.
		return err
	}
	if limitsEqual(current, expected) {
		return nil
	}
	return ports.ErrConflict
}
// sanctionsEqual reports whether two sanction records have identical
// contents. Identifier, code and actor fields are compared with ==, while
// timestamps go through time.Time.Equal (and optionalTimeEqual for the
// optional ones) rather than ==.
func sanctionsEqual(left policy.SanctionRecord, right policy.SanctionRecord) bool {
	sameScalars := left.RecordID == right.RecordID &&
		left.UserID == right.UserID &&
		left.SanctionCode == right.SanctionCode &&
		left.Scope == right.Scope &&
		left.ReasonCode == right.ReasonCode &&
		left.Actor == right.Actor &&
		left.RemovedBy == right.RemovedBy &&
		left.RemovedReasonCode == right.RemovedReasonCode

	return sameScalars &&
		left.AppliedAt.Equal(right.AppliedAt) &&
		optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) &&
		optionalTimeEqual(left.RemovedAt, right.RemovedAt)
}
// limitsEqual reports whether two limit records have identical contents.
// Identifier, code, value and actor fields are compared with ==, while
// timestamps go through time.Time.Equal (and optionalTimeEqual for the
// optional ones) rather than ==.
func limitsEqual(left policy.LimitRecord, right policy.LimitRecord) bool {
	sameScalars := left.RecordID == right.RecordID &&
		left.UserID == right.UserID &&
		left.LimitCode == right.LimitCode &&
		left.Value == right.Value &&
		left.ReasonCode == right.ReasonCode &&
		left.Actor == right.Actor &&
		left.RemovedBy == right.RemovedBy &&
		left.RemovedReasonCode == right.RemovedReasonCode

	return sameScalars &&
		left.AppliedAt.Equal(right.AppliedAt) &&
		optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) &&
		optionalTimeEqual(left.RemovedAt, right.RemovedAt)
}
// SanctionStore exposes a *Store through the ports.SanctionStore interface.
// Each method forwards unchanged to the matching Store method.
type SanctionStore struct{ store *Store }

// Sanctions wraps store in a SanctionStore. A nil store yields a nil
// adapter.
func (store *Store) Sanctions() *SanctionStore {
	if store != nil {
		return &SanctionStore{store: store}
	}
	return nil
}

// Create forwards to Store.CreateSanction.
func (s *SanctionStore) Create(ctx context.Context, record policy.SanctionRecord) error {
	return s.store.CreateSanction(ctx, record)
}

// GetByRecordID forwards to Store.GetSanctionByRecordID.
func (s *SanctionStore) GetByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) {
	return s.store.GetSanctionByRecordID(ctx, recordID)
}

// ListByUserID forwards to Store.ListSanctionsByUserID.
func (s *SanctionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) {
	return s.store.ListSanctionsByUserID(ctx, userID)
}

// Update forwards to Store.UpdateSanction.
func (s *SanctionStore) Update(ctx context.Context, record policy.SanctionRecord) error {
	return s.store.UpdateSanction(ctx, record)
}

// Compile-time check that SanctionStore satisfies ports.SanctionStore.
var _ ports.SanctionStore = (*SanctionStore)(nil)
// LimitStore exposes a *Store through the ports.LimitStore interface. Each
// method forwards unchanged to the matching Store method.
type LimitStore struct{ store *Store }

// Limits wraps store in a LimitStore. A nil store yields a nil adapter.
func (store *Store) Limits() *LimitStore {
	if store != nil {
		return &LimitStore{store: store}
	}
	return nil
}

// Create forwards to Store.CreateLimit.
func (s *LimitStore) Create(ctx context.Context, record policy.LimitRecord) error {
	return s.store.CreateLimit(ctx, record)
}

// GetByRecordID forwards to Store.GetLimitByRecordID.
func (s *LimitStore) GetByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) {
	return s.store.GetLimitByRecordID(ctx, recordID)
}

// ListByUserID forwards to Store.ListLimitsByUserID.
func (s *LimitStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) {
	return s.store.ListLimitsByUserID(ctx, userID)
}

// Update forwards to Store.UpdateLimit.
func (s *LimitStore) Update(ctx context.Context, record policy.LimitRecord) error {
	return s.store.UpdateLimit(ctx, record)
}

// Compile-time check that LimitStore satisfies ports.LimitStore.
var _ ports.LimitStore = (*LimitStore)(nil)
// PolicyLifecycleStore exposes a *Store through the
// ports.PolicyLifecycleStore interface. Each method forwards unchanged to
// the Store method of the same name, which runs its work in one
// transaction.
type PolicyLifecycleStore struct{ store *Store }

// PolicyLifecycle wraps store in a PolicyLifecycleStore. A nil store yields
// a nil adapter.
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore {
	if store != nil {
		return &PolicyLifecycleStore{store: store}
	}
	return nil
}

// ApplySanction forwards to Store.ApplySanction.
func (s *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
	return s.store.ApplySanction(ctx, input)
}

// RemoveSanction forwards to Store.RemoveSanction.
func (s *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
	return s.store.RemoveSanction(ctx, input)
}

// SetLimit forwards to Store.SetLimit.
func (s *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
	return s.store.SetLimit(ctx, input)
}

// RemoveLimit forwards to Store.RemoveLimit.
func (s *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
	return s.store.RemoveLimit(ctx, input)
}

// Compile-time check that PolicyLifecycleStore satisfies
// ports.PolicyLifecycleStore.
var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)