package userstore

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
	"galaxy/user/internal/domain/common"
	"galaxy/user/internal/domain/policy"
	"galaxy/user/internal/ports"
	pg "github.com/go-jet/jet/v2/postgres"
)

// sanctionSelectColumns is the canonical SELECT list for sanction_records,
// matching scanSanction's column order.
//
// The order is load-bearing: scanSanction reads columns positionally, so any
// change here must be mirrored there (and vice versa).
var sanctionSelectColumns = pg.ColumnList{
	pgtable.SanctionRecords.RecordID,
	pgtable.SanctionRecords.UserID,
	pgtable.SanctionRecords.SanctionCode,
	pgtable.SanctionRecords.Scope,
	pgtable.SanctionRecords.ReasonCode,
	pgtable.SanctionRecords.ActorType,
	pgtable.SanctionRecords.ActorID,
	pgtable.SanctionRecords.AppliedAt,
	pgtable.SanctionRecords.ExpiresAt,
	pgtable.SanctionRecords.RemovedAt,
	pgtable.SanctionRecords.RemovedByType,
	pgtable.SanctionRecords.RemovedByID,
	pgtable.SanctionRecords.RemovedReasonCode,
}

// limitSelectColumns is the canonical SELECT list for limit_records, matching
// scanLimit's column order.
//
// As with sanctionSelectColumns, scanLimit reads these positionally; keep the
// two in lockstep.
var limitSelectColumns = pg.ColumnList{
	pgtable.LimitRecords.RecordID,
	pgtable.LimitRecords.UserID,
	pgtable.LimitRecords.LimitCode,
	pgtable.LimitRecords.Value,
	pgtable.LimitRecords.ReasonCode,
	pgtable.LimitRecords.ActorType,
	pgtable.LimitRecords.ActorID,
	pgtable.LimitRecords.AppliedAt,
	pgtable.LimitRecords.ExpiresAt,
	pgtable.LimitRecords.RemovedAt,
	pgtable.LimitRecords.RemovedByType,
	pgtable.LimitRecords.RemovedByID,
	pgtable.LimitRecords.RemovedReasonCode,
}

