871 lines
29 KiB
Go
871 lines
29 KiB
Go
package userstore
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
|
|
"galaxy/user/internal/domain/common"
|
|
"galaxy/user/internal/domain/policy"
|
|
"galaxy/user/internal/ports"
|
|
|
|
pg "github.com/go-jet/jet/v2/postgres"
|
|
)
|
|
|
|
// sanctionSelectColumns is the canonical SELECT list for sanction_records,
|
|
// matching scanSanction's column order.
|
|
var sanctionSelectColumns = pg.ColumnList{
|
|
pgtable.SanctionRecords.RecordID,
|
|
pgtable.SanctionRecords.UserID,
|
|
pgtable.SanctionRecords.SanctionCode,
|
|
pgtable.SanctionRecords.Scope,
|
|
pgtable.SanctionRecords.ReasonCode,
|
|
pgtable.SanctionRecords.ActorType,
|
|
pgtable.SanctionRecords.ActorID,
|
|
pgtable.SanctionRecords.AppliedAt,
|
|
pgtable.SanctionRecords.ExpiresAt,
|
|
pgtable.SanctionRecords.RemovedAt,
|
|
pgtable.SanctionRecords.RemovedByType,
|
|
pgtable.SanctionRecords.RemovedByID,
|
|
pgtable.SanctionRecords.RemovedReasonCode,
|
|
}
|
|
|
|
// limitSelectColumns is the canonical SELECT list for limit_records, matching
|
|
// scanLimit's column order.
|
|
var limitSelectColumns = pg.ColumnList{
|
|
pgtable.LimitRecords.RecordID,
|
|
pgtable.LimitRecords.UserID,
|
|
pgtable.LimitRecords.LimitCode,
|
|
pgtable.LimitRecords.Value,
|
|
pgtable.LimitRecords.ReasonCode,
|
|
pgtable.LimitRecords.ActorType,
|
|
pgtable.LimitRecords.ActorID,
|
|
pgtable.LimitRecords.AppliedAt,
|
|
pgtable.LimitRecords.ExpiresAt,
|
|
pgtable.LimitRecords.RemovedAt,
|
|
pgtable.LimitRecords.RemovedByType,
|
|
pgtable.LimitRecords.RemovedByID,
|
|
pgtable.LimitRecords.RemovedReasonCode,
|
|
}
|
|
|
|
// CreateSanction stores one new sanction history record.
|
|
func (store *Store) CreateSanction(ctx context.Context, record policy.SanctionRecord) error {
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("create sanction in postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "create sanction in postgres")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cancel()
|
|
return insertSanctionRecord(operationCtx, store.db, record)
|
|
}
|
|
|
|
func insertSanctionRecord(ctx context.Context, q queryer, record policy.SanctionRecord) error {
|
|
stmt := pgtable.SanctionRecords.INSERT(
|
|
pgtable.SanctionRecords.RecordID,
|
|
pgtable.SanctionRecords.UserID,
|
|
pgtable.SanctionRecords.SanctionCode,
|
|
pgtable.SanctionRecords.Scope,
|
|
pgtable.SanctionRecords.ReasonCode,
|
|
pgtable.SanctionRecords.ActorType,
|
|
pgtable.SanctionRecords.ActorID,
|
|
pgtable.SanctionRecords.AppliedAt,
|
|
pgtable.SanctionRecords.ExpiresAt,
|
|
pgtable.SanctionRecords.RemovedAt,
|
|
pgtable.SanctionRecords.RemovedByType,
|
|
pgtable.SanctionRecords.RemovedByID,
|
|
pgtable.SanctionRecords.RemovedReasonCode,
|
|
).VALUES(
|
|
record.RecordID.String(),
|
|
record.UserID.String(),
|
|
string(record.SanctionCode),
|
|
record.Scope.String(),
|
|
record.ReasonCode.String(),
|
|
record.Actor.Type.String(),
|
|
nullableActorID(record.Actor.ID),
|
|
record.AppliedAt.UTC(),
|
|
nullableTime(record.ExpiresAt),
|
|
nullableTime(record.RemovedAt),
|
|
nullableActorType(record.RemovedBy.Type),
|
|
nullableActorID(record.RemovedBy.ID),
|
|
nullableReasonCode(record.RemovedReasonCode),
|
|
)
|
|
|
|
query, args := stmt.Sql()
|
|
_, err := q.ExecContext(ctx, query, args...)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
if isUniqueViolation(err) {
|
|
return fmt.Errorf("create sanction %q in postgres: %w", record.RecordID, ports.ErrConflict)
|
|
}
|
|
return fmt.Errorf("create sanction %q in postgres: %w", record.RecordID, err)
|
|
}
|
|
|
|
// GetSanctionByRecordID returns the sanction history record identified by
|
|
// recordID.
|
|
func (store *Store) GetSanctionByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) {
|
|
if err := recordID.Validate(); err != nil {
|
|
return policy.SanctionRecord{}, fmt.Errorf("get sanction from postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "get sanction from postgres")
|
|
if err != nil {
|
|
return policy.SanctionRecord{}, err
|
|
}
|
|
defer cancel()
|
|
|
|
stmt := pg.SELECT(sanctionSelectColumns).
|
|
FROM(pgtable.SanctionRecords).
|
|
WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(recordID.String())))
|
|
|
|
query, args := stmt.Sql()
|
|
row := store.db.QueryRowContext(operationCtx, query, args...)
|
|
record, err := scanSanctionRow(row)
|
|
switch {
|
|
case errors.Is(err, ports.ErrNotFound):
|
|
return policy.SanctionRecord{}, fmt.Errorf("get sanction %q from postgres: %w", recordID, ports.ErrNotFound)
|
|
case err != nil:
|
|
return policy.SanctionRecord{}, fmt.Errorf("get sanction %q from postgres: %w", recordID, err)
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
// ListSanctionsByUserID returns every sanction history record owned by
|
|
// userID, ordered by applied_at ascending.
|
|
func (store *Store) ListSanctionsByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) {
|
|
if err := userID.Validate(); err != nil {
|
|
return nil, fmt.Errorf("list sanctions from postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "list sanctions from postgres")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer cancel()
|
|
|
|
stmt := pg.SELECT(sanctionSelectColumns).
|
|
FROM(pgtable.SanctionRecords).
|
|
WHERE(pgtable.SanctionRecords.UserID.EQ(pg.String(userID.String()))).
|
|
ORDER_BY(pgtable.SanctionRecords.AppliedAt.ASC(), pgtable.SanctionRecords.RecordID.ASC())
|
|
|
|
query, args := stmt.Sql()
|
|
rows, err := store.db.QueryContext(operationCtx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
out := make([]policy.SanctionRecord, 0)
|
|
for rows.Next() {
|
|
record, err := scanSanction(rows)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err)
|
|
}
|
|
out = append(out, record)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// UpdateSanction replaces one stored sanction history record. The matched
|
|
// row is identified by record_id; ports.ErrNotFound is returned when no row
|
|
// matches.
|
|
func (store *Store) UpdateSanction(ctx context.Context, record policy.SanctionRecord) error {
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("update sanction in postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "update sanction in postgres")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cancel()
|
|
return updateSanctionRecordTx(operationCtx, store.db, record)
|
|
}
|
|
|
|
func updateSanctionRecordTx(ctx context.Context, q queryer, record policy.SanctionRecord) error {
|
|
stmt := pgtable.SanctionRecords.UPDATE(
|
|
pgtable.SanctionRecords.UserID,
|
|
pgtable.SanctionRecords.SanctionCode,
|
|
pgtable.SanctionRecords.Scope,
|
|
pgtable.SanctionRecords.ReasonCode,
|
|
pgtable.SanctionRecords.ActorType,
|
|
pgtable.SanctionRecords.ActorID,
|
|
pgtable.SanctionRecords.AppliedAt,
|
|
pgtable.SanctionRecords.ExpiresAt,
|
|
pgtable.SanctionRecords.RemovedAt,
|
|
pgtable.SanctionRecords.RemovedByType,
|
|
pgtable.SanctionRecords.RemovedByID,
|
|
pgtable.SanctionRecords.RemovedReasonCode,
|
|
).SET(
|
|
record.UserID.String(),
|
|
string(record.SanctionCode),
|
|
record.Scope.String(),
|
|
record.ReasonCode.String(),
|
|
record.Actor.Type.String(),
|
|
nullableActorID(record.Actor.ID),
|
|
record.AppliedAt.UTC(),
|
|
nullableTime(record.ExpiresAt),
|
|
nullableTime(record.RemovedAt),
|
|
nullableActorType(record.RemovedBy.Type),
|
|
nullableActorID(record.RemovedBy.ID),
|
|
nullableReasonCode(record.RemovedReasonCode),
|
|
).WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(record.RecordID.String())))
|
|
|
|
query, args := stmt.Sql()
|
|
res, err := q.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, err)
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, err)
|
|
}
|
|
if rows == 0 {
|
|
return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, ports.ErrNotFound)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func scanSanctionRow(row *sql.Row) (policy.SanctionRecord, error) {
|
|
record, err := scanSanction(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return policy.SanctionRecord{}, ports.ErrNotFound
|
|
}
|
|
return record, err
|
|
}
|
|
|
|
func scanSanction(row scannableRow) (policy.SanctionRecord, error) {
|
|
var (
|
|
recordID string
|
|
userID string
|
|
code string
|
|
scope string
|
|
reason string
|
|
actorType string
|
|
actorID *string
|
|
appliedAt time.Time
|
|
expiresAt *time.Time
|
|
removedAt *time.Time
|
|
rmByType *string
|
|
rmByID *string
|
|
rmReason *string
|
|
)
|
|
if err := row.Scan(
|
|
&recordID, &userID, &code, &scope, &reason,
|
|
&actorType, &actorID, &appliedAt,
|
|
&expiresAt, &removedAt,
|
|
&rmByType, &rmByID, &rmReason,
|
|
); err != nil {
|
|
return policy.SanctionRecord{}, err
|
|
}
|
|
record := policy.SanctionRecord{
|
|
RecordID: policy.SanctionRecordID(recordID),
|
|
UserID: common.UserID(userID),
|
|
SanctionCode: policy.SanctionCode(code),
|
|
Scope: common.Scope(scope),
|
|
ReasonCode: common.ReasonCode(reason),
|
|
Actor: common.ActorRef{Type: common.ActorType(actorType)},
|
|
AppliedAt: appliedAt.UTC(),
|
|
ExpiresAt: timeFromNullable(expiresAt),
|
|
RemovedAt: timeFromNullable(removedAt),
|
|
}
|
|
if actorID != nil {
|
|
record.Actor.ID = common.ActorID(*actorID)
|
|
}
|
|
if rmByType != nil {
|
|
record.RemovedBy.Type = common.ActorType(*rmByType)
|
|
}
|
|
if rmByID != nil {
|
|
record.RemovedBy.ID = common.ActorID(*rmByID)
|
|
}
|
|
if rmReason != nil {
|
|
record.RemovedReasonCode = common.ReasonCode(*rmReason)
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
// CreateLimit stores one new limit history record.
|
|
func (store *Store) CreateLimit(ctx context.Context, record policy.LimitRecord) error {
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("create limit in postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "create limit in postgres")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cancel()
|
|
return insertLimitRecord(operationCtx, store.db, record)
|
|
}
|
|
|
|
func insertLimitRecord(ctx context.Context, q queryer, record policy.LimitRecord) error {
|
|
stmt := pgtable.LimitRecords.INSERT(
|
|
pgtable.LimitRecords.RecordID,
|
|
pgtable.LimitRecords.UserID,
|
|
pgtable.LimitRecords.LimitCode,
|
|
pgtable.LimitRecords.Value,
|
|
pgtable.LimitRecords.ReasonCode,
|
|
pgtable.LimitRecords.ActorType,
|
|
pgtable.LimitRecords.ActorID,
|
|
pgtable.LimitRecords.AppliedAt,
|
|
pgtable.LimitRecords.ExpiresAt,
|
|
pgtable.LimitRecords.RemovedAt,
|
|
pgtable.LimitRecords.RemovedByType,
|
|
pgtable.LimitRecords.RemovedByID,
|
|
pgtable.LimitRecords.RemovedReasonCode,
|
|
).VALUES(
|
|
record.RecordID.String(),
|
|
record.UserID.String(),
|
|
string(record.LimitCode),
|
|
record.Value,
|
|
record.ReasonCode.String(),
|
|
record.Actor.Type.String(),
|
|
nullableActorID(record.Actor.ID),
|
|
record.AppliedAt.UTC(),
|
|
nullableTime(record.ExpiresAt),
|
|
nullableTime(record.RemovedAt),
|
|
nullableActorType(record.RemovedBy.Type),
|
|
nullableActorID(record.RemovedBy.ID),
|
|
nullableReasonCode(record.RemovedReasonCode),
|
|
)
|
|
|
|
query, args := stmt.Sql()
|
|
_, err := q.ExecContext(ctx, query, args...)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
if isUniqueViolation(err) {
|
|
return fmt.Errorf("create limit %q in postgres: %w", record.RecordID, ports.ErrConflict)
|
|
}
|
|
return fmt.Errorf("create limit %q in postgres: %w", record.RecordID, err)
|
|
}
|
|
|
|
// GetLimitByRecordID returns the limit history record identified by recordID.
|
|
func (store *Store) GetLimitByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) {
|
|
if err := recordID.Validate(); err != nil {
|
|
return policy.LimitRecord{}, fmt.Errorf("get limit from postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "get limit from postgres")
|
|
if err != nil {
|
|
return policy.LimitRecord{}, err
|
|
}
|
|
defer cancel()
|
|
|
|
stmt := pg.SELECT(limitSelectColumns).
|
|
FROM(pgtable.LimitRecords).
|
|
WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(recordID.String())))
|
|
|
|
query, args := stmt.Sql()
|
|
row := store.db.QueryRowContext(operationCtx, query, args...)
|
|
record, err := scanLimitRow(row)
|
|
switch {
|
|
case errors.Is(err, ports.ErrNotFound):
|
|
return policy.LimitRecord{}, fmt.Errorf("get limit %q from postgres: %w", recordID, ports.ErrNotFound)
|
|
case err != nil:
|
|
return policy.LimitRecord{}, fmt.Errorf("get limit %q from postgres: %w", recordID, err)
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
// ListLimitsByUserID returns every limit history record owned by userID,
|
|
// ordered by applied_at ascending.
|
|
func (store *Store) ListLimitsByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) {
|
|
if err := userID.Validate(); err != nil {
|
|
return nil, fmt.Errorf("list limits from postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "list limits from postgres")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer cancel()
|
|
|
|
stmt := pg.SELECT(limitSelectColumns).
|
|
FROM(pgtable.LimitRecords).
|
|
WHERE(pgtable.LimitRecords.UserID.EQ(pg.String(userID.String()))).
|
|
ORDER_BY(pgtable.LimitRecords.AppliedAt.ASC(), pgtable.LimitRecords.RecordID.ASC())
|
|
|
|
query, args := stmt.Sql()
|
|
rows, err := store.db.QueryContext(operationCtx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
out := make([]policy.LimitRecord, 0)
|
|
for rows.Next() {
|
|
record, err := scanLimit(rows)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err)
|
|
}
|
|
out = append(out, record)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// UpdateLimit replaces one stored limit history record.
|
|
func (store *Store) UpdateLimit(ctx context.Context, record policy.LimitRecord) error {
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("update limit in postgres: %w", err)
|
|
}
|
|
operationCtx, cancel, err := store.operationContext(ctx, "update limit in postgres")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cancel()
|
|
return updateLimitRecordTx(operationCtx, store.db, record)
|
|
}
|
|
|
|
func updateLimitRecordTx(ctx context.Context, q queryer, record policy.LimitRecord) error {
|
|
stmt := pgtable.LimitRecords.UPDATE(
|
|
pgtable.LimitRecords.UserID,
|
|
pgtable.LimitRecords.LimitCode,
|
|
pgtable.LimitRecords.Value,
|
|
pgtable.LimitRecords.ReasonCode,
|
|
pgtable.LimitRecords.ActorType,
|
|
pgtable.LimitRecords.ActorID,
|
|
pgtable.LimitRecords.AppliedAt,
|
|
pgtable.LimitRecords.ExpiresAt,
|
|
pgtable.LimitRecords.RemovedAt,
|
|
pgtable.LimitRecords.RemovedByType,
|
|
pgtable.LimitRecords.RemovedByID,
|
|
pgtable.LimitRecords.RemovedReasonCode,
|
|
).SET(
|
|
record.UserID.String(),
|
|
string(record.LimitCode),
|
|
record.Value,
|
|
record.ReasonCode.String(),
|
|
record.Actor.Type.String(),
|
|
nullableActorID(record.Actor.ID),
|
|
record.AppliedAt.UTC(),
|
|
nullableTime(record.ExpiresAt),
|
|
nullableTime(record.RemovedAt),
|
|
nullableActorType(record.RemovedBy.Type),
|
|
nullableActorID(record.RemovedBy.ID),
|
|
nullableReasonCode(record.RemovedReasonCode),
|
|
).WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(record.RecordID.String())))
|
|
|
|
query, args := stmt.Sql()
|
|
res, err := q.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, err)
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, err)
|
|
}
|
|
if rows == 0 {
|
|
return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, ports.ErrNotFound)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func scanLimitRow(row *sql.Row) (policy.LimitRecord, error) {
|
|
record, err := scanLimit(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return policy.LimitRecord{}, ports.ErrNotFound
|
|
}
|
|
return record, err
|
|
}
|
|
|
|
func scanLimit(row scannableRow) (policy.LimitRecord, error) {
|
|
var (
|
|
recordID string
|
|
userID string
|
|
code string
|
|
value int
|
|
reason string
|
|
actorType string
|
|
actorID *string
|
|
appliedAt time.Time
|
|
expiresAt *time.Time
|
|
removedAt *time.Time
|
|
rmByType *string
|
|
rmByID *string
|
|
rmReason *string
|
|
)
|
|
if err := row.Scan(
|
|
&recordID, &userID, &code, &value, &reason,
|
|
&actorType, &actorID, &appliedAt,
|
|
&expiresAt, &removedAt,
|
|
&rmByType, &rmByID, &rmReason,
|
|
); err != nil {
|
|
return policy.LimitRecord{}, err
|
|
}
|
|
record := policy.LimitRecord{
|
|
RecordID: policy.LimitRecordID(recordID),
|
|
UserID: common.UserID(userID),
|
|
LimitCode: policy.LimitCode(code),
|
|
Value: value,
|
|
ReasonCode: common.ReasonCode(reason),
|
|
Actor: common.ActorRef{Type: common.ActorType(actorType)},
|
|
AppliedAt: appliedAt.UTC(),
|
|
ExpiresAt: timeFromNullable(expiresAt),
|
|
RemovedAt: timeFromNullable(removedAt),
|
|
}
|
|
if actorID != nil {
|
|
record.Actor.ID = common.ActorID(*actorID)
|
|
}
|
|
if rmByType != nil {
|
|
record.RemovedBy.Type = common.ActorType(*rmByType)
|
|
}
|
|
if rmByID != nil {
|
|
record.RemovedBy.ID = common.ActorID(*rmByID)
|
|
}
|
|
if rmReason != nil {
|
|
record.RemovedReasonCode = common.ReasonCode(*rmReason)
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
// ApplySanction inserts the new sanction history row and points
|
|
// sanction_active at it. Re-applying the same code while another active
|
|
// record exists returns ports.ErrConflict.
|
|
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
|
|
if err := input.Validate(); err != nil {
|
|
return fmt.Errorf("apply sanction in postgres: %w", err)
|
|
}
|
|
return store.withTx(ctx, "apply sanction in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
|
if err := insertSanctionRecord(ctx, tx, input.NewRecord); err != nil {
|
|
return err
|
|
}
|
|
stmt := pgtable.SanctionActive.INSERT(
|
|
pgtable.SanctionActive.UserID,
|
|
pgtable.SanctionActive.SanctionCode,
|
|
pgtable.SanctionActive.RecordID,
|
|
).VALUES(
|
|
input.NewRecord.UserID.String(),
|
|
string(input.NewRecord.SanctionCode),
|
|
input.NewRecord.RecordID.String(),
|
|
)
|
|
query, args := stmt.Sql()
|
|
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
|
if isUniqueViolation(err) {
|
|
return fmt.Errorf("apply sanction %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict)
|
|
}
|
|
return fmt.Errorf("apply sanction %q in postgres: %w", input.NewRecord.RecordID, err)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// RemoveSanction updates the existing sanction record with remove metadata
|
|
// and clears the sanction_active row that pointed at it.
|
|
func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
|
|
if err := input.Validate(); err != nil {
|
|
return fmt.Errorf("remove sanction in postgres: %w", err)
|
|
}
|
|
return store.withTx(ctx, "remove sanction in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
|
if err := lockSanctionMatching(ctx, tx, input.ExpectedActiveRecord); err != nil {
|
|
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
|
|
}
|
|
if err := updateSanctionRecordTx(ctx, tx, input.UpdatedRecord); err != nil {
|
|
return err
|
|
}
|
|
stmt := pgtable.SanctionActive.DELETE().
|
|
WHERE(pg.AND(
|
|
pgtable.SanctionActive.UserID.EQ(pg.String(input.ExpectedActiveRecord.UserID.String())),
|
|
pgtable.SanctionActive.SanctionCode.EQ(pg.String(string(input.ExpectedActiveRecord.SanctionCode))),
|
|
pgtable.SanctionActive.RecordID.EQ(pg.String(input.ExpectedActiveRecord.RecordID.String())),
|
|
))
|
|
query, args := stmt.Sql()
|
|
res, err := tx.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
|
|
}
|
|
if rows == 0 {
|
|
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, ports.ErrConflict)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// SetLimit creates a new active limit (or replaces one) for the user. When
|
|
// ExpectedActiveRecord is nil the call must succeed only if no active row
|
|
// exists for (user_id, limit_code); otherwise the existing record is
|
|
// updated with remove metadata and superseded by NewRecord.
|
|
func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
|
|
if err := input.Validate(); err != nil {
|
|
return fmt.Errorf("set limit in postgres: %w", err)
|
|
}
|
|
return store.withTx(ctx, "set limit in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
|
if input.ExpectedActiveRecord != nil {
|
|
if err := lockLimitMatching(ctx, tx, *input.ExpectedActiveRecord); err != nil {
|
|
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
|
|
}
|
|
if err := updateLimitRecordTx(ctx, tx, *input.UpdatedActiveRecord); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
probe := pg.SELECT(pgtable.LimitActive.RecordID).
|
|
FROM(pgtable.LimitActive).
|
|
WHERE(pg.AND(
|
|
pgtable.LimitActive.UserID.EQ(pg.String(input.NewRecord.UserID.String())),
|
|
pgtable.LimitActive.LimitCode.EQ(pg.String(string(input.NewRecord.LimitCode))),
|
|
)).
|
|
FOR(pg.UPDATE())
|
|
probeQuery, probeArgs := probe.Sql()
|
|
row := tx.QueryRowContext(ctx, probeQuery, probeArgs...)
|
|
var marker string
|
|
if err := row.Scan(&marker); err == nil {
|
|
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict)
|
|
} else if !errors.Is(err, sql.ErrNoRows) {
|
|
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
|
|
}
|
|
}
|
|
|
|
if err := insertLimitRecord(ctx, tx, input.NewRecord); err != nil {
|
|
return err
|
|
}
|
|
|
|
upsert := pgtable.LimitActive.INSERT(
|
|
pgtable.LimitActive.UserID,
|
|
pgtable.LimitActive.LimitCode,
|
|
pgtable.LimitActive.RecordID,
|
|
pgtable.LimitActive.Value,
|
|
).VALUES(
|
|
input.NewRecord.UserID.String(),
|
|
string(input.NewRecord.LimitCode),
|
|
input.NewRecord.RecordID.String(),
|
|
input.NewRecord.Value,
|
|
).ON_CONFLICT(pgtable.LimitActive.UserID, pgtable.LimitActive.LimitCode).DO_UPDATE(
|
|
pg.SET(
|
|
pgtable.LimitActive.RecordID.SET(pgtable.LimitActive.EXCLUDED.RecordID),
|
|
pgtable.LimitActive.Value.SET(pgtable.LimitActive.EXCLUDED.Value),
|
|
),
|
|
)
|
|
upsertQuery, upsertArgs := upsert.Sql()
|
|
if _, err := tx.ExecContext(ctx, upsertQuery, upsertArgs...); err != nil {
|
|
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// RemoveLimit updates the limit record with remove metadata and removes the
|
|
// active row that referenced it.
|
|
func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
|
|
if err := input.Validate(); err != nil {
|
|
return fmt.Errorf("remove limit in postgres: %w", err)
|
|
}
|
|
return store.withTx(ctx, "remove limit in postgres", func(ctx context.Context, tx *sql.Tx) error {
|
|
if err := lockLimitMatching(ctx, tx, input.ExpectedActiveRecord); err != nil {
|
|
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
|
|
}
|
|
if err := updateLimitRecordTx(ctx, tx, input.UpdatedRecord); err != nil {
|
|
return err
|
|
}
|
|
stmt := pgtable.LimitActive.DELETE().
|
|
WHERE(pg.AND(
|
|
pgtable.LimitActive.UserID.EQ(pg.String(input.ExpectedActiveRecord.UserID.String())),
|
|
pgtable.LimitActive.LimitCode.EQ(pg.String(string(input.ExpectedActiveRecord.LimitCode))),
|
|
pgtable.LimitActive.RecordID.EQ(pg.String(input.ExpectedActiveRecord.RecordID.String())),
|
|
))
|
|
query, args := stmt.Sql()
|
|
res, err := tx.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
|
|
}
|
|
if rows == 0 {
|
|
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, ports.ErrConflict)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func lockSanctionMatching(ctx context.Context, tx *sql.Tx, expected policy.SanctionRecord) error {
|
|
stmt := pg.SELECT(sanctionSelectColumns).
|
|
FROM(pgtable.SanctionRecords).
|
|
WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(expected.RecordID.String()))).
|
|
FOR(pg.UPDATE())
|
|
|
|
query, args := stmt.Sql()
|
|
row := tx.QueryRowContext(ctx, query, args...)
|
|
current, err := scanSanctionRow(row)
|
|
switch {
|
|
case errors.Is(err, ports.ErrNotFound):
|
|
return ports.ErrNotFound
|
|
case err != nil:
|
|
return err
|
|
}
|
|
if !sanctionsEqual(current, expected) {
|
|
return ports.ErrConflict
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func lockLimitMatching(ctx context.Context, tx *sql.Tx, expected policy.LimitRecord) error {
|
|
stmt := pg.SELECT(limitSelectColumns).
|
|
FROM(pgtable.LimitRecords).
|
|
WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(expected.RecordID.String()))).
|
|
FOR(pg.UPDATE())
|
|
|
|
query, args := stmt.Sql()
|
|
row := tx.QueryRowContext(ctx, query, args...)
|
|
current, err := scanLimitRow(row)
|
|
switch {
|
|
case errors.Is(err, ports.ErrNotFound):
|
|
return ports.ErrNotFound
|
|
case err != nil:
|
|
return err
|
|
}
|
|
if !limitsEqual(current, expected) {
|
|
return ports.ErrConflict
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func sanctionsEqual(left policy.SanctionRecord, right policy.SanctionRecord) bool {
|
|
if left.RecordID != right.RecordID ||
|
|
left.UserID != right.UserID ||
|
|
left.SanctionCode != right.SanctionCode ||
|
|
left.Scope != right.Scope ||
|
|
left.ReasonCode != right.ReasonCode ||
|
|
left.Actor != right.Actor ||
|
|
left.RemovedBy != right.RemovedBy ||
|
|
left.RemovedReasonCode != right.RemovedReasonCode {
|
|
return false
|
|
}
|
|
if !left.AppliedAt.Equal(right.AppliedAt) {
|
|
return false
|
|
}
|
|
if !optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) {
|
|
return false
|
|
}
|
|
return optionalTimeEqual(left.RemovedAt, right.RemovedAt)
|
|
}
|
|
|
|
func limitsEqual(left policy.LimitRecord, right policy.LimitRecord) bool {
|
|
if left.RecordID != right.RecordID ||
|
|
left.UserID != right.UserID ||
|
|
left.LimitCode != right.LimitCode ||
|
|
left.Value != right.Value ||
|
|
left.ReasonCode != right.ReasonCode ||
|
|
left.Actor != right.Actor ||
|
|
left.RemovedBy != right.RemovedBy ||
|
|
left.RemovedReasonCode != right.RemovedReasonCode {
|
|
return false
|
|
}
|
|
if !left.AppliedAt.Equal(right.AppliedAt) {
|
|
return false
|
|
}
|
|
if !optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) {
|
|
return false
|
|
}
|
|
return optionalTimeEqual(left.RemovedAt, right.RemovedAt)
|
|
}
|
|
|
|
// SanctionStore adapts Store to the SanctionStore port.
|
|
type SanctionStore struct{ store *Store }
|
|
|
|
// Sanctions returns one adapter that exposes the sanction store port.
|
|
func (store *Store) Sanctions() *SanctionStore {
|
|
if store == nil {
|
|
return nil
|
|
}
|
|
return &SanctionStore{store: store}
|
|
}
|
|
|
|
// Create stores one new sanction history record.
|
|
func (a *SanctionStore) Create(ctx context.Context, record policy.SanctionRecord) error {
|
|
return a.store.CreateSanction(ctx, record)
|
|
}
|
|
|
|
// GetByRecordID returns the sanction record identified by recordID.
|
|
func (a *SanctionStore) GetByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) {
|
|
return a.store.GetSanctionByRecordID(ctx, recordID)
|
|
}
|
|
|
|
// ListByUserID returns every sanction record owned by userID.
|
|
func (a *SanctionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) {
|
|
return a.store.ListSanctionsByUserID(ctx, userID)
|
|
}
|
|
|
|
// Update replaces one stored sanction record.
|
|
func (a *SanctionStore) Update(ctx context.Context, record policy.SanctionRecord) error {
|
|
return a.store.UpdateSanction(ctx, record)
|
|
}
|
|
|
|
var _ ports.SanctionStore = (*SanctionStore)(nil)
|
|
|
|
// LimitStore adapts Store to the LimitStore port.
|
|
type LimitStore struct{ store *Store }
|
|
|
|
// Limits returns one adapter that exposes the limit store port.
|
|
func (store *Store) Limits() *LimitStore {
|
|
if store == nil {
|
|
return nil
|
|
}
|
|
return &LimitStore{store: store}
|
|
}
|
|
|
|
// Create stores one new limit history record.
|
|
func (a *LimitStore) Create(ctx context.Context, record policy.LimitRecord) error {
|
|
return a.store.CreateLimit(ctx, record)
|
|
}
|
|
|
|
// GetByRecordID returns the limit record identified by recordID.
|
|
func (a *LimitStore) GetByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) {
|
|
return a.store.GetLimitByRecordID(ctx, recordID)
|
|
}
|
|
|
|
// ListByUserID returns every limit record owned by userID.
|
|
func (a *LimitStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) {
|
|
return a.store.ListLimitsByUserID(ctx, userID)
|
|
}
|
|
|
|
// Update replaces one stored limit record.
|
|
func (a *LimitStore) Update(ctx context.Context, record policy.LimitRecord) error {
|
|
return a.store.UpdateLimit(ctx, record)
|
|
}
|
|
|
|
var _ ports.LimitStore = (*LimitStore)(nil)
|
|
|
|
// PolicyLifecycleStore adapts Store to the PolicyLifecycleStore port.
|
|
type PolicyLifecycleStore struct{ store *Store }
|
|
|
|
// PolicyLifecycle returns one adapter that exposes the policy-lifecycle
|
|
// store port.
|
|
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore {
|
|
if store == nil {
|
|
return nil
|
|
}
|
|
return &PolicyLifecycleStore{store: store}
|
|
}
|
|
|
|
// ApplySanction atomically creates one new active sanction record.
|
|
func (a *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
|
|
return a.store.ApplySanction(ctx, input)
|
|
}
|
|
|
|
// RemoveSanction atomically removes one active sanction record.
|
|
func (a *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
|
|
return a.store.RemoveSanction(ctx, input)
|
|
}
|
|
|
|
// SetLimit atomically creates or replaces one active limit record.
|
|
func (a *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
|
|
return a.store.SetLimit(ctx, input)
|
|
}
|
|
|
|
// RemoveLimit atomically removes one active limit record.
|
|
func (a *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
|
|
return a.store.RemoveLimit(ctx, input)
|
|
}
|
|
|
|
var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)
|