package userstore import ( "context" "errors" "fmt" "time" "galaxy/user/internal/domain/policy" "galaxy/user/internal/ports" "github.com/redis/go-redis/v9" ) // ApplySanction atomically creates one new active sanction record. func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("apply sanction in redis: %w", err) } recordPayload, err := marshalSanctionRecord(input.NewRecord) if err != nil { return fmt.Errorf("apply sanction in redis: %w", err) } recordKey := store.keyspace.SanctionRecord(input.NewRecord.RecordID) historyKey := store.keyspace.SanctionHistory(input.NewRecord.UserID) activeKey := store.keyspace.ActiveSanction(input.NewRecord.UserID, input.NewRecord.SanctionCode) snapshotKey := store.keyspace.EntitlementSnapshot(input.NewRecord.UserID) watchedKeys := append( []string{recordKey, historyKey, activeKey, snapshotKey}, store.activeSanctionWatchKeys(input.NewRecord.UserID)..., ) operationCtx, cancel, err := store.operationContext(ctx, "apply sanction in redis") if err != nil { return err } defer cancel() watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error { if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil { return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err) } if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil { return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err) } snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.NewRecord.UserID) if err != nil { return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err) } activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewRecord.UserID) if err != nil { return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err) } activeSanctionCodes[input.NewRecord.SanctionCode] = struct{}{} _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { pipe.Set(operationCtx, recordKey, recordPayload, 0) pipe.ZAdd(operationCtx, historyKey, redis.Z{ Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()), Member: input.NewRecord.RecordID.String(), }) setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt) store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeSanctionCodes) store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewRecord.UserID, snapshot.IsPaid, activeSanctionCodes) return nil }) if err != nil { return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err) } return nil }, watchedKeys...) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // RemoveSanction atomically removes one active sanction record. func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("remove sanction in redis: %w", err) } updatedPayload, err := marshalSanctionRecord(input.UpdatedRecord) if err != nil { return fmt.Errorf("remove sanction in redis: %w", err) } recordKey := store.keyspace.SanctionRecord(input.ExpectedActiveRecord.RecordID) activeKey := store.keyspace.ActiveSanction(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.SanctionCode) snapshotKey := store.keyspace.EntitlementSnapshot(input.ExpectedActiveRecord.UserID) watchedKeys := append( []string{recordKey, activeKey, snapshotKey}, store.activeSanctionWatchKeys(input.ExpectedActiveRecord.UserID)..., ) operationCtx, cancel, err := store.operationContext(ctx, "remove sanction in redis") if err != nil { return err } defer cancel() watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error { activeRecordID, err := store.loadActiveSanctionRecordID(operationCtx, tx, activeKey) if err != nil { return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } if activeRecordID != input.ExpectedActiveRecord.RecordID { return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict) } storedRecord, err := store.loadSanctionRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID) if err != nil { return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } if !equalSanctionRecords(storedRecord, input.ExpectedActiveRecord) { return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict) } snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedActiveRecord.UserID) if err != nil { return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID) if err != nil { return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } delete(activeSanctionCodes, input.ExpectedActiveRecord.SanctionCode) _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { pipe.Set(operationCtx, recordKey, updatedPayload, 0) pipe.Del(operationCtx, activeKey) store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeSanctionCodes) store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, snapshot.IsPaid, activeSanctionCodes) return nil }) if err != nil { return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } return nil }, watchedKeys...) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // SetLimit atomically creates or replaces one active limit record. func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("set limit in redis: %w", err) } newRecordPayload, err := marshalLimitRecord(input.NewRecord) if err != nil { return fmt.Errorf("set limit in redis: %w", err) } newRecordKey := store.keyspace.LimitRecord(input.NewRecord.RecordID) historyKey := store.keyspace.LimitHistory(input.NewRecord.UserID) activeKey := store.keyspace.ActiveLimit(input.NewRecord.UserID, input.NewRecord.LimitCode) watchedKeys := append( []string{newRecordKey, historyKey, activeKey}, store.activeLimitWatchKeys(input.NewRecord.UserID)..., ) operationCtx, cancel, err := store.operationContext(ctx, "set limit in redis") if err != nil { return err } defer cancel() if input.ExpectedActiveRecord != nil { watchedKeys = append(watchedKeys, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID)) } watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error { if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err) } var updatedPayload []byte if input.ExpectedActiveRecord == nil { if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err) } } else { activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey) if err != nil { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err) } if activeRecordID != input.ExpectedActiveRecord.RecordID { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict) } storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID) if err != nil { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err) } if !equalLimitRecords(storedRecord, *input.ExpectedActiveRecord) { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict) } updatedPayload, err = marshalLimitRecord(*input.UpdatedActiveRecord) if err != nil { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err) } } activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.NewRecord.UserID) if err != nil { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err) } activeLimitCodes[input.NewRecord.LimitCode] = struct{}{} _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { if input.ExpectedActiveRecord != nil { pipe.Set(operationCtx, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID), updatedPayload, 0) } pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0) pipe.ZAdd(operationCtx, historyKey, redis.Z{ Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()), Member: input.NewRecord.RecordID.String(), }) setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt) store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeLimitCodes) return nil }) if err != nil { return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err) } return nil }, watchedKeys...) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // RemoveLimit atomically removes one active limit record. func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("remove limit in redis: %w", err) } updatedPayload, err := marshalLimitRecord(input.UpdatedRecord) if err != nil { return fmt.Errorf("remove limit in redis: %w", err) } recordKey := store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID) activeKey := store.keyspace.ActiveLimit(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.LimitCode) watchedKeys := append( []string{recordKey, activeKey}, store.activeLimitWatchKeys(input.ExpectedActiveRecord.UserID)..., ) operationCtx, cancel, err := store.operationContext(ctx, "remove limit in redis") if err != nil { return err } defer cancel() watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error { activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey) if err != nil { return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } if activeRecordID != input.ExpectedActiveRecord.RecordID { return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict) } storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID) if err != nil { return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } if !equalLimitRecords(storedRecord, input.ExpectedActiveRecord) { return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict) } activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID) if err != nil { return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } delete(activeLimitCodes, input.ExpectedActiveRecord.LimitCode) _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { pipe.Set(operationCtx, recordKey, updatedPayload, 0) pipe.Del(operationCtx, activeKey) store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeLimitCodes) return nil }) if err != nil { return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err) } return nil }, watchedKeys...) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict) case watchErr != nil: return watchErr default: return nil } } func (store *Store) loadActiveSanctionRecordID( ctx context.Context, getter bytesGetter, key string, ) (policy.SanctionRecordID, error) { value, err := getter.Get(ctx, key).Result() switch { case errors.Is(err, redis.Nil): return "", ports.ErrNotFound case err != nil: return "", err } recordID := policy.SanctionRecordID(value) if err := recordID.Validate(); err != nil { return "", fmt.Errorf("active sanction record id: %w", err) } return recordID, nil } func (store *Store) loadActiveLimitRecordID( ctx context.Context, getter bytesGetter, key string, ) (policy.LimitRecordID, error) { value, err := getter.Get(ctx, key).Result() switch { case errors.Is(err, redis.Nil): return "", ports.ErrNotFound case err != nil: return "", err } recordID := policy.LimitRecordID(value) if err := recordID.Validate(); err != nil { return "", fmt.Errorf("active limit record id: %w", err) } return recordID, nil } func setActiveSlot( pipe redis.Pipeliner, ctx context.Context, key string, recordID string, expiresAt *time.Time, ) { pipe.Set(ctx, key, recordID, 0) if expiresAt != nil { pipe.PExpireAt(ctx, key, expiresAt.UTC()) } } func equalSanctionRecords(left policy.SanctionRecord, right policy.SanctionRecord) bool { return left.RecordID == right.RecordID && left.UserID == right.UserID && left.SanctionCode == right.SanctionCode && left.Scope == right.Scope && left.ReasonCode == right.ReasonCode && left.Actor == right.Actor && left.AppliedAt.Equal(right.AppliedAt) && equalOptionalTime(left.ExpiresAt, right.ExpiresAt) && equalOptionalTime(left.RemovedAt, right.RemovedAt) && left.RemovedBy == right.RemovedBy && left.RemovedReasonCode == right.RemovedReasonCode } func equalLimitRecords(left policy.LimitRecord, right policy.LimitRecord) bool { return left.RecordID == right.RecordID && left.UserID == right.UserID && left.LimitCode == right.LimitCode && left.Value == right.Value && left.ReasonCode == right.ReasonCode && left.Actor == right.Actor && left.AppliedAt.Equal(right.AppliedAt) && equalOptionalTime(left.ExpiresAt, right.ExpiresAt) && equalOptionalTime(left.RemovedAt, right.RemovedAt) && left.RemovedBy == right.RemovedBy && left.RemovedReasonCode == right.RemovedReasonCode } // PolicyLifecycleStore adapts Store to the existing PolicyLifecycleStore // port. type PolicyLifecycleStore struct { store *Store } // PolicyLifecycle returns one adapter that exposes the atomic policy-lifecycle // store port over Store. func (store *Store) PolicyLifecycle() *PolicyLifecycleStore { if store == nil { return nil } return &PolicyLifecycleStore{store: store} } // ApplySanction atomically creates one new active sanction record. func (adapter *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error { return adapter.store.ApplySanction(ctx, input) } // RemoveSanction atomically removes one active sanction record. func (adapter *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error { return adapter.store.RemoveSanction(ctx, input) } // SetLimit atomically creates or replaces one active limit record. func (adapter *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error { return adapter.store.SetLimit(ctx, input) } // RemoveLimit atomically removes one active limit record. func (adapter *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error { return adapter.store.RemoveLimit(ctx, input) } var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)