feat: user service

This commit is contained in:
Ilia Denisov
2026-04-10 19:05:02 +02:00
committed by GitHub
parent 710bad712e
commit 23ffcb7535
140 changed files with 33418 additions and 952 deletions
@@ -0,0 +1,215 @@
package userstore
import (
"context"
"errors"
"galaxy/user/internal/adapters/redisstate"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
var knownSanctionCodes = []policy.SanctionCode{
policy.SanctionCodeLoginBlock,
policy.SanctionCodePrivateGameCreateBlock,
policy.SanctionCodePrivateGameManageBlock,
policy.SanctionCodeGameJoinBlock,
policy.SanctionCodeProfileUpdateBlock,
}
var knownLimitCodes = []policy.LimitCode{
policy.LimitCodeMaxOwnedPrivateGames,
policy.LimitCodeMaxPendingPublicApplications,
policy.LimitCodeMaxActiveGameMemberships,
}
var knownEligibilityMarkers = []policy.EligibilityMarker{
policy.EligibilityMarkerCanLogin,
policy.EligibilityMarkerCanCreatePrivateGame,
policy.EligibilityMarkerCanManagePrivateGame,
policy.EligibilityMarkerCanJoinGame,
policy.EligibilityMarkerCanUpdateProfile,
}
func (store *Store) addCreatedAtIndex(
pipe redis.Pipeliner,
ctx context.Context,
record account.UserAccount,
) {
pipe.ZAdd(ctx, store.keyspace.CreatedAtIndex(), redis.Z{
Score: redisstate.CreatedAtScore(record.CreatedAt),
Member: record.UserID.String(),
})
}
func (store *Store) syncDeclaredCountryIndex(
pipe redis.Pipeliner,
ctx context.Context,
previous account.UserAccount,
current account.UserAccount,
) {
if !previous.DeclaredCountry.IsZero() {
pipe.SRem(ctx, store.keyspace.DeclaredCountryIndex(previous.DeclaredCountry), current.UserID.String())
}
if !current.DeclaredCountry.IsZero() {
pipe.SAdd(ctx, store.keyspace.DeclaredCountryIndex(current.DeclaredCountry), current.UserID.String())
}
}
func (store *Store) syncEntitlementIndexes(
pipe redis.Pipeliner,
ctx context.Context,
snapshot entitlement.CurrentSnapshot,
) {
pipe.SRem(ctx, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), snapshot.UserID.String())
pipe.SRem(ctx, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), snapshot.UserID.String())
pipe.SAdd(ctx, store.keyspace.PaidStateIndex(paidStateFromSnapshot(snapshot)), snapshot.UserID.String())
pipe.ZRem(ctx, store.keyspace.FinitePaidExpiryIndex(), snapshot.UserID.String())
if snapshot.HasFiniteExpiry() {
pipe.ZAdd(ctx, store.keyspace.FinitePaidExpiryIndex(), redis.Z{
Score: redisstate.ExpiryScore(*snapshot.EndsAt),
Member: snapshot.UserID.String(),
})
}
}
func (store *Store) syncActiveSanctionCodeIndexes(
pipe redis.Pipeliner,
ctx context.Context,
userID common.UserID,
activeCodes map[policy.SanctionCode]struct{},
) {
for _, code := range knownSanctionCodes {
pipe.SRem(ctx, store.keyspace.ActiveSanctionCodeIndex(code), userID.String())
if _, ok := activeCodes[code]; ok {
pipe.SAdd(ctx, store.keyspace.ActiveSanctionCodeIndex(code), userID.String())
}
}
}
func (store *Store) syncActiveLimitCodeIndexes(
pipe redis.Pipeliner,
ctx context.Context,
userID common.UserID,
activeCodes map[policy.LimitCode]struct{},
) {
for _, code := range knownLimitCodes {
pipe.SRem(ctx, store.keyspace.ActiveLimitCodeIndex(code), userID.String())
if _, ok := activeCodes[code]; ok {
pipe.SAdd(ctx, store.keyspace.ActiveLimitCodeIndex(code), userID.String())
}
}
}
func (store *Store) syncEligibilityMarkerIndexes(
pipe redis.Pipeliner,
ctx context.Context,
userID common.UserID,
isPaid bool,
activeSanctionCodes map[policy.SanctionCode]struct{},
) {
values := deriveEligibilityMarkerValues(isPaid, activeSanctionCodes)
for _, marker := range knownEligibilityMarkers {
pipe.SRem(ctx, store.keyspace.EligibilityMarkerIndex(marker, true), userID.String())
pipe.SRem(ctx, store.keyspace.EligibilityMarkerIndex(marker, false), userID.String())
pipe.SAdd(ctx, store.keyspace.EligibilityMarkerIndex(marker, values[marker]), userID.String())
}
}
func (store *Store) loadActiveSanctionCodeSet(
ctx context.Context,
getter bytesGetter,
userID common.UserID,
) (map[policy.SanctionCode]struct{}, error) {
activeCodes := make(map[policy.SanctionCode]struct{}, len(knownSanctionCodes))
for _, code := range knownSanctionCodes {
_, err := store.loadActiveSanctionRecordID(ctx, getter, store.keyspace.ActiveSanction(userID, code))
switch {
case err == nil:
activeCodes[code] = struct{}{}
case errors.Is(err, ports.ErrNotFound):
continue
default:
return nil, err
}
}
return activeCodes, nil
}
func (store *Store) loadActiveLimitCodeSet(
ctx context.Context,
getter bytesGetter,
userID common.UserID,
) (map[policy.LimitCode]struct{}, error) {
activeCodes := make(map[policy.LimitCode]struct{}, len(knownLimitCodes))
for _, code := range knownLimitCodes {
_, err := store.loadActiveLimitRecordID(ctx, getter, store.keyspace.ActiveLimit(userID, code))
switch {
case err == nil:
activeCodes[code] = struct{}{}
case errors.Is(err, ports.ErrNotFound):
continue
default:
return nil, err
}
}
return activeCodes, nil
}
func (store *Store) activeSanctionWatchKeys(userID common.UserID) []string {
keys := make([]string, 0, len(knownSanctionCodes))
for _, code := range knownSanctionCodes {
keys = append(keys, store.keyspace.ActiveSanction(userID, code))
}
return keys
}
func (store *Store) activeLimitWatchKeys(userID common.UserID) []string {
keys := make([]string, 0, len(knownLimitCodes))
for _, code := range knownLimitCodes {
keys = append(keys, store.keyspace.ActiveLimit(userID, code))
}
return keys
}
func deriveEligibilityMarkerValues(
isPaid bool,
activeSanctionCodes map[policy.SanctionCode]struct{},
) map[policy.EligibilityMarker]bool {
_, loginBlocked := activeSanctionCodes[policy.SanctionCodeLoginBlock]
_, createBlocked := activeSanctionCodes[policy.SanctionCodePrivateGameCreateBlock]
_, manageBlocked := activeSanctionCodes[policy.SanctionCodePrivateGameManageBlock]
_, joinBlocked := activeSanctionCodes[policy.SanctionCodeGameJoinBlock]
_, profileBlocked := activeSanctionCodes[policy.SanctionCodeProfileUpdateBlock]
canLogin := !loginBlocked
return map[policy.EligibilityMarker]bool{
policy.EligibilityMarkerCanLogin: canLogin,
policy.EligibilityMarkerCanCreatePrivateGame: canLogin && isPaid && !createBlocked,
policy.EligibilityMarkerCanManagePrivateGame: canLogin && isPaid && !manageBlocked,
policy.EligibilityMarkerCanJoinGame: canLogin && !joinBlocked,
policy.EligibilityMarkerCanUpdateProfile: canLogin && !profileBlocked,
}
}
func paidStateFromSnapshot(snapshot entitlement.CurrentSnapshot) entitlement.PaidState {
if snapshot.IsPaid {
return entitlement.PaidStatePaid
}
return entitlement.PaidStateFree
}
@@ -0,0 +1,449 @@
package userstore
import (
"context"
"testing"
"time"
"galaxy/user/internal/adapters/redisstate"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"galaxy/user/internal/service/adminusers"
"galaxy/user/internal/service/entitlementsvc"
"github.com/stretchr/testify/require"
)
func TestListUserIDsCreatedAtPagination(t *testing.T) {
t.Parallel()
store := newTestStore(t)
base := time.Unix(1_775_240_000, 0).UTC()
first := validAccountRecord()
first.UserID = common.UserID("user-100")
first.Email = common.Email("u100@example.com")
first.RaceName = common.RaceName("User 100")
first.CreatedAt = base.Add(-time.Hour)
first.UpdatedAt = first.CreatedAt
second := validAccountRecord()
second.UserID = common.UserID("user-200")
second.Email = common.Email("u200@example.com")
second.RaceName = common.RaceName("User 200")
second.CreatedAt = base
second.UpdatedAt = second.CreatedAt
third := validAccountRecord()
third.UserID = common.UserID("user-300")
third.Email = common.Email("u300@example.com")
third.RaceName = common.RaceName("User 300")
third.CreatedAt = base
third.UpdatedAt = third.CreatedAt
require.NoError(t, store.Create(context.Background(), createAccountInput(first)))
require.NoError(t, store.Create(context.Background(), createAccountInput(second)))
require.NoError(t, store.Create(context.Background(), createAccountInput(third)))
firstPage, err := store.ListUserIDs(context.Background(), ports.ListUsersInput{
PageSize: 2,
Filters: ports.UserListFilters{},
})
require.NoError(t, err)
require.Equal(t, []common.UserID{third.UserID, second.UserID}, firstPage.UserIDs)
require.NotEmpty(t, firstPage.NextPageToken)
secondPage, err := store.ListUserIDs(context.Background(), ports.ListUsersInput{
PageSize: 2,
PageToken: firstPage.NextPageToken,
Filters: ports.UserListFilters{},
})
require.NoError(t, err)
require.Equal(t, []common.UserID{first.UserID}, secondPage.UserIDs)
require.Empty(t, secondPage.NextPageToken)
}
func TestEnsureByEmailInitialAdminIndexes(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.DeclaredCountry = common.CountryCode("DE")
record.CreatedAt = now
record.UpdatedAt = now
result, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(record.UserID, now),
EntitlementRecord: validEntitlementRecord(record.UserID, now),
Reservation: raceNameReservation(record.UserID, record.RaceName, now),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeCreated, result.Outcome)
requireSortedSetScore(t, store, store.keyspace.CreatedAtIndex(), record.UserID.String(), redisstate.CreatedAtScore(record.CreatedAt))
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
requireSetContains(t, store, store.keyspace.DeclaredCountryIndex(record.DeclaredCountry), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, false), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanJoinGame, true), record.UserID.String())
}
func TestAccountUpdateSyncsDeclaredCountryIndex(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
record.DeclaredCountry = common.CountryCode("DE")
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
updated := record
updated.DeclaredCountry = common.CountryCode("FR")
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
require.NoError(t, accountStore.Update(context.Background(), updated))
requireSetNotContains(t, store, store.keyspace.DeclaredCountryIndex(common.CountryCode("DE")), record.UserID.String())
requireSetContains(t, store, store.keyspace.DeclaredCountryIndex(common.CountryCode("FR")), record.UserID.String())
}
func TestEntitlementLifecycleSyncsAdminIndexes(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.CreatedAt = now
record.UpdatedAt = now
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(record.UserID, now),
EntitlementRecord: validEntitlementRecord(record.UserID, now),
Reservation: raceNameReservation(record.UserID, record.RaceName, now),
})
require.NoError(t, err)
lifecycleStore := store.EntitlementLifecycle()
freeRecord := validEntitlementRecord(record.UserID, now)
freeSnapshot := validEntitlementSnapshot(record.UserID, now)
grantStartsAt := now.Add(time.Hour)
grantEndsAt := grantStartsAt.Add(30 * 24 * time.Hour)
grantedRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-1"),
record.UserID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
grantedSnapshot := paidEntitlementSnapshot(
record.UserID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
closedFreeRecord := freeRecord
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
require.NoError(t, lifecycleStore.Grant(context.Background(), ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: freeSnapshot,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFreeRecord,
NewRecord: grantedRecord,
NewSnapshot: grantedSnapshot,
}))
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
requireSortedSetScore(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String(), redisstate.ExpiryScore(grantEndsAt))
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, true), record.UserID.String())
extendedEndsAt := grantEndsAt.Add(30 * 24 * time.Hour)
extensionRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-2"),
record.UserID,
entitlement.PlanCodePaidMonthly,
grantEndsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
extendedSnapshot := paidEntitlementSnapshot(
record.UserID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
require.NoError(t, lifecycleStore.Extend(context.Background(), ports.ExtendEntitlementInput{
ExpectedCurrentSnapshot: grantedSnapshot,
NewRecord: extensionRecord,
NewSnapshot: extendedSnapshot,
}))
requireSortedSetScore(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String(), redisstate.ExpiryScore(extendedEndsAt))
revokeAt := grantEndsAt.Add(12 * time.Hour)
revokedCurrentRecord := extensionRecord
revokedCurrentRecord.ClosedAt = timePointer(revokeAt)
revokedCurrentRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
revokedCurrentRecord.ClosedReasonCode = common.ReasonCode("manual_revoke")
freeAfterRevokeRecord := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-free-2"),
UserID: record.UserID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
StartsAt: revokeAt,
CreatedAt: revokeAt,
}
freeAfterRevokeSnapshot := entitlement.CurrentSnapshot{
UserID: record.UserID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: revokeAt,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
UpdatedAt: revokeAt,
}
require.NoError(t, lifecycleStore.Revoke(context.Background(), ports.RevokeEntitlementInput{
ExpectedCurrentSnapshot: extendedSnapshot,
ExpectedCurrentRecord: extensionRecord,
UpdatedCurrentRecord: revokedCurrentRecord,
NewRecord: freeAfterRevokeRecord,
NewSnapshot: freeAfterRevokeSnapshot,
}))
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
requireSortedSetMissing(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, false), record.UserID.String())
}
func TestPolicyLifecycleSyncsAdminIndexes(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.CreatedAt = now
record.UpdatedAt = now
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(record.UserID, now),
EntitlementRecord: validEntitlementRecord(record.UserID, now),
Reservation: raceNameReservation(record.UserID, record.RaceName, now),
})
require.NoError(t, err)
lifecycleStore := store.PolicyLifecycle()
sanctionRecord := policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-1"),
UserID: record.UserID,
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("auth"),
ReasonCode: common.ReasonCode("manual_block"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now,
}
require.NoError(t, lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
NewRecord: sanctionRecord,
}))
requireSetContains(t, store, store.keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, false), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanJoinGame, false), record.UserID.String())
removedSanction := sanctionRecord
removedAt := now.Add(time.Minute)
removedSanction.RemovedAt = &removedAt
removedSanction.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
removedSanction.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveSanction(context.Background(), ports.RemoveSanctionInput{
ExpectedActiveRecord: sanctionRecord,
UpdatedRecord: removedSanction,
}))
requireSetNotContains(t, store, store.keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true), record.UserID.String())
limitRecord := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-1"),
UserID: record.UserID,
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 5,
ReasonCode: common.ReasonCode("manual_override"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now.Add(2 * time.Minute),
}
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
NewRecord: limitRecord,
}))
requireSetContains(t, store, store.keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames), record.UserID.String())
removedLimit := limitRecord
limitRemovedAt := now.Add(3 * time.Minute)
removedLimit.RemovedAt = &limitRemovedAt
removedLimit.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
removedLimit.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveLimit(context.Background(), ports.RemoveLimitInput{
ExpectedActiveRecord: limitRecord,
UpdatedRecord: removedLimit,
}))
requireSetNotContains(t, store, store.keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames), record.UserID.String())
}
func TestAdminListerReevaluatesExpiredPaidSnapshots(t *testing.T) {
t.Parallel()
store := newTestStore(t)
userID := common.UserID("user-123")
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.CreatedAt = now.Add(-2 * time.Hour)
record.UpdatedAt = record.CreatedAt
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(userID, record.CreatedAt),
EntitlementRecord: validEntitlementRecord(userID, record.CreatedAt),
Reservation: raceNameReservation(userID, record.RaceName, record.CreatedAt),
})
require.NoError(t, err)
grantStartsAt := now.Add(-90 * time.Minute)
grantEndsAt := now.Add(-30 * time.Minute)
freeRecord := validEntitlementRecord(userID, record.CreatedAt)
freeSnapshot := validEntitlementSnapshot(userID, record.CreatedAt)
grantedRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-expired"),
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
grantedSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
closedFreeRecord := freeRecord
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
require.NoError(t, store.EntitlementLifecycle().Grant(context.Background(), ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: freeSnapshot,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFreeRecord,
NewRecord: grantedRecord,
NewSnapshot: grantedSnapshot,
}))
reader, err := entitlementsvc.NewReader(
store.EntitlementSnapshots(),
store.EntitlementLifecycle(),
adminStoreClock{now: now},
adminStoreIDGenerator{entitlementRecordID: entitlement.EntitlementRecordID("entitlement-free-after-expiry")},
)
require.NoError(t, err)
lister, err := adminusers.NewLister(store.Accounts(), reader, store.Sanctions(), store.Limits(), adminStoreClock{now: now}, store)
require.NoError(t, err)
result, err := lister.Execute(context.Background(), adminusers.ListUsersInput{PaidState: "free"})
require.NoError(t, err)
require.Len(t, result.Items, 1)
require.Equal(t, "user-123", result.Items[0].UserID)
require.Equal(t, "free", result.Items[0].Entitlement.PlanCode)
require.False(t, result.Items[0].Entitlement.IsPaid)
storedSnapshot, err := store.EntitlementSnapshots().GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, entitlement.PlanCodeFree, storedSnapshot.PlanCode)
require.False(t, storedSnapshot.IsPaid)
}
type adminStoreClock struct {
now time.Time
}
func (clock adminStoreClock) Now() time.Time {
return clock.now
}
type adminStoreIDGenerator struct {
entitlementRecordID entitlement.EntitlementRecordID
}
func (generator adminStoreIDGenerator) NewUserID() (common.UserID, error) {
return "", nil
}
func (generator adminStoreIDGenerator) NewInitialRaceName() (common.RaceName, error) {
return "", nil
}
func (generator adminStoreIDGenerator) NewEntitlementRecordID() (entitlement.EntitlementRecordID, error) {
return generator.entitlementRecordID, nil
}
func (generator adminStoreIDGenerator) NewSanctionRecordID() (policy.SanctionRecordID, error) {
return "", nil
}
func (generator adminStoreIDGenerator) NewLimitRecordID() (policy.LimitRecordID, error) {
return "", nil
}
func requireSetContains(t *testing.T, store *Store, key string, member string) {
t.Helper()
exists, err := store.client.SIsMember(context.Background(), key, member).Result()
require.NoError(t, err)
require.True(t, exists, "expected %q to contain %q", key, member)
}
func requireSetNotContains(t *testing.T, store *Store, key string, member string) {
t.Helper()
exists, err := store.client.SIsMember(context.Background(), key, member).Result()
require.NoError(t, err)
require.False(t, exists, "expected %q not to contain %q", key, member)
}
func requireSortedSetScore(t *testing.T, store *Store, key string, member string, want float64) {
t.Helper()
got, err := store.client.ZScore(context.Background(), key, member).Result()
require.NoError(t, err)
require.Equal(t, want, got)
}
func requireSortedSetMissing(t *testing.T, store *Store, key string, member string) {
t.Helper()
_, err := store.client.ZScore(context.Background(), key, member).Result()
require.Error(t, err)
}
@@ -0,0 +1,752 @@
package userstore
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
type entitlementPeriodRecord struct {
RecordID string `json:"record_id"`
UserID string `json:"user_id"`
PlanCode string `json:"plan_code"`
Source string `json:"source"`
ActorType string `json:"actor_type"`
ActorID *string `json:"actor_id,omitempty"`
ReasonCode string `json:"reason_code"`
StartsAt string `json:"starts_at"`
EndsAt *string `json:"ends_at,omitempty"`
CreatedAt string `json:"created_at"`
ClosedAt *string `json:"closed_at,omitempty"`
ClosedByType *string `json:"closed_by_type,omitempty"`
ClosedByID *string `json:"closed_by_id,omitempty"`
ClosedReasonCode *string `json:"closed_reason_code,omitempty"`
}
// CreateEntitlementRecord stores one new entitlement history record.
func (store *Store) CreateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("create entitlement record in redis: %w", err)
}
payload, err := marshalEntitlementPeriodRecord(record)
if err != nil {
return fmt.Errorf("create entitlement record in redis: %w", err)
}
recordKey := store.keyspace.EntitlementRecord(record.RecordID)
historyKey := store.keyspace.EntitlementHistory(record.UserID)
operationCtx, cancel, err := store.operationContext(ctx, "create entitlement record in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil {
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, err)
}
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, payload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(record.StartsAt.UTC().UnixMicro()),
Member: record.RecordID.String(),
})
return nil
})
if err != nil {
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, err)
}
return nil
}, recordKey, historyKey)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// GetEntitlementRecordByRecordID returns the entitlement history record
// identified by recordID.
func (store *Store) GetEntitlementRecordByRecordID(
ctx context.Context,
recordID entitlement.EntitlementRecordID,
) (entitlement.PeriodRecord, error) {
if err := recordID.Validate(); err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id from redis: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get entitlement record by record id from redis")
if err != nil {
return entitlement.PeriodRecord{}, err
}
defer cancel()
record, err := store.loadEntitlementRecord(operationCtx, store.client, recordID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id %q from redis: %w", recordID, ports.ErrNotFound)
default:
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id %q from redis: %w", recordID, err)
}
}
return record, nil
}
// ListEntitlementRecordsByUserID returns every entitlement history record
// owned by userID.
func (store *Store) ListEntitlementRecordsByUserID(
ctx context.Context,
userID common.UserID,
) ([]entitlement.PeriodRecord, error) {
if err := userID.Validate(); err != nil {
return nil, fmt.Errorf("list entitlement records by user id from redis: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list entitlement records by user id from redis")
if err != nil {
return nil, err
}
defer cancel()
recordIDs, err := store.client.ZRange(operationCtx, store.keyspace.EntitlementHistory(userID), 0, -1).Result()
if err != nil {
return nil, fmt.Errorf("list entitlement records by user id %q from redis: %w", userID, err)
}
records := make([]entitlement.PeriodRecord, 0, len(recordIDs))
for _, rawRecordID := range recordIDs {
record, err := store.loadEntitlementRecord(operationCtx, store.client, entitlement.EntitlementRecordID(rawRecordID))
if err != nil {
return nil, fmt.Errorf("list entitlement records by user id %q from redis: %w", userID, err)
}
records = append(records, record)
}
return records, nil
}
// UpdateEntitlementRecord replaces one stored entitlement history record.
func (store *Store) UpdateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("update entitlement record in redis: %w", err)
}
payload, err := marshalEntitlementPeriodRecord(record)
if err != nil {
return fmt.Errorf("update entitlement record in redis: %w", err)
}
recordKey := store.keyspace.EntitlementRecord(record.RecordID)
operationCtx, cancel, err := store.operationContext(ctx, "update entitlement record in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if _, err := store.loadEntitlementRecord(operationCtx, tx, record.RecordID); err != nil {
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, err)
}
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, payload, 0)
return nil
})
if err != nil {
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, err)
}
return nil
}, recordKey)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// GrantEntitlement atomically closes the current free history record, creates
// one paid history record, and replaces the current snapshot.
func (store *Store) GrantEntitlement(ctx context.Context, input ports.GrantEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
updatedCurrentRecordPayload, err := marshalEntitlementPeriodRecord(input.UpdatedCurrentRecord)
if err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
currentRecordKey := store.keyspace.EntitlementRecord(input.ExpectedCurrentRecord.RecordID)
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{currentRecordKey, newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "grant entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
storedCurrentRecord, err := store.loadEntitlementRecord(operationCtx, tx, input.ExpectedCurrentRecord.RecordID)
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementPeriodRecords(storedCurrentRecord, input.ExpectedCurrentRecord) {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, currentRecordKey, updatedCurrentRecordPayload, 0)
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// ExtendEntitlement atomically appends one paid history segment and replaces
// the current paid snapshot.
func (store *Store) ExtendEntitlement(ctx context.Context, input ports.ExtendEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("extend entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("extend entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("extend entitlement in redis: %w", err)
}
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "extend entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
if err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RevokeEntitlement atomically closes the current paid history record,
// creates one free history record, and replaces the current snapshot.
func (store *Store) RevokeEntitlement(ctx context.Context, input ports.RevokeEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
updatedCurrentRecordPayload, err := marshalEntitlementPeriodRecord(input.UpdatedCurrentRecord)
if err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
currentRecordKey := store.keyspace.EntitlementRecord(input.ExpectedCurrentRecord.RecordID)
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{currentRecordKey, newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "revoke entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
storedCurrentRecord, err := store.loadEntitlementRecord(operationCtx, tx, input.ExpectedCurrentRecord.RecordID)
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementPeriodRecords(storedCurrentRecord, input.ExpectedCurrentRecord) {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, currentRecordKey, updatedCurrentRecordPayload, 0)
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RepairExpiredEntitlement atomically replaces one expired finite paid
// snapshot with a materialized free state.
func (store *Store) RepairExpiredEntitlement(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("repair expired entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("repair expired entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("repair expired entitlement in redis: %w", err)
}
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "repair expired entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedExpiredSnapshot.UserID)
if err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedExpiredSnapshot) {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
func (store *Store) loadEntitlementRecord(
ctx context.Context,
getter bytesGetter,
recordID entitlement.EntitlementRecordID,
) (entitlement.PeriodRecord, error) {
payload, err := getter.Get(ctx, store.keyspace.EntitlementRecord(recordID)).Bytes()
switch {
case errors.Is(err, redis.Nil):
return entitlement.PeriodRecord{}, ports.ErrNotFound
case err != nil:
return entitlement.PeriodRecord{}, err
}
return decodeEntitlementPeriodRecord(payload)
}
func marshalEntitlementPeriodRecord(record entitlement.PeriodRecord) ([]byte, error) {
encoded := entitlementPeriodRecord{
RecordID: record.RecordID.String(),
UserID: record.UserID.String(),
PlanCode: string(record.PlanCode),
Source: record.Source.String(),
ActorType: record.Actor.Type.String(),
ReasonCode: record.ReasonCode.String(),
StartsAt: record.StartsAt.UTC().Format(time.RFC3339Nano),
CreatedAt: record.CreatedAt.UTC().Format(time.RFC3339Nano),
}
if !record.Actor.ID.IsZero() {
value := record.Actor.ID.String()
encoded.ActorID = &value
}
if record.EndsAt != nil {
value := record.EndsAt.UTC().Format(time.RFC3339Nano)
encoded.EndsAt = &value
}
if record.ClosedAt != nil {
value := record.ClosedAt.UTC().Format(time.RFC3339Nano)
encoded.ClosedAt = &value
}
if !record.ClosedBy.Type.IsZero() {
value := record.ClosedBy.Type.String()
encoded.ClosedByType = &value
}
if !record.ClosedBy.ID.IsZero() {
value := record.ClosedBy.ID.String()
encoded.ClosedByID = &value
}
if !record.ClosedReasonCode.IsZero() {
value := record.ClosedReasonCode.String()
encoded.ClosedReasonCode = &value
}
return json.Marshal(encoded)
}
func decodeEntitlementPeriodRecord(payload []byte) (entitlement.PeriodRecord, error) {
var encoded entitlementPeriodRecord
if err := decodeJSONPayload(payload, &encoded); err != nil {
return entitlement.PeriodRecord{}, err
}
startsAt, err := time.Parse(time.RFC3339Nano, encoded.StartsAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record starts_at: %w", err)
}
createdAt, err := time.Parse(time.RFC3339Nano, encoded.CreatedAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record created_at: %w", err)
}
record := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID(encoded.RecordID),
UserID: common.UserID(encoded.UserID),
PlanCode: entitlement.PlanCode(encoded.PlanCode),
Source: common.Source(encoded.Source),
Actor: common.ActorRef{Type: common.ActorType(encoded.ActorType)},
ReasonCode: common.ReasonCode(encoded.ReasonCode),
StartsAt: startsAt.UTC(),
CreatedAt: createdAt.UTC(),
}
if encoded.ActorID != nil {
record.Actor.ID = common.ActorID(*encoded.ActorID)
}
if encoded.EndsAt != nil {
value, err := time.Parse(time.RFC3339Nano, *encoded.EndsAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record ends_at: %w", err)
}
value = value.UTC()
record.EndsAt = &value
}
if encoded.ClosedAt != nil {
value, err := time.Parse(time.RFC3339Nano, *encoded.ClosedAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record closed_at: %w", err)
}
value = value.UTC()
record.ClosedAt = &value
}
if encoded.ClosedByType != nil {
record.ClosedBy.Type = common.ActorType(*encoded.ClosedByType)
}
if encoded.ClosedByID != nil {
record.ClosedBy.ID = common.ActorID(*encoded.ClosedByID)
}
if encoded.ClosedReasonCode != nil {
record.ClosedReasonCode = common.ReasonCode(*encoded.ClosedReasonCode)
}
if err := record.Validate(); err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record: %w", err)
}
return record, nil
}
func equalEntitlementSnapshots(left entitlement.CurrentSnapshot, right entitlement.CurrentSnapshot) bool {
return left.UserID == right.UserID &&
left.PlanCode == right.PlanCode &&
left.IsPaid == right.IsPaid &&
left.StartsAt.Equal(right.StartsAt) &&
equalOptionalTime(left.EndsAt, right.EndsAt) &&
left.Source == right.Source &&
left.Actor == right.Actor &&
left.ReasonCode == right.ReasonCode &&
left.UpdatedAt.Equal(right.UpdatedAt)
}
func equalEntitlementPeriodRecords(left entitlement.PeriodRecord, right entitlement.PeriodRecord) bool {
return left.RecordID == right.RecordID &&
left.UserID == right.UserID &&
left.PlanCode == right.PlanCode &&
left.Source == right.Source &&
left.Actor == right.Actor &&
left.ReasonCode == right.ReasonCode &&
left.StartsAt.Equal(right.StartsAt) &&
equalOptionalTime(left.EndsAt, right.EndsAt) &&
left.CreatedAt.Equal(right.CreatedAt) &&
equalOptionalTime(left.ClosedAt, right.ClosedAt) &&
left.ClosedBy == right.ClosedBy &&
left.ClosedReasonCode == right.ClosedReasonCode
}
func equalOptionalTime(left *time.Time, right *time.Time) bool {
switch {
case left == nil && right == nil:
return true
case left == nil || right == nil:
return false
default:
return left.Equal(*right)
}
}
// EntitlementHistoryStore adapts Store to the existing
// EntitlementHistoryStore port.
type EntitlementHistoryStore struct {
store *Store
}
// EntitlementHistory returns one adapter that exposes the entitlement-history
// store port over Store.
func (store *Store) EntitlementHistory() *EntitlementHistoryStore {
if store == nil {
return nil
}
return &EntitlementHistoryStore{store: store}
}
// Create stores one new entitlement history record.
func (adapter *EntitlementHistoryStore) Create(ctx context.Context, record entitlement.PeriodRecord) error {
return adapter.store.CreateEntitlementRecord(ctx, record)
}
// GetByRecordID returns the entitlement history record identified by recordID.
func (adapter *EntitlementHistoryStore) GetByRecordID(
ctx context.Context,
recordID entitlement.EntitlementRecordID,
) (entitlement.PeriodRecord, error) {
return adapter.store.GetEntitlementRecordByRecordID(ctx, recordID)
}
// ListByUserID returns every entitlement history record owned by userID.
func (adapter *EntitlementHistoryStore) ListByUserID(
ctx context.Context,
userID common.UserID,
) ([]entitlement.PeriodRecord, error) {
return adapter.store.ListEntitlementRecordsByUserID(ctx, userID)
}
// Update replaces one stored entitlement history record.
func (adapter *EntitlementHistoryStore) Update(ctx context.Context, record entitlement.PeriodRecord) error {
return adapter.store.UpdateEntitlementRecord(ctx, record)
}
var _ ports.EntitlementHistoryStore = (*EntitlementHistoryStore)(nil)
// EntitlementLifecycleStore adapts Store to the existing
// EntitlementLifecycleStore port.
type EntitlementLifecycleStore struct {
store *Store
}
// EntitlementLifecycle returns one adapter that exposes the atomic
// entitlement-lifecycle store port over Store.
func (store *Store) EntitlementLifecycle() *EntitlementLifecycleStore {
if store == nil {
return nil
}
return &EntitlementLifecycleStore{store: store}
}
// Grant atomically applies one free-to-paid transition.
func (adapter *EntitlementLifecycleStore) Grant(ctx context.Context, input ports.GrantEntitlementInput) error {
return adapter.store.GrantEntitlement(ctx, input)
}
// Extend atomically appends one paid extension segment and updates the current
// snapshot.
func (adapter *EntitlementLifecycleStore) Extend(ctx context.Context, input ports.ExtendEntitlementInput) error {
return adapter.store.ExtendEntitlement(ctx, input)
}
// Revoke atomically applies one paid-to-free transition.
func (adapter *EntitlementLifecycleStore) Revoke(ctx context.Context, input ports.RevokeEntitlementInput) error {
return adapter.store.RevokeEntitlement(ctx, input)
}
// RepairExpired atomically repairs one expired finite paid snapshot.
func (adapter *EntitlementLifecycleStore) RepairExpired(
ctx context.Context,
input ports.RepairExpiredEntitlementInput,
) error {
return adapter.store.RepairExpiredEntitlement(ctx, input)
}
var _ ports.EntitlementLifecycleStore = (*EntitlementLifecycleStore)(nil)
@@ -0,0 +1,137 @@
package userstore
import (
"context"
"errors"
"fmt"
"time"
"galaxy/user/internal/adapters/redisstate"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
// ListUserIDs returns one deterministic page of user identifiers ordered by
// `created_at desc`, then `user_id desc`.
func (store *Store) ListUserIDs(ctx context.Context, input ports.ListUsersInput) (ports.ListUsersResult, error) {
if err := input.Validate(); err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list users in redis")
if err != nil {
return ports.ListUsersResult{}, err
}
defer cancel()
startIndex := int64(0)
filters := userListFiltersFromPorts(input.Filters)
if input.PageToken != "" {
cursor, err := redisstate.DecodePageToken(input.PageToken, filters)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
}
score, err := store.client.ZScore(operationCtx, store.keyspace.CreatedAtIndex(), cursor.UserID.String()).Result()
switch {
case errors.Is(err, redis.Nil):
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
case err != nil:
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
if !time.UnixMicro(int64(score)).UTC().Equal(cursor.CreatedAt.UTC()) {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
}
rank, err := store.client.ZRevRank(operationCtx, store.keyspace.CreatedAtIndex(), cursor.UserID.String()).Result()
switch {
case errors.Is(err, redis.Nil):
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
case err != nil:
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
startIndex = rank + 1
}
rawPage, err := store.client.ZRevRangeWithScores(
operationCtx,
store.keyspace.CreatedAtIndex(),
startIndex,
startIndex+int64(input.PageSize),
).Result()
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
result := ports.ListUsersResult{
UserIDs: make([]common.UserID, 0, min(len(rawPage), input.PageSize)),
}
visibleCount := min(len(rawPage), input.PageSize)
for index := 0; index < visibleCount; index++ {
userID, err := memberUserID(rawPage[index].Member)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
result.UserIDs = append(result.UserIDs, userID)
}
if len(rawPage) > input.PageSize {
lastVisible := rawPage[input.PageSize-1]
lastUserID, err := memberUserID(lastVisible.Member)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
token, err := redisstate.EncodePageToken(redisstate.PageCursor{
CreatedAt: time.UnixMicro(int64(lastVisible.Score)).UTC(),
UserID: lastUserID,
}, filters)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
result.NextPageToken = token
}
return result, nil
}
func userListFiltersFromPorts(filters ports.UserListFilters) redisstate.UserListFilters {
return redisstate.UserListFilters{
PaidState: filters.PaidState,
PaidExpiresBefore: filters.PaidExpiresBefore,
PaidExpiresAfter: filters.PaidExpiresAfter,
DeclaredCountry: filters.DeclaredCountry,
SanctionCode: filters.SanctionCode,
LimitCode: filters.LimitCode,
CanLogin: filters.CanLogin,
CanCreatePrivateGame: filters.CanCreatePrivateGame,
CanJoinGame: filters.CanJoinGame,
}
}
func memberUserID(member any) (common.UserID, error) {
value, ok := member.(string)
if !ok {
return "", fmt.Errorf("unexpected created-at index member type %T", member)
}
userID := common.UserID(value)
if err := userID.Validate(); err != nil {
return "", fmt.Errorf("created-at index member user id: %w", err)
}
return userID, nil
}
func min(left int, right int) int {
if left < right {
return left
}
return right
}
var _ ports.UserListStore = (*Store)(nil)
@@ -0,0 +1,445 @@
package userstore
import (
"context"
"errors"
"fmt"
"time"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
// ApplySanction atomically creates one new active sanction record.
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("apply sanction in redis: %w", err)
}
recordPayload, err := marshalSanctionRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("apply sanction in redis: %w", err)
}
recordKey := store.keyspace.SanctionRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.SanctionHistory(input.NewRecord.UserID)
activeKey := store.keyspace.ActiveSanction(input.NewRecord.UserID, input.NewRecord.SanctionCode)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewRecord.UserID)
watchedKeys := append(
[]string{recordKey, historyKey, activeKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "apply sanction in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.NewRecord.UserID)
if err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewRecord.UserID)
if err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
activeSanctionCodes[input.NewRecord.SanctionCode] = struct{}{}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, recordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeSanctionCodes)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RemoveSanction atomically removes one active sanction record.
func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("remove sanction in redis: %w", err)
}
updatedPayload, err := marshalSanctionRecord(input.UpdatedRecord)
if err != nil {
return fmt.Errorf("remove sanction in redis: %w", err)
}
recordKey := store.keyspace.SanctionRecord(input.ExpectedActiveRecord.RecordID)
activeKey := store.keyspace.ActiveSanction(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.SanctionCode)
snapshotKey := store.keyspace.EntitlementSnapshot(input.ExpectedActiveRecord.UserID)
watchedKeys := append(
[]string{recordKey, activeKey, snapshotKey},
store.activeSanctionWatchKeys(input.ExpectedActiveRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "remove sanction in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
activeRecordID, err := store.loadActiveSanctionRecordID(operationCtx, tx, activeKey)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if activeRecordID != input.ExpectedActiveRecord.RecordID {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
storedRecord, err := store.loadSanctionRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if !equalSanctionRecords(storedRecord, input.ExpectedActiveRecord) {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedActiveRecord.UserID)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
delete(activeSanctionCodes, input.ExpectedActiveRecord.SanctionCode)
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, updatedPayload, 0)
pipe.Del(operationCtx, activeKey)
store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeSanctionCodes)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// SetLimit atomically creates or replaces one active limit record.
func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("set limit in redis: %w", err)
}
newRecordPayload, err := marshalLimitRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("set limit in redis: %w", err)
}
newRecordKey := store.keyspace.LimitRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.LimitHistory(input.NewRecord.UserID)
activeKey := store.keyspace.ActiveLimit(input.NewRecord.UserID, input.NewRecord.LimitCode)
watchedKeys := append(
[]string{newRecordKey, historyKey, activeKey},
store.activeLimitWatchKeys(input.NewRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "set limit in redis")
if err != nil {
return err
}
defer cancel()
if input.ExpectedActiveRecord != nil {
watchedKeys = append(watchedKeys, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID))
}
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
var updatedPayload []byte
if input.ExpectedActiveRecord == nil {
if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
} else {
activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
if activeRecordID != input.ExpectedActiveRecord.RecordID {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
}
storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
if !equalLimitRecords(storedRecord, *input.ExpectedActiveRecord) {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
}
updatedPayload, err = marshalLimitRecord(*input.UpdatedActiveRecord)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
}
activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.NewRecord.UserID)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
activeLimitCodes[input.NewRecord.LimitCode] = struct{}{}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
if input.ExpectedActiveRecord != nil {
pipe.Set(operationCtx, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID), updatedPayload, 0)
}
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeLimitCodes)
return nil
})
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RemoveLimit atomically removes one active limit record.
func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("remove limit in redis: %w", err)
}
updatedPayload, err := marshalLimitRecord(input.UpdatedRecord)
if err != nil {
return fmt.Errorf("remove limit in redis: %w", err)
}
recordKey := store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID)
activeKey := store.keyspace.ActiveLimit(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.LimitCode)
watchedKeys := append(
[]string{recordKey, activeKey},
store.activeLimitWatchKeys(input.ExpectedActiveRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "remove limit in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if activeRecordID != input.ExpectedActiveRecord.RecordID {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if !equalLimitRecords(storedRecord, input.ExpectedActiveRecord) {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
delete(activeLimitCodes, input.ExpectedActiveRecord.LimitCode)
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, updatedPayload, 0)
pipe.Del(operationCtx, activeKey)
store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeLimitCodes)
return nil
})
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
func (store *Store) loadActiveSanctionRecordID(
ctx context.Context,
getter bytesGetter,
key string,
) (policy.SanctionRecordID, error) {
value, err := getter.Get(ctx, key).Result()
switch {
case errors.Is(err, redis.Nil):
return "", ports.ErrNotFound
case err != nil:
return "", err
}
recordID := policy.SanctionRecordID(value)
if err := recordID.Validate(); err != nil {
return "", fmt.Errorf("active sanction record id: %w", err)
}
return recordID, nil
}
func (store *Store) loadActiveLimitRecordID(
ctx context.Context,
getter bytesGetter,
key string,
) (policy.LimitRecordID, error) {
value, err := getter.Get(ctx, key).Result()
switch {
case errors.Is(err, redis.Nil):
return "", ports.ErrNotFound
case err != nil:
return "", err
}
recordID := policy.LimitRecordID(value)
if err := recordID.Validate(); err != nil {
return "", fmt.Errorf("active limit record id: %w", err)
}
return recordID, nil
}
func setActiveSlot(
pipe redis.Pipeliner,
ctx context.Context,
key string,
recordID string,
expiresAt *time.Time,
) {
pipe.Set(ctx, key, recordID, 0)
if expiresAt != nil {
pipe.PExpireAt(ctx, key, expiresAt.UTC())
}
}
func equalSanctionRecords(left policy.SanctionRecord, right policy.SanctionRecord) bool {
return left.RecordID == right.RecordID &&
left.UserID == right.UserID &&
left.SanctionCode == right.SanctionCode &&
left.Scope == right.Scope &&
left.ReasonCode == right.ReasonCode &&
left.Actor == right.Actor &&
left.AppliedAt.Equal(right.AppliedAt) &&
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
left.RemovedBy == right.RemovedBy &&
left.RemovedReasonCode == right.RemovedReasonCode
}
func equalLimitRecords(left policy.LimitRecord, right policy.LimitRecord) bool {
return left.RecordID == right.RecordID &&
left.UserID == right.UserID &&
left.LimitCode == right.LimitCode &&
left.Value == right.Value &&
left.ReasonCode == right.ReasonCode &&
left.Actor == right.Actor &&
left.AppliedAt.Equal(right.AppliedAt) &&
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
left.RemovedBy == right.RemovedBy &&
left.RemovedReasonCode == right.RemovedReasonCode
}
// PolicyLifecycleStore adapts Store to the existing PolicyLifecycleStore
// port.
type PolicyLifecycleStore struct {
store *Store
}
// PolicyLifecycle returns one adapter that exposes the atomic policy-lifecycle
// store port over Store.
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore {
if store == nil {
return nil
}
return &PolicyLifecycleStore{store: store}
}
// ApplySanction atomically creates one new active sanction record.
func (adapter *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
return adapter.store.ApplySanction(ctx, input)
}
// RemoveSanction atomically removes one active sanction record.
func (adapter *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
return adapter.store.RemoveSanction(ctx, input)
}
// SetLimit atomically creates or replaces one active limit record.
func (adapter *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
return adapter.store.SetLimit(ctx, input)
}
// RemoveLimit atomically removes one active limit record.
func (adapter *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
return adapter.store.RemoveLimit(ctx, input)
}
var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,930 @@
package userstore
import (
"context"
"strings"
"testing"
"time"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/authblock"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/require"
)
func TestAccountStoreCreateAndLookups(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.Equal(t, record, byUserID)
byEmail, err := accountStore.GetByEmail(context.Background(), record.Email)
require.NoError(t, err)
require.Equal(t, record, byEmail)
byRaceName, err := accountStore.GetByRaceName(context.Background(), record.RaceName)
require.NoError(t, err)
require.Equal(t, record, byRaceName)
exists, err := accountStore.ExistsByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.True(t, exists)
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(record.RaceName))
require.NoError(t, err)
require.Equal(t, record.UserID, reservation.UserID)
require.Equal(t, record.RaceName, reservation.RaceName)
}
func TestBlockedEmailStoreUpsertAndGet(t *testing.T) {
t.Parallel()
store := newTestStore(t)
blockedEmailStore := store.BlockedEmails()
record := authblock.BlockedEmailSubject{
Email: common.Email("blocked@example.com"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: time.Unix(1_775_240_100, 0).UTC(),
ResolvedUserID: common.UserID("user-123"),
}
require.NoError(t, blockedEmailStore.Upsert(context.Background(), record))
got, err := blockedEmailStore.GetByEmail(context.Background(), record.Email)
require.NoError(t, err)
require.Equal(t, record, got)
}
func TestEnsureResolveAndBlockFlows(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
accountRecord := validAccountRecord()
entitlementSnapshot := validEntitlementSnapshot(accountRecord.UserID, now)
created, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: accountRecord.Email,
Account: accountRecord,
Entitlement: entitlementSnapshot,
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
Reservation: raceNameReservation(accountRecord.UserID, accountRecord.RaceName, accountRecord.UpdatedAt),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeCreated, created.Outcome)
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(accountRecord.RaceName))
require.NoError(t, err)
require.Equal(t, accountRecord.UserID, reservation.UserID)
entitlementHistory, err := store.ListEntitlementRecordsByUserID(context.Background(), accountRecord.UserID)
require.NoError(t, err)
require.Len(t, entitlementHistory, 1)
require.Equal(t, validEntitlementRecord(accountRecord.UserID, now), entitlementHistory[0])
resolved, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindExisting, resolved.Kind)
blockedByUserID, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: accountRecord.UserID,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now.Add(time.Minute),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, blockedByUserID.Outcome)
repeatedBlock, err := store.BlockByEmail(context.Background(), ports.BlockByEmailInput{
Email: accountRecord.Email,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now.Add(2 * time.Minute),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeAlreadyBlocked, repeatedBlock.Outcome)
require.Equal(t, accountRecord.UserID, repeatedBlock.UserID)
blockedResolution, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindBlocked, blockedResolution.Kind)
ensureBlocked, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: accountRecord.Email,
Account: accountRecord,
Entitlement: entitlementSnapshot,
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
Reservation: raceNameReservation(accountRecord.UserID, accountRecord.RaceName, accountRecord.UpdatedAt),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, ensureBlocked.Outcome)
}
func TestBlockedEmailWithoutUserPreventsEnsureCreate(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
accountRecord := validAccountRecord()
entitlementSnapshot := validEntitlementSnapshot(accountRecord.UserID, now)
blocked, err := store.BlockByEmail(context.Background(), ports.BlockByEmailInput{
Email: accountRecord.Email,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now,
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, blocked.Outcome)
require.True(t, blocked.UserID.IsZero())
resolved, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindBlocked, resolved.Kind)
ensured, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: accountRecord.Email,
Account: accountRecord,
Entitlement: entitlementSnapshot,
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
Reservation: raceNameReservation(accountRecord.UserID, accountRecord.RaceName, accountRecord.UpdatedAt),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, ensured.Outcome)
exists, err := store.ExistsByUserID(context.Background(), accountRecord.UserID)
require.NoError(t, err)
require.False(t, exists)
}
func TestEnsureByEmailExistingDoesNotOverwriteStoredSettings(t *testing.T) {
t.Parallel()
store := newTestStore(t)
createdAt := time.Unix(1_775_240_000, 0).UTC()
existingAccount := account.UserAccount{
UserID: common.UserID("user-existing"),
Email: common.Email("pilot@example.com"),
RaceName: common.RaceName("Pilot Nova"),
PreferredLanguage: common.LanguageTag("en"),
TimeZone: common.TimeZoneName("Europe/Kaliningrad"),
CreatedAt: createdAt,
UpdatedAt: createdAt,
}
require.NoError(t, store.Create(context.Background(), createAccountInput(existingAccount)))
result, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: existingAccount.Email,
Account: account.UserAccount{
UserID: common.UserID("user-created"),
Email: existingAccount.Email,
RaceName: common.RaceName("player-new123"),
PreferredLanguage: common.LanguageTag("fr-FR"),
TimeZone: common.TimeZoneName("UTC"),
CreatedAt: createdAt.Add(time.Minute),
UpdatedAt: createdAt.Add(time.Minute),
},
Entitlement: validEntitlementSnapshot(common.UserID("user-created"), createdAt.Add(time.Minute)),
EntitlementRecord: validEntitlementRecord(common.UserID("user-created"), createdAt.Add(time.Minute)),
Reservation: raceNameReservation(common.UserID("user-created"), common.RaceName("player-new123"), createdAt.Add(time.Minute)),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeExisting, result.Outcome)
require.Equal(t, existingAccount.UserID, result.UserID)
storedAccount, err := store.GetByEmail(context.Background(), existingAccount.Email)
require.NoError(t, err)
require.Equal(t, existingAccount, storedAccount)
}
func TestAccountStoreRenameRaceNameSwapsLookupAtomically(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
updatedAt := record.UpdatedAt.Add(time.Minute)
require.NoError(t, accountStore.RenameRaceName(context.Background(), renameRaceNameInput(record, common.RaceName("Nova Prime"), updatedAt)))
stored, err := accountStore.GetByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.Equal(t, common.RaceName("Nova Prime"), stored.RaceName)
require.True(t, stored.UpdatedAt.Equal(updatedAt))
_, err = accountStore.GetByRaceName(context.Background(), record.RaceName)
require.ErrorIs(t, err, ports.ErrNotFound)
renamed, err := accountStore.GetByRaceName(context.Background(), common.RaceName("Nova Prime"))
require.NoError(t, err)
require.Equal(t, record.UserID, renamed.UserID)
_, err = store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(record.RaceName))
require.ErrorIs(t, err, ports.ErrNotFound)
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(common.RaceName("Nova Prime")))
require.NoError(t, err)
require.Equal(t, common.RaceName("Nova Prime"), reservation.RaceName)
}
func TestAccountStoreRenameRaceNameAllowsSameOwnerCanonicalSlot(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
record.RaceName = common.RaceName("Pilot Nova")
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
updatedAt := record.UpdatedAt.Add(time.Minute)
require.NoError(t, accountStore.RenameRaceName(context.Background(), renameRaceNameInput(record, common.RaceName("P1lot Nova"), updatedAt)))
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(common.RaceName("P1lot Nova")))
require.NoError(t, err)
require.Equal(t, common.RaceName("P1lot Nova"), reservation.RaceName)
}
func TestAccountStoreRenameRaceNameReturnsConflictWhenTargetExists(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
first := validAccountRecord()
second := validAccountRecord()
second.UserID = common.UserID("user-456")
second.Email = common.Email("other@example.com")
second.RaceName = common.RaceName("Taken Name")
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(first)))
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(second)))
err := accountStore.RenameRaceName(context.Background(), renameRaceNameInput(first, second.RaceName, first.UpdatedAt.Add(time.Minute)))
require.ErrorIs(t, err, ports.ErrConflict)
stored, err := accountStore.GetByUserID(context.Background(), first.UserID)
require.NoError(t, err)
require.Equal(t, first.RaceName, stored.RaceName)
}
func TestAccountStoreUpdateDeclaredCountryPreservesLookups(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
updated := record
updated.DeclaredCountry = common.CountryCode("FR")
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
require.NoError(t, accountStore.Update(context.Background(), updated))
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.Equal(t, updated, byUserID)
byEmail, err := accountStore.GetByEmail(context.Background(), record.Email)
require.NoError(t, err)
require.Equal(t, updated, byEmail)
byRaceName, err := accountStore.GetByRaceName(context.Background(), record.RaceName)
require.NoError(t, err)
require.Equal(t, updated, byRaceName)
}
func TestAccountStoreCreateReturnsConflictWhenCanonicalReservationExists(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
first := validAccountRecord()
second := validAccountRecord()
second.UserID = common.UserID("user-456")
second.Email = common.Email("other@example.com")
second.RaceName = common.RaceName("P1lot Nova")
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(first)))
err := accountStore.Create(context.Background(), createAccountInput(second))
require.ErrorIs(t, err, ports.ErrConflict)
}
func TestBlockByUserIDRepeatedCallsStayIdempotent(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
accountRecord := validAccountRecord()
require.NoError(t, store.Create(context.Background(), createAccountInput(accountRecord)))
first, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: accountRecord.UserID,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now,
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, first.Outcome)
second, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: accountRecord.UserID,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now.Add(time.Minute),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeAlreadyBlocked, second.Outcome)
require.Equal(t, accountRecord.UserID, second.UserID)
}
func TestBlockByUserIDUnknownUserReturnsNotFound(t *testing.T) {
t.Parallel()
store := newTestStore(t)
_, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: common.UserID("user-missing"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: time.Unix(1_775_240_000, 0).UTC(),
})
require.ErrorIs(t, err, ports.ErrNotFound)
}
func TestSanctionAndLimitStoresRoundTrip(t *testing.T) {
t.Parallel()
store := newTestStore(t)
sanctionStore := store.Sanctions()
limitStore := store.Limits()
now := time.Unix(1_775_240_000, 0).UTC()
sanctionRecord := policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-1"),
UserID: common.UserID("user-123"),
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("self_service"),
ReasonCode: common.ReasonCode("policy_enforced"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
AppliedAt: now,
}
require.NoError(t, sanctionStore.Create(context.Background(), sanctionRecord))
gotSanction, err := sanctionStore.GetByRecordID(context.Background(), sanctionRecord.RecordID)
require.NoError(t, err)
require.Equal(t, sanctionRecord, gotSanction)
sanctions, err := sanctionStore.ListByUserID(context.Background(), sanctionRecord.UserID)
require.NoError(t, err)
require.Len(t, sanctions, 1)
expiresAt := now.Add(time.Hour)
sanctionRecord.ExpiresAt = &expiresAt
require.NoError(t, sanctionStore.Update(context.Background(), sanctionRecord))
gotSanction, err = sanctionStore.GetByRecordID(context.Background(), sanctionRecord.RecordID)
require.NoError(t, err)
require.Equal(t, sanctionRecord.RecordID, gotSanction.RecordID)
require.Equal(t, sanctionRecord.UserID, gotSanction.UserID)
require.Equal(t, sanctionRecord.SanctionCode, gotSanction.SanctionCode)
require.Equal(t, sanctionRecord.Scope, gotSanction.Scope)
require.Equal(t, sanctionRecord.ReasonCode, gotSanction.ReasonCode)
require.Equal(t, sanctionRecord.Actor, gotSanction.Actor)
require.True(t, gotSanction.AppliedAt.Equal(sanctionRecord.AppliedAt))
require.NotNil(t, gotSanction.ExpiresAt)
require.True(t, gotSanction.ExpiresAt.Equal(*sanctionRecord.ExpiresAt))
limitRecord := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-1"),
UserID: common.UserID("user-123"),
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 3,
ReasonCode: common.ReasonCode("policy_enforced"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
AppliedAt: now,
}
require.NoError(t, limitStore.Create(context.Background(), limitRecord))
gotLimit, err := limitStore.GetByRecordID(context.Background(), limitRecord.RecordID)
require.NoError(t, err)
require.Equal(t, limitRecord, gotLimit)
limits, err := limitStore.ListByUserID(context.Background(), limitRecord.UserID)
require.NoError(t, err)
require.Len(t, limits, 1)
limitRecord.Value = 5
require.NoError(t, limitStore.Update(context.Background(), limitRecord))
gotLimit, err = limitStore.GetByRecordID(context.Background(), limitRecord.RecordID)
require.NoError(t, err)
require.Equal(t, limitRecord, gotLimit)
}
func TestPolicyLifecycleApplyAndRemoveSanction(t *testing.T) {
t.Parallel()
store := newTestStore(t)
lifecycleStore := store.PolicyLifecycle()
sanctionStore := store.Sanctions()
snapshotStore := store.EntitlementSnapshots()
now := time.Unix(1_775_240_000, 0).UTC()
userID := common.UserID("user-123")
require.NoError(t, snapshotStore.Put(context.Background(), validEntitlementSnapshot(userID, now)))
record := policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-1"),
UserID: userID,
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("auth"),
ReasonCode: common.ReasonCode("manual_block"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now,
}
require.NoError(t, lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
NewRecord: record,
}))
activeRecordID, err := store.loadActiveSanctionRecordID(
context.Background(),
store.client,
store.keyspace.ActiveSanction(userID, policy.SanctionCodeLoginBlock),
)
require.NoError(t, err)
require.Equal(t, record.RecordID, activeRecordID)
err = lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
NewRecord: policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-2"),
UserID: userID,
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("auth"),
ReasonCode: common.ReasonCode("manual_block"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")},
AppliedAt: now.Add(time.Minute),
},
})
require.ErrorIs(t, err, ports.ErrConflict)
removed := record
removedAt := now.Add(30 * time.Minute)
removed.RemovedAt = &removedAt
removed.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
removed.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveSanction(context.Background(), ports.RemoveSanctionInput{
ExpectedActiveRecord: record,
UpdatedRecord: removed,
}))
stored, err := sanctionStore.GetByRecordID(context.Background(), record.RecordID)
require.NoError(t, err)
require.Equal(t, removed, stored)
_, err = store.loadActiveSanctionRecordID(
context.Background(),
store.client,
store.keyspace.ActiveSanction(userID, policy.SanctionCodeLoginBlock),
)
require.ErrorIs(t, err, ports.ErrNotFound)
}
func TestPolicyLifecycleSetAndRemoveLimit(t *testing.T) {
t.Parallel()
store := newTestStore(t)
lifecycleStore := store.PolicyLifecycle()
limitStore := store.Limits()
now := time.Unix(1_775_240_000, 0).UTC()
userID := common.UserID("user-123")
first := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-1"),
UserID: userID,
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 3,
ReasonCode: common.ReasonCode("manual_override"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now,
}
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
NewRecord: first,
}))
activeRecordID, err := store.loadActiveLimitRecordID(
context.Background(),
store.client,
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
)
require.NoError(t, err)
require.Equal(t, first.RecordID, activeRecordID)
second := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-2"),
UserID: userID,
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 5,
ReasonCode: common.ReasonCode("manual_override"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")},
AppliedAt: now.Add(time.Hour),
}
updatedFirst := first
removedAt := second.AppliedAt
updatedFirst.RemovedAt = &removedAt
updatedFirst.RemovedBy = second.Actor
updatedFirst.RemovedReasonCode = second.ReasonCode
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
ExpectedActiveRecord: &first,
UpdatedActiveRecord: &updatedFirst,
NewRecord: second,
}))
storedFirst, err := limitStore.GetByRecordID(context.Background(), first.RecordID)
require.NoError(t, err)
require.Equal(t, updatedFirst, storedFirst)
activeRecordID, err = store.loadActiveLimitRecordID(
context.Background(),
store.client,
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
)
require.NoError(t, err)
require.Equal(t, second.RecordID, activeRecordID)
removedSecond := second
removeAt := now.Add(90 * time.Minute)
removedSecond.RemovedAt = &removeAt
removedSecond.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-3")}
removedSecond.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveLimit(context.Background(), ports.RemoveLimitInput{
ExpectedActiveRecord: second,
UpdatedRecord: removedSecond,
}))
storedSecond, err := limitStore.GetByRecordID(context.Background(), second.RecordID)
require.NoError(t, err)
require.Equal(t, removedSecond, storedSecond)
_, err = store.loadActiveLimitRecordID(
context.Background(),
store.client,
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
)
require.ErrorIs(t, err, ports.ErrNotFound)
}
func TestEntitlementLifecycleTransitions(t *testing.T) {
t.Parallel()
store := newTestStore(t)
historyStore := store.EntitlementHistory()
snapshotStore := store.EntitlementSnapshots()
lifecycleStore := store.EntitlementLifecycle()
userID := common.UserID("user-123")
startedFreeAt := time.Unix(1_775_240_000, 0).UTC()
freeRecord := validEntitlementRecord(userID, startedFreeAt)
freeSnapshot := validEntitlementSnapshot(userID, startedFreeAt)
require.NoError(t, historyStore.Create(context.Background(), freeRecord))
require.NoError(t, snapshotStore.Put(context.Background(), freeSnapshot))
grantStartsAt := startedFreeAt.Add(24 * time.Hour)
grantEndsAt := grantStartsAt.Add(30 * 24 * time.Hour)
grantedRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-1"),
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
grantedSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
closedFreeRecord := freeRecord
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
require.NoError(t, lifecycleStore.Grant(context.Background(), ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: freeSnapshot,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFreeRecord,
NewRecord: grantedRecord,
NewSnapshot: grantedSnapshot,
}))
storedSnapshot, err := snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, grantedSnapshot, storedSnapshot)
storedFreeRecord, err := historyStore.GetByRecordID(context.Background(), freeRecord.RecordID)
require.NoError(t, err)
require.Equal(t, closedFreeRecord, storedFreeRecord)
extendedEndsAt := grantEndsAt.Add(30 * 24 * time.Hour)
extensionRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-2"),
userID,
entitlement.PlanCodePaidMonthly,
grantEndsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
extendedSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
require.NoError(t, lifecycleStore.Extend(context.Background(), ports.ExtendEntitlementInput{
ExpectedCurrentSnapshot: grantedSnapshot,
NewRecord: extensionRecord,
NewSnapshot: extendedSnapshot,
}))
storedSnapshot, err = snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, extendedSnapshot, storedSnapshot)
revokeAt := grantEndsAt.Add(12 * time.Hour)
revokedCurrentRecord := extensionRecord
revokedCurrentRecord.ClosedAt = timePointer(revokeAt)
revokedCurrentRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
revokedCurrentRecord.ClosedReasonCode = common.ReasonCode("manual_revoke")
freeAfterRevokeRecord := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-free-2"),
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
StartsAt: revokeAt,
CreatedAt: revokeAt,
}
freeAfterRevokeSnapshot := entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: revokeAt,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
UpdatedAt: revokeAt,
}
require.NoError(t, lifecycleStore.Revoke(context.Background(), ports.RevokeEntitlementInput{
ExpectedCurrentSnapshot: extendedSnapshot,
ExpectedCurrentRecord: extensionRecord,
UpdatedCurrentRecord: revokedCurrentRecord,
NewRecord: freeAfterRevokeRecord,
NewSnapshot: freeAfterRevokeSnapshot,
}))
storedSnapshot, err = snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, freeAfterRevokeSnapshot, storedSnapshot)
historyRecords, err := historyStore.ListByUserID(context.Background(), userID)
require.NoError(t, err)
require.Len(t, historyRecords, 4)
}
func TestRepairExpiredEntitlementMaterializesFreeSnapshot(t *testing.T) {
t.Parallel()
store := newTestStore(t)
historyStore := store.EntitlementHistory()
snapshotStore := store.EntitlementSnapshots()
lifecycleStore := store.EntitlementLifecycle()
userID := common.UserID("user-123")
startsAt := time.Unix(1_775_240_000, 0).UTC()
endsAt := startsAt.Add(24 * time.Hour)
expiredSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
startsAt,
endsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
expiredSnapshot.UpdatedAt = endsAt.Add(24 * time.Hour)
expiredRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-1"),
userID,
entitlement.PlanCodePaidMonthly,
startsAt,
endsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
require.NoError(t, historyStore.Create(context.Background(), expiredRecord))
require.NoError(t, snapshotStore.Put(context.Background(), expiredSnapshot))
repairedAt := endsAt.Add(2 * time.Hour)
freeRecord := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-free-after-expiry"),
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("entitlement_expiry_repair"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("paid_entitlement_expired"),
StartsAt: endsAt,
CreatedAt: repairedAt,
}
freeSnapshot := entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: endsAt,
Source: common.Source("entitlement_expiry_repair"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("paid_entitlement_expired"),
UpdatedAt: repairedAt,
}
require.NoError(t, lifecycleStore.RepairExpired(context.Background(), ports.RepairExpiredEntitlementInput{
ExpectedExpiredSnapshot: expiredSnapshot,
NewRecord: freeRecord,
NewSnapshot: freeSnapshot,
}))
storedSnapshot, err := snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, freeSnapshot, storedSnapshot)
historyRecords, err := historyStore.ListByUserID(context.Background(), userID)
require.NoError(t, err)
require.Len(t, historyRecords, 2)
require.Equal(t, freeRecord, historyRecords[1])
}
func newTestStore(t *testing.T) *Store {
t.Helper()
server := miniredis.RunT(t)
store, err := New(Config{
Addr: server.Addr(),
DB: 0,
KeyspacePrefix: "user:test:",
OperationTimeout: 250 * time.Millisecond,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = store.Close()
})
return store
}
func validAccountRecord() account.UserAccount {
createdAt := time.Unix(1_775_240_000, 0).UTC()
return account.UserAccount{
UserID: common.UserID("user-123"),
Email: common.Email("pilot@example.com"),
RaceName: common.RaceName("Pilot Nova"),
PreferredLanguage: common.LanguageTag("en"),
TimeZone: common.TimeZoneName("Europe/Kaliningrad"),
CreatedAt: createdAt,
UpdatedAt: createdAt,
}
}
func validEntitlementSnapshot(userID common.UserID, now time.Time) entitlement.CurrentSnapshot {
return entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: now,
Source: common.Source("auth_registration"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("initial_free_entitlement"),
UpdatedAt: now,
}
}
func validEntitlementRecord(userID common.UserID, now time.Time) entitlement.PeriodRecord {
return entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-" + userID.String()),
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("auth_registration"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("initial_free_entitlement"),
StartsAt: now,
CreatedAt: now,
}
}
func paidEntitlementRecord(
recordID entitlement.EntitlementRecordID,
userID common.UserID,
planCode entitlement.PlanCode,
startsAt time.Time,
endsAt time.Time,
source common.Source,
reasonCode common.ReasonCode,
) entitlement.PeriodRecord {
return entitlement.PeriodRecord{
RecordID: recordID,
UserID: userID,
PlanCode: planCode,
Source: source,
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: reasonCode,
StartsAt: startsAt,
EndsAt: timePointer(endsAt),
CreatedAt: startsAt,
}
}
func paidEntitlementSnapshot(
userID common.UserID,
planCode entitlement.PlanCode,
startsAt time.Time,
endsAt time.Time,
source common.Source,
reasonCode common.ReasonCode,
) entitlement.CurrentSnapshot {
return entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: planCode,
IsPaid: true,
StartsAt: startsAt,
EndsAt: timePointer(endsAt),
Source: source,
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: reasonCode,
UpdatedAt: startsAt,
}
}
func timePointer(value time.Time) *time.Time {
utcValue := value.UTC()
return &utcValue
}
func createAccountInput(record account.UserAccount) ports.CreateAccountInput {
return ports.CreateAccountInput{
Account: record,
Reservation: raceNameReservation(record.UserID, record.RaceName, record.UpdatedAt),
}
}
func renameRaceNameInput(
record account.UserAccount,
newRaceName common.RaceName,
updatedAt time.Time,
) ports.RenameRaceNameInput {
return ports.RenameRaceNameInput{
UserID: record.UserID,
CurrentCanonicalKey: canonicalKey(record.RaceName),
NewRaceName: newRaceName,
NewReservation: raceNameReservation(record.UserID, newRaceName, updatedAt),
UpdatedAt: updatedAt,
}
}
func raceNameReservation(
userID common.UserID,
raceName common.RaceName,
reservedAt time.Time,
) account.RaceNameReservation {
return account.RaceNameReservation{
CanonicalKey: canonicalKey(raceName),
UserID: userID,
RaceName: raceName,
ReservedAt: reservedAt.UTC(),
}
}
func canonicalKey(raceName common.RaceName) account.RaceNameCanonicalKey {
return account.RaceNameCanonicalKey(strings.NewReplacer(
"1", "i",
"0", "o",
"8", "b",
).Replace(strings.ToLower(raceName.String())))
}