package userstore
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"galaxy/user/internal/domain/policy"
|
|
"galaxy/user/internal/ports"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// ApplySanction atomically creates one new active sanction record.
//
// The write runs inside an optimistic WATCH/MULTI transaction: every key the
// operation reads or writes — the record, the user's history and active slot,
// the entitlement snapshot, and the per-user active-sanction index keys — is
// watched, so any concurrent modification aborts the EXEC and surfaces to the
// caller as ports.ErrConflict.
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("apply sanction in redis: %w", err)
	}

	// Serialize up front so a marshal failure can never abort mid-transaction.
	recordPayload, err := marshalSanctionRecord(input.NewRecord)
	if err != nil {
		return fmt.Errorf("apply sanction in redis: %w", err)
	}

	recordKey := store.keyspace.SanctionRecord(input.NewRecord.RecordID)
	historyKey := store.keyspace.SanctionHistory(input.NewRecord.UserID)
	activeKey := store.keyspace.ActiveSanction(input.NewRecord.UserID, input.NewRecord.SanctionCode)
	snapshotKey := store.keyspace.EntitlementSnapshot(input.NewRecord.UserID)
	watchedKeys := append(
		[]string{recordKey, historyKey, activeKey, snapshotKey},
		store.activeSanctionWatchKeys(input.NewRecord.UserID)...,
	)

	operationCtx, cancel, err := store.operationContext(ctx, "apply sanction in redis")
	if err != nil {
		return err
	}
	defer cancel()

	watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
		// Creating a record or active slot that already exists is a
		// conflict, never a silent overwrite.
		if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil {
			return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
		}
		if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
			return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
		}
		snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.NewRecord.UserID)
		if err != nil {
			return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
		}
		activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewRecord.UserID)
		if err != nil {
			return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
		}
		activeSanctionCodes[input.NewRecord.SanctionCode] = struct{}{}

		// All reads are done; queue every write in one MULTI/EXEC so they
		// land atomically (or not at all if a watched key changed).
		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			pipe.Set(operationCtx, recordKey, recordPayload, 0)
			// History is a sorted set scored by the microsecond apply time.
			pipe.ZAdd(operationCtx, historyKey, redis.Z{
				Score:  float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
				Member: input.NewRecord.RecordID.String(),
			})
			setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
			store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeSanctionCodes)
			store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
			return nil
		})
		if err != nil {
			return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
		}

		return nil
	}, watchErr...)

	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		// A watched key changed between WATCH and EXEC.
		return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}
|
|
|
|
// RemoveSanction atomically removes one active sanction record.
//
// The caller supplies both the record it believes is currently active
// (ExpectedActiveRecord) and its closed-out replacement (UpdatedRecord). The
// WATCH/MULTI transaction verifies the stored state still matches the
// expectation — first by record ID in the active slot, then field-for-field
// against the stored record — before writing; any mismatch or concurrent
// modification is reported as ports.ErrConflict.
func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("remove sanction in redis: %w", err)
	}

	// Serialize the replacement payload before opening the transaction.
	updatedPayload, err := marshalSanctionRecord(input.UpdatedRecord)
	if err != nil {
		return fmt.Errorf("remove sanction in redis: %w", err)
	}

	recordKey := store.keyspace.SanctionRecord(input.ExpectedActiveRecord.RecordID)
	activeKey := store.keyspace.ActiveSanction(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.SanctionCode)
	snapshotKey := store.keyspace.EntitlementSnapshot(input.ExpectedActiveRecord.UserID)
	watchedKeys := append(
		[]string{recordKey, activeKey, snapshotKey},
		store.activeSanctionWatchKeys(input.ExpectedActiveRecord.UserID)...,
	)

	operationCtx, cancel, err := store.operationContext(ctx, "remove sanction in redis")
	if err != nil {
		return err
	}
	defer cancel()

	watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
		// The active slot must still point at the expected record.
		activeRecordID, err := store.loadActiveSanctionRecordID(operationCtx, tx, activeKey)
		if err != nil {
			return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}
		if activeRecordID != input.ExpectedActiveRecord.RecordID {
			return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
		}

		// Value-level optimistic check: the stored record must equal the
		// caller's expectation exactly, not just share an ID.
		storedRecord, err := store.loadSanctionRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
		if err != nil {
			return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}
		if !equalSanctionRecords(storedRecord, input.ExpectedActiveRecord) {
			return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
		}
		snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedActiveRecord.UserID)
		if err != nil {
			return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}
		activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
		if err != nil {
			return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}
		delete(activeSanctionCodes, input.ExpectedActiveRecord.SanctionCode)

		// One MULTI/EXEC: overwrite the record with its closed-out form,
		// drop the active slot, and resync the derived per-user indexes.
		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			pipe.Set(operationCtx, recordKey, updatedPayload, 0)
			pipe.Del(operationCtx, activeKey)
			store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeSanctionCodes)
			store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
			return nil
		})
		if err != nil {
			return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}

		return nil
	}, watchedKeys...)

	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		// A watched key changed between WATCH and EXEC.
		return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}