// CreateSanction stores one new sanction history record.
func (store *Store) CreateSanction(ctx context.Context, record policy.SanctionRecord) error { if err := record.Validate(); err != nil { return fmt.Errorf("create sanction in postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "create sanction in postgres") if err != nil { return err } defer cancel() return insertSanctionRecord(operationCtx, store.db, record) } func insertSanctionRecord(ctx context.Context, q queryer, record policy.SanctionRecord) error { stmt := pgtable.SanctionRecords.INSERT( pgtable.SanctionRecords.RecordID, pgtable.SanctionRecords.UserID, pgtable.SanctionRecords.SanctionCode, pgtable.SanctionRecords.Scope, pgtable.SanctionRecords.ReasonCode, pgtable.SanctionRecords.ActorType, pgtable.SanctionRecords.ActorID, pgtable.SanctionRecords.AppliedAt, pgtable.SanctionRecords.ExpiresAt, pgtable.SanctionRecords.RemovedAt, pgtable.SanctionRecords.RemovedByType, pgtable.SanctionRecords.RemovedByID, pgtable.SanctionRecords.RemovedReasonCode, ).VALUES( record.RecordID.String(), record.UserID.String(), string(record.SanctionCode), record.Scope.String(), record.ReasonCode.String(), record.Actor.Type.String(), nullableActorID(record.Actor.ID), record.AppliedAt.UTC(), nullableTime(record.ExpiresAt), nullableTime(record.RemovedAt), nullableActorType(record.RemovedBy.Type), nullableActorID(record.RemovedBy.ID), nullableReasonCode(record.RemovedReasonCode), ) query, args := stmt.Sql() _, err := q.ExecContext(ctx, query, args...) if err == nil { return nil } if isUniqueViolation(err) { return fmt.Errorf("create sanction %q in postgres: %w", record.RecordID, ports.ErrConflict) } return fmt.Errorf("create sanction %q in postgres: %w", record.RecordID, err) } // GetSanctionByRecordID returns the sanction history record identified by // recordID. 
func (store *Store) GetSanctionByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) { if err := recordID.Validate(); err != nil { return policy.SanctionRecord{}, fmt.Errorf("get sanction from postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "get sanction from postgres") if err != nil { return policy.SanctionRecord{}, err } defer cancel() stmt := pg.SELECT(sanctionSelectColumns). FROM(pgtable.SanctionRecords). WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(recordID.String()))) query, args := stmt.Sql() row := store.db.QueryRowContext(operationCtx, query, args...) record, err := scanSanctionRow(row) switch { case errors.Is(err, ports.ErrNotFound): return policy.SanctionRecord{}, fmt.Errorf("get sanction %q from postgres: %w", recordID, ports.ErrNotFound) case err != nil: return policy.SanctionRecord{}, fmt.Errorf("get sanction %q from postgres: %w", recordID, err) } return record, nil } // ListSanctionsByUserID returns every sanction history record owned by // userID, ordered by applied_at ascending. func (store *Store) ListSanctionsByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) { if err := userID.Validate(); err != nil { return nil, fmt.Errorf("list sanctions from postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "list sanctions from postgres") if err != nil { return nil, err } defer cancel() stmt := pg.SELECT(sanctionSelectColumns). FROM(pgtable.SanctionRecords). WHERE(pgtable.SanctionRecords.UserID.EQ(pg.String(userID.String()))). ORDER_BY(pgtable.SanctionRecords.AppliedAt.ASC(), pgtable.SanctionRecords.RecordID.ASC()) query, args := stmt.Sql() rows, err := store.db.QueryContext(operationCtx, query, args...) 
if err != nil { return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err) } defer func() { _ = rows.Close() }() out := make([]policy.SanctionRecord, 0) for rows.Next() { record, err := scanSanction(rows) if err != nil { return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err) } out = append(out, record) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err) } return out, nil } // UpdateSanction replaces one stored sanction history record. The matched // row is identified by record_id; ports.ErrNotFound is returned when no row // matches. func (store *Store) UpdateSanction(ctx context.Context, record policy.SanctionRecord) error { if err := record.Validate(); err != nil { return fmt.Errorf("update sanction in postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "update sanction in postgres") if err != nil { return err } defer cancel() return updateSanctionRecordTx(operationCtx, store.db, record) } func updateSanctionRecordTx(ctx context.Context, q queryer, record policy.SanctionRecord) error { stmt := pgtable.SanctionRecords.UPDATE( pgtable.SanctionRecords.UserID, pgtable.SanctionRecords.SanctionCode, pgtable.SanctionRecords.Scope, pgtable.SanctionRecords.ReasonCode, pgtable.SanctionRecords.ActorType, pgtable.SanctionRecords.ActorID, pgtable.SanctionRecords.AppliedAt, pgtable.SanctionRecords.ExpiresAt, pgtable.SanctionRecords.RemovedAt, pgtable.SanctionRecords.RemovedByType, pgtable.SanctionRecords.RemovedByID, pgtable.SanctionRecords.RemovedReasonCode, ).SET( record.UserID.String(), string(record.SanctionCode), record.Scope.String(), record.ReasonCode.String(), record.Actor.Type.String(), nullableActorID(record.Actor.ID), record.AppliedAt.UTC(), nullableTime(record.ExpiresAt), nullableTime(record.RemovedAt), nullableActorType(record.RemovedBy.Type), nullableActorID(record.RemovedBy.ID), nullableReasonCode(record.RemovedReasonCode), 
).WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(record.RecordID.String()))) query, args := stmt.Sql() res, err := q.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, err) } if rows == 0 { return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, ports.ErrNotFound) } return nil } func scanSanctionRow(row *sql.Row) (policy.SanctionRecord, error) { record, err := scanSanction(row) if errors.Is(err, sql.ErrNoRows) { return policy.SanctionRecord{}, ports.ErrNotFound } return record, err } func scanSanction(row scannableRow) (policy.SanctionRecord, error) { var ( recordID string userID string code string scope string reason string actorType string actorID *string appliedAt time.Time expiresAt *time.Time removedAt *time.Time rmByType *string rmByID *string rmReason *string ) if err := row.Scan( &recordID, &userID, &code, &scope, &reason, &actorType, &actorID, &appliedAt, &expiresAt, &removedAt, &rmByType, &rmByID, &rmReason, ); err != nil { return policy.SanctionRecord{}, err } record := policy.SanctionRecord{ RecordID: policy.SanctionRecordID(recordID), UserID: common.UserID(userID), SanctionCode: policy.SanctionCode(code), Scope: common.Scope(scope), ReasonCode: common.ReasonCode(reason), Actor: common.ActorRef{Type: common.ActorType(actorType)}, AppliedAt: appliedAt.UTC(), ExpiresAt: timeFromNullable(expiresAt), RemovedAt: timeFromNullable(removedAt), } if actorID != nil { record.Actor.ID = common.ActorID(*actorID) } if rmByType != nil { record.RemovedBy.Type = common.ActorType(*rmByType) } if rmByID != nil { record.RemovedBy.ID = common.ActorID(*rmByID) } if rmReason != nil { record.RemovedReasonCode = common.ReasonCode(*rmReason) } return record, nil } // CreateLimit stores one new limit history record. 
func (store *Store) CreateLimit(ctx context.Context, record policy.LimitRecord) error { if err := record.Validate(); err != nil { return fmt.Errorf("create limit in postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "create limit in postgres") if err != nil { return err } defer cancel() return insertLimitRecord(operationCtx, store.db, record) } func insertLimitRecord(ctx context.Context, q queryer, record policy.LimitRecord) error { stmt := pgtable.LimitRecords.INSERT( pgtable.LimitRecords.RecordID, pgtable.LimitRecords.UserID, pgtable.LimitRecords.LimitCode, pgtable.LimitRecords.Value, pgtable.LimitRecords.ReasonCode, pgtable.LimitRecords.ActorType, pgtable.LimitRecords.ActorID, pgtable.LimitRecords.AppliedAt, pgtable.LimitRecords.ExpiresAt, pgtable.LimitRecords.RemovedAt, pgtable.LimitRecords.RemovedByType, pgtable.LimitRecords.RemovedByID, pgtable.LimitRecords.RemovedReasonCode, ).VALUES( record.RecordID.String(), record.UserID.String(), string(record.LimitCode), record.Value, record.ReasonCode.String(), record.Actor.Type.String(), nullableActorID(record.Actor.ID), record.AppliedAt.UTC(), nullableTime(record.ExpiresAt), nullableTime(record.RemovedAt), nullableActorType(record.RemovedBy.Type), nullableActorID(record.RemovedBy.ID), nullableReasonCode(record.RemovedReasonCode), ) query, args := stmt.Sql() _, err := q.ExecContext(ctx, query, args...) if err == nil { return nil } if isUniqueViolation(err) { return fmt.Errorf("create limit %q in postgres: %w", record.RecordID, ports.ErrConflict) } return fmt.Errorf("create limit %q in postgres: %w", record.RecordID, err) } // GetLimitByRecordID returns the limit history record identified by recordID. 
func (store *Store) GetLimitByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) { if err := recordID.Validate(); err != nil { return policy.LimitRecord{}, fmt.Errorf("get limit from postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "get limit from postgres") if err != nil { return policy.LimitRecord{}, err } defer cancel() stmt := pg.SELECT(limitSelectColumns). FROM(pgtable.LimitRecords). WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(recordID.String()))) query, args := stmt.Sql() row := store.db.QueryRowContext(operationCtx, query, args...) record, err := scanLimitRow(row) switch { case errors.Is(err, ports.ErrNotFound): return policy.LimitRecord{}, fmt.Errorf("get limit %q from postgres: %w", recordID, ports.ErrNotFound) case err != nil: return policy.LimitRecord{}, fmt.Errorf("get limit %q from postgres: %w", recordID, err) } return record, nil } // ListLimitsByUserID returns every limit history record owned by userID, // ordered by applied_at ascending. func (store *Store) ListLimitsByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) { if err := userID.Validate(); err != nil { return nil, fmt.Errorf("list limits from postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "list limits from postgres") if err != nil { return nil, err } defer cancel() stmt := pg.SELECT(limitSelectColumns). FROM(pgtable.LimitRecords). WHERE(pgtable.LimitRecords.UserID.EQ(pg.String(userID.String()))). ORDER_BY(pgtable.LimitRecords.AppliedAt.ASC(), pgtable.LimitRecords.RecordID.ASC()) query, args := stmt.Sql() rows, err := store.db.QueryContext(operationCtx, query, args...) 
if err != nil { return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err) } defer func() { _ = rows.Close() }() out := make([]policy.LimitRecord, 0) for rows.Next() { record, err := scanLimit(rows) if err != nil { return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err) } out = append(out, record) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err) } return out, nil } // UpdateLimit replaces one stored limit history record. func (store *Store) UpdateLimit(ctx context.Context, record policy.LimitRecord) error { if err := record.Validate(); err != nil { return fmt.Errorf("update limit in postgres: %w", err) } operationCtx, cancel, err := store.operationContext(ctx, "update limit in postgres") if err != nil { return err } defer cancel() return updateLimitRecordTx(operationCtx, store.db, record) } func updateLimitRecordTx(ctx context.Context, q queryer, record policy.LimitRecord) error { stmt := pgtable.LimitRecords.UPDATE( pgtable.LimitRecords.UserID, pgtable.LimitRecords.LimitCode, pgtable.LimitRecords.Value, pgtable.LimitRecords.ReasonCode, pgtable.LimitRecords.ActorType, pgtable.LimitRecords.ActorID, pgtable.LimitRecords.AppliedAt, pgtable.LimitRecords.ExpiresAt, pgtable.LimitRecords.RemovedAt, pgtable.LimitRecords.RemovedByType, pgtable.LimitRecords.RemovedByID, pgtable.LimitRecords.RemovedReasonCode, ).SET( record.UserID.String(), string(record.LimitCode), record.Value, record.ReasonCode.String(), record.Actor.Type.String(), nullableActorID(record.Actor.ID), record.AppliedAt.UTC(), nullableTime(record.ExpiresAt), nullableTime(record.RemovedAt), nullableActorType(record.RemovedBy.Type), nullableActorID(record.RemovedBy.ID), nullableReasonCode(record.RemovedReasonCode), ).WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(record.RecordID.String()))) query, args := stmt.Sql() res, err := q.ExecContext(ctx, query, args...) 
if err != nil { return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, err) } if rows == 0 { return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, ports.ErrNotFound) } return nil } func scanLimitRow(row *sql.Row) (policy.LimitRecord, error) { record, err := scanLimit(row) if errors.Is(err, sql.ErrNoRows) { return policy.LimitRecord{}, ports.ErrNotFound } return record, err } func scanLimit(row scannableRow) (policy.LimitRecord, error) { var ( recordID string userID string code string value int reason string actorType string actorID *string appliedAt time.Time expiresAt *time.Time removedAt *time.Time rmByType *string rmByID *string rmReason *string ) if err := row.Scan( &recordID, &userID, &code, &value, &reason, &actorType, &actorID, &appliedAt, &expiresAt, &removedAt, &rmByType, &rmByID, &rmReason, ); err != nil { return policy.LimitRecord{}, err } record := policy.LimitRecord{ RecordID: policy.LimitRecordID(recordID), UserID: common.UserID(userID), LimitCode: policy.LimitCode(code), Value: value, ReasonCode: common.ReasonCode(reason), Actor: common.ActorRef{Type: common.ActorType(actorType)}, AppliedAt: appliedAt.UTC(), ExpiresAt: timeFromNullable(expiresAt), RemovedAt: timeFromNullable(removedAt), } if actorID != nil { record.Actor.ID = common.ActorID(*actorID) } if rmByType != nil { record.RemovedBy.Type = common.ActorType(*rmByType) } if rmByID != nil { record.RemovedBy.ID = common.ActorID(*rmByID) } if rmReason != nil { record.RemovedReasonCode = common.ReasonCode(*rmReason) } return record, nil } // ApplySanction inserts the new sanction history row and points // sanction_active at it. Re-applying the same code while another active // record exists returns ports.ErrConflict. 
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("apply sanction in postgres: %w", err) } return store.withTx(ctx, "apply sanction in postgres", func(ctx context.Context, tx *sql.Tx) error { if err := insertSanctionRecord(ctx, tx, input.NewRecord); err != nil { return err } stmt := pgtable.SanctionActive.INSERT( pgtable.SanctionActive.UserID, pgtable.SanctionActive.SanctionCode, pgtable.SanctionActive.RecordID, ).VALUES( input.NewRecord.UserID.String(), string(input.NewRecord.SanctionCode), input.NewRecord.RecordID.String(), ) query, args := stmt.Sql() if _, err := tx.ExecContext(ctx, query, args...); err != nil { if isUniqueViolation(err) { return fmt.Errorf("apply sanction %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict) } return fmt.Errorf("apply sanction %q in postgres: %w", input.NewRecord.RecordID, err) } return nil }) } // RemoveSanction updates the existing sanction record with remove metadata // and clears the sanction_active row that pointed at it. func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("remove sanction in postgres: %w", err) } return store.withTx(ctx, "remove sanction in postgres", func(ctx context.Context, tx *sql.Tx) error { if err := lockSanctionMatching(ctx, tx, input.ExpectedActiveRecord); err != nil { return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err) } if err := updateSanctionRecordTx(ctx, tx, input.UpdatedRecord); err != nil { return err } stmt := pgtable.SanctionActive.DELETE(). 
WHERE(pg.AND( pgtable.SanctionActive.UserID.EQ(pg.String(input.ExpectedActiveRecord.UserID.String())), pgtable.SanctionActive.SanctionCode.EQ(pg.String(string(input.ExpectedActiveRecord.SanctionCode))), pgtable.SanctionActive.RecordID.EQ(pg.String(input.ExpectedActiveRecord.RecordID.String())), )) query, args := stmt.Sql() res, err := tx.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err) } if rows == 0 { return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, ports.ErrConflict) } return nil }) } // SetLimit creates a new active limit (or replaces one) for the user. When // ExpectedActiveRecord is nil the call must succeed only if no active row // exists for (user_id, limit_code); otherwise the existing record is // updated with remove metadata and superseded by NewRecord. func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("set limit in postgres: %w", err) } return store.withTx(ctx, "set limit in postgres", func(ctx context.Context, tx *sql.Tx) error { if input.ExpectedActiveRecord != nil { if err := lockLimitMatching(ctx, tx, *input.ExpectedActiveRecord); err != nil { return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err) } if err := updateLimitRecordTx(ctx, tx, *input.UpdatedActiveRecord); err != nil { return err } } else { probe := pg.SELECT(pgtable.LimitActive.RecordID). FROM(pgtable.LimitActive). WHERE(pg.AND( pgtable.LimitActive.UserID.EQ(pg.String(input.NewRecord.UserID.String())), pgtable.LimitActive.LimitCode.EQ(pg.String(string(input.NewRecord.LimitCode))), )). 
FOR(pg.UPDATE()) probeQuery, probeArgs := probe.Sql() row := tx.QueryRowContext(ctx, probeQuery, probeArgs...) var marker string if err := row.Scan(&marker); err == nil { return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict) } else if !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err) } } if err := insertLimitRecord(ctx, tx, input.NewRecord); err != nil { return err } upsert := pgtable.LimitActive.INSERT( pgtable.LimitActive.UserID, pgtable.LimitActive.LimitCode, pgtable.LimitActive.RecordID, pgtable.LimitActive.Value, ).VALUES( input.NewRecord.UserID.String(), string(input.NewRecord.LimitCode), input.NewRecord.RecordID.String(), input.NewRecord.Value, ).ON_CONFLICT(pgtable.LimitActive.UserID, pgtable.LimitActive.LimitCode).DO_UPDATE( pg.SET( pgtable.LimitActive.RecordID.SET(pgtable.LimitActive.EXCLUDED.RecordID), pgtable.LimitActive.Value.SET(pgtable.LimitActive.EXCLUDED.Value), ), ) upsertQuery, upsertArgs := upsert.Sql() if _, err := tx.ExecContext(ctx, upsertQuery, upsertArgs...); err != nil { return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err) } return nil }) } // RemoveLimit updates the limit record with remove metadata and removes the // active row that referenced it. func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("remove limit in postgres: %w", err) } return store.withTx(ctx, "remove limit in postgres", func(ctx context.Context, tx *sql.Tx) error { if err := lockLimitMatching(ctx, tx, input.ExpectedActiveRecord); err != nil { return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err) } if err := updateLimitRecordTx(ctx, tx, input.UpdatedRecord); err != nil { return err } stmt := pgtable.LimitActive.DELETE(). 
WHERE(pg.AND( pgtable.LimitActive.UserID.EQ(pg.String(input.ExpectedActiveRecord.UserID.String())), pgtable.LimitActive.LimitCode.EQ(pg.String(string(input.ExpectedActiveRecord.LimitCode))), pgtable.LimitActive.RecordID.EQ(pg.String(input.ExpectedActiveRecord.RecordID.String())), )) query, args := stmt.Sql() res, err := tx.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err) } if rows == 0 { return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, ports.ErrConflict) } return nil }) } func lockSanctionMatching(ctx context.Context, tx *sql.Tx, expected policy.SanctionRecord) error { stmt := pg.SELECT(sanctionSelectColumns). FROM(pgtable.SanctionRecords). WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(expected.RecordID.String()))). FOR(pg.UPDATE()) query, args := stmt.Sql() row := tx.QueryRowContext(ctx, query, args...) current, err := scanSanctionRow(row) switch { case errors.Is(err, ports.ErrNotFound): return ports.ErrNotFound case err != nil: return err } if !sanctionsEqual(current, expected) { return ports.ErrConflict } return nil } func lockLimitMatching(ctx context.Context, tx *sql.Tx, expected policy.LimitRecord) error { stmt := pg.SELECT(limitSelectColumns). FROM(pgtable.LimitRecords). WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(expected.RecordID.String()))). FOR(pg.UPDATE()) query, args := stmt.Sql() row := tx.QueryRowContext(ctx, query, args...) 
current, err := scanLimitRow(row) switch { case errors.Is(err, ports.ErrNotFound): return ports.ErrNotFound case err != nil: return err } if !limitsEqual(current, expected) { return ports.ErrConflict } return nil } func sanctionsEqual(left policy.SanctionRecord, right policy.SanctionRecord) bool { if left.RecordID != right.RecordID || left.UserID != right.UserID || left.SanctionCode != right.SanctionCode || left.Scope != right.Scope || left.ReasonCode != right.ReasonCode || left.Actor != right.Actor || left.RemovedBy != right.RemovedBy || left.RemovedReasonCode != right.RemovedReasonCode { return false } if !left.AppliedAt.Equal(right.AppliedAt) { return false } if !optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) { return false } return optionalTimeEqual(left.RemovedAt, right.RemovedAt) } func limitsEqual(left policy.LimitRecord, right policy.LimitRecord) bool { if left.RecordID != right.RecordID || left.UserID != right.UserID || left.LimitCode != right.LimitCode || left.Value != right.Value || left.ReasonCode != right.ReasonCode || left.Actor != right.Actor || left.RemovedBy != right.RemovedBy || left.RemovedReasonCode != right.RemovedReasonCode { return false } if !left.AppliedAt.Equal(right.AppliedAt) { return false } if !optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) { return false } return optionalTimeEqual(left.RemovedAt, right.RemovedAt) } // SanctionStore adapts Store to the SanctionStore port. type SanctionStore struct{ store *Store } // Sanctions returns one adapter that exposes the sanction store port. func (store *Store) Sanctions() *SanctionStore { if store == nil { return nil } return &SanctionStore{store: store} } // Create stores one new sanction history record. func (a *SanctionStore) Create(ctx context.Context, record policy.SanctionRecord) error { return a.store.CreateSanction(ctx, record) } // GetByRecordID returns the sanction record identified by recordID. 
func (a *SanctionStore) GetByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) { return a.store.GetSanctionByRecordID(ctx, recordID) } // ListByUserID returns every sanction record owned by userID. func (a *SanctionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) { return a.store.ListSanctionsByUserID(ctx, userID) } // Update replaces one stored sanction record. func (a *SanctionStore) Update(ctx context.Context, record policy.SanctionRecord) error { return a.store.UpdateSanction(ctx, record) } var _ ports.SanctionStore = (*SanctionStore)(nil) // LimitStore adapts Store to the LimitStore port. type LimitStore struct{ store *Store } // Limits returns one adapter that exposes the limit store port. func (store *Store) Limits() *LimitStore { if store == nil { return nil } return &LimitStore{store: store} } // Create stores one new limit history record. func (a *LimitStore) Create(ctx context.Context, record policy.LimitRecord) error { return a.store.CreateLimit(ctx, record) } // GetByRecordID returns the limit record identified by recordID. func (a *LimitStore) GetByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) { return a.store.GetLimitByRecordID(ctx, recordID) } // ListByUserID returns every limit record owned by userID. func (a *LimitStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) { return a.store.ListLimitsByUserID(ctx, userID) } // Update replaces one stored limit record. func (a *LimitStore) Update(ctx context.Context, record policy.LimitRecord) error { return a.store.UpdateLimit(ctx, record) } var _ ports.LimitStore = (*LimitStore)(nil) // PolicyLifecycleStore adapts Store to the PolicyLifecycleStore port. type PolicyLifecycleStore struct{ store *Store } // PolicyLifecycle returns one adapter that exposes the policy-lifecycle // store port. 
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore { if store == nil { return nil } return &PolicyLifecycleStore{store: store} } // ApplySanction atomically creates one new active sanction record. func (a *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error { return a.store.ApplySanction(ctx, input) } // RemoveSanction atomically removes one active sanction record. func (a *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error { return a.store.RemoveSanction(ctx, input) } // SetLimit atomically creates or replaces one active limit record. func (a *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error { return a.store.SetLimit(ctx, input) } // RemoveLimit atomically removes one active limit record. func (a *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error { return a.store.RemoveLimit(ctx, input) } var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)