feat: user service
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
// Package local provides small in-process runtime adapters used by the user
|
||||
// service process.
|
||||
package local
|
||||
|
||||
import "time"
|
||||
|
||||
// Clock returns the current wall-clock time.
|
||||
type Clock struct{}
|
||||
|
||||
// Now returns the current time.
|
||||
func (Clock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"galaxy/user/internal/ports"
|
||||
)
|
||||
|
||||
// NoopDeclaredCountryChangedPublisher validates and discards auxiliary
|
||||
// declared-country change events.
|
||||
type NoopDeclaredCountryChangedPublisher struct{}
|
||||
|
||||
// PublishDeclaredCountryChanged validates event and discards it.
|
||||
func (NoopDeclaredCountryChangedPublisher) PublishDeclaredCountryChanged(
|
||||
ctx context.Context,
|
||||
event ports.DeclaredCountryChangedEvent,
|
||||
) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("publish declared-country changed event: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return event.Validate()
|
||||
}
|
||||
|
||||
var _ ports.DeclaredCountryChangedPublisher = NoopDeclaredCountryChangedPublisher{}
|
||||
@@ -0,0 +1,62 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"galaxy/user/internal/ports"
|
||||
)
|
||||
|
||||
// NoopDomainEventPublisher validates and discards auxiliary user-domain
|
||||
// events.
|
||||
type NoopDomainEventPublisher struct{}
|
||||
|
||||
// PublishProfileChanged validates event and discards it.
|
||||
func (NoopDomainEventPublisher) PublishProfileChanged(ctx context.Context, event ports.ProfileChangedEvent) error {
|
||||
return validateNoopPublish(ctx, "publish profile changed event", event.Validate)
|
||||
}
|
||||
|
||||
// PublishSettingsChanged validates event and discards it.
|
||||
func (NoopDomainEventPublisher) PublishSettingsChanged(ctx context.Context, event ports.SettingsChangedEvent) error {
|
||||
return validateNoopPublish(ctx, "publish settings changed event", event.Validate)
|
||||
}
|
||||
|
||||
// PublishEntitlementChanged validates event and discards it.
|
||||
func (NoopDomainEventPublisher) PublishEntitlementChanged(ctx context.Context, event ports.EntitlementChangedEvent) error {
|
||||
return validateNoopPublish(ctx, "publish entitlement changed event", event.Validate)
|
||||
}
|
||||
|
||||
// PublishSanctionChanged validates event and discards it.
|
||||
func (NoopDomainEventPublisher) PublishSanctionChanged(ctx context.Context, event ports.SanctionChangedEvent) error {
|
||||
return validateNoopPublish(ctx, "publish sanction changed event", event.Validate)
|
||||
}
|
||||
|
||||
// PublishLimitChanged validates event and discards it.
|
||||
func (NoopDomainEventPublisher) PublishLimitChanged(ctx context.Context, event ports.LimitChangedEvent) error {
|
||||
return validateNoopPublish(ctx, "publish limit changed event", event.Validate)
|
||||
}
|
||||
|
||||
// PublishDeclaredCountryChanged validates event and discards it.
|
||||
func (NoopDomainEventPublisher) PublishDeclaredCountryChanged(ctx context.Context, event ports.DeclaredCountryChangedEvent) error {
|
||||
return validateNoopPublish(ctx, "publish declared-country changed event", event.Validate)
|
||||
}
|
||||
|
||||
func validateNoopPublish(ctx context.Context, operation string, validate func() error) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validate()
|
||||
}
|
||||
|
||||
var (
|
||||
_ ports.ProfileChangedPublisher = NoopDomainEventPublisher{}
|
||||
_ ports.SettingsChangedPublisher = NoopDomainEventPublisher{}
|
||||
_ ports.EntitlementChangedPublisher = NoopDomainEventPublisher{}
|
||||
_ ports.SanctionChangedPublisher = NoopDomainEventPublisher{}
|
||||
_ ports.LimitChangedPublisher = NoopDomainEventPublisher{}
|
||||
_ ports.DeclaredCountryChangedPublisher = NoopDomainEventPublisher{}
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base32"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
)
|
||||
|
||||
var base32NoPadding = base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
|
||||
// IDGenerator creates opaque stable user identifiers and generated initial
|
||||
// race names.
|
||||
type IDGenerator struct{}
|
||||
|
||||
// NewUserID returns one newly generated opaque user identifier.
|
||||
func (IDGenerator) NewUserID() (common.UserID, error) {
|
||||
token, err := randomToken(10)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate user id: %w", err)
|
||||
}
|
||||
|
||||
userID := common.UserID("user-" + token)
|
||||
if err := userID.Validate(); err != nil {
|
||||
return "", fmt.Errorf("generate user id: %w", err)
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// NewInitialRaceName returns one generated race name in the `player-<shortid>`
|
||||
// form.
|
||||
func (IDGenerator) NewInitialRaceName() (common.RaceName, error) {
|
||||
token, err := randomToken(5)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate initial race name: %w", err)
|
||||
}
|
||||
|
||||
raceName := common.RaceName("player-" + token)
|
||||
if err := raceName.Validate(); err != nil {
|
||||
return "", fmt.Errorf("generate initial race name: %w", err)
|
||||
}
|
||||
|
||||
return raceName, nil
|
||||
}
|
||||
|
||||
// NewEntitlementRecordID returns one generated entitlement history record
|
||||
// identifier.
|
||||
func (IDGenerator) NewEntitlementRecordID() (entitlement.EntitlementRecordID, error) {
|
||||
token, err := randomToken(10)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate entitlement record id: %w", err)
|
||||
}
|
||||
|
||||
recordID := entitlement.EntitlementRecordID("entitlement-" + token)
|
||||
if err := recordID.Validate(); err != nil {
|
||||
return "", fmt.Errorf("generate entitlement record id: %w", err)
|
||||
}
|
||||
|
||||
return recordID, nil
|
||||
}
|
||||
|
||||
// NewSanctionRecordID returns one generated sanction history record
|
||||
// identifier.
|
||||
func (IDGenerator) NewSanctionRecordID() (policy.SanctionRecordID, error) {
|
||||
token, err := randomToken(10)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate sanction record id: %w", err)
|
||||
}
|
||||
|
||||
recordID := policy.SanctionRecordID("sanction-" + token)
|
||||
if err := recordID.Validate(); err != nil {
|
||||
return "", fmt.Errorf("generate sanction record id: %w", err)
|
||||
}
|
||||
|
||||
return recordID, nil
|
||||
}
|
||||
|
||||
// NewLimitRecordID returns one generated limit history record identifier.
|
||||
func (IDGenerator) NewLimitRecordID() (policy.LimitRecordID, error) {
|
||||
token, err := randomToken(10)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate limit record id: %w", err)
|
||||
}
|
||||
|
||||
recordID := policy.LimitRecordID("limit-" + token)
|
||||
if err := recordID.Validate(); err != nil {
|
||||
return "", fmt.Errorf("generate limit record id: %w", err)
|
||||
}
|
||||
|
||||
return recordID, nil
|
||||
}
|
||||
|
||||
func randomToken(size int) (string, error) {
|
||||
buffer := make([]byte, size)
|
||||
if _, err := rand.Read(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.ToLower(base32NoPadding.EncodeToString(buffer)), nil
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"galaxy/user/internal/domain/account"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
confusables "github.com/disciplinedware/go-confusables"
|
||||
"golang.org/x/text/cases"
|
||||
)
|
||||
|
||||
type confusableSkeletoner interface {
|
||||
Skeleton(string) string
|
||||
}
|
||||
|
||||
type raceNamePolicy struct {
|
||||
caseFolder cases.Caser
|
||||
skeletoner confusableSkeletoner
|
||||
}
|
||||
|
||||
var raceNameAntiFraudReplacer = strings.NewReplacer(
|
||||
"1", "i",
|
||||
"0", "o",
|
||||
"8", "b",
|
||||
)
|
||||
|
||||
// NewRaceNamePolicy returns the local Stage 06 race-name canonicalization
|
||||
// policy backed by Unicode case folding, explicit ASCII anti-fraud mappings,
|
||||
// and a TR39 confusable skeleton.
|
||||
func NewRaceNamePolicy() (ports.RaceNamePolicy, error) {
|
||||
policy := &raceNamePolicy{
|
||||
caseFolder: cases.Fold(),
|
||||
skeletoner: confusables.Default(),
|
||||
}
|
||||
if policy.skeletoner == nil {
|
||||
return nil, fmt.Errorf("new race-name policy: nil confusable skeletoner")
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// CanonicalKey returns the stable uniqueness key for raceName.
|
||||
func (policy *raceNamePolicy) CanonicalKey(raceName common.RaceName) (account.RaceNameCanonicalKey, error) {
|
||||
switch {
|
||||
case policy == nil:
|
||||
return "", fmt.Errorf("canonicalize race name: nil policy")
|
||||
case policy.skeletoner == nil:
|
||||
return "", fmt.Errorf("canonicalize race name: nil confusable skeletoner")
|
||||
}
|
||||
if err := raceName.Validate(); err != nil {
|
||||
return "", fmt.Errorf("canonicalize race name: %w", err)
|
||||
}
|
||||
|
||||
folded := policy.caseFolder.String(raceName.String())
|
||||
antiFraudMapped := raceNameAntiFraudReplacer.Replace(folded)
|
||||
key := account.RaceNameCanonicalKey(policy.skeletoner.Skeleton(antiFraudMapped))
|
||||
if err := key.Validate(); err != nil {
|
||||
return "", fmt.Errorf("canonicalize race name: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/account"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/service/shared"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRaceNamePolicyCanonicalKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
policy, err := NewRaceNamePolicy()
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
left common.RaceName
|
||||
right common.RaceName
|
||||
}{
|
||||
{
|
||||
name: "case insensitive collision",
|
||||
left: common.RaceName("Pilot Nova"),
|
||||
right: common.RaceName("pilot nova"),
|
||||
},
|
||||
{
|
||||
name: "ascii anti fraud collision",
|
||||
left: common.RaceName("Pilot Nova"),
|
||||
right: common.RaceName("P1lot N0va"),
|
||||
},
|
||||
{
|
||||
name: "unicode confusable collision",
|
||||
left: common.RaceName("paypal"),
|
||||
right: common.RaceName("раураl"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
leftKey, err := policy.CanonicalKey(tt.left)
|
||||
require.NoError(t, err)
|
||||
rightKey, err := policy.CanonicalKey(tt.right)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rightKey, leftKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRaceNameReservationPreservesOriginalDisplayValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
policy, err := NewRaceNamePolicy()
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := shared.BuildRaceNameReservation(
|
||||
policy,
|
||||
common.UserID("user-123"),
|
||||
common.RaceName("P1lot Nova"),
|
||||
time.Unix(1_775_240_000, 0).UTC(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, common.RaceName("P1lot Nova"), record.RaceName)
|
||||
require.NotEqual(t, account.RaceNameCanonicalKey(""), record.CanonicalKey)
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
// Package domainevents implements Redis Stream-backed auxiliary user-domain
|
||||
// event publishers.
|
||||
package domainevents
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Config configures one Redis-backed user domain-event publisher.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// Stream identifies the Redis Stream key used for domain events.
|
||||
Stream string
|
||||
|
||||
// StreamMaxLen bounds the stream with approximate trimming via
|
||||
// `XADD MAXLEN ~`.
|
||||
StreamMaxLen int64
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Publisher publishes auxiliary user-domain events into one Redis Stream.
|
||||
type Publisher struct {
|
||||
client *redis.Client
|
||||
stream string
|
||||
streamMaxLen int64
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed domain-event publisher from cfg.
|
||||
func New(cfg Config) (*Publisher, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.Addr) == "":
|
||||
return nil, errors.New("new redis domain-event publisher: redis addr must not be empty")
|
||||
case cfg.DB < 0:
|
||||
return nil, errors.New("new redis domain-event publisher: redis db must not be negative")
|
||||
case strings.TrimSpace(cfg.Stream) == "":
|
||||
return nil, errors.New("new redis domain-event publisher: stream must not be empty")
|
||||
case cfg.StreamMaxLen <= 0:
|
||||
return nil, errors.New("new redis domain-event publisher: stream max len must be positive")
|
||||
case cfg.OperationTimeout <= 0:
|
||||
return nil, errors.New("new redis domain-event publisher: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Publisher{
|
||||
client: redis.NewClient(options),
|
||||
stream: cfg.Stream,
|
||||
streamMaxLen: cfg.StreamMaxLen,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (publisher *Publisher) Close() error {
|
||||
if publisher == nil || publisher.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return publisher.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (publisher *Publisher) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := publisher.operationContext(ctx, "ping redis domain-event publisher")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := publisher.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis domain-event publisher: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishProfileChanged publishes one committed profile-change event.
|
||||
func (publisher *Publisher) PublishProfileChanged(ctx context.Context, event ports.ProfileChangedEvent) error {
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("publish profile changed event: %w", err)
|
||||
}
|
||||
|
||||
values := buildEnvelope(ports.ProfileChangedEventType, event.UserID.String(), event.OccurredAt, event.Source.String(), traceIDFromContext(ctx, event.TraceID))
|
||||
values["operation"] = string(event.Operation)
|
||||
values["race_name"] = event.RaceName.String()
|
||||
|
||||
return publisher.publish(ctx, "publish profile changed event", values)
|
||||
}
|
||||
|
||||
// PublishSettingsChanged publishes one committed settings-change event.
|
||||
func (publisher *Publisher) PublishSettingsChanged(ctx context.Context, event ports.SettingsChangedEvent) error {
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("publish settings changed event: %w", err)
|
||||
}
|
||||
|
||||
values := buildEnvelope(ports.SettingsChangedEventType, event.UserID.String(), event.OccurredAt, event.Source.String(), traceIDFromContext(ctx, event.TraceID))
|
||||
values["operation"] = string(event.Operation)
|
||||
values["preferred_language"] = event.PreferredLanguage.String()
|
||||
values["time_zone"] = event.TimeZone.String()
|
||||
|
||||
return publisher.publish(ctx, "publish settings changed event", values)
|
||||
}
|
||||
|
||||
// PublishEntitlementChanged publishes one committed entitlement-change event.
|
||||
func (publisher *Publisher) PublishEntitlementChanged(ctx context.Context, event ports.EntitlementChangedEvent) error {
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("publish entitlement changed event: %w", err)
|
||||
}
|
||||
|
||||
values := buildEnvelope(ports.EntitlementChangedEventType, event.UserID.String(), event.OccurredAt, event.Source.String(), traceIDFromContext(ctx, event.TraceID))
|
||||
values["operation"] = string(event.Operation)
|
||||
values["plan_code"] = string(event.PlanCode)
|
||||
values["is_paid"] = strconv.FormatBool(event.IsPaid)
|
||||
values["starts_at_ms"] = strconv.FormatInt(event.StartsAt.UTC().UnixMilli(), 10)
|
||||
values["reason_code"] = event.ReasonCode.String()
|
||||
values["actor_type"] = event.Actor.Type.String()
|
||||
values["updated_at_ms"] = strconv.FormatInt(event.UpdatedAt.UTC().UnixMilli(), 10)
|
||||
if !event.Actor.ID.IsZero() {
|
||||
values["actor_id"] = event.Actor.ID.String()
|
||||
}
|
||||
if event.EndsAt != nil {
|
||||
values["ends_at_ms"] = strconv.FormatInt(event.EndsAt.UTC().UnixMilli(), 10)
|
||||
}
|
||||
|
||||
return publisher.publish(ctx, "publish entitlement changed event", values)
|
||||
}
|
||||
|
||||
// PublishSanctionChanged publishes one committed sanction-change event.
|
||||
func (publisher *Publisher) PublishSanctionChanged(ctx context.Context, event ports.SanctionChangedEvent) error {
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("publish sanction changed event: %w", err)
|
||||
}
|
||||
|
||||
values := buildEnvelope(ports.SanctionChangedEventType, event.UserID.String(), event.OccurredAt, event.Source.String(), traceIDFromContext(ctx, event.TraceID))
|
||||
values["operation"] = string(event.Operation)
|
||||
values["sanction_code"] = string(event.SanctionCode)
|
||||
values["scope"] = event.Scope.String()
|
||||
values["reason_code"] = event.ReasonCode.String()
|
||||
values["actor_type"] = event.Actor.Type.String()
|
||||
values["applied_at_ms"] = strconv.FormatInt(event.AppliedAt.UTC().UnixMilli(), 10)
|
||||
if !event.Actor.ID.IsZero() {
|
||||
values["actor_id"] = event.Actor.ID.String()
|
||||
}
|
||||
if event.ExpiresAt != nil {
|
||||
values["expires_at_ms"] = strconv.FormatInt(event.ExpiresAt.UTC().UnixMilli(), 10)
|
||||
}
|
||||
if event.RemovedAt != nil {
|
||||
values["removed_at_ms"] = strconv.FormatInt(event.RemovedAt.UTC().UnixMilli(), 10)
|
||||
}
|
||||
|
||||
return publisher.publish(ctx, "publish sanction changed event", values)
|
||||
}
|
||||
|
||||
// PublishLimitChanged publishes one committed limit-change event.
|
||||
func (publisher *Publisher) PublishLimitChanged(ctx context.Context, event ports.LimitChangedEvent) error {
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("publish limit changed event: %w", err)
|
||||
}
|
||||
|
||||
values := buildEnvelope(ports.LimitChangedEventType, event.UserID.String(), event.OccurredAt, event.Source.String(), traceIDFromContext(ctx, event.TraceID))
|
||||
values["operation"] = string(event.Operation)
|
||||
values["limit_code"] = string(event.LimitCode)
|
||||
values["reason_code"] = event.ReasonCode.String()
|
||||
values["actor_type"] = event.Actor.Type.String()
|
||||
values["applied_at_ms"] = strconv.FormatInt(event.AppliedAt.UTC().UnixMilli(), 10)
|
||||
if event.Value != nil {
|
||||
values["value"] = strconv.Itoa(*event.Value)
|
||||
}
|
||||
if !event.Actor.ID.IsZero() {
|
||||
values["actor_id"] = event.Actor.ID.String()
|
||||
}
|
||||
if event.ExpiresAt != nil {
|
||||
values["expires_at_ms"] = strconv.FormatInt(event.ExpiresAt.UTC().UnixMilli(), 10)
|
||||
}
|
||||
if event.RemovedAt != nil {
|
||||
values["removed_at_ms"] = strconv.FormatInt(event.RemovedAt.UTC().UnixMilli(), 10)
|
||||
}
|
||||
|
||||
return publisher.publish(ctx, "publish limit changed event", values)
|
||||
}
|
||||
|
||||
// PublishDeclaredCountryChanged publishes one committed declared-country change
|
||||
// event.
|
||||
func (publisher *Publisher) PublishDeclaredCountryChanged(ctx context.Context, event ports.DeclaredCountryChangedEvent) error {
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("publish declared-country changed event: %w", err)
|
||||
}
|
||||
|
||||
values := buildEnvelope(
|
||||
ports.DeclaredCountryChangedEventType,
|
||||
event.UserID.String(),
|
||||
event.UpdatedAt,
|
||||
event.Source.String(),
|
||||
traceIDFromContext(ctx, event.TraceID),
|
||||
)
|
||||
values["declared_country"] = event.DeclaredCountry.String()
|
||||
values["updated_at_ms"] = strconv.FormatInt(event.UpdatedAt.UTC().UnixMilli(), 10)
|
||||
|
||||
return publisher.publish(ctx, "publish declared-country changed event", values)
|
||||
}
|
||||
|
||||
func (publisher *Publisher) publish(ctx context.Context, operation string, values map[string]any) error {
|
||||
operationCtx, cancel, err := publisher.operationContext(ctx, operation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := publisher.client.XAdd(operationCtx, &redis.XAddArgs{
|
||||
Stream: publisher.stream,
|
||||
MaxLen: publisher.streamMaxLen,
|
||||
Approx: true,
|
||||
Values: values,
|
||||
}).Err(); err != nil {
|
||||
return fmt.Errorf("%s: %w", operation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (publisher *Publisher) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if publisher == nil || publisher.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil publisher", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, publisher.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func buildEnvelope(eventType string, userID string, occurredAt time.Time, source string, traceID string) map[string]any {
|
||||
values := map[string]any{
|
||||
"event_type": eventType,
|
||||
"user_id": userID,
|
||||
"occurred_at_ms": strconv.FormatInt(occurredAt.UTC().UnixMilli(), 10),
|
||||
"source": source,
|
||||
}
|
||||
if traceID != "" {
|
||||
values["trace_id"] = traceID
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func traceIDFromContext(ctx context.Context, fallback string) string {
|
||||
if strings.TrimSpace(fallback) != "" {
|
||||
return fallback
|
||||
}
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
spanContext := trace.SpanContextFromContext(ctx)
|
||||
if !spanContext.IsValid() {
|
||||
return ""
|
||||
}
|
||||
|
||||
return spanContext.TraceID().String()
|
||||
}
|
||||
|
||||
var (
|
||||
_ interface{ Close() error } = (*Publisher)(nil)
|
||||
_ interface{ Ping(context.Context) error } = (*Publisher)(nil)
|
||||
_ ports.ProfileChangedPublisher = (*Publisher)(nil)
|
||||
_ ports.SettingsChangedPublisher = (*Publisher)(nil)
|
||||
_ ports.EntitlementChangedPublisher = (*Publisher)(nil)
|
||||
_ ports.SanctionChangedPublisher = (*Publisher)(nil)
|
||||
_ ports.LimitChangedPublisher = (*Publisher)(nil)
|
||||
_ ports.DeclaredCountryChangedPublisher = (*Publisher)(nil)
|
||||
)
|
||||
@@ -0,0 +1,90 @@
|
||||
package domainevents
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPublisherPublishesFlatRedisStreamEntry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher, err := New(Config{
|
||||
Addr: server.Addr(),
|
||||
Stream: "user:test_events",
|
||||
StreamMaxLen: 5,
|
||||
OperationTimeout: time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
occurredAt := time.Unix(1_775_240_000, 0).UTC()
|
||||
err = publisher.PublishProfileChanged(context.Background(), ports.ProfileChangedEvent{
|
||||
UserID: common.UserID("user-123"),
|
||||
OccurredAt: occurredAt,
|
||||
Source: common.Source("gateway_self_service"),
|
||||
TraceID: "4bf92f3577b34da6a3ce929d0e0e4736",
|
||||
Operation: ports.ProfileChangedOperationUpdated,
|
||||
RaceName: common.RaceName("Nova Prime"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.stream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
require.Equal(t, ports.ProfileChangedEventType, entries[0].Values["event_type"])
|
||||
require.Equal(t, "user-123", entries[0].Values["user_id"])
|
||||
require.Equal(t, strconv.FormatInt(occurredAt.UnixMilli(), 10), entries[0].Values["occurred_at_ms"])
|
||||
require.Equal(t, "gateway_self_service", entries[0].Values["source"])
|
||||
require.Equal(t, "4bf92f3577b34da6a3ce929d0e0e4736", entries[0].Values["trace_id"])
|
||||
require.Equal(t, string(ports.ProfileChangedOperationUpdated), entries[0].Values["operation"])
|
||||
require.Equal(t, "Nova Prime", entries[0].Values["race_name"])
|
||||
|
||||
for index := 0; index < 20; index++ {
|
||||
err = publisher.PublishSettingsChanged(context.Background(), ports.SettingsChangedEvent{
|
||||
UserID: common.UserID("user-123"),
|
||||
OccurredAt: occurredAt.Add(time.Duration(index+1) * time.Second),
|
||||
Source: common.Source("gateway_self_service"),
|
||||
Operation: ports.SettingsChangedOperationUpdated,
|
||||
PreferredLanguage: common.LanguageTag("en-US"),
|
||||
TimeZone: common.TimeZoneName("UTC"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
length, err := publisher.client.XLen(context.Background(), publisher.stream).Result()
|
||||
require.NoError(t, err)
|
||||
require.LessOrEqual(t, length, int64(20))
|
||||
}
|
||||
|
||||
func TestPublisherRejectsInvalidEventBeforeXAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher, err := New(Config{
|
||||
Addr: server.Addr(),
|
||||
Stream: "user:test_events",
|
||||
StreamMaxLen: 5,
|
||||
OperationTimeout: time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = publisher.PublishProfileChanged(context.Background(), ports.ProfileChangedEvent{
|
||||
UserID: common.UserID("user-123"),
|
||||
OccurredAt: time.Unix(1_775_240_000, 0).UTC(),
|
||||
Operation: ports.ProfileChangedOperationUpdated,
|
||||
RaceName: common.RaceName("Nova Prime"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
length, xLenErr := publisher.client.XLen(context.Background(), publisher.stream).Result()
|
||||
require.NoError(t, xLenErr)
|
||||
require.Zero(t, length)
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/user/internal/adapters/redisstate"
|
||||
"galaxy/user/internal/domain/account"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var knownSanctionCodes = []policy.SanctionCode{
|
||||
policy.SanctionCodeLoginBlock,
|
||||
policy.SanctionCodePrivateGameCreateBlock,
|
||||
policy.SanctionCodePrivateGameManageBlock,
|
||||
policy.SanctionCodeGameJoinBlock,
|
||||
policy.SanctionCodeProfileUpdateBlock,
|
||||
}
|
||||
|
||||
var knownLimitCodes = []policy.LimitCode{
|
||||
policy.LimitCodeMaxOwnedPrivateGames,
|
||||
policy.LimitCodeMaxPendingPublicApplications,
|
||||
policy.LimitCodeMaxActiveGameMemberships,
|
||||
}
|
||||
|
||||
var knownEligibilityMarkers = []policy.EligibilityMarker{
|
||||
policy.EligibilityMarkerCanLogin,
|
||||
policy.EligibilityMarkerCanCreatePrivateGame,
|
||||
policy.EligibilityMarkerCanManagePrivateGame,
|
||||
policy.EligibilityMarkerCanJoinGame,
|
||||
policy.EligibilityMarkerCanUpdateProfile,
|
||||
}
|
||||
|
||||
func (store *Store) addCreatedAtIndex(
|
||||
pipe redis.Pipeliner,
|
||||
ctx context.Context,
|
||||
record account.UserAccount,
|
||||
) {
|
||||
pipe.ZAdd(ctx, store.keyspace.CreatedAtIndex(), redis.Z{
|
||||
Score: redisstate.CreatedAtScore(record.CreatedAt),
|
||||
Member: record.UserID.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func (store *Store) syncDeclaredCountryIndex(
|
||||
pipe redis.Pipeliner,
|
||||
ctx context.Context,
|
||||
previous account.UserAccount,
|
||||
current account.UserAccount,
|
||||
) {
|
||||
if !previous.DeclaredCountry.IsZero() {
|
||||
pipe.SRem(ctx, store.keyspace.DeclaredCountryIndex(previous.DeclaredCountry), current.UserID.String())
|
||||
}
|
||||
if !current.DeclaredCountry.IsZero() {
|
||||
pipe.SAdd(ctx, store.keyspace.DeclaredCountryIndex(current.DeclaredCountry), current.UserID.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) syncEntitlementIndexes(
|
||||
pipe redis.Pipeliner,
|
||||
ctx context.Context,
|
||||
snapshot entitlement.CurrentSnapshot,
|
||||
) {
|
||||
pipe.SRem(ctx, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), snapshot.UserID.String())
|
||||
pipe.SRem(ctx, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), snapshot.UserID.String())
|
||||
pipe.SAdd(ctx, store.keyspace.PaidStateIndex(paidStateFromSnapshot(snapshot)), snapshot.UserID.String())
|
||||
|
||||
pipe.ZRem(ctx, store.keyspace.FinitePaidExpiryIndex(), snapshot.UserID.String())
|
||||
if snapshot.HasFiniteExpiry() {
|
||||
pipe.ZAdd(ctx, store.keyspace.FinitePaidExpiryIndex(), redis.Z{
|
||||
Score: redisstate.ExpiryScore(*snapshot.EndsAt),
|
||||
Member: snapshot.UserID.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) syncActiveSanctionCodeIndexes(
|
||||
pipe redis.Pipeliner,
|
||||
ctx context.Context,
|
||||
userID common.UserID,
|
||||
activeCodes map[policy.SanctionCode]struct{},
|
||||
) {
|
||||
for _, code := range knownSanctionCodes {
|
||||
pipe.SRem(ctx, store.keyspace.ActiveSanctionCodeIndex(code), userID.String())
|
||||
if _, ok := activeCodes[code]; ok {
|
||||
pipe.SAdd(ctx, store.keyspace.ActiveSanctionCodeIndex(code), userID.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) syncActiveLimitCodeIndexes(
|
||||
pipe redis.Pipeliner,
|
||||
ctx context.Context,
|
||||
userID common.UserID,
|
||||
activeCodes map[policy.LimitCode]struct{},
|
||||
) {
|
||||
for _, code := range knownLimitCodes {
|
||||
pipe.SRem(ctx, store.keyspace.ActiveLimitCodeIndex(code), userID.String())
|
||||
if _, ok := activeCodes[code]; ok {
|
||||
pipe.SAdd(ctx, store.keyspace.ActiveLimitCodeIndex(code), userID.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) syncEligibilityMarkerIndexes(
|
||||
pipe redis.Pipeliner,
|
||||
ctx context.Context,
|
||||
userID common.UserID,
|
||||
isPaid bool,
|
||||
activeSanctionCodes map[policy.SanctionCode]struct{},
|
||||
) {
|
||||
values := deriveEligibilityMarkerValues(isPaid, activeSanctionCodes)
|
||||
|
||||
for _, marker := range knownEligibilityMarkers {
|
||||
pipe.SRem(ctx, store.keyspace.EligibilityMarkerIndex(marker, true), userID.String())
|
||||
pipe.SRem(ctx, store.keyspace.EligibilityMarkerIndex(marker, false), userID.String())
|
||||
pipe.SAdd(ctx, store.keyspace.EligibilityMarkerIndex(marker, values[marker]), userID.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) loadActiveSanctionCodeSet(
|
||||
ctx context.Context,
|
||||
getter bytesGetter,
|
||||
userID common.UserID,
|
||||
) (map[policy.SanctionCode]struct{}, error) {
|
||||
activeCodes := make(map[policy.SanctionCode]struct{}, len(knownSanctionCodes))
|
||||
|
||||
for _, code := range knownSanctionCodes {
|
||||
_, err := store.loadActiveSanctionRecordID(ctx, getter, store.keyspace.ActiveSanction(userID, code))
|
||||
switch {
|
||||
case err == nil:
|
||||
activeCodes[code] = struct{}{}
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
continue
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return activeCodes, nil
|
||||
}
|
||||
|
||||
func (store *Store) loadActiveLimitCodeSet(
|
||||
ctx context.Context,
|
||||
getter bytesGetter,
|
||||
userID common.UserID,
|
||||
) (map[policy.LimitCode]struct{}, error) {
|
||||
activeCodes := make(map[policy.LimitCode]struct{}, len(knownLimitCodes))
|
||||
|
||||
for _, code := range knownLimitCodes {
|
||||
_, err := store.loadActiveLimitRecordID(ctx, getter, store.keyspace.ActiveLimit(userID, code))
|
||||
switch {
|
||||
case err == nil:
|
||||
activeCodes[code] = struct{}{}
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
continue
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return activeCodes, nil
|
||||
}
|
||||
|
||||
func (store *Store) activeSanctionWatchKeys(userID common.UserID) []string {
|
||||
keys := make([]string, 0, len(knownSanctionCodes))
|
||||
for _, code := range knownSanctionCodes {
|
||||
keys = append(keys, store.keyspace.ActiveSanction(userID, code))
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (store *Store) activeLimitWatchKeys(userID common.UserID) []string {
|
||||
keys := make([]string, 0, len(knownLimitCodes))
|
||||
for _, code := range knownLimitCodes {
|
||||
keys = append(keys, store.keyspace.ActiveLimit(userID, code))
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func deriveEligibilityMarkerValues(
|
||||
isPaid bool,
|
||||
activeSanctionCodes map[policy.SanctionCode]struct{},
|
||||
) map[policy.EligibilityMarker]bool {
|
||||
_, loginBlocked := activeSanctionCodes[policy.SanctionCodeLoginBlock]
|
||||
_, createBlocked := activeSanctionCodes[policy.SanctionCodePrivateGameCreateBlock]
|
||||
_, manageBlocked := activeSanctionCodes[policy.SanctionCodePrivateGameManageBlock]
|
||||
_, joinBlocked := activeSanctionCodes[policy.SanctionCodeGameJoinBlock]
|
||||
_, profileBlocked := activeSanctionCodes[policy.SanctionCodeProfileUpdateBlock]
|
||||
|
||||
canLogin := !loginBlocked
|
||||
|
||||
return map[policy.EligibilityMarker]bool{
|
||||
policy.EligibilityMarkerCanLogin: canLogin,
|
||||
policy.EligibilityMarkerCanCreatePrivateGame: canLogin && isPaid && !createBlocked,
|
||||
policy.EligibilityMarkerCanManagePrivateGame: canLogin && isPaid && !manageBlocked,
|
||||
policy.EligibilityMarkerCanJoinGame: canLogin && !joinBlocked,
|
||||
policy.EligibilityMarkerCanUpdateProfile: canLogin && !profileBlocked,
|
||||
}
|
||||
}
|
||||
|
||||
func paidStateFromSnapshot(snapshot entitlement.CurrentSnapshot) entitlement.PaidState {
|
||||
if snapshot.IsPaid {
|
||||
return entitlement.PaidStatePaid
|
||||
}
|
||||
|
||||
return entitlement.PaidStateFree
|
||||
}
|
||||
@@ -0,0 +1,449 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/adapters/redisstate"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
"galaxy/user/internal/ports"
|
||||
"galaxy/user/internal/service/adminusers"
|
||||
"galaxy/user/internal/service/entitlementsvc"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestListUserIDsCreatedAtPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
base := time.Unix(1_775_240_000, 0).UTC()
|
||||
|
||||
first := validAccountRecord()
|
||||
first.UserID = common.UserID("user-100")
|
||||
first.Email = common.Email("u100@example.com")
|
||||
first.RaceName = common.RaceName("User 100")
|
||||
first.CreatedAt = base.Add(-time.Hour)
|
||||
first.UpdatedAt = first.CreatedAt
|
||||
|
||||
second := validAccountRecord()
|
||||
second.UserID = common.UserID("user-200")
|
||||
second.Email = common.Email("u200@example.com")
|
||||
second.RaceName = common.RaceName("User 200")
|
||||
second.CreatedAt = base
|
||||
second.UpdatedAt = second.CreatedAt
|
||||
|
||||
third := validAccountRecord()
|
||||
third.UserID = common.UserID("user-300")
|
||||
third.Email = common.Email("u300@example.com")
|
||||
third.RaceName = common.RaceName("User 300")
|
||||
third.CreatedAt = base
|
||||
third.UpdatedAt = third.CreatedAt
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), createAccountInput(first)))
|
||||
require.NoError(t, store.Create(context.Background(), createAccountInput(second)))
|
||||
require.NoError(t, store.Create(context.Background(), createAccountInput(third)))
|
||||
|
||||
firstPage, err := store.ListUserIDs(context.Background(), ports.ListUsersInput{
|
||||
PageSize: 2,
|
||||
Filters: ports.UserListFilters{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []common.UserID{third.UserID, second.UserID}, firstPage.UserIDs)
|
||||
require.NotEmpty(t, firstPage.NextPageToken)
|
||||
|
||||
secondPage, err := store.ListUserIDs(context.Background(), ports.ListUsersInput{
|
||||
PageSize: 2,
|
||||
PageToken: firstPage.NextPageToken,
|
||||
Filters: ports.UserListFilters{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []common.UserID{first.UserID}, secondPage.UserIDs)
|
||||
require.Empty(t, secondPage.NextPageToken)
|
||||
}
|
||||
|
||||
func TestEnsureByEmailInitialAdminIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
record := validAccountRecord()
|
||||
record.DeclaredCountry = common.CountryCode("DE")
|
||||
record.CreatedAt = now
|
||||
record.UpdatedAt = now
|
||||
|
||||
result, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: record.Email,
|
||||
Account: record,
|
||||
Entitlement: validEntitlementSnapshot(record.UserID, now),
|
||||
EntitlementRecord: validEntitlementRecord(record.UserID, now),
|
||||
Reservation: raceNameReservation(record.UserID, record.RaceName, now),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.EnsureByEmailOutcomeCreated, result.Outcome)
|
||||
|
||||
requireSortedSetScore(t, store, store.keyspace.CreatedAtIndex(), record.UserID.String(), redisstate.CreatedAtScore(record.CreatedAt))
|
||||
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
|
||||
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.DeclaredCountryIndex(record.DeclaredCountry), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, false), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanJoinGame, true), record.UserID.String())
|
||||
}
|
||||
|
||||
func TestAccountUpdateSyncsDeclaredCountryIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
accountStore := store.Accounts()
|
||||
record := validAccountRecord()
|
||||
record.DeclaredCountry = common.CountryCode("DE")
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
|
||||
|
||||
updated := record
|
||||
updated.DeclaredCountry = common.CountryCode("FR")
|
||||
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
|
||||
require.NoError(t, accountStore.Update(context.Background(), updated))
|
||||
|
||||
requireSetNotContains(t, store, store.keyspace.DeclaredCountryIndex(common.CountryCode("DE")), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.DeclaredCountryIndex(common.CountryCode("FR")), record.UserID.String())
|
||||
}
|
||||
|
||||
func TestEntitlementLifecycleSyncsAdminIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
record := validAccountRecord()
|
||||
record.CreatedAt = now
|
||||
record.UpdatedAt = now
|
||||
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: record.Email,
|
||||
Account: record,
|
||||
Entitlement: validEntitlementSnapshot(record.UserID, now),
|
||||
EntitlementRecord: validEntitlementRecord(record.UserID, now),
|
||||
Reservation: raceNameReservation(record.UserID, record.RaceName, now),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
lifecycleStore := store.EntitlementLifecycle()
|
||||
freeRecord := validEntitlementRecord(record.UserID, now)
|
||||
freeSnapshot := validEntitlementSnapshot(record.UserID, now)
|
||||
|
||||
grantStartsAt := now.Add(time.Hour)
|
||||
grantEndsAt := grantStartsAt.Add(30 * 24 * time.Hour)
|
||||
grantedRecord := paidEntitlementRecord(
|
||||
entitlement.EntitlementRecordID("entitlement-paid-1"),
|
||||
record.UserID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
grantEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
grantedSnapshot := paidEntitlementSnapshot(
|
||||
record.UserID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
grantEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
closedFreeRecord := freeRecord
|
||||
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
|
||||
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
|
||||
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
|
||||
|
||||
require.NoError(t, lifecycleStore.Grant(context.Background(), ports.GrantEntitlementInput{
|
||||
ExpectedCurrentSnapshot: freeSnapshot,
|
||||
ExpectedCurrentRecord: freeRecord,
|
||||
UpdatedCurrentRecord: closedFreeRecord,
|
||||
NewRecord: grantedRecord,
|
||||
NewSnapshot: grantedSnapshot,
|
||||
}))
|
||||
|
||||
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
|
||||
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
|
||||
requireSortedSetScore(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String(), redisstate.ExpiryScore(grantEndsAt))
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, true), record.UserID.String())
|
||||
|
||||
extendedEndsAt := grantEndsAt.Add(30 * 24 * time.Hour)
|
||||
extensionRecord := paidEntitlementRecord(
|
||||
entitlement.EntitlementRecordID("entitlement-paid-2"),
|
||||
record.UserID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantEndsAt,
|
||||
extendedEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_extend"),
|
||||
)
|
||||
extendedSnapshot := paidEntitlementSnapshot(
|
||||
record.UserID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
extendedEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_extend"),
|
||||
)
|
||||
require.NoError(t, lifecycleStore.Extend(context.Background(), ports.ExtendEntitlementInput{
|
||||
ExpectedCurrentSnapshot: grantedSnapshot,
|
||||
NewRecord: extensionRecord,
|
||||
NewSnapshot: extendedSnapshot,
|
||||
}))
|
||||
|
||||
requireSortedSetScore(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String(), redisstate.ExpiryScore(extendedEndsAt))
|
||||
|
||||
revokeAt := grantEndsAt.Add(12 * time.Hour)
|
||||
revokedCurrentRecord := extensionRecord
|
||||
revokedCurrentRecord.ClosedAt = timePointer(revokeAt)
|
||||
revokedCurrentRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
|
||||
revokedCurrentRecord.ClosedReasonCode = common.ReasonCode("manual_revoke")
|
||||
freeAfterRevokeRecord := entitlement.PeriodRecord{
|
||||
RecordID: entitlement.EntitlementRecordID("entitlement-free-2"),
|
||||
UserID: record.UserID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
Source: common.Source("admin"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
ReasonCode: common.ReasonCode("manual_revoke"),
|
||||
StartsAt: revokeAt,
|
||||
CreatedAt: revokeAt,
|
||||
}
|
||||
freeAfterRevokeSnapshot := entitlement.CurrentSnapshot{
|
||||
UserID: record.UserID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
IsPaid: false,
|
||||
StartsAt: revokeAt,
|
||||
Source: common.Source("admin"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
ReasonCode: common.ReasonCode("manual_revoke"),
|
||||
UpdatedAt: revokeAt,
|
||||
}
|
||||
require.NoError(t, lifecycleStore.Revoke(context.Background(), ports.RevokeEntitlementInput{
|
||||
ExpectedCurrentSnapshot: extendedSnapshot,
|
||||
ExpectedCurrentRecord: extensionRecord,
|
||||
UpdatedCurrentRecord: revokedCurrentRecord,
|
||||
NewRecord: freeAfterRevokeRecord,
|
||||
NewSnapshot: freeAfterRevokeSnapshot,
|
||||
}))
|
||||
|
||||
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
|
||||
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
|
||||
requireSortedSetMissing(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, false), record.UserID.String())
|
||||
}
|
||||
|
||||
func TestPolicyLifecycleSyncsAdminIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
record := validAccountRecord()
|
||||
record.CreatedAt = now
|
||||
record.UpdatedAt = now
|
||||
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: record.Email,
|
||||
Account: record,
|
||||
Entitlement: validEntitlementSnapshot(record.UserID, now),
|
||||
EntitlementRecord: validEntitlementRecord(record.UserID, now),
|
||||
Reservation: raceNameReservation(record.UserID, record.RaceName, now),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
lifecycleStore := store.PolicyLifecycle()
|
||||
sanctionRecord := policy.SanctionRecord{
|
||||
RecordID: policy.SanctionRecordID("sanction-1"),
|
||||
UserID: record.UserID,
|
||||
SanctionCode: policy.SanctionCodeLoginBlock,
|
||||
Scope: common.Scope("auth"),
|
||||
ReasonCode: common.ReasonCode("manual_block"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
AppliedAt: now,
|
||||
}
|
||||
require.NoError(t, lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
|
||||
NewRecord: sanctionRecord,
|
||||
}))
|
||||
|
||||
requireSetContains(t, store, store.keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, false), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanJoinGame, false), record.UserID.String())
|
||||
|
||||
removedSanction := sanctionRecord
|
||||
removedAt := now.Add(time.Minute)
|
||||
removedSanction.RemovedAt = &removedAt
|
||||
removedSanction.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
|
||||
removedSanction.RemovedReasonCode = common.ReasonCode("manual_remove")
|
||||
require.NoError(t, lifecycleStore.RemoveSanction(context.Background(), ports.RemoveSanctionInput{
|
||||
ExpectedActiveRecord: sanctionRecord,
|
||||
UpdatedRecord: removedSanction,
|
||||
}))
|
||||
|
||||
requireSetNotContains(t, store, store.keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock), record.UserID.String())
|
||||
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true), record.UserID.String())
|
||||
|
||||
limitRecord := policy.LimitRecord{
|
||||
RecordID: policy.LimitRecordID("limit-1"),
|
||||
UserID: record.UserID,
|
||||
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
|
||||
Value: 5,
|
||||
ReasonCode: common.ReasonCode("manual_override"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
AppliedAt: now.Add(2 * time.Minute),
|
||||
}
|
||||
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
|
||||
NewRecord: limitRecord,
|
||||
}))
|
||||
|
||||
requireSetContains(t, store, store.keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames), record.UserID.String())
|
||||
|
||||
removedLimit := limitRecord
|
||||
limitRemovedAt := now.Add(3 * time.Minute)
|
||||
removedLimit.RemovedAt = &limitRemovedAt
|
||||
removedLimit.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
|
||||
removedLimit.RemovedReasonCode = common.ReasonCode("manual_remove")
|
||||
require.NoError(t, lifecycleStore.RemoveLimit(context.Background(), ports.RemoveLimitInput{
|
||||
ExpectedActiveRecord: limitRecord,
|
||||
UpdatedRecord: removedLimit,
|
||||
}))
|
||||
|
||||
requireSetNotContains(t, store, store.keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames), record.UserID.String())
|
||||
}
|
||||
|
||||
func TestAdminListerReevaluatesExpiredPaidSnapshots(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
userID := common.UserID("user-123")
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
record := validAccountRecord()
|
||||
record.CreatedAt = now.Add(-2 * time.Hour)
|
||||
record.UpdatedAt = record.CreatedAt
|
||||
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: record.Email,
|
||||
Account: record,
|
||||
Entitlement: validEntitlementSnapshot(userID, record.CreatedAt),
|
||||
EntitlementRecord: validEntitlementRecord(userID, record.CreatedAt),
|
||||
Reservation: raceNameReservation(userID, record.RaceName, record.CreatedAt),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
grantStartsAt := now.Add(-90 * time.Minute)
|
||||
grantEndsAt := now.Add(-30 * time.Minute)
|
||||
freeRecord := validEntitlementRecord(userID, record.CreatedAt)
|
||||
freeSnapshot := validEntitlementSnapshot(userID, record.CreatedAt)
|
||||
grantedRecord := paidEntitlementRecord(
|
||||
entitlement.EntitlementRecordID("entitlement-paid-expired"),
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
grantEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
grantedSnapshot := paidEntitlementSnapshot(
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
grantEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
closedFreeRecord := freeRecord
|
||||
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
|
||||
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
|
||||
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
|
||||
require.NoError(t, store.EntitlementLifecycle().Grant(context.Background(), ports.GrantEntitlementInput{
|
||||
ExpectedCurrentSnapshot: freeSnapshot,
|
||||
ExpectedCurrentRecord: freeRecord,
|
||||
UpdatedCurrentRecord: closedFreeRecord,
|
||||
NewRecord: grantedRecord,
|
||||
NewSnapshot: grantedSnapshot,
|
||||
}))
|
||||
|
||||
reader, err := entitlementsvc.NewReader(
|
||||
store.EntitlementSnapshots(),
|
||||
store.EntitlementLifecycle(),
|
||||
adminStoreClock{now: now},
|
||||
adminStoreIDGenerator{entitlementRecordID: entitlement.EntitlementRecordID("entitlement-free-after-expiry")},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
lister, err := adminusers.NewLister(store.Accounts(), reader, store.Sanctions(), store.Limits(), adminStoreClock{now: now}, store)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := lister.Execute(context.Background(), adminusers.ListUsersInput{PaidState: "free"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Items, 1)
|
||||
require.Equal(t, "user-123", result.Items[0].UserID)
|
||||
require.Equal(t, "free", result.Items[0].Entitlement.PlanCode)
|
||||
require.False(t, result.Items[0].Entitlement.IsPaid)
|
||||
|
||||
storedSnapshot, err := store.EntitlementSnapshots().GetByUserID(context.Background(), userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, entitlement.PlanCodeFree, storedSnapshot.PlanCode)
|
||||
require.False(t, storedSnapshot.IsPaid)
|
||||
}
|
||||
|
||||
type adminStoreClock struct {
|
||||
now time.Time
|
||||
}
|
||||
|
||||
func (clock adminStoreClock) Now() time.Time {
|
||||
return clock.now
|
||||
}
|
||||
|
||||
type adminStoreIDGenerator struct {
|
||||
entitlementRecordID entitlement.EntitlementRecordID
|
||||
}
|
||||
|
||||
func (generator adminStoreIDGenerator) NewUserID() (common.UserID, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (generator adminStoreIDGenerator) NewInitialRaceName() (common.RaceName, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (generator adminStoreIDGenerator) NewEntitlementRecordID() (entitlement.EntitlementRecordID, error) {
|
||||
return generator.entitlementRecordID, nil
|
||||
}
|
||||
|
||||
func (generator adminStoreIDGenerator) NewSanctionRecordID() (policy.SanctionRecordID, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (generator adminStoreIDGenerator) NewLimitRecordID() (policy.LimitRecordID, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func requireSetContains(t *testing.T, store *Store, key string, member string) {
|
||||
t.Helper()
|
||||
|
||||
exists, err := store.client.SIsMember(context.Background(), key, member).Result()
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists, "expected %q to contain %q", key, member)
|
||||
}
|
||||
|
||||
func requireSetNotContains(t *testing.T, store *Store, key string, member string) {
|
||||
t.Helper()
|
||||
|
||||
exists, err := store.client.SIsMember(context.Background(), key, member).Result()
|
||||
require.NoError(t, err)
|
||||
require.False(t, exists, "expected %q not to contain %q", key, member)
|
||||
}
|
||||
|
||||
func requireSortedSetScore(t *testing.T, store *Store, key string, member string, want float64) {
|
||||
t.Helper()
|
||||
|
||||
got, err := store.client.ZScore(context.Background(), key, member).Result()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func requireSortedSetMissing(t *testing.T, store *Store, key string, member string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := store.client.ZScore(context.Background(), key, member).Result()
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -0,0 +1,752 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type entitlementPeriodRecord struct {
|
||||
RecordID string `json:"record_id"`
|
||||
UserID string `json:"user_id"`
|
||||
PlanCode string `json:"plan_code"`
|
||||
Source string `json:"source"`
|
||||
ActorType string `json:"actor_type"`
|
||||
ActorID *string `json:"actor_id,omitempty"`
|
||||
ReasonCode string `json:"reason_code"`
|
||||
StartsAt string `json:"starts_at"`
|
||||
EndsAt *string `json:"ends_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ClosedAt *string `json:"closed_at,omitempty"`
|
||||
ClosedByType *string `json:"closed_by_type,omitempty"`
|
||||
ClosedByID *string `json:"closed_by_id,omitempty"`
|
||||
ClosedReasonCode *string `json:"closed_reason_code,omitempty"`
|
||||
}
|
||||
|
||||
// CreateEntitlementRecord stores one new entitlement history record.
|
||||
func (store *Store) CreateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create entitlement record in redis: %w", err)
|
||||
}
|
||||
|
||||
payload, err := marshalEntitlementPeriodRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create entitlement record in redis: %w", err)
|
||||
}
|
||||
|
||||
recordKey := store.keyspace.EntitlementRecord(record.RecordID)
|
||||
historyKey := store.keyspace.EntitlementHistory(record.UserID)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "create entitlement record in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil {
|
||||
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, err)
|
||||
}
|
||||
|
||||
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, recordKey, payload, 0)
|
||||
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
||||
Score: float64(record.StartsAt.UTC().UnixMicro()),
|
||||
Member: record.RecordID.String(),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, recordKey, historyKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetEntitlementRecordByRecordID returns the entitlement history record
|
||||
// identified by recordID.
|
||||
func (store *Store) GetEntitlementRecordByRecordID(
|
||||
ctx context.Context,
|
||||
recordID entitlement.EntitlementRecordID,
|
||||
) (entitlement.PeriodRecord, error) {
|
||||
if err := recordID.Validate(); err != nil {
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "get entitlement record by record id from redis")
|
||||
if err != nil {
|
||||
return entitlement.PeriodRecord{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
record, err := store.loadEntitlementRecord(operationCtx, store.client, recordID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id %q from redis: %w", recordID, ports.ErrNotFound)
|
||||
default:
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id %q from redis: %w", recordID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// ListEntitlementRecordsByUserID returns every entitlement history record
|
||||
// owned by userID.
|
||||
func (store *Store) ListEntitlementRecordsByUserID(
|
||||
ctx context.Context,
|
||||
userID common.UserID,
|
||||
) ([]entitlement.PeriodRecord, error) {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("list entitlement records by user id from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "list entitlement records by user id from redis")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
recordIDs, err := store.client.ZRange(operationCtx, store.keyspace.EntitlementHistory(userID), 0, -1).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list entitlement records by user id %q from redis: %w", userID, err)
|
||||
}
|
||||
|
||||
records := make([]entitlement.PeriodRecord, 0, len(recordIDs))
|
||||
for _, rawRecordID := range recordIDs {
|
||||
record, err := store.loadEntitlementRecord(operationCtx, store.client, entitlement.EntitlementRecordID(rawRecordID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list entitlement records by user id %q from redis: %w", userID, err)
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// UpdateEntitlementRecord replaces one stored entitlement history record.
|
||||
func (store *Store) UpdateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("update entitlement record in redis: %w", err)
|
||||
}
|
||||
|
||||
payload, err := marshalEntitlementPeriodRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update entitlement record in redis: %w", err)
|
||||
}
|
||||
|
||||
recordKey := store.keyspace.EntitlementRecord(record.RecordID)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "update entitlement record in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
if _, err := store.loadEntitlementRecord(operationCtx, tx, record.RecordID); err != nil {
|
||||
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, err)
|
||||
}
|
||||
|
||||
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, recordKey, payload, 0)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, recordKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GrantEntitlement atomically closes the current free history record, creates
|
||||
// one paid history record, and replaces the current snapshot.
|
||||
func (store *Store) GrantEntitlement(ctx context.Context, input ports.GrantEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("grant entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
updatedCurrentRecordPayload, err := marshalEntitlementPeriodRecord(input.UpdatedCurrentRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant entitlement in redis: %w", err)
|
||||
}
|
||||
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant entitlement in redis: %w", err)
|
||||
}
|
||||
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
currentRecordKey := store.keyspace.EntitlementRecord(input.ExpectedCurrentRecord.RecordID)
|
||||
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
|
||||
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
|
||||
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
|
||||
watchedKeys := append(
|
||||
[]string{currentRecordKey, newRecordKey, historyKey, snapshotKey},
|
||||
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "grant entitlement in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
storedCurrentRecord, err := store.loadEntitlementRecord(operationCtx, tx, input.ExpectedCurrentRecord.RecordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if !equalEntitlementPeriodRecords(storedCurrentRecord, input.ExpectedCurrentRecord) {
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
}
|
||||
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, currentRecordKey, updatedCurrentRecordPayload, 0)
|
||||
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
|
||||
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
||||
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
|
||||
Member: input.NewRecord.RecordID.String(),
|
||||
})
|
||||
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
|
||||
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
|
||||
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExtendEntitlement atomically appends one paid history segment and replaces
|
||||
// the current paid snapshot.
|
||||
func (store *Store) ExtendEntitlement(ctx context.Context, input ports.ExtendEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("extend entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extend entitlement in redis: %w", err)
|
||||
}
|
||||
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extend entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
|
||||
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
|
||||
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
|
||||
watchedKeys := append(
|
||||
[]string{newRecordKey, historyKey, snapshotKey},
|
||||
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "extend entitlement in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
|
||||
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
}
|
||||
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
|
||||
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
|
||||
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
||||
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
|
||||
Member: input.NewRecord.RecordID.String(),
|
||||
})
|
||||
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
|
||||
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
|
||||
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RevokeEntitlement atomically closes the current paid history record,
|
||||
// creates one free history record, and replaces the current snapshot.
|
||||
func (store *Store) RevokeEntitlement(ctx context.Context, input ports.RevokeEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("revoke entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
updatedCurrentRecordPayload, err := marshalEntitlementPeriodRecord(input.UpdatedCurrentRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke entitlement in redis: %w", err)
|
||||
}
|
||||
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke entitlement in redis: %w", err)
|
||||
}
|
||||
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
currentRecordKey := store.keyspace.EntitlementRecord(input.ExpectedCurrentRecord.RecordID)
|
||||
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
|
||||
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
|
||||
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
|
||||
watchedKeys := append(
|
||||
[]string{currentRecordKey, newRecordKey, historyKey, snapshotKey},
|
||||
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "revoke entitlement in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
storedCurrentRecord, err := store.loadEntitlementRecord(operationCtx, tx, input.ExpectedCurrentRecord.RecordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
if !equalEntitlementPeriodRecords(storedCurrentRecord, input.ExpectedCurrentRecord) {
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
}
|
||||
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, currentRecordKey, updatedCurrentRecordPayload, 0)
|
||||
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
|
||||
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
||||
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
|
||||
Member: input.NewRecord.RecordID.String(),
|
||||
})
|
||||
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
|
||||
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
|
||||
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RepairExpiredEntitlement atomically replaces one expired finite paid
|
||||
// snapshot with a materialized free state.
|
||||
func (store *Store) RepairExpiredEntitlement(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("repair expired entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("repair expired entitlement in redis: %w", err)
|
||||
}
|
||||
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("repair expired entitlement in redis: %w", err)
|
||||
}
|
||||
|
||||
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
|
||||
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
|
||||
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
|
||||
watchedKeys := append(
|
||||
[]string{newRecordKey, historyKey, snapshotKey},
|
||||
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "repair expired entitlement in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedExpiredSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
|
||||
}
|
||||
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedExpiredSnapshot) {
|
||||
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, ports.ErrConflict)
|
||||
}
|
||||
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
|
||||
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
|
||||
}
|
||||
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
|
||||
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
||||
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
|
||||
Member: input.NewRecord.RecordID.String(),
|
||||
})
|
||||
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
|
||||
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
|
||||
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) loadEntitlementRecord(
|
||||
ctx context.Context,
|
||||
getter bytesGetter,
|
||||
recordID entitlement.EntitlementRecordID,
|
||||
) (entitlement.PeriodRecord, error) {
|
||||
payload, err := getter.Get(ctx, store.keyspace.EntitlementRecord(recordID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return entitlement.PeriodRecord{}, ports.ErrNotFound
|
||||
case err != nil:
|
||||
return entitlement.PeriodRecord{}, err
|
||||
}
|
||||
|
||||
return decodeEntitlementPeriodRecord(payload)
|
||||
}
|
||||
|
||||
func marshalEntitlementPeriodRecord(record entitlement.PeriodRecord) ([]byte, error) {
|
||||
encoded := entitlementPeriodRecord{
|
||||
RecordID: record.RecordID.String(),
|
||||
UserID: record.UserID.String(),
|
||||
PlanCode: string(record.PlanCode),
|
||||
Source: record.Source.String(),
|
||||
ActorType: record.Actor.Type.String(),
|
||||
ReasonCode: record.ReasonCode.String(),
|
||||
StartsAt: record.StartsAt.UTC().Format(time.RFC3339Nano),
|
||||
CreatedAt: record.CreatedAt.UTC().Format(time.RFC3339Nano),
|
||||
}
|
||||
if !record.Actor.ID.IsZero() {
|
||||
value := record.Actor.ID.String()
|
||||
encoded.ActorID = &value
|
||||
}
|
||||
if record.EndsAt != nil {
|
||||
value := record.EndsAt.UTC().Format(time.RFC3339Nano)
|
||||
encoded.EndsAt = &value
|
||||
}
|
||||
if record.ClosedAt != nil {
|
||||
value := record.ClosedAt.UTC().Format(time.RFC3339Nano)
|
||||
encoded.ClosedAt = &value
|
||||
}
|
||||
if !record.ClosedBy.Type.IsZero() {
|
||||
value := record.ClosedBy.Type.String()
|
||||
encoded.ClosedByType = &value
|
||||
}
|
||||
if !record.ClosedBy.ID.IsZero() {
|
||||
value := record.ClosedBy.ID.String()
|
||||
encoded.ClosedByID = &value
|
||||
}
|
||||
if !record.ClosedReasonCode.IsZero() {
|
||||
value := record.ClosedReasonCode.String()
|
||||
encoded.ClosedReasonCode = &value
|
||||
}
|
||||
|
||||
return json.Marshal(encoded)
|
||||
}
|
||||
|
||||
func decodeEntitlementPeriodRecord(payload []byte) (entitlement.PeriodRecord, error) {
|
||||
var encoded entitlementPeriodRecord
|
||||
if err := decodeJSONPayload(payload, &encoded); err != nil {
|
||||
return entitlement.PeriodRecord{}, err
|
||||
}
|
||||
|
||||
startsAt, err := time.Parse(time.RFC3339Nano, encoded.StartsAt)
|
||||
if err != nil {
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record starts_at: %w", err)
|
||||
}
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, encoded.CreatedAt)
|
||||
if err != nil {
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record created_at: %w", err)
|
||||
}
|
||||
|
||||
record := entitlement.PeriodRecord{
|
||||
RecordID: entitlement.EntitlementRecordID(encoded.RecordID),
|
||||
UserID: common.UserID(encoded.UserID),
|
||||
PlanCode: entitlement.PlanCode(encoded.PlanCode),
|
||||
Source: common.Source(encoded.Source),
|
||||
Actor: common.ActorRef{Type: common.ActorType(encoded.ActorType)},
|
||||
ReasonCode: common.ReasonCode(encoded.ReasonCode),
|
||||
StartsAt: startsAt.UTC(),
|
||||
CreatedAt: createdAt.UTC(),
|
||||
}
|
||||
if encoded.ActorID != nil {
|
||||
record.Actor.ID = common.ActorID(*encoded.ActorID)
|
||||
}
|
||||
if encoded.EndsAt != nil {
|
||||
value, err := time.Parse(time.RFC3339Nano, *encoded.EndsAt)
|
||||
if err != nil {
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record ends_at: %w", err)
|
||||
}
|
||||
value = value.UTC()
|
||||
record.EndsAt = &value
|
||||
}
|
||||
if encoded.ClosedAt != nil {
|
||||
value, err := time.Parse(time.RFC3339Nano, *encoded.ClosedAt)
|
||||
if err != nil {
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record closed_at: %w", err)
|
||||
}
|
||||
value = value.UTC()
|
||||
record.ClosedAt = &value
|
||||
}
|
||||
if encoded.ClosedByType != nil {
|
||||
record.ClosedBy.Type = common.ActorType(*encoded.ClosedByType)
|
||||
}
|
||||
if encoded.ClosedByID != nil {
|
||||
record.ClosedBy.ID = common.ActorID(*encoded.ClosedByID)
|
||||
}
|
||||
if encoded.ClosedReasonCode != nil {
|
||||
record.ClosedReasonCode = common.ReasonCode(*encoded.ClosedReasonCode)
|
||||
}
|
||||
if err := record.Validate(); err != nil {
|
||||
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func equalEntitlementSnapshots(left entitlement.CurrentSnapshot, right entitlement.CurrentSnapshot) bool {
|
||||
return left.UserID == right.UserID &&
|
||||
left.PlanCode == right.PlanCode &&
|
||||
left.IsPaid == right.IsPaid &&
|
||||
left.StartsAt.Equal(right.StartsAt) &&
|
||||
equalOptionalTime(left.EndsAt, right.EndsAt) &&
|
||||
left.Source == right.Source &&
|
||||
left.Actor == right.Actor &&
|
||||
left.ReasonCode == right.ReasonCode &&
|
||||
left.UpdatedAt.Equal(right.UpdatedAt)
|
||||
}
|
||||
|
||||
func equalEntitlementPeriodRecords(left entitlement.PeriodRecord, right entitlement.PeriodRecord) bool {
|
||||
return left.RecordID == right.RecordID &&
|
||||
left.UserID == right.UserID &&
|
||||
left.PlanCode == right.PlanCode &&
|
||||
left.Source == right.Source &&
|
||||
left.Actor == right.Actor &&
|
||||
left.ReasonCode == right.ReasonCode &&
|
||||
left.StartsAt.Equal(right.StartsAt) &&
|
||||
equalOptionalTime(left.EndsAt, right.EndsAt) &&
|
||||
left.CreatedAt.Equal(right.CreatedAt) &&
|
||||
equalOptionalTime(left.ClosedAt, right.ClosedAt) &&
|
||||
left.ClosedBy == right.ClosedBy &&
|
||||
left.ClosedReasonCode == right.ClosedReasonCode
|
||||
}
|
||||
|
||||
func equalOptionalTime(left *time.Time, right *time.Time) bool {
|
||||
switch {
|
||||
case left == nil && right == nil:
|
||||
return true
|
||||
case left == nil || right == nil:
|
||||
return false
|
||||
default:
|
||||
return left.Equal(*right)
|
||||
}
|
||||
}
|
||||
|
||||
// EntitlementHistoryStore adapts Store to the existing
|
||||
// EntitlementHistoryStore port.
|
||||
type EntitlementHistoryStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// EntitlementHistory returns one adapter that exposes the entitlement-history
|
||||
// store port over Store.
|
||||
func (store *Store) EntitlementHistory() *EntitlementHistoryStore {
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &EntitlementHistoryStore{store: store}
|
||||
}
|
||||
|
||||
// Create stores one new entitlement history record.
|
||||
func (adapter *EntitlementHistoryStore) Create(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
return adapter.store.CreateEntitlementRecord(ctx, record)
|
||||
}
|
||||
|
||||
// GetByRecordID returns the entitlement history record identified by recordID.
|
||||
func (adapter *EntitlementHistoryStore) GetByRecordID(
|
||||
ctx context.Context,
|
||||
recordID entitlement.EntitlementRecordID,
|
||||
) (entitlement.PeriodRecord, error) {
|
||||
return adapter.store.GetEntitlementRecordByRecordID(ctx, recordID)
|
||||
}
|
||||
|
||||
// ListByUserID returns every entitlement history record owned by userID.
|
||||
func (adapter *EntitlementHistoryStore) ListByUserID(
|
||||
ctx context.Context,
|
||||
userID common.UserID,
|
||||
) ([]entitlement.PeriodRecord, error) {
|
||||
return adapter.store.ListEntitlementRecordsByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// Update replaces one stored entitlement history record.
|
||||
func (adapter *EntitlementHistoryStore) Update(ctx context.Context, record entitlement.PeriodRecord) error {
|
||||
return adapter.store.UpdateEntitlementRecord(ctx, record)
|
||||
}
|
||||
|
||||
var _ ports.EntitlementHistoryStore = (*EntitlementHistoryStore)(nil)
|
||||
|
||||
// EntitlementLifecycleStore adapts Store to the existing
|
||||
// EntitlementLifecycleStore port.
|
||||
type EntitlementLifecycleStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// EntitlementLifecycle returns one adapter that exposes the atomic
|
||||
// entitlement-lifecycle store port over Store.
|
||||
func (store *Store) EntitlementLifecycle() *EntitlementLifecycleStore {
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &EntitlementLifecycleStore{store: store}
|
||||
}
|
||||
|
||||
// Grant atomically applies one free-to-paid transition.
|
||||
func (adapter *EntitlementLifecycleStore) Grant(ctx context.Context, input ports.GrantEntitlementInput) error {
|
||||
return adapter.store.GrantEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
// Extend atomically appends one paid extension segment and updates the current
|
||||
// snapshot.
|
||||
func (adapter *EntitlementLifecycleStore) Extend(ctx context.Context, input ports.ExtendEntitlementInput) error {
|
||||
return adapter.store.ExtendEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
// Revoke atomically applies one paid-to-free transition.
|
||||
func (adapter *EntitlementLifecycleStore) Revoke(ctx context.Context, input ports.RevokeEntitlementInput) error {
|
||||
return adapter.store.RevokeEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
// RepairExpired atomically repairs one expired finite paid snapshot.
|
||||
func (adapter *EntitlementLifecycleStore) RepairExpired(
|
||||
ctx context.Context,
|
||||
input ports.RepairExpiredEntitlementInput,
|
||||
) error {
|
||||
return adapter.store.RepairExpiredEntitlement(ctx, input)
|
||||
}
|
||||
|
||||
var _ ports.EntitlementLifecycleStore = (*EntitlementLifecycleStore)(nil)
|
||||
@@ -0,0 +1,137 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/adapters/redisstate"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ListUserIDs returns one deterministic page of user identifiers ordered by
|
||||
// `created_at desc`, then `user_id desc`.
|
||||
func (store *Store) ListUserIDs(ctx context.Context, input ports.ListUsersInput) (ports.ListUsersResult, error) {
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "list users in redis")
|
||||
if err != nil {
|
||||
return ports.ListUsersResult{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
startIndex := int64(0)
|
||||
filters := userListFiltersFromPorts(input.Filters)
|
||||
if input.PageToken != "" {
|
||||
cursor, err := redisstate.DecodePageToken(input.PageToken, filters)
|
||||
if err != nil {
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
|
||||
}
|
||||
|
||||
score, err := store.client.ZScore(operationCtx, store.keyspace.CreatedAtIndex(), cursor.UserID.String()).Result()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
|
||||
case err != nil:
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
|
||||
}
|
||||
if !time.UnixMicro(int64(score)).UTC().Equal(cursor.CreatedAt.UTC()) {
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
|
||||
}
|
||||
|
||||
rank, err := store.client.ZRevRank(operationCtx, store.keyspace.CreatedAtIndex(), cursor.UserID.String()).Result()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
|
||||
case err != nil:
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
|
||||
}
|
||||
|
||||
startIndex = rank + 1
|
||||
}
|
||||
|
||||
rawPage, err := store.client.ZRevRangeWithScores(
|
||||
operationCtx,
|
||||
store.keyspace.CreatedAtIndex(),
|
||||
startIndex,
|
||||
startIndex+int64(input.PageSize),
|
||||
).Result()
|
||||
if err != nil {
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
|
||||
}
|
||||
|
||||
result := ports.ListUsersResult{
|
||||
UserIDs: make([]common.UserID, 0, min(len(rawPage), input.PageSize)),
|
||||
}
|
||||
|
||||
visibleCount := min(len(rawPage), input.PageSize)
|
||||
for index := 0; index < visibleCount; index++ {
|
||||
userID, err := memberUserID(rawPage[index].Member)
|
||||
if err != nil {
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
|
||||
}
|
||||
result.UserIDs = append(result.UserIDs, userID)
|
||||
}
|
||||
|
||||
if len(rawPage) > input.PageSize {
|
||||
lastVisible := rawPage[input.PageSize-1]
|
||||
lastUserID, err := memberUserID(lastVisible.Member)
|
||||
if err != nil {
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
|
||||
}
|
||||
token, err := redisstate.EncodePageToken(redisstate.PageCursor{
|
||||
CreatedAt: time.UnixMicro(int64(lastVisible.Score)).UTC(),
|
||||
UserID: lastUserID,
|
||||
}, filters)
|
||||
if err != nil {
|
||||
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
|
||||
}
|
||||
result.NextPageToken = token
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func userListFiltersFromPorts(filters ports.UserListFilters) redisstate.UserListFilters {
|
||||
return redisstate.UserListFilters{
|
||||
PaidState: filters.PaidState,
|
||||
PaidExpiresBefore: filters.PaidExpiresBefore,
|
||||
PaidExpiresAfter: filters.PaidExpiresAfter,
|
||||
DeclaredCountry: filters.DeclaredCountry,
|
||||
SanctionCode: filters.SanctionCode,
|
||||
LimitCode: filters.LimitCode,
|
||||
CanLogin: filters.CanLogin,
|
||||
CanCreatePrivateGame: filters.CanCreatePrivateGame,
|
||||
CanJoinGame: filters.CanJoinGame,
|
||||
}
|
||||
}
|
||||
|
||||
func memberUserID(member any) (common.UserID, error) {
|
||||
value, ok := member.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unexpected created-at index member type %T", member)
|
||||
}
|
||||
|
||||
userID := common.UserID(value)
|
||||
if err := userID.Validate(); err != nil {
|
||||
return "", fmt.Errorf("created-at index member user id: %w", err)
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func min(left int, right int) int {
|
||||
if left < right {
|
||||
return left
|
||||
}
|
||||
|
||||
return right
|
||||
}
|
||||
|
||||
var _ ports.UserListStore = (*Store)(nil)
|
||||
@@ -0,0 +1,445 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/policy"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ApplySanction atomically creates one new active sanction record.
|
||||
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("apply sanction in redis: %w", err)
|
||||
}
|
||||
|
||||
recordPayload, err := marshalSanctionRecord(input.NewRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply sanction in redis: %w", err)
|
||||
}
|
||||
|
||||
recordKey := store.keyspace.SanctionRecord(input.NewRecord.RecordID)
|
||||
historyKey := store.keyspace.SanctionHistory(input.NewRecord.UserID)
|
||||
activeKey := store.keyspace.ActiveSanction(input.NewRecord.UserID, input.NewRecord.SanctionCode)
|
||||
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewRecord.UserID)
|
||||
watchedKeys := append(
|
||||
[]string{recordKey, historyKey, activeKey, snapshotKey},
|
||||
store.activeSanctionWatchKeys(input.NewRecord.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "apply sanction in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil {
|
||||
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
|
||||
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.NewRecord.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewRecord.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
activeSanctionCodes[input.NewRecord.SanctionCode] = struct{}{}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, recordKey, recordPayload, 0)
|
||||
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
||||
Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
|
||||
Member: input.NewRecord.RecordID.String(),
|
||||
})
|
||||
setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
|
||||
store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeSanctionCodes)
|
||||
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveSanction atomically removes one active sanction record.
|
||||
func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("remove sanction in redis: %w", err)
|
||||
}
|
||||
|
||||
updatedPayload, err := marshalSanctionRecord(input.UpdatedRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove sanction in redis: %w", err)
|
||||
}
|
||||
|
||||
recordKey := store.keyspace.SanctionRecord(input.ExpectedActiveRecord.RecordID)
|
||||
activeKey := store.keyspace.ActiveSanction(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.SanctionCode)
|
||||
snapshotKey := store.keyspace.EntitlementSnapshot(input.ExpectedActiveRecord.UserID)
|
||||
watchedKeys := append(
|
||||
[]string{recordKey, activeKey, snapshotKey},
|
||||
store.activeSanctionWatchKeys(input.ExpectedActiveRecord.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "remove sanction in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
activeRecordID, err := store.loadActiveSanctionRecordID(operationCtx, tx, activeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
if activeRecordID != input.ExpectedActiveRecord.RecordID {
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
storedRecord, err := store.loadSanctionRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
if !equalSanctionRecords(storedRecord, input.ExpectedActiveRecord) {
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
|
||||
}
|
||||
snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedActiveRecord.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
delete(activeSanctionCodes, input.ExpectedActiveRecord.SanctionCode)
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, recordKey, updatedPayload, 0)
|
||||
pipe.Del(operationCtx, activeKey)
|
||||
store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeSanctionCodes)
|
||||
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetLimit atomically creates or replaces one active limit record.
|
||||
func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("set limit in redis: %w", err)
|
||||
}
|
||||
|
||||
newRecordPayload, err := marshalLimitRecord(input.NewRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set limit in redis: %w", err)
|
||||
}
|
||||
|
||||
newRecordKey := store.keyspace.LimitRecord(input.NewRecord.RecordID)
|
||||
historyKey := store.keyspace.LimitHistory(input.NewRecord.UserID)
|
||||
activeKey := store.keyspace.ActiveLimit(input.NewRecord.UserID, input.NewRecord.LimitCode)
|
||||
watchedKeys := append(
|
||||
[]string{newRecordKey, historyKey, activeKey},
|
||||
store.activeLimitWatchKeys(input.NewRecord.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "set limit in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
if input.ExpectedActiveRecord != nil {
|
||||
watchedKeys = append(watchedKeys, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID))
|
||||
}
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
|
||||
var updatedPayload []byte
|
||||
if input.ExpectedActiveRecord == nil {
|
||||
if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
} else {
|
||||
activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
if activeRecordID != input.ExpectedActiveRecord.RecordID {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
if !equalLimitRecords(storedRecord, *input.ExpectedActiveRecord) {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
updatedPayload, err = marshalLimitRecord(*input.UpdatedActiveRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
}
|
||||
activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.NewRecord.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
activeLimitCodes[input.NewRecord.LimitCode] = struct{}{}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
if input.ExpectedActiveRecord != nil {
|
||||
pipe.Set(operationCtx, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID), updatedPayload, 0)
|
||||
}
|
||||
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
|
||||
pipe.ZAdd(operationCtx, historyKey, redis.Z{
|
||||
Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
|
||||
Member: input.NewRecord.RecordID.String(),
|
||||
})
|
||||
setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
|
||||
store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeLimitCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveLimit atomically removes one active limit record.
|
||||
func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
|
||||
if err := input.Validate(); err != nil {
|
||||
return fmt.Errorf("remove limit in redis: %w", err)
|
||||
}
|
||||
|
||||
updatedPayload, err := marshalLimitRecord(input.UpdatedRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove limit in redis: %w", err)
|
||||
}
|
||||
|
||||
recordKey := store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID)
|
||||
activeKey := store.keyspace.ActiveLimit(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.LimitCode)
|
||||
watchedKeys := append(
|
||||
[]string{recordKey, activeKey},
|
||||
store.activeLimitWatchKeys(input.ExpectedActiveRecord.UserID)...,
|
||||
)
|
||||
|
||||
operationCtx, cancel, err := store.operationContext(ctx, "remove limit in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
if activeRecordID != input.ExpectedActiveRecord.RecordID {
|
||||
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
if !equalLimitRecords(storedRecord, input.ExpectedActiveRecord) {
|
||||
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
|
||||
}
|
||||
activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
delete(activeLimitCodes, input.ExpectedActiveRecord.LimitCode)
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, recordKey, updatedPayload, 0)
|
||||
pipe.Del(operationCtx, activeKey)
|
||||
store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeLimitCodes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, watchedKeys...)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) loadActiveSanctionRecordID(
|
||||
ctx context.Context,
|
||||
getter bytesGetter,
|
||||
key string,
|
||||
) (policy.SanctionRecordID, error) {
|
||||
value, err := getter.Get(ctx, key).Result()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return "", ports.ErrNotFound
|
||||
case err != nil:
|
||||
return "", err
|
||||
}
|
||||
|
||||
recordID := policy.SanctionRecordID(value)
|
||||
if err := recordID.Validate(); err != nil {
|
||||
return "", fmt.Errorf("active sanction record id: %w", err)
|
||||
}
|
||||
|
||||
return recordID, nil
|
||||
}
|
||||
|
||||
func (store *Store) loadActiveLimitRecordID(
|
||||
ctx context.Context,
|
||||
getter bytesGetter,
|
||||
key string,
|
||||
) (policy.LimitRecordID, error) {
|
||||
value, err := getter.Get(ctx, key).Result()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return "", ports.ErrNotFound
|
||||
case err != nil:
|
||||
return "", err
|
||||
}
|
||||
|
||||
recordID := policy.LimitRecordID(value)
|
||||
if err := recordID.Validate(); err != nil {
|
||||
return "", fmt.Errorf("active limit record id: %w", err)
|
||||
}
|
||||
|
||||
return recordID, nil
|
||||
}
|
||||
|
||||
func setActiveSlot(
|
||||
pipe redis.Pipeliner,
|
||||
ctx context.Context,
|
||||
key string,
|
||||
recordID string,
|
||||
expiresAt *time.Time,
|
||||
) {
|
||||
pipe.Set(ctx, key, recordID, 0)
|
||||
if expiresAt != nil {
|
||||
pipe.PExpireAt(ctx, key, expiresAt.UTC())
|
||||
}
|
||||
}
|
||||
|
||||
func equalSanctionRecords(left policy.SanctionRecord, right policy.SanctionRecord) bool {
|
||||
return left.RecordID == right.RecordID &&
|
||||
left.UserID == right.UserID &&
|
||||
left.SanctionCode == right.SanctionCode &&
|
||||
left.Scope == right.Scope &&
|
||||
left.ReasonCode == right.ReasonCode &&
|
||||
left.Actor == right.Actor &&
|
||||
left.AppliedAt.Equal(right.AppliedAt) &&
|
||||
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
|
||||
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
|
||||
left.RemovedBy == right.RemovedBy &&
|
||||
left.RemovedReasonCode == right.RemovedReasonCode
|
||||
}
|
||||
|
||||
func equalLimitRecords(left policy.LimitRecord, right policy.LimitRecord) bool {
|
||||
return left.RecordID == right.RecordID &&
|
||||
left.UserID == right.UserID &&
|
||||
left.LimitCode == right.LimitCode &&
|
||||
left.Value == right.Value &&
|
||||
left.ReasonCode == right.ReasonCode &&
|
||||
left.Actor == right.Actor &&
|
||||
left.AppliedAt.Equal(right.AppliedAt) &&
|
||||
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
|
||||
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
|
||||
left.RemovedBy == right.RemovedBy &&
|
||||
left.RemovedReasonCode == right.RemovedReasonCode
|
||||
}
|
||||
|
||||
// PolicyLifecycleStore adapts Store to the existing PolicyLifecycleStore
|
||||
// port.
|
||||
type PolicyLifecycleStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// PolicyLifecycle returns one adapter that exposes the atomic policy-lifecycle
|
||||
// store port over Store.
|
||||
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore {
|
||||
if store == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &PolicyLifecycleStore{store: store}
|
||||
}
|
||||
|
||||
// ApplySanction atomically creates one new active sanction record.
|
||||
func (adapter *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
|
||||
return adapter.store.ApplySanction(ctx, input)
|
||||
}
|
||||
|
||||
// RemoveSanction atomically removes one active sanction record.
|
||||
func (adapter *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
|
||||
return adapter.store.RemoveSanction(ctx, input)
|
||||
}
|
||||
|
||||
// SetLimit atomically creates or replaces one active limit record.
|
||||
func (adapter *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
|
||||
return adapter.store.SetLimit(ctx, input)
|
||||
}
|
||||
|
||||
// RemoveLimit atomically removes one active limit record.
|
||||
func (adapter *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
|
||||
return adapter.store.RemoveLimit(ctx, input)
|
||||
}
|
||||
|
||||
var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,930 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/account"
|
||||
"galaxy/user/internal/domain/authblock"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
"galaxy/user/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountStoreCreateAndLookups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
accountStore := store.Accounts()
|
||||
|
||||
record := validAccountRecord()
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
|
||||
|
||||
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, record, byUserID)
|
||||
|
||||
byEmail, err := accountStore.GetByEmail(context.Background(), record.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, record, byEmail)
|
||||
|
||||
byRaceName, err := accountStore.GetByRaceName(context.Background(), record.RaceName)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, record, byRaceName)
|
||||
|
||||
exists, err := accountStore.ExistsByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
|
||||
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(record.RaceName))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, record.UserID, reservation.UserID)
|
||||
require.Equal(t, record.RaceName, reservation.RaceName)
|
||||
}
|
||||
|
||||
func TestBlockedEmailStoreUpsertAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
blockedEmailStore := store.BlockedEmails()
|
||||
|
||||
record := authblock.BlockedEmailSubject{
|
||||
Email: common.Email("blocked@example.com"),
|
||||
ReasonCode: common.ReasonCode("policy_blocked"),
|
||||
BlockedAt: time.Unix(1_775_240_100, 0).UTC(),
|
||||
ResolvedUserID: common.UserID("user-123"),
|
||||
}
|
||||
require.NoError(t, blockedEmailStore.Upsert(context.Background(), record))
|
||||
|
||||
got, err := blockedEmailStore.GetByEmail(context.Background(), record.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, record, got)
|
||||
}
|
||||
|
||||
func TestEnsureResolveAndBlockFlows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
accountRecord := validAccountRecord()
|
||||
entitlementSnapshot := validEntitlementSnapshot(accountRecord.UserID, now)
|
||||
|
||||
created, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: accountRecord.Email,
|
||||
Account: accountRecord,
|
||||
Entitlement: entitlementSnapshot,
|
||||
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
|
||||
Reservation: raceNameReservation(accountRecord.UserID, accountRecord.RaceName, accountRecord.UpdatedAt),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.EnsureByEmailOutcomeCreated, created.Outcome)
|
||||
|
||||
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(accountRecord.RaceName))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, accountRecord.UserID, reservation.UserID)
|
||||
|
||||
entitlementHistory, err := store.ListEntitlementRecordsByUserID(context.Background(), accountRecord.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entitlementHistory, 1)
|
||||
require.Equal(t, validEntitlementRecord(accountRecord.UserID, now), entitlementHistory[0])
|
||||
|
||||
resolved, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthResolutionKindExisting, resolved.Kind)
|
||||
|
||||
blockedByUserID, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
|
||||
UserID: accountRecord.UserID,
|
||||
ReasonCode: common.ReasonCode("policy_blocked"),
|
||||
BlockedAt: now.Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthBlockOutcomeBlocked, blockedByUserID.Outcome)
|
||||
|
||||
repeatedBlock, err := store.BlockByEmail(context.Background(), ports.BlockByEmailInput{
|
||||
Email: accountRecord.Email,
|
||||
ReasonCode: common.ReasonCode("policy_blocked"),
|
||||
BlockedAt: now.Add(2 * time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthBlockOutcomeAlreadyBlocked, repeatedBlock.Outcome)
|
||||
require.Equal(t, accountRecord.UserID, repeatedBlock.UserID)
|
||||
|
||||
blockedResolution, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthResolutionKindBlocked, blockedResolution.Kind)
|
||||
|
||||
ensureBlocked, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: accountRecord.Email,
|
||||
Account: accountRecord,
|
||||
Entitlement: entitlementSnapshot,
|
||||
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
|
||||
Reservation: raceNameReservation(accountRecord.UserID, accountRecord.RaceName, accountRecord.UpdatedAt),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, ensureBlocked.Outcome)
|
||||
}
|
||||
|
||||
func TestBlockedEmailWithoutUserPreventsEnsureCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
accountRecord := validAccountRecord()
|
||||
entitlementSnapshot := validEntitlementSnapshot(accountRecord.UserID, now)
|
||||
|
||||
blocked, err := store.BlockByEmail(context.Background(), ports.BlockByEmailInput{
|
||||
Email: accountRecord.Email,
|
||||
ReasonCode: common.ReasonCode("policy_blocked"),
|
||||
BlockedAt: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthBlockOutcomeBlocked, blocked.Outcome)
|
||||
require.True(t, blocked.UserID.IsZero())
|
||||
|
||||
resolved, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthResolutionKindBlocked, resolved.Kind)
|
||||
|
||||
ensured, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: accountRecord.Email,
|
||||
Account: accountRecord,
|
||||
Entitlement: entitlementSnapshot,
|
||||
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
|
||||
Reservation: raceNameReservation(accountRecord.UserID, accountRecord.RaceName, accountRecord.UpdatedAt),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, ensured.Outcome)
|
||||
|
||||
exists, err := store.ExistsByUserID(context.Background(), accountRecord.UserID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestEnsureByEmailExistingDoesNotOverwriteStoredSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
createdAt := time.Unix(1_775_240_000, 0).UTC()
|
||||
existingAccount := account.UserAccount{
|
||||
UserID: common.UserID("user-existing"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
RaceName: common.RaceName("Pilot Nova"),
|
||||
PreferredLanguage: common.LanguageTag("en"),
|
||||
TimeZone: common.TimeZoneName("Europe/Kaliningrad"),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: createdAt,
|
||||
}
|
||||
require.NoError(t, store.Create(context.Background(), createAccountInput(existingAccount)))
|
||||
|
||||
result, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
|
||||
Email: existingAccount.Email,
|
||||
Account: account.UserAccount{
|
||||
UserID: common.UserID("user-created"),
|
||||
Email: existingAccount.Email,
|
||||
RaceName: common.RaceName("player-new123"),
|
||||
PreferredLanguage: common.LanguageTag("fr-FR"),
|
||||
TimeZone: common.TimeZoneName("UTC"),
|
||||
CreatedAt: createdAt.Add(time.Minute),
|
||||
UpdatedAt: createdAt.Add(time.Minute),
|
||||
},
|
||||
Entitlement: validEntitlementSnapshot(common.UserID("user-created"), createdAt.Add(time.Minute)),
|
||||
EntitlementRecord: validEntitlementRecord(common.UserID("user-created"), createdAt.Add(time.Minute)),
|
||||
Reservation: raceNameReservation(common.UserID("user-created"), common.RaceName("player-new123"), createdAt.Add(time.Minute)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.EnsureByEmailOutcomeExisting, result.Outcome)
|
||||
require.Equal(t, existingAccount.UserID, result.UserID)
|
||||
|
||||
storedAccount, err := store.GetByEmail(context.Background(), existingAccount.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, existingAccount, storedAccount)
|
||||
}
|
||||
|
||||
func TestAccountStoreRenameRaceNameSwapsLookupAtomically(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
accountStore := store.Accounts()
|
||||
record := validAccountRecord()
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
|
||||
|
||||
updatedAt := record.UpdatedAt.Add(time.Minute)
|
||||
require.NoError(t, accountStore.RenameRaceName(context.Background(), renameRaceNameInput(record, common.RaceName("Nova Prime"), updatedAt)))
|
||||
|
||||
stored, err := accountStore.GetByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, common.RaceName("Nova Prime"), stored.RaceName)
|
||||
require.True(t, stored.UpdatedAt.Equal(updatedAt))
|
||||
|
||||
_, err = accountStore.GetByRaceName(context.Background(), record.RaceName)
|
||||
require.ErrorIs(t, err, ports.ErrNotFound)
|
||||
|
||||
renamed, err := accountStore.GetByRaceName(context.Background(), common.RaceName("Nova Prime"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, record.UserID, renamed.UserID)
|
||||
|
||||
_, err = store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(record.RaceName))
|
||||
require.ErrorIs(t, err, ports.ErrNotFound)
|
||||
|
||||
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(common.RaceName("Nova Prime")))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, common.RaceName("Nova Prime"), reservation.RaceName)
|
||||
}
|
||||
|
||||
func TestAccountStoreRenameRaceNameAllowsSameOwnerCanonicalSlot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
accountStore := store.Accounts()
|
||||
|
||||
record := validAccountRecord()
|
||||
record.RaceName = common.RaceName("Pilot Nova")
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
|
||||
|
||||
updatedAt := record.UpdatedAt.Add(time.Minute)
|
||||
require.NoError(t, accountStore.RenameRaceName(context.Background(), renameRaceNameInput(record, common.RaceName("P1lot Nova"), updatedAt)))
|
||||
|
||||
reservation, err := store.loadRaceNameReservation(context.Background(), store.client, canonicalKey(common.RaceName("P1lot Nova")))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, common.RaceName("P1lot Nova"), reservation.RaceName)
|
||||
}
|
||||
|
||||
func TestAccountStoreRenameRaceNameReturnsConflictWhenTargetExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
accountStore := store.Accounts()
|
||||
|
||||
first := validAccountRecord()
|
||||
second := validAccountRecord()
|
||||
second.UserID = common.UserID("user-456")
|
||||
second.Email = common.Email("other@example.com")
|
||||
second.RaceName = common.RaceName("Taken Name")
|
||||
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(first)))
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(second)))
|
||||
|
||||
err := accountStore.RenameRaceName(context.Background(), renameRaceNameInput(first, second.RaceName, first.UpdatedAt.Add(time.Minute)))
|
||||
require.ErrorIs(t, err, ports.ErrConflict)
|
||||
|
||||
stored, err := accountStore.GetByUserID(context.Background(), first.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, first.RaceName, stored.RaceName)
|
||||
}
|
||||
|
||||
func TestAccountStoreUpdateDeclaredCountryPreservesLookups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
accountStore := store.Accounts()
|
||||
|
||||
record := validAccountRecord()
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
|
||||
|
||||
updated := record
|
||||
updated.DeclaredCountry = common.CountryCode("FR")
|
||||
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
|
||||
|
||||
require.NoError(t, accountStore.Update(context.Background(), updated))
|
||||
|
||||
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated, byUserID)
|
||||
|
||||
byEmail, err := accountStore.GetByEmail(context.Background(), record.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated, byEmail)
|
||||
|
||||
byRaceName, err := accountStore.GetByRaceName(context.Background(), record.RaceName)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated, byRaceName)
|
||||
}
|
||||
|
||||
func TestAccountStoreCreateReturnsConflictWhenCanonicalReservationExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
accountStore := store.Accounts()
|
||||
|
||||
first := validAccountRecord()
|
||||
second := validAccountRecord()
|
||||
second.UserID = common.UserID("user-456")
|
||||
second.Email = common.Email("other@example.com")
|
||||
second.RaceName = common.RaceName("P1lot Nova")
|
||||
|
||||
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(first)))
|
||||
|
||||
err := accountStore.Create(context.Background(), createAccountInput(second))
|
||||
require.ErrorIs(t, err, ports.ErrConflict)
|
||||
}
|
||||
|
||||
func TestBlockByUserIDRepeatedCallsStayIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
accountRecord := validAccountRecord()
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), createAccountInput(accountRecord)))
|
||||
|
||||
first, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
|
||||
UserID: accountRecord.UserID,
|
||||
ReasonCode: common.ReasonCode("policy_blocked"),
|
||||
BlockedAt: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthBlockOutcomeBlocked, first.Outcome)
|
||||
|
||||
second, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
|
||||
UserID: accountRecord.UserID,
|
||||
ReasonCode: common.ReasonCode("policy_blocked"),
|
||||
BlockedAt: now.Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.AuthBlockOutcomeAlreadyBlocked, second.Outcome)
|
||||
require.Equal(t, accountRecord.UserID, second.UserID)
|
||||
}
|
||||
|
||||
func TestBlockByUserIDUnknownUserReturnsNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
|
||||
_, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
|
||||
UserID: common.UserID("user-missing"),
|
||||
ReasonCode: common.ReasonCode("policy_blocked"),
|
||||
BlockedAt: time.Unix(1_775_240_000, 0).UTC(),
|
||||
})
|
||||
require.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestSanctionAndLimitStoresRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
sanctionStore := store.Sanctions()
|
||||
limitStore := store.Limits()
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
|
||||
sanctionRecord := policy.SanctionRecord{
|
||||
RecordID: policy.SanctionRecordID("sanction-1"),
|
||||
UserID: common.UserID("user-123"),
|
||||
SanctionCode: policy.SanctionCodeLoginBlock,
|
||||
Scope: common.Scope("self_service"),
|
||||
ReasonCode: common.ReasonCode("policy_enforced"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
|
||||
AppliedAt: now,
|
||||
}
|
||||
require.NoError(t, sanctionStore.Create(context.Background(), sanctionRecord))
|
||||
|
||||
gotSanction, err := sanctionStore.GetByRecordID(context.Background(), sanctionRecord.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sanctionRecord, gotSanction)
|
||||
|
||||
sanctions, err := sanctionStore.ListByUserID(context.Background(), sanctionRecord.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, sanctions, 1)
|
||||
|
||||
expiresAt := now.Add(time.Hour)
|
||||
sanctionRecord.ExpiresAt = &expiresAt
|
||||
require.NoError(t, sanctionStore.Update(context.Background(), sanctionRecord))
|
||||
|
||||
gotSanction, err = sanctionStore.GetByRecordID(context.Background(), sanctionRecord.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sanctionRecord.RecordID, gotSanction.RecordID)
|
||||
require.Equal(t, sanctionRecord.UserID, gotSanction.UserID)
|
||||
require.Equal(t, sanctionRecord.SanctionCode, gotSanction.SanctionCode)
|
||||
require.Equal(t, sanctionRecord.Scope, gotSanction.Scope)
|
||||
require.Equal(t, sanctionRecord.ReasonCode, gotSanction.ReasonCode)
|
||||
require.Equal(t, sanctionRecord.Actor, gotSanction.Actor)
|
||||
require.True(t, gotSanction.AppliedAt.Equal(sanctionRecord.AppliedAt))
|
||||
require.NotNil(t, gotSanction.ExpiresAt)
|
||||
require.True(t, gotSanction.ExpiresAt.Equal(*sanctionRecord.ExpiresAt))
|
||||
|
||||
limitRecord := policy.LimitRecord{
|
||||
RecordID: policy.LimitRecordID("limit-1"),
|
||||
UserID: common.UserID("user-123"),
|
||||
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
|
||||
Value: 3,
|
||||
ReasonCode: common.ReasonCode("policy_enforced"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
|
||||
AppliedAt: now,
|
||||
}
|
||||
require.NoError(t, limitStore.Create(context.Background(), limitRecord))
|
||||
|
||||
gotLimit, err := limitStore.GetByRecordID(context.Background(), limitRecord.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, limitRecord, gotLimit)
|
||||
|
||||
limits, err := limitStore.ListByUserID(context.Background(), limitRecord.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, limits, 1)
|
||||
|
||||
limitRecord.Value = 5
|
||||
require.NoError(t, limitStore.Update(context.Background(), limitRecord))
|
||||
|
||||
gotLimit, err = limitStore.GetByRecordID(context.Background(), limitRecord.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, limitRecord, gotLimit)
|
||||
}
|
||||
|
||||
func TestPolicyLifecycleApplyAndRemoveSanction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
lifecycleStore := store.PolicyLifecycle()
|
||||
sanctionStore := store.Sanctions()
|
||||
snapshotStore := store.EntitlementSnapshots()
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
userID := common.UserID("user-123")
|
||||
require.NoError(t, snapshotStore.Put(context.Background(), validEntitlementSnapshot(userID, now)))
|
||||
|
||||
record := policy.SanctionRecord{
|
||||
RecordID: policy.SanctionRecordID("sanction-1"),
|
||||
UserID: userID,
|
||||
SanctionCode: policy.SanctionCodeLoginBlock,
|
||||
Scope: common.Scope("auth"),
|
||||
ReasonCode: common.ReasonCode("manual_block"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
AppliedAt: now,
|
||||
}
|
||||
require.NoError(t, lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
|
||||
NewRecord: record,
|
||||
}))
|
||||
|
||||
activeRecordID, err := store.loadActiveSanctionRecordID(
|
||||
context.Background(),
|
||||
store.client,
|
||||
store.keyspace.ActiveSanction(userID, policy.SanctionCodeLoginBlock),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, record.RecordID, activeRecordID)
|
||||
|
||||
err = lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
|
||||
NewRecord: policy.SanctionRecord{
|
||||
RecordID: policy.SanctionRecordID("sanction-2"),
|
||||
UserID: userID,
|
||||
SanctionCode: policy.SanctionCodeLoginBlock,
|
||||
Scope: common.Scope("auth"),
|
||||
ReasonCode: common.ReasonCode("manual_block"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")},
|
||||
AppliedAt: now.Add(time.Minute),
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, ports.ErrConflict)
|
||||
|
||||
removed := record
|
||||
removedAt := now.Add(30 * time.Minute)
|
||||
removed.RemovedAt = &removedAt
|
||||
removed.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
|
||||
removed.RemovedReasonCode = common.ReasonCode("manual_remove")
|
||||
require.NoError(t, lifecycleStore.RemoveSanction(context.Background(), ports.RemoveSanctionInput{
|
||||
ExpectedActiveRecord: record,
|
||||
UpdatedRecord: removed,
|
||||
}))
|
||||
|
||||
stored, err := sanctionStore.GetByRecordID(context.Background(), record.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, removed, stored)
|
||||
|
||||
_, err = store.loadActiveSanctionRecordID(
|
||||
context.Background(),
|
||||
store.client,
|
||||
store.keyspace.ActiveSanction(userID, policy.SanctionCodeLoginBlock),
|
||||
)
|
||||
require.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestPolicyLifecycleSetAndRemoveLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
lifecycleStore := store.PolicyLifecycle()
|
||||
limitStore := store.Limits()
|
||||
now := time.Unix(1_775_240_000, 0).UTC()
|
||||
userID := common.UserID("user-123")
|
||||
|
||||
first := policy.LimitRecord{
|
||||
RecordID: policy.LimitRecordID("limit-1"),
|
||||
UserID: userID,
|
||||
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
|
||||
Value: 3,
|
||||
ReasonCode: common.ReasonCode("manual_override"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
AppliedAt: now,
|
||||
}
|
||||
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
|
||||
NewRecord: first,
|
||||
}))
|
||||
|
||||
activeRecordID, err := store.loadActiveLimitRecordID(
|
||||
context.Background(),
|
||||
store.client,
|
||||
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, first.RecordID, activeRecordID)
|
||||
|
||||
second := policy.LimitRecord{
|
||||
RecordID: policy.LimitRecordID("limit-2"),
|
||||
UserID: userID,
|
||||
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
|
||||
Value: 5,
|
||||
ReasonCode: common.ReasonCode("manual_override"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")},
|
||||
AppliedAt: now.Add(time.Hour),
|
||||
}
|
||||
updatedFirst := first
|
||||
removedAt := second.AppliedAt
|
||||
updatedFirst.RemovedAt = &removedAt
|
||||
updatedFirst.RemovedBy = second.Actor
|
||||
updatedFirst.RemovedReasonCode = second.ReasonCode
|
||||
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
|
||||
ExpectedActiveRecord: &first,
|
||||
UpdatedActiveRecord: &updatedFirst,
|
||||
NewRecord: second,
|
||||
}))
|
||||
|
||||
storedFirst, err := limitStore.GetByRecordID(context.Background(), first.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updatedFirst, storedFirst)
|
||||
|
||||
activeRecordID, err = store.loadActiveLimitRecordID(
|
||||
context.Background(),
|
||||
store.client,
|
||||
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, second.RecordID, activeRecordID)
|
||||
|
||||
removedSecond := second
|
||||
removeAt := now.Add(90 * time.Minute)
|
||||
removedSecond.RemovedAt = &removeAt
|
||||
removedSecond.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-3")}
|
||||
removedSecond.RemovedReasonCode = common.ReasonCode("manual_remove")
|
||||
require.NoError(t, lifecycleStore.RemoveLimit(context.Background(), ports.RemoveLimitInput{
|
||||
ExpectedActiveRecord: second,
|
||||
UpdatedRecord: removedSecond,
|
||||
}))
|
||||
|
||||
storedSecond, err := limitStore.GetByRecordID(context.Background(), second.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, removedSecond, storedSecond)
|
||||
|
||||
_, err = store.loadActiveLimitRecordID(
|
||||
context.Background(),
|
||||
store.client,
|
||||
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
|
||||
)
|
||||
require.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestEntitlementLifecycleTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
historyStore := store.EntitlementHistory()
|
||||
snapshotStore := store.EntitlementSnapshots()
|
||||
lifecycleStore := store.EntitlementLifecycle()
|
||||
userID := common.UserID("user-123")
|
||||
startedFreeAt := time.Unix(1_775_240_000, 0).UTC()
|
||||
|
||||
freeRecord := validEntitlementRecord(userID, startedFreeAt)
|
||||
freeSnapshot := validEntitlementSnapshot(userID, startedFreeAt)
|
||||
require.NoError(t, historyStore.Create(context.Background(), freeRecord))
|
||||
require.NoError(t, snapshotStore.Put(context.Background(), freeSnapshot))
|
||||
|
||||
grantStartsAt := startedFreeAt.Add(24 * time.Hour)
|
||||
grantEndsAt := grantStartsAt.Add(30 * 24 * time.Hour)
|
||||
grantedRecord := paidEntitlementRecord(
|
||||
entitlement.EntitlementRecordID("entitlement-paid-1"),
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
grantEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
grantedSnapshot := paidEntitlementSnapshot(
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
grantEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
closedFreeRecord := freeRecord
|
||||
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
|
||||
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
|
||||
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
|
||||
|
||||
require.NoError(t, lifecycleStore.Grant(context.Background(), ports.GrantEntitlementInput{
|
||||
ExpectedCurrentSnapshot: freeSnapshot,
|
||||
ExpectedCurrentRecord: freeRecord,
|
||||
UpdatedCurrentRecord: closedFreeRecord,
|
||||
NewRecord: grantedRecord,
|
||||
NewSnapshot: grantedSnapshot,
|
||||
}))
|
||||
|
||||
storedSnapshot, err := snapshotStore.GetByUserID(context.Background(), userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, grantedSnapshot, storedSnapshot)
|
||||
|
||||
storedFreeRecord, err := historyStore.GetByRecordID(context.Background(), freeRecord.RecordID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, closedFreeRecord, storedFreeRecord)
|
||||
|
||||
extendedEndsAt := grantEndsAt.Add(30 * 24 * time.Hour)
|
||||
extensionRecord := paidEntitlementRecord(
|
||||
entitlement.EntitlementRecordID("entitlement-paid-2"),
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantEndsAt,
|
||||
extendedEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_extend"),
|
||||
)
|
||||
extendedSnapshot := paidEntitlementSnapshot(
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
grantStartsAt,
|
||||
extendedEndsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_extend"),
|
||||
)
|
||||
|
||||
require.NoError(t, lifecycleStore.Extend(context.Background(), ports.ExtendEntitlementInput{
|
||||
ExpectedCurrentSnapshot: grantedSnapshot,
|
||||
NewRecord: extensionRecord,
|
||||
NewSnapshot: extendedSnapshot,
|
||||
}))
|
||||
|
||||
storedSnapshot, err = snapshotStore.GetByUserID(context.Background(), userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, extendedSnapshot, storedSnapshot)
|
||||
|
||||
revokeAt := grantEndsAt.Add(12 * time.Hour)
|
||||
revokedCurrentRecord := extensionRecord
|
||||
revokedCurrentRecord.ClosedAt = timePointer(revokeAt)
|
||||
revokedCurrentRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
|
||||
revokedCurrentRecord.ClosedReasonCode = common.ReasonCode("manual_revoke")
|
||||
|
||||
freeAfterRevokeRecord := entitlement.PeriodRecord{
|
||||
RecordID: entitlement.EntitlementRecordID("entitlement-free-2"),
|
||||
UserID: userID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
Source: common.Source("admin"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
ReasonCode: common.ReasonCode("manual_revoke"),
|
||||
StartsAt: revokeAt,
|
||||
CreatedAt: revokeAt,
|
||||
}
|
||||
freeAfterRevokeSnapshot := entitlement.CurrentSnapshot{
|
||||
UserID: userID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
IsPaid: false,
|
||||
StartsAt: revokeAt,
|
||||
Source: common.Source("admin"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
ReasonCode: common.ReasonCode("manual_revoke"),
|
||||
UpdatedAt: revokeAt,
|
||||
}
|
||||
|
||||
require.NoError(t, lifecycleStore.Revoke(context.Background(), ports.RevokeEntitlementInput{
|
||||
ExpectedCurrentSnapshot: extendedSnapshot,
|
||||
ExpectedCurrentRecord: extensionRecord,
|
||||
UpdatedCurrentRecord: revokedCurrentRecord,
|
||||
NewRecord: freeAfterRevokeRecord,
|
||||
NewSnapshot: freeAfterRevokeSnapshot,
|
||||
}))
|
||||
|
||||
storedSnapshot, err = snapshotStore.GetByUserID(context.Background(), userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, freeAfterRevokeSnapshot, storedSnapshot)
|
||||
|
||||
historyRecords, err := historyStore.ListByUserID(context.Background(), userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, historyRecords, 4)
|
||||
}
|
||||
|
||||
func TestRepairExpiredEntitlementMaterializesFreeSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newTestStore(t)
|
||||
historyStore := store.EntitlementHistory()
|
||||
snapshotStore := store.EntitlementSnapshots()
|
||||
lifecycleStore := store.EntitlementLifecycle()
|
||||
userID := common.UserID("user-123")
|
||||
startsAt := time.Unix(1_775_240_000, 0).UTC()
|
||||
endsAt := startsAt.Add(24 * time.Hour)
|
||||
expiredSnapshot := paidEntitlementSnapshot(
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
startsAt,
|
||||
endsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
expiredSnapshot.UpdatedAt = endsAt.Add(24 * time.Hour)
|
||||
expiredRecord := paidEntitlementRecord(
|
||||
entitlement.EntitlementRecordID("entitlement-paid-1"),
|
||||
userID,
|
||||
entitlement.PlanCodePaidMonthly,
|
||||
startsAt,
|
||||
endsAt,
|
||||
common.Source("admin"),
|
||||
common.ReasonCode("manual_grant"),
|
||||
)
|
||||
require.NoError(t, historyStore.Create(context.Background(), expiredRecord))
|
||||
require.NoError(t, snapshotStore.Put(context.Background(), expiredSnapshot))
|
||||
|
||||
repairedAt := endsAt.Add(2 * time.Hour)
|
||||
freeRecord := entitlement.PeriodRecord{
|
||||
RecordID: entitlement.EntitlementRecordID("entitlement-free-after-expiry"),
|
||||
UserID: userID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
Source: common.Source("entitlement_expiry_repair"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
|
||||
ReasonCode: common.ReasonCode("paid_entitlement_expired"),
|
||||
StartsAt: endsAt,
|
||||
CreatedAt: repairedAt,
|
||||
}
|
||||
freeSnapshot := entitlement.CurrentSnapshot{
|
||||
UserID: userID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
IsPaid: false,
|
||||
StartsAt: endsAt,
|
||||
Source: common.Source("entitlement_expiry_repair"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
|
||||
ReasonCode: common.ReasonCode("paid_entitlement_expired"),
|
||||
UpdatedAt: repairedAt,
|
||||
}
|
||||
|
||||
require.NoError(t, lifecycleStore.RepairExpired(context.Background(), ports.RepairExpiredEntitlementInput{
|
||||
ExpectedExpiredSnapshot: expiredSnapshot,
|
||||
NewRecord: freeRecord,
|
||||
NewSnapshot: freeSnapshot,
|
||||
}))
|
||||
|
||||
storedSnapshot, err := snapshotStore.GetByUserID(context.Background(), userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, freeSnapshot, storedSnapshot)
|
||||
|
||||
historyRecords, err := historyStore.ListByUserID(context.Background(), userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, historyRecords, 2)
|
||||
require.Equal(t, freeRecord, historyRecords[1])
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store, err := New(Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 0,
|
||||
KeyspacePrefix: "user:test:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = store.Close()
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func validAccountRecord() account.UserAccount {
|
||||
createdAt := time.Unix(1_775_240_000, 0).UTC()
|
||||
return account.UserAccount{
|
||||
UserID: common.UserID("user-123"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
RaceName: common.RaceName("Pilot Nova"),
|
||||
PreferredLanguage: common.LanguageTag("en"),
|
||||
TimeZone: common.TimeZoneName("Europe/Kaliningrad"),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
|
||||
func validEntitlementSnapshot(userID common.UserID, now time.Time) entitlement.CurrentSnapshot {
|
||||
return entitlement.CurrentSnapshot{
|
||||
UserID: userID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
IsPaid: false,
|
||||
StartsAt: now,
|
||||
Source: common.Source("auth_registration"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
|
||||
ReasonCode: common.ReasonCode("initial_free_entitlement"),
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func validEntitlementRecord(userID common.UserID, now time.Time) entitlement.PeriodRecord {
|
||||
return entitlement.PeriodRecord{
|
||||
RecordID: entitlement.EntitlementRecordID("entitlement-" + userID.String()),
|
||||
UserID: userID,
|
||||
PlanCode: entitlement.PlanCodeFree,
|
||||
Source: common.Source("auth_registration"),
|
||||
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
|
||||
ReasonCode: common.ReasonCode("initial_free_entitlement"),
|
||||
StartsAt: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func paidEntitlementRecord(
|
||||
recordID entitlement.EntitlementRecordID,
|
||||
userID common.UserID,
|
||||
planCode entitlement.PlanCode,
|
||||
startsAt time.Time,
|
||||
endsAt time.Time,
|
||||
source common.Source,
|
||||
reasonCode common.ReasonCode,
|
||||
) entitlement.PeriodRecord {
|
||||
return entitlement.PeriodRecord{
|
||||
RecordID: recordID,
|
||||
UserID: userID,
|
||||
PlanCode: planCode,
|
||||
Source: source,
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
ReasonCode: reasonCode,
|
||||
StartsAt: startsAt,
|
||||
EndsAt: timePointer(endsAt),
|
||||
CreatedAt: startsAt,
|
||||
}
|
||||
}
|
||||
|
||||
func paidEntitlementSnapshot(
|
||||
userID common.UserID,
|
||||
planCode entitlement.PlanCode,
|
||||
startsAt time.Time,
|
||||
endsAt time.Time,
|
||||
source common.Source,
|
||||
reasonCode common.ReasonCode,
|
||||
) entitlement.CurrentSnapshot {
|
||||
return entitlement.CurrentSnapshot{
|
||||
UserID: userID,
|
||||
PlanCode: planCode,
|
||||
IsPaid: true,
|
||||
StartsAt: startsAt,
|
||||
EndsAt: timePointer(endsAt),
|
||||
Source: source,
|
||||
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
|
||||
ReasonCode: reasonCode,
|
||||
UpdatedAt: startsAt,
|
||||
}
|
||||
}
|
||||
|
||||
func timePointer(value time.Time) *time.Time {
|
||||
utcValue := value.UTC()
|
||||
return &utcValue
|
||||
}
|
||||
|
||||
func createAccountInput(record account.UserAccount) ports.CreateAccountInput {
|
||||
return ports.CreateAccountInput{
|
||||
Account: record,
|
||||
Reservation: raceNameReservation(record.UserID, record.RaceName, record.UpdatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
func renameRaceNameInput(
|
||||
record account.UserAccount,
|
||||
newRaceName common.RaceName,
|
||||
updatedAt time.Time,
|
||||
) ports.RenameRaceNameInput {
|
||||
return ports.RenameRaceNameInput{
|
||||
UserID: record.UserID,
|
||||
CurrentCanonicalKey: canonicalKey(record.RaceName),
|
||||
NewRaceName: newRaceName,
|
||||
NewReservation: raceNameReservation(record.UserID, newRaceName, updatedAt),
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func raceNameReservation(
|
||||
userID common.UserID,
|
||||
raceName common.RaceName,
|
||||
reservedAt time.Time,
|
||||
) account.RaceNameReservation {
|
||||
return account.RaceNameReservation{
|
||||
CanonicalKey: canonicalKey(raceName),
|
||||
UserID: userID,
|
||||
RaceName: raceName,
|
||||
ReservedAt: reservedAt.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func canonicalKey(raceName common.RaceName) account.RaceNameCanonicalKey {
|
||||
return account.RaceNameCanonicalKey(strings.NewReplacer(
|
||||
"1", "i",
|
||||
"0", "o",
|
||||
"8", "b",
|
||||
).Replace(strings.ToLower(raceName.String())))
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
// Package redisstate defines the frozen Redis logical keyspace and pagination
|
||||
// helpers used by future User Service storage adapters.
|
||||
package redisstate
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/account"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
)
|
||||
|
||||
const defaultPrefix = "user:"
|
||||
|
||||
// Keyspace builds the frozen Redis logical keys used by future storage
|
||||
// adapters. The package intentionally exposes key construction only and does
|
||||
// not depend on any Redis client.
|
||||
type Keyspace struct {
|
||||
// Prefix stores the namespace prefix applied to every key. The zero value
|
||||
// uses `user:`.
|
||||
Prefix string
|
||||
}
|
||||
|
||||
// Account returns the primary user-account key for userID.
|
||||
func (k Keyspace) Account(userID common.UserID) string {
|
||||
return k.prefix() + "account:" + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
// EmailLookup returns the exact normalized e-mail lookup key.
|
||||
func (k Keyspace) EmailLookup(email common.Email) string {
|
||||
return k.prefix() + "lookup:email:" + encodeKeyComponent(email.String())
|
||||
}
|
||||
|
||||
// RaceNameLookup returns the exact stored race-name lookup key.
|
||||
func (k Keyspace) RaceNameLookup(raceName common.RaceName) string {
|
||||
return k.prefix() + "lookup:race-name:" + encodeKeyComponent(raceName.String())
|
||||
}
|
||||
|
||||
// RaceNameReservation returns the replaceable canonical race-name reservation
|
||||
// key.
|
||||
func (k Keyspace) RaceNameReservation(key account.RaceNameCanonicalKey) string {
|
||||
return k.prefix() + "reservation:race-name:" + encodeKeyComponent(key.String())
|
||||
}
|
||||
|
||||
// BlockedEmailSubject returns the dedicated blocked-email-subject key.
|
||||
func (k Keyspace) BlockedEmailSubject(email common.Email) string {
|
||||
return k.prefix() + "blocked-email:" + encodeKeyComponent(email.String())
|
||||
}
|
||||
|
||||
// EntitlementRecord returns the primary entitlement history-record key.
|
||||
func (k Keyspace) EntitlementRecord(recordID entitlement.EntitlementRecordID) string {
|
||||
return k.prefix() + "entitlement:record:" + encodeKeyComponent(recordID.String())
|
||||
}
|
||||
|
||||
// EntitlementHistory returns the per-user entitlement-history index key.
|
||||
func (k Keyspace) EntitlementHistory(userID common.UserID) string {
|
||||
return k.prefix() + "entitlement:history:" + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
// EntitlementSnapshot returns the current entitlement-snapshot key.
|
||||
func (k Keyspace) EntitlementSnapshot(userID common.UserID) string {
|
||||
return k.prefix() + "entitlement:snapshot:" + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
// SanctionRecord returns the primary sanction history-record key.
|
||||
func (k Keyspace) SanctionRecord(recordID policy.SanctionRecordID) string {
|
||||
return k.prefix() + "sanction:record:" + encodeKeyComponent(recordID.String())
|
||||
}
|
||||
|
||||
// SanctionHistory returns the per-user sanction-history index key.
|
||||
func (k Keyspace) SanctionHistory(userID common.UserID) string {
|
||||
return k.prefix() + "sanction:history:" + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
// ActiveSanction returns the per-user active-sanction slot for one sanction
|
||||
// code. The slot guarantees at most one active sanction per `user_id +
|
||||
// sanction_code`.
|
||||
func (k Keyspace) ActiveSanction(userID common.UserID, code policy.SanctionCode) string {
|
||||
return k.prefix() + "sanction:active:" + encodeKeyComponent(userID.String()) + ":" + encodeKeyComponent(string(code))
|
||||
}
|
||||
|
||||
// LimitRecord returns the primary limit history-record key.
|
||||
func (k Keyspace) LimitRecord(recordID policy.LimitRecordID) string {
|
||||
return k.prefix() + "limit:record:" + encodeKeyComponent(recordID.String())
|
||||
}
|
||||
|
||||
// LimitHistory returns the per-user limit-history index key.
|
||||
func (k Keyspace) LimitHistory(userID common.UserID) string {
|
||||
return k.prefix() + "limit:history:" + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
// ActiveLimit returns the per-user active-limit slot for one limit code. The
|
||||
// slot guarantees at most one active limit per `user_id + limit_code`.
|
||||
func (k Keyspace) ActiveLimit(userID common.UserID, code policy.LimitCode) string {
|
||||
return k.prefix() + "limit:active:" + encodeKeyComponent(userID.String()) + ":" + encodeKeyComponent(string(code))
|
||||
}
|
||||
|
||||
// CreatedAtIndex returns the deterministic newest-first user-ordering index.
|
||||
func (k Keyspace) CreatedAtIndex() string {
|
||||
return k.prefix() + "index:created-at"
|
||||
}
|
||||
|
||||
// PaidStateIndex returns the coarse free-versus-paid index key.
|
||||
func (k Keyspace) PaidStateIndex(state entitlement.PaidState) string {
|
||||
return k.prefix() + "index:paid-state:" + encodeKeyComponent(string(state))
|
||||
}
|
||||
|
||||
// FinitePaidExpiryIndex returns the finite paid-expiry index key. Lifetime
|
||||
// plans intentionally do not participate in this index.
|
||||
func (k Keyspace) FinitePaidExpiryIndex() string {
|
||||
return k.prefix() + "index:paid-expiry:finite"
|
||||
}
|
||||
|
||||
// DeclaredCountryIndex returns the current declared-country reverse-lookup
|
||||
// index key.
|
||||
func (k Keyspace) DeclaredCountryIndex(code common.CountryCode) string {
|
||||
return k.prefix() + "index:declared-country:" + encodeKeyComponent(code.String())
|
||||
}
|
||||
|
||||
// ActiveSanctionCodeIndex returns the reverse-lookup index key for users with
|
||||
// an active sanction code.
|
||||
func (k Keyspace) ActiveSanctionCodeIndex(code policy.SanctionCode) string {
|
||||
return k.prefix() + "index:active-sanction:" + encodeKeyComponent(string(code))
|
||||
}
|
||||
|
||||
// ActiveLimitCodeIndex returns the reverse-lookup index key for users with an
|
||||
// active limit code.
|
||||
func (k Keyspace) ActiveLimitCodeIndex(code policy.LimitCode) string {
|
||||
return k.prefix() + "index:active-limit:" + encodeKeyComponent(string(code))
|
||||
}
|
||||
|
||||
// EligibilityMarkerIndex returns the reverse-lookup index key for one derived
|
||||
// eligibility marker boolean.
|
||||
func (k Keyspace) EligibilityMarkerIndex(marker policy.EligibilityMarker, value bool) string {
|
||||
return fmt.Sprintf("%sindex:eligibility:%s:%t", k.prefix(), encodeKeyComponent(string(marker)), value)
|
||||
}
|
||||
|
||||
// CreatedAtScore returns the frozen ZSET score representation for created-at
|
||||
// ordering and deterministic pagination.
|
||||
func CreatedAtScore(createdAt time.Time) float64 {
|
||||
return float64(createdAt.UTC().UnixMicro())
|
||||
}
|
||||
|
||||
// ExpiryScore returns the frozen ZSET score representation for finite paid
|
||||
// expiry ordering.
|
||||
func ExpiryScore(expiresAt time.Time) float64 {
|
||||
return float64(expiresAt.UTC().UnixMicro())
|
||||
}
|
||||
|
||||
// PageCursor identifies the last seen `(created_at, user_id)` tuple used by
|
||||
// deterministic newest-first pagination.
|
||||
type PageCursor struct {
|
||||
// CreatedAt stores the created-at component of the last seen row.
|
||||
CreatedAt time.Time
|
||||
|
||||
// UserID stores the user-id tiebreaker component of the last seen row.
|
||||
UserID common.UserID
|
||||
}
|
||||
|
||||
// Validate reports whether PageCursor contains a complete cursor tuple.
|
||||
func (cursor PageCursor) Validate() error {
|
||||
if err := common.ValidateTimestamp("page cursor created at", cursor.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cursor.UserID.Validate(); err != nil {
|
||||
return fmt.Errorf("page cursor user id: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ComparePageOrder compares two listing positions using the frozen ordering:
|
||||
// `created_at desc`, then `user_id desc`.
|
||||
func ComparePageOrder(left PageCursor, right PageCursor) int {
|
||||
switch {
|
||||
case left.CreatedAt.After(right.CreatedAt):
|
||||
return -1
|
||||
case left.CreatedAt.Before(right.CreatedAt):
|
||||
return 1
|
||||
default:
|
||||
return -strings.Compare(left.UserID.String(), right.UserID.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (k Keyspace) prefix() string {
|
||||
prefix := strings.TrimSpace(k.Prefix)
|
||||
if prefix == "" {
|
||||
return defaultPrefix
|
||||
}
|
||||
|
||||
return prefix
|
||||
}
|
||||
|
||||
func encodeKeyComponent(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package redisstate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/account"
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestKeyspaceBuildsStableKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyspace := Keyspace{Prefix: "custom:"}
|
||||
|
||||
require.Equal(t, "custom:account:dXNlci0xMjM", keyspace.Account(common.UserID("user-123")))
|
||||
require.Equal(t, "custom:lookup:email:cGlsb3RAZXhhbXBsZS5jb20", keyspace.EmailLookup(common.Email("pilot@example.com")))
|
||||
require.Equal(t, "custom:lookup:race-name:UGlsb3QgTm92YQ", keyspace.RaceNameLookup(common.RaceName("Pilot Nova")))
|
||||
require.Equal(t, "custom:reservation:race-name:cGlsb3Qtbm92YQ", keyspace.RaceNameReservation(account.RaceNameCanonicalKey("pilot-nova")))
|
||||
require.Equal(t, "custom:blocked-email:cGlsb3RAZXhhbXBsZS5jb20", keyspace.BlockedEmailSubject(common.Email("pilot@example.com")))
|
||||
require.Equal(t, "custom:entitlement:record:ZW50aXRsZW1lbnQtMTIz", keyspace.EntitlementRecord(entitlement.EntitlementRecordID("entitlement-123")))
|
||||
require.Equal(t, "custom:sanction:record:c2FuY3Rpb24tMQ", keyspace.SanctionRecord(policy.SanctionRecordID("sanction-1")))
|
||||
require.Equal(t, "custom:limit:record:bGltaXQtMQ", keyspace.LimitRecord(policy.LimitRecordID("limit-1")))
|
||||
require.Equal(t, "custom:sanction:active:dXNlci0xMjM:bG9naW5fYmxvY2s", keyspace.ActiveSanction(common.UserID("user-123"), policy.SanctionCodeLoginBlock))
|
||||
require.Equal(t, "custom:limit:active:dXNlci0xMjM:bWF4X293bmVkX3ByaXZhdGVfZ2FtZXM", keyspace.ActiveLimit(common.UserID("user-123"), policy.LimitCodeMaxOwnedPrivateGames))
|
||||
require.Equal(t, "custom:index:created-at", keyspace.CreatedAtIndex())
|
||||
require.Equal(t, "custom:index:paid-state:cGFpZA", keyspace.PaidStateIndex(entitlement.PaidStatePaid))
|
||||
require.Equal(t, "custom:index:paid-expiry:finite", keyspace.FinitePaidExpiryIndex())
|
||||
require.Equal(t, "custom:index:declared-country:REU", keyspace.DeclaredCountryIndex(common.CountryCode("DE")))
|
||||
require.Equal(t, "custom:index:active-sanction:bG9naW5fYmxvY2s", keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock))
|
||||
require.Equal(t, "custom:index:active-limit:bWF4X293bmVkX3ByaXZhdGVfZ2FtZXM", keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames))
|
||||
require.Equal(t, "custom:index:eligibility:Y2FuX2xvZ2lu:true", keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true))
|
||||
}
|
||||
|
||||
func TestComparePageOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newer := PageCursor{CreatedAt: time.Unix(20, 0).UTC(), UserID: common.UserID("user-200")}
|
||||
older := PageCursor{CreatedAt: time.Unix(10, 0).UTC(), UserID: common.UserID("user-100")}
|
||||
sameTimeHigherUserID := PageCursor{CreatedAt: time.Unix(20, 0).UTC(), UserID: common.UserID("user-300")}
|
||||
|
||||
require.Negative(t, ComparePageOrder(newer, older))
|
||||
require.Positive(t, ComparePageOrder(older, newer))
|
||||
require.Negative(t, ComparePageOrder(sameTimeHigherUserID, newer))
|
||||
}
|
||||
|
||||
func TestScoresUseUnixMicro(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
value := time.Unix(1_775_240_000, 123_000).UTC()
|
||||
want := float64(value.UnixMicro())
|
||||
|
||||
require.Equal(t, want, CreatedAtScore(value))
|
||||
require.Equal(t, want, ExpiryScore(value))
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
package redisstate
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPageTokenFiltersMismatch reports that a supplied page token was created
|
||||
// for a different normalized filter set.
|
||||
ErrPageTokenFiltersMismatch = errors.New("page token filters do not match current filters")
|
||||
)
|
||||
|
||||
// UserListFilters stores the frozen admin-listing filter set that becomes part
|
||||
// of the opaque page token fingerprint.
|
||||
type UserListFilters struct {
|
||||
// PaidState stores the coarse free-versus-paid filter.
|
||||
PaidState entitlement.PaidState
|
||||
|
||||
// PaidExpiresBefore stores the optional finite-paid expiry upper bound.
|
||||
PaidExpiresBefore *time.Time
|
||||
|
||||
// PaidExpiresAfter stores the optional finite-paid expiry lower bound.
|
||||
PaidExpiresAfter *time.Time
|
||||
|
||||
// DeclaredCountry stores the optional declared-country filter.
|
||||
DeclaredCountry common.CountryCode
|
||||
|
||||
// SanctionCode stores the optional active-sanction filter.
|
||||
SanctionCode policy.SanctionCode
|
||||
|
||||
// LimitCode stores the optional active-limit filter.
|
||||
LimitCode policy.LimitCode
|
||||
|
||||
// CanLogin stores the optional login-eligibility filter.
|
||||
CanLogin *bool
|
||||
|
||||
// CanCreatePrivateGame stores the optional private-game-create eligibility
|
||||
// filter.
|
||||
CanCreatePrivateGame *bool
|
||||
|
||||
// CanJoinGame stores the optional join-game eligibility filter.
|
||||
CanJoinGame *bool
|
||||
}
|
||||
|
||||
// Validate reports whether UserListFilters is structurally valid.
|
||||
func (filters UserListFilters) Validate() error {
|
||||
if !filters.PaidState.IsKnown() {
|
||||
return fmt.Errorf("paid state %q is unsupported", filters.PaidState)
|
||||
}
|
||||
if filters.PaidExpiresBefore != nil && filters.PaidExpiresBefore.IsZero() {
|
||||
return fmt.Errorf("paid expires before must not be zero")
|
||||
}
|
||||
if filters.PaidExpiresAfter != nil && filters.PaidExpiresAfter.IsZero() {
|
||||
return fmt.Errorf("paid expires after must not be zero")
|
||||
}
|
||||
if !filters.DeclaredCountry.IsZero() {
|
||||
if err := filters.DeclaredCountry.Validate(); err != nil {
|
||||
return fmt.Errorf("declared country: %w", err)
|
||||
}
|
||||
}
|
||||
if filters.SanctionCode != "" && !filters.SanctionCode.IsKnown() {
|
||||
return fmt.Errorf("sanction code %q is unsupported", filters.SanctionCode)
|
||||
}
|
||||
if filters.LimitCode != "" && !filters.LimitCode.IsKnown() {
|
||||
return fmt.Errorf("limit code %q is unsupported", filters.LimitCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodePageToken encodes cursor and filters into the frozen opaque page token
|
||||
// format.
|
||||
func EncodePageToken(cursor PageCursor, filters UserListFilters) (string, error) {
|
||||
if err := cursor.Validate(); err != nil {
|
||||
return "", fmt.Errorf("encode page token: %w", err)
|
||||
}
|
||||
fingerprint, err := normalizeFilters(filters)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encode page token: %w", err)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(pageTokenPayload{
|
||||
CreatedAt: cursor.CreatedAt.UTC().Format(time.RFC3339Nano),
|
||||
UserID: cursor.UserID.String(),
|
||||
Filters: fingerprint,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encode page token: %w", err)
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(payload), nil
|
||||
}
|
||||
|
||||
// DecodePageToken decodes raw into the frozen page cursor and verifies that
|
||||
// the embedded normalized filter set matches expectedFilters.
|
||||
func DecodePageToken(raw string, expectedFilters UserListFilters) (PageCursor, error) {
|
||||
fingerprint, err := normalizeFilters(expectedFilters)
|
||||
if err != nil {
|
||||
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
|
||||
}
|
||||
|
||||
var token pageTokenPayload
|
||||
if err := json.Unmarshal(payload, &token); err != nil {
|
||||
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
|
||||
}
|
||||
if token.Filters != fingerprint {
|
||||
return PageCursor{}, ErrPageTokenFiltersMismatch
|
||||
}
|
||||
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, token.CreatedAt)
|
||||
if err != nil {
|
||||
return PageCursor{}, fmt.Errorf("decode page token: parse created_at: %w", err)
|
||||
}
|
||||
|
||||
cursor := PageCursor{
|
||||
CreatedAt: createdAt.UTC(),
|
||||
UserID: common.UserID(token.UserID),
|
||||
}
|
||||
if err := cursor.Validate(); err != nil {
|
||||
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
|
||||
}
|
||||
|
||||
return cursor, nil
|
||||
}
|
||||
|
||||
type pageTokenPayload struct {
|
||||
CreatedAt string `json:"created_at"`
|
||||
UserID string `json:"user_id"`
|
||||
Filters normalizedFilterPayload `json:"filters"`
|
||||
}
|
||||
|
||||
type normalizedFilterPayload struct {
|
||||
PaidState string `json:"paid_state,omitempty"`
|
||||
PaidExpiresBeforeUTC string `json:"paid_expires_before_utc,omitempty"`
|
||||
PaidExpiresAfterUTC string `json:"paid_expires_after_utc,omitempty"`
|
||||
DeclaredCountry string `json:"declared_country,omitempty"`
|
||||
SanctionCode string `json:"sanction_code,omitempty"`
|
||||
LimitCode string `json:"limit_code,omitempty"`
|
||||
CanLogin string `json:"can_login,omitempty"`
|
||||
CanCreatePrivateGame string `json:"can_create_private_game,omitempty"`
|
||||
CanJoinGame string `json:"can_join_game,omitempty"`
|
||||
}
|
||||
|
||||
func normalizeFilters(filters UserListFilters) (normalizedFilterPayload, error) {
|
||||
if err := filters.Validate(); err != nil {
|
||||
return normalizedFilterPayload{}, err
|
||||
}
|
||||
|
||||
return normalizedFilterPayload{
|
||||
PaidState: string(filters.PaidState),
|
||||
PaidExpiresBeforeUTC: formatOptionalTime(filters.PaidExpiresBefore),
|
||||
PaidExpiresAfterUTC: formatOptionalTime(filters.PaidExpiresAfter),
|
||||
DeclaredCountry: filters.DeclaredCountry.String(),
|
||||
SanctionCode: string(filters.SanctionCode),
|
||||
LimitCode: string(filters.LimitCode),
|
||||
CanLogin: formatOptionalBool(filters.CanLogin),
|
||||
CanCreatePrivateGame: formatOptionalBool(filters.CanCreatePrivateGame),
|
||||
CanJoinGame: formatOptionalBool(filters.CanJoinGame),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func formatOptionalTime(value *time.Time) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return value.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func formatOptionalBool(value *bool) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
if *value {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package redisstate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/user/internal/domain/common"
|
||||
"galaxy/user/internal/domain/entitlement"
|
||||
"galaxy/user/internal/domain/policy"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncodeDecodePageToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
before := time.Unix(1_775_250_000, 0).UTC()
|
||||
after := time.Unix(1_775_240_000, 0).UTC()
|
||||
canLogin := true
|
||||
canCreate := false
|
||||
canJoin := true
|
||||
|
||||
filters := UserListFilters{
|
||||
PaidState: entitlement.PaidStatePaid,
|
||||
PaidExpiresBefore: &before,
|
||||
PaidExpiresAfter: &after,
|
||||
DeclaredCountry: common.CountryCode("DE"),
|
||||
SanctionCode: policy.SanctionCodeLoginBlock,
|
||||
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
|
||||
CanLogin: &canLogin,
|
||||
CanCreatePrivateGame: &canCreate,
|
||||
CanJoinGame: &canJoin,
|
||||
}
|
||||
cursor := PageCursor{
|
||||
CreatedAt: time.Unix(1_775_240_100, 987_000_000).UTC(),
|
||||
UserID: common.UserID("user-123"),
|
||||
}
|
||||
|
||||
token, err := EncodePageToken(cursor, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
decoded, err := DecodePageToken(token, filters)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cursor, decoded)
|
||||
}
|
||||
|
||||
func TestDecodePageTokenFilterMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cursor := PageCursor{
|
||||
CreatedAt: time.Unix(1_775_240_100, 0).UTC(),
|
||||
UserID: common.UserID("user-123"),
|
||||
}
|
||||
filters := UserListFilters{
|
||||
PaidState: entitlement.PaidStatePaid,
|
||||
}
|
||||
|
||||
token, err := EncodePageToken(cursor, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = DecodePageToken(token, UserListFilters{PaidState: entitlement.PaidStateFree})
|
||||
require.ErrorIs(t, err, ErrPageTokenFiltersMismatch)
|
||||
}
|
||||
|
||||
func TestDecodePageTokenRejectsInvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := DecodePageToken("%%%not-base64%%%", UserListFilters{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
Reference in New Issue
Block a user