|
|
|
|
// SetLimit atomically creates or replaces one active limit record.
|
|
func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
|
|
if err := input.Validate(); err != nil {
|
|
return fmt.Errorf("set limit in redis: %w", err)
|
|
}
|
|
|
|
newRecordPayload, err := marshalLimitRecord(input.NewRecord)
|
|
if err != nil {
|
|
return fmt.Errorf("set limit in redis: %w", err)
|
|
}
|
|
|
|
newRecordKey := store.keyspace.LimitRecord(input.NewRecord.RecordID)
|
|
historyKey := store.keyspace.LimitHistory(input.NewRecord.UserID)
|
|
activeKey := store.keyspace.ActiveLimit(input.NewRecord.UserID, input.NewRecord.LimitCode)
|
|
watchedKeys := append(
|
|
[]string{newRecordKey, historyKey, activeKey},
|
|
store.activeLimitWatchKeys(input.NewRecord.UserID)...,
|
|
)
|
|
|
|
operationCtx, cancel, err := store.operationContext(ctx, "set limit in redis")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cancel()
|
|
if input.ExpectedActiveRecord != nil {
|
|
watchedKeys = append(watchedKeys, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID))
|
|
}
|
|
|
|
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
|
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
|
}
|
|
|
|
var updatedPayload []byte
|
|
if input.ExpectedActiveRecord == nil {
|
|
if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
|
}
|
|
} else {
|
|
activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
|
|
if err != nil {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
|
}
|
|
if activeRecordID != input.ExpectedActiveRecord.RecordID {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
|
|
}
|
|
|
|
storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
|
|
if err != nil {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
|
}
|
|
if !equalLimitRecords(storedRecord, *input.ExpectedActiveRecord) {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
|
|
}
|
|
|
|
updatedPayload, err = marshalLimitRecord(*input.UpdatedActiveRecord)
|
|
if err != nil {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
|
}
|
|
}
|
|
activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.NewRecord.UserID)
|
|
if err != nil {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
|
}
|
|
activeLimitCodes[input.NewRecord.LimitCode] = struct{}{}
|
|
|
|
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
|
if input.ExpectedActiveRecord != nil {
|
|
pipe.Set(operationCtx, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID), updatedPayload, 0)
|
|
}
|
|
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
|
|
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
|
Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
|
|
Member: input.NewRecord.RecordID.String(),
|
|
})
|
|
setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
|
|
store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeLimitCodes)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
|
}
|
|
|
|
return nil
|
|
}, watchedKeys...)
|
|
|
|
switch {
|
|
case errors.Is(watchErr, redis.TxFailedErr):
|
|
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
|
|
case watchErr != nil:
|
|
return watchErr
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// RemoveLimit atomically removes one active limit record.
//
// The caller supplies both the record it believes is currently active
// (ExpectedActiveRecord) and its closed-out replacement (UpdatedRecord). The
// WATCH/MULTI transaction verifies the stored state still matches the
// expectation — by record ID in the active slot and field-for-field against
// the stored record — before writing; any mismatch or concurrent
// modification is reported as ports.ErrConflict.
func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
	if err := input.Validate(); err != nil {
		return fmt.Errorf("remove limit in redis: %w", err)
	}

	// Serialize the replacement payload before opening the transaction.
	updatedPayload, err := marshalLimitRecord(input.UpdatedRecord)
	if err != nil {
		return fmt.Errorf("remove limit in redis: %w", err)
	}

	recordKey := store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID)
	activeKey := store.keyspace.ActiveLimit(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.LimitCode)
	watchedKeys := append(
		[]string{recordKey, activeKey},
		store.activeLimitWatchKeys(input.ExpectedActiveRecord.UserID)...,
	)

	operationCtx, cancel, err := store.operationContext(ctx, "remove limit in redis")
	if err != nil {
		return err
	}
	defer cancel()

	watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
		// The active slot must still point at the expected record.
		activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
		if err != nil {
			return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}
		if activeRecordID != input.ExpectedActiveRecord.RecordID {
			return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
		}

		// Value-level optimistic check: the stored record must equal the
		// caller's expectation exactly, not just share an ID.
		storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
		if err != nil {
			return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}
		if !equalLimitRecords(storedRecord, input.ExpectedActiveRecord) {
			return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
		}
		activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
		if err != nil {
			return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}
		delete(activeLimitCodes, input.ExpectedActiveRecord.LimitCode)

		// One MULTI/EXEC: overwrite the record with its closed-out form,
		// drop the active slot, and resync the derived per-user index.
		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			pipe.Set(operationCtx, recordKey, updatedPayload, 0)
			pipe.Del(operationCtx, activeKey)
			store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeLimitCodes)
			return nil
		})
		if err != nil {
			return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
		}

		return nil
	}, watchedKeys...)

	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		// A watched key changed between WATCH and EXEC.
		return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}
|
|
|
|
func (store *Store) loadActiveSanctionRecordID(
|
|
ctx context.Context,
|
|
getter bytesGetter,
|
|
key string,
|
|
) (policy.SanctionRecordID, error) {
|
|
value, err := getter.Get(ctx, key).Result()
|
|
switch {
|
|
case errors.Is(err, redis.Nil):
|
|
return "", ports.ErrNotFound
|
|
case err != nil:
|
|
return "", err
|
|
}
|
|
|
|
recordID := policy.SanctionRecordID(value)
|
|
if err := recordID.Validate(); err != nil {
|
|
return "", fmt.Errorf("active sanction record id: %w", err)
|
|
}
|
|
|
|
return recordID, nil
|
|
}
|
|
|
|
func (store *Store) loadActiveLimitRecordID(
|
|
ctx context.Context,
|
|
getter bytesGetter,
|
|
key string,
|
|
) (policy.LimitRecordID, error) {
|
|
value, err := getter.Get(ctx, key).Result()
|
|
switch {
|
|
case errors.Is(err, redis.Nil):
|
|
return "", ports.ErrNotFound
|
|
case err != nil:
|
|
return "", err
|
|
}
|
|
|
|
recordID := policy.LimitRecordID(value)
|
|
if err := recordID.Validate(); err != nil {
|
|
return "", fmt.Errorf("active limit record id: %w", err)
|
|
}
|
|
|
|
return recordID, nil
|
|
}
|
|
|
|
func setActiveSlot(
|
|
pipe redis.Pipeliner,
|
|
ctx context.Context,
|
|
key string,
|
|
recordID string,
|
|
expiresAt *time.Time,
|
|
) {
|
|
pipe.Set(ctx, key, recordID, 0)
|
|
if expiresAt != nil {
|
|
pipe.PExpireAt(ctx, key, expiresAt.UTC())
|
|
}
|
|
}
|
|
|
|
func equalSanctionRecords(left policy.SanctionRecord, right policy.SanctionRecord) bool {
|
|
return left.RecordID == right.RecordID &&
|
|
left.UserID == right.UserID &&
|
|
left.SanctionCode == right.SanctionCode &&
|
|
left.Scope == right.Scope &&
|
|
left.ReasonCode == right.ReasonCode &&
|
|
left.Actor == right.Actor &&
|
|
left.AppliedAt.Equal(right.AppliedAt) &&
|
|
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
|
|
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
|
|
left.RemovedBy == right.RemovedBy &&
|
|
left.RemovedReasonCode == right.RemovedReasonCode
|
|
}
|
|
|
|
func equalLimitRecords(left policy.LimitRecord, right policy.LimitRecord) bool {
|
|
return left.RecordID == right.RecordID &&
|
|
left.UserID == right.UserID &&
|
|
left.LimitCode == right.LimitCode &&
|
|
left.Value == right.Value &&
|
|
left.ReasonCode == right.ReasonCode &&
|
|
left.Actor == right.Actor &&
|
|
left.AppliedAt.Equal(right.AppliedAt) &&
|
|
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
|
|
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
|
|
left.RemovedBy == right.RemovedBy &&
|
|
left.RemovedReasonCode == right.RemovedReasonCode
|
|
}
|
|
|
|
// PolicyLifecycleStore adapts Store to the existing PolicyLifecycleStore
// port.
type PolicyLifecycleStore struct {
	// store is the Redis-backed Store every port method delegates to.
	store *Store
}
|
|
|
|
// PolicyLifecycle returns one adapter that exposes the atomic policy-lifecycle
|
|
// store port over Store.
|
|
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore {
|
|
if store == nil {
|
|
return nil
|
|
}
|
|
|
|
return &PolicyLifecycleStore{store: store}
|
|
}
|
|
|
|
// ApplySanction atomically creates one new active sanction record.
// It delegates directly to Store.ApplySanction.
func (adapter *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
	return adapter.store.ApplySanction(ctx, input)
}
|
|
|
|
// RemoveSanction atomically removes one active sanction record.
// It delegates directly to Store.RemoveSanction.
func (adapter *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
	return adapter.store.RemoveSanction(ctx, input)
}
|
|
|
|
// SetLimit atomically creates or replaces one active limit record.
// It delegates directly to Store.SetLimit.
func (adapter *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
	return adapter.store.SetLimit(ctx, input)
}
|
|
|
|
// RemoveLimit atomically removes one active limit record.
// It delegates directly to Store.RemoveLimit.
func (adapter *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
	return adapter.store.RemoveLimit(ctx, input)
}
|
|
|
|
var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)
|