feat: use postgres

This commit is contained in:
Ilia Denisov
2026-04-26 20:34:39 +02:00
committed by GitHub
parent 48b0056b49
commit fe829285a6
365 changed files with 29223 additions and 24049 deletions
@@ -0,0 +1,25 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type Accounts struct {
UserID string `sql:"primary_key"`
Email string
UserName string
DisplayName string
PreferredLanguage string
TimeZone string
DeclaredCountry *string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
@@ -0,0 +1,21 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type BlockedEmails struct {
Email string `sql:"primary_key"`
ReasonCode string
BlockedAt time.Time
ActorType *string
ActorID *string
ResolvedUserID *string
}
@@ -0,0 +1,29 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type EntitlementRecords struct {
RecordID string `sql:"primary_key"`
UserID string
PlanCode string
Source string
ActorType string
ActorID *string
ReasonCode string
StartsAt time.Time
EndsAt *time.Time
CreatedAt time.Time
ClosedAt *time.Time
ClosedByType *string
ClosedByID *string
ClosedReasonCode *string
}
@@ -0,0 +1,25 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type EntitlementSnapshots struct {
UserID string `sql:"primary_key"`
PlanCode string
IsPaid bool
StartsAt time.Time
EndsAt *time.Time
Source string
ActorType string
ActorID *string
ReasonCode string
UpdatedAt time.Time
}
@@ -0,0 +1,19 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type GooseDbVersion struct {
ID int32 `sql:"primary_key"`
VersionID int64
IsApplied bool
Tstamp time.Time
}
@@ -0,0 +1,15 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
type LimitActive struct {
UserID string `sql:"primary_key"`
LimitCode string `sql:"primary_key"`
RecordID string
Value int32
}
@@ -0,0 +1,28 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type LimitRecords struct {
RecordID string `sql:"primary_key"`
UserID string
LimitCode string
Value int32
ReasonCode string
ActorType string
ActorID *string
AppliedAt time.Time
ExpiresAt *time.Time
RemovedAt *time.Time
RemovedByType *string
RemovedByID *string
RemovedReasonCode *string
}
@@ -0,0 +1,14 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
type SanctionActive struct {
UserID string `sql:"primary_key"`
SanctionCode string `sql:"primary_key"`
RecordID string
}
@@ -0,0 +1,28 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type SanctionRecords struct {
RecordID string `sql:"primary_key"`
UserID string
SanctionCode string
Scope string
ReasonCode string
ActorType string
ActorID *string
AppliedAt time.Time
ExpiresAt *time.Time
RemovedAt *time.Time
RemovedByType *string
RemovedByID *string
RemovedReasonCode *string
}
@@ -0,0 +1,105 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var Accounts = newAccountsTable("user", "accounts", "")
type accountsTable struct {
postgres.Table
// Columns
UserID postgres.ColumnString
Email postgres.ColumnString
UserName postgres.ColumnString
DisplayName postgres.ColumnString
PreferredLanguage postgres.ColumnString
TimeZone postgres.ColumnString
DeclaredCountry postgres.ColumnString
CreatedAt postgres.ColumnTimestampz
UpdatedAt postgres.ColumnTimestampz
DeletedAt postgres.ColumnTimestampz
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type AccountsTable struct {
accountsTable
EXCLUDED accountsTable
}
// AS creates new AccountsTable with assigned alias
func (a AccountsTable) AS(alias string) *AccountsTable {
return newAccountsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new AccountsTable with assigned schema name
func (a AccountsTable) FromSchema(schemaName string) *AccountsTable {
return newAccountsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new AccountsTable with assigned table prefix
func (a AccountsTable) WithPrefix(prefix string) *AccountsTable {
return newAccountsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new AccountsTable with assigned table suffix
func (a AccountsTable) WithSuffix(suffix string) *AccountsTable {
return newAccountsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newAccountsTable(schemaName, tableName, alias string) *AccountsTable {
return &AccountsTable{
accountsTable: newAccountsTableImpl(schemaName, tableName, alias),
EXCLUDED: newAccountsTableImpl("", "excluded", ""),
}
}
func newAccountsTableImpl(schemaName, tableName, alias string) accountsTable {
var (
UserIDColumn = postgres.StringColumn("user_id")
EmailColumn = postgres.StringColumn("email")
UserNameColumn = postgres.StringColumn("user_name")
DisplayNameColumn = postgres.StringColumn("display_name")
PreferredLanguageColumn = postgres.StringColumn("preferred_language")
TimeZoneColumn = postgres.StringColumn("time_zone")
DeclaredCountryColumn = postgres.StringColumn("declared_country")
CreatedAtColumn = postgres.TimestampzColumn("created_at")
UpdatedAtColumn = postgres.TimestampzColumn("updated_at")
DeletedAtColumn = postgres.TimestampzColumn("deleted_at")
allColumns = postgres.ColumnList{UserIDColumn, EmailColumn, UserNameColumn, DisplayNameColumn, PreferredLanguageColumn, TimeZoneColumn, DeclaredCountryColumn, CreatedAtColumn, UpdatedAtColumn, DeletedAtColumn}
mutableColumns = postgres.ColumnList{EmailColumn, UserNameColumn, DisplayNameColumn, PreferredLanguageColumn, TimeZoneColumn, DeclaredCountryColumn, CreatedAtColumn, UpdatedAtColumn, DeletedAtColumn}
defaultColumns = postgres.ColumnList{DisplayNameColumn}
)
return accountsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
UserID: UserIDColumn,
Email: EmailColumn,
UserName: UserNameColumn,
DisplayName: DisplayNameColumn,
PreferredLanguage: PreferredLanguageColumn,
TimeZone: TimeZoneColumn,
DeclaredCountry: DeclaredCountryColumn,
CreatedAt: CreatedAtColumn,
UpdatedAt: UpdatedAtColumn,
DeletedAt: DeletedAtColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,93 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var BlockedEmails = newBlockedEmailsTable("user", "blocked_emails", "")
type blockedEmailsTable struct {
postgres.Table
// Columns
Email postgres.ColumnString
ReasonCode postgres.ColumnString
BlockedAt postgres.ColumnTimestampz
ActorType postgres.ColumnString
ActorID postgres.ColumnString
ResolvedUserID postgres.ColumnString
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type BlockedEmailsTable struct {
blockedEmailsTable
EXCLUDED blockedEmailsTable
}
// AS creates new BlockedEmailsTable with assigned alias
func (a BlockedEmailsTable) AS(alias string) *BlockedEmailsTable {
return newBlockedEmailsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new BlockedEmailsTable with assigned schema name
func (a BlockedEmailsTable) FromSchema(schemaName string) *BlockedEmailsTable {
return newBlockedEmailsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new BlockedEmailsTable with assigned table prefix
func (a BlockedEmailsTable) WithPrefix(prefix string) *BlockedEmailsTable {
return newBlockedEmailsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new BlockedEmailsTable with assigned table suffix
func (a BlockedEmailsTable) WithSuffix(suffix string) *BlockedEmailsTable {
return newBlockedEmailsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newBlockedEmailsTable(schemaName, tableName, alias string) *BlockedEmailsTable {
return &BlockedEmailsTable{
blockedEmailsTable: newBlockedEmailsTableImpl(schemaName, tableName, alias),
EXCLUDED: newBlockedEmailsTableImpl("", "excluded", ""),
}
}
func newBlockedEmailsTableImpl(schemaName, tableName, alias string) blockedEmailsTable {
var (
EmailColumn = postgres.StringColumn("email")
ReasonCodeColumn = postgres.StringColumn("reason_code")
BlockedAtColumn = postgres.TimestampzColumn("blocked_at")
ActorTypeColumn = postgres.StringColumn("actor_type")
ActorIDColumn = postgres.StringColumn("actor_id")
ResolvedUserIDColumn = postgres.StringColumn("resolved_user_id")
allColumns = postgres.ColumnList{EmailColumn, ReasonCodeColumn, BlockedAtColumn, ActorTypeColumn, ActorIDColumn, ResolvedUserIDColumn}
mutableColumns = postgres.ColumnList{ReasonCodeColumn, BlockedAtColumn, ActorTypeColumn, ActorIDColumn, ResolvedUserIDColumn}
defaultColumns = postgres.ColumnList{}
)
return blockedEmailsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
Email: EmailColumn,
ReasonCode: ReasonCodeColumn,
BlockedAt: BlockedAtColumn,
ActorType: ActorTypeColumn,
ActorID: ActorIDColumn,
ResolvedUserID: ResolvedUserIDColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,117 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var EntitlementRecords = newEntitlementRecordsTable("user", "entitlement_records", "")
type entitlementRecordsTable struct {
postgres.Table
// Columns
RecordID postgres.ColumnString
UserID postgres.ColumnString
PlanCode postgres.ColumnString
Source postgres.ColumnString
ActorType postgres.ColumnString
ActorID postgres.ColumnString
ReasonCode postgres.ColumnString
StartsAt postgres.ColumnTimestampz
EndsAt postgres.ColumnTimestampz
CreatedAt postgres.ColumnTimestampz
ClosedAt postgres.ColumnTimestampz
ClosedByType postgres.ColumnString
ClosedByID postgres.ColumnString
ClosedReasonCode postgres.ColumnString
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type EntitlementRecordsTable struct {
entitlementRecordsTable
EXCLUDED entitlementRecordsTable
}
// AS creates new EntitlementRecordsTable with assigned alias
func (a EntitlementRecordsTable) AS(alias string) *EntitlementRecordsTable {
return newEntitlementRecordsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new EntitlementRecordsTable with assigned schema name
func (a EntitlementRecordsTable) FromSchema(schemaName string) *EntitlementRecordsTable {
return newEntitlementRecordsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new EntitlementRecordsTable with assigned table prefix
func (a EntitlementRecordsTable) WithPrefix(prefix string) *EntitlementRecordsTable {
return newEntitlementRecordsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new EntitlementRecordsTable with assigned table suffix
func (a EntitlementRecordsTable) WithSuffix(suffix string) *EntitlementRecordsTable {
return newEntitlementRecordsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newEntitlementRecordsTable(schemaName, tableName, alias string) *EntitlementRecordsTable {
return &EntitlementRecordsTable{
entitlementRecordsTable: newEntitlementRecordsTableImpl(schemaName, tableName, alias),
EXCLUDED: newEntitlementRecordsTableImpl("", "excluded", ""),
}
}
func newEntitlementRecordsTableImpl(schemaName, tableName, alias string) entitlementRecordsTable {
var (
RecordIDColumn = postgres.StringColumn("record_id")
UserIDColumn = postgres.StringColumn("user_id")
PlanCodeColumn = postgres.StringColumn("plan_code")
SourceColumn = postgres.StringColumn("source")
ActorTypeColumn = postgres.StringColumn("actor_type")
ActorIDColumn = postgres.StringColumn("actor_id")
ReasonCodeColumn = postgres.StringColumn("reason_code")
StartsAtColumn = postgres.TimestampzColumn("starts_at")
EndsAtColumn = postgres.TimestampzColumn("ends_at")
CreatedAtColumn = postgres.TimestampzColumn("created_at")
ClosedAtColumn = postgres.TimestampzColumn("closed_at")
ClosedByTypeColumn = postgres.StringColumn("closed_by_type")
ClosedByIDColumn = postgres.StringColumn("closed_by_id")
ClosedReasonCodeColumn = postgres.StringColumn("closed_reason_code")
allColumns = postgres.ColumnList{RecordIDColumn, UserIDColumn, PlanCodeColumn, SourceColumn, ActorTypeColumn, ActorIDColumn, ReasonCodeColumn, StartsAtColumn, EndsAtColumn, CreatedAtColumn, ClosedAtColumn, ClosedByTypeColumn, ClosedByIDColumn, ClosedReasonCodeColumn}
mutableColumns = postgres.ColumnList{UserIDColumn, PlanCodeColumn, SourceColumn, ActorTypeColumn, ActorIDColumn, ReasonCodeColumn, StartsAtColumn, EndsAtColumn, CreatedAtColumn, ClosedAtColumn, ClosedByTypeColumn, ClosedByIDColumn, ClosedReasonCodeColumn}
defaultColumns = postgres.ColumnList{}
)
return entitlementRecordsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
RecordID: RecordIDColumn,
UserID: UserIDColumn,
PlanCode: PlanCodeColumn,
Source: SourceColumn,
ActorType: ActorTypeColumn,
ActorID: ActorIDColumn,
ReasonCode: ReasonCodeColumn,
StartsAt: StartsAtColumn,
EndsAt: EndsAtColumn,
CreatedAt: CreatedAtColumn,
ClosedAt: ClosedAtColumn,
ClosedByType: ClosedByTypeColumn,
ClosedByID: ClosedByIDColumn,
ClosedReasonCode: ClosedReasonCodeColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,105 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var EntitlementSnapshots = newEntitlementSnapshotsTable("user", "entitlement_snapshots", "")
type entitlementSnapshotsTable struct {
postgres.Table
// Columns
UserID postgres.ColumnString
PlanCode postgres.ColumnString
IsPaid postgres.ColumnBool
StartsAt postgres.ColumnTimestampz
EndsAt postgres.ColumnTimestampz
Source postgres.ColumnString
ActorType postgres.ColumnString
ActorID postgres.ColumnString
ReasonCode postgres.ColumnString
UpdatedAt postgres.ColumnTimestampz
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type EntitlementSnapshotsTable struct {
entitlementSnapshotsTable
EXCLUDED entitlementSnapshotsTable
}
// AS creates new EntitlementSnapshotsTable with assigned alias
func (a EntitlementSnapshotsTable) AS(alias string) *EntitlementSnapshotsTable {
return newEntitlementSnapshotsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new EntitlementSnapshotsTable with assigned schema name
func (a EntitlementSnapshotsTable) FromSchema(schemaName string) *EntitlementSnapshotsTable {
return newEntitlementSnapshotsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new EntitlementSnapshotsTable with assigned table prefix
func (a EntitlementSnapshotsTable) WithPrefix(prefix string) *EntitlementSnapshotsTable {
return newEntitlementSnapshotsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new EntitlementSnapshotsTable with assigned table suffix
func (a EntitlementSnapshotsTable) WithSuffix(suffix string) *EntitlementSnapshotsTable {
return newEntitlementSnapshotsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newEntitlementSnapshotsTable(schemaName, tableName, alias string) *EntitlementSnapshotsTable {
return &EntitlementSnapshotsTable{
entitlementSnapshotsTable: newEntitlementSnapshotsTableImpl(schemaName, tableName, alias),
EXCLUDED: newEntitlementSnapshotsTableImpl("", "excluded", ""),
}
}
func newEntitlementSnapshotsTableImpl(schemaName, tableName, alias string) entitlementSnapshotsTable {
var (
UserIDColumn = postgres.StringColumn("user_id")
PlanCodeColumn = postgres.StringColumn("plan_code")
IsPaidColumn = postgres.BoolColumn("is_paid")
StartsAtColumn = postgres.TimestampzColumn("starts_at")
EndsAtColumn = postgres.TimestampzColumn("ends_at")
SourceColumn = postgres.StringColumn("source")
ActorTypeColumn = postgres.StringColumn("actor_type")
ActorIDColumn = postgres.StringColumn("actor_id")
ReasonCodeColumn = postgres.StringColumn("reason_code")
UpdatedAtColumn = postgres.TimestampzColumn("updated_at")
allColumns = postgres.ColumnList{UserIDColumn, PlanCodeColumn, IsPaidColumn, StartsAtColumn, EndsAtColumn, SourceColumn, ActorTypeColumn, ActorIDColumn, ReasonCodeColumn, UpdatedAtColumn}
mutableColumns = postgres.ColumnList{PlanCodeColumn, IsPaidColumn, StartsAtColumn, EndsAtColumn, SourceColumn, ActorTypeColumn, ActorIDColumn, ReasonCodeColumn, UpdatedAtColumn}
defaultColumns = postgres.ColumnList{}
)
return entitlementSnapshotsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
UserID: UserIDColumn,
PlanCode: PlanCodeColumn,
IsPaid: IsPaidColumn,
StartsAt: StartsAtColumn,
EndsAt: EndsAtColumn,
Source: SourceColumn,
ActorType: ActorTypeColumn,
ActorID: ActorIDColumn,
ReasonCode: ReasonCodeColumn,
UpdatedAt: UpdatedAtColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,87 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var GooseDbVersion = newGooseDbVersionTable("user", "goose_db_version", "")
type gooseDbVersionTable struct {
postgres.Table
// Columns
ID postgres.ColumnInteger
VersionID postgres.ColumnInteger
IsApplied postgres.ColumnBool
Tstamp postgres.ColumnTimestamp
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type GooseDbVersionTable struct {
gooseDbVersionTable
EXCLUDED gooseDbVersionTable
}
// AS creates new GooseDbVersionTable with assigned alias
func (a GooseDbVersionTable) AS(alias string) *GooseDbVersionTable {
return newGooseDbVersionTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new GooseDbVersionTable with assigned schema name
func (a GooseDbVersionTable) FromSchema(schemaName string) *GooseDbVersionTable {
return newGooseDbVersionTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new GooseDbVersionTable with assigned table prefix
func (a GooseDbVersionTable) WithPrefix(prefix string) *GooseDbVersionTable {
return newGooseDbVersionTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new GooseDbVersionTable with assigned table suffix
func (a GooseDbVersionTable) WithSuffix(suffix string) *GooseDbVersionTable {
return newGooseDbVersionTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newGooseDbVersionTable(schemaName, tableName, alias string) *GooseDbVersionTable {
return &GooseDbVersionTable{
gooseDbVersionTable: newGooseDbVersionTableImpl(schemaName, tableName, alias),
EXCLUDED: newGooseDbVersionTableImpl("", "excluded", ""),
}
}
func newGooseDbVersionTableImpl(schemaName, tableName, alias string) gooseDbVersionTable {
var (
IDColumn = postgres.IntegerColumn("id")
VersionIDColumn = postgres.IntegerColumn("version_id")
IsAppliedColumn = postgres.BoolColumn("is_applied")
TstampColumn = postgres.TimestampColumn("tstamp")
allColumns = postgres.ColumnList{IDColumn, VersionIDColumn, IsAppliedColumn, TstampColumn}
mutableColumns = postgres.ColumnList{VersionIDColumn, IsAppliedColumn, TstampColumn}
defaultColumns = postgres.ColumnList{TstampColumn}
)
return gooseDbVersionTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
ID: IDColumn,
VersionID: VersionIDColumn,
IsApplied: IsAppliedColumn,
Tstamp: TstampColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,87 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var LimitActive = newLimitActiveTable("user", "limit_active", "")
type limitActiveTable struct {
postgres.Table
// Columns
UserID postgres.ColumnString
LimitCode postgres.ColumnString
RecordID postgres.ColumnString
Value postgres.ColumnInteger
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type LimitActiveTable struct {
limitActiveTable
EXCLUDED limitActiveTable
}
// AS creates new LimitActiveTable with assigned alias
func (a LimitActiveTable) AS(alias string) *LimitActiveTable {
return newLimitActiveTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new LimitActiveTable with assigned schema name
func (a LimitActiveTable) FromSchema(schemaName string) *LimitActiveTable {
return newLimitActiveTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new LimitActiveTable with assigned table prefix
func (a LimitActiveTable) WithPrefix(prefix string) *LimitActiveTable {
return newLimitActiveTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new LimitActiveTable with assigned table suffix
func (a LimitActiveTable) WithSuffix(suffix string) *LimitActiveTable {
return newLimitActiveTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newLimitActiveTable(schemaName, tableName, alias string) *LimitActiveTable {
return &LimitActiveTable{
limitActiveTable: newLimitActiveTableImpl(schemaName, tableName, alias),
EXCLUDED: newLimitActiveTableImpl("", "excluded", ""),
}
}
func newLimitActiveTableImpl(schemaName, tableName, alias string) limitActiveTable {
var (
UserIDColumn = postgres.StringColumn("user_id")
LimitCodeColumn = postgres.StringColumn("limit_code")
RecordIDColumn = postgres.StringColumn("record_id")
ValueColumn = postgres.IntegerColumn("value")
allColumns = postgres.ColumnList{UserIDColumn, LimitCodeColumn, RecordIDColumn, ValueColumn}
mutableColumns = postgres.ColumnList{RecordIDColumn, ValueColumn}
defaultColumns = postgres.ColumnList{}
)
return limitActiveTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
UserID: UserIDColumn,
LimitCode: LimitCodeColumn,
RecordID: RecordIDColumn,
Value: ValueColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,114 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var LimitRecords = newLimitRecordsTable("user", "limit_records", "")
type limitRecordsTable struct {
postgres.Table
// Columns
RecordID postgres.ColumnString
UserID postgres.ColumnString
LimitCode postgres.ColumnString
Value postgres.ColumnInteger
ReasonCode postgres.ColumnString
ActorType postgres.ColumnString
ActorID postgres.ColumnString
AppliedAt postgres.ColumnTimestampz
ExpiresAt postgres.ColumnTimestampz
RemovedAt postgres.ColumnTimestampz
RemovedByType postgres.ColumnString
RemovedByID postgres.ColumnString
RemovedReasonCode postgres.ColumnString
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type LimitRecordsTable struct {
limitRecordsTable
EXCLUDED limitRecordsTable
}
// AS creates new LimitRecordsTable with assigned alias
func (a LimitRecordsTable) AS(alias string) *LimitRecordsTable {
return newLimitRecordsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new LimitRecordsTable with assigned schema name
func (a LimitRecordsTable) FromSchema(schemaName string) *LimitRecordsTable {
return newLimitRecordsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new LimitRecordsTable with assigned table prefix
func (a LimitRecordsTable) WithPrefix(prefix string) *LimitRecordsTable {
return newLimitRecordsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new LimitRecordsTable with assigned table suffix
func (a LimitRecordsTable) WithSuffix(suffix string) *LimitRecordsTable {
return newLimitRecordsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newLimitRecordsTable(schemaName, tableName, alias string) *LimitRecordsTable {
return &LimitRecordsTable{
limitRecordsTable: newLimitRecordsTableImpl(schemaName, tableName, alias),
EXCLUDED: newLimitRecordsTableImpl("", "excluded", ""),
}
}
func newLimitRecordsTableImpl(schemaName, tableName, alias string) limitRecordsTable {
var (
RecordIDColumn = postgres.StringColumn("record_id")
UserIDColumn = postgres.StringColumn("user_id")
LimitCodeColumn = postgres.StringColumn("limit_code")
ValueColumn = postgres.IntegerColumn("value")
ReasonCodeColumn = postgres.StringColumn("reason_code")
ActorTypeColumn = postgres.StringColumn("actor_type")
ActorIDColumn = postgres.StringColumn("actor_id")
AppliedAtColumn = postgres.TimestampzColumn("applied_at")
ExpiresAtColumn = postgres.TimestampzColumn("expires_at")
RemovedAtColumn = postgres.TimestampzColumn("removed_at")
RemovedByTypeColumn = postgres.StringColumn("removed_by_type")
RemovedByIDColumn = postgres.StringColumn("removed_by_id")
RemovedReasonCodeColumn = postgres.StringColumn("removed_reason_code")
allColumns = postgres.ColumnList{RecordIDColumn, UserIDColumn, LimitCodeColumn, ValueColumn, ReasonCodeColumn, ActorTypeColumn, ActorIDColumn, AppliedAtColumn, ExpiresAtColumn, RemovedAtColumn, RemovedByTypeColumn, RemovedByIDColumn, RemovedReasonCodeColumn}
mutableColumns = postgres.ColumnList{UserIDColumn, LimitCodeColumn, ValueColumn, ReasonCodeColumn, ActorTypeColumn, ActorIDColumn, AppliedAtColumn, ExpiresAtColumn, RemovedAtColumn, RemovedByTypeColumn, RemovedByIDColumn, RemovedReasonCodeColumn}
defaultColumns = postgres.ColumnList{}
)
return limitRecordsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
RecordID: RecordIDColumn,
UserID: UserIDColumn,
LimitCode: LimitCodeColumn,
Value: ValueColumn,
ReasonCode: ReasonCodeColumn,
ActorType: ActorTypeColumn,
ActorID: ActorIDColumn,
AppliedAt: AppliedAtColumn,
ExpiresAt: ExpiresAtColumn,
RemovedAt: RemovedAtColumn,
RemovedByType: RemovedByTypeColumn,
RemovedByID: RemovedByIDColumn,
RemovedReasonCode: RemovedReasonCodeColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,84 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var SanctionActive = newSanctionActiveTable("user", "sanction_active", "")
type sanctionActiveTable struct {
postgres.Table
// Columns
UserID postgres.ColumnString
SanctionCode postgres.ColumnString
RecordID postgres.ColumnString
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type SanctionActiveTable struct {
sanctionActiveTable
EXCLUDED sanctionActiveTable
}
// AS creates new SanctionActiveTable with assigned alias
func (a SanctionActiveTable) AS(alias string) *SanctionActiveTable {
return newSanctionActiveTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new SanctionActiveTable with assigned schema name
func (a SanctionActiveTable) FromSchema(schemaName string) *SanctionActiveTable {
return newSanctionActiveTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new SanctionActiveTable with assigned table prefix
func (a SanctionActiveTable) WithPrefix(prefix string) *SanctionActiveTable {
return newSanctionActiveTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new SanctionActiveTable with assigned table suffix
func (a SanctionActiveTable) WithSuffix(suffix string) *SanctionActiveTable {
return newSanctionActiveTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newSanctionActiveTable(schemaName, tableName, alias string) *SanctionActiveTable {
return &SanctionActiveTable{
sanctionActiveTable: newSanctionActiveTableImpl(schemaName, tableName, alias),
EXCLUDED: newSanctionActiveTableImpl("", "excluded", ""),
}
}
func newSanctionActiveTableImpl(schemaName, tableName, alias string) sanctionActiveTable {
var (
UserIDColumn = postgres.StringColumn("user_id")
SanctionCodeColumn = postgres.StringColumn("sanction_code")
RecordIDColumn = postgres.StringColumn("record_id")
allColumns = postgres.ColumnList{UserIDColumn, SanctionCodeColumn, RecordIDColumn}
mutableColumns = postgres.ColumnList{RecordIDColumn}
defaultColumns = postgres.ColumnList{}
)
return sanctionActiveTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
UserID: UserIDColumn,
SanctionCode: SanctionCodeColumn,
RecordID: RecordIDColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,114 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var SanctionRecords = newSanctionRecordsTable("user", "sanction_records", "")
type sanctionRecordsTable struct {
postgres.Table
// Columns
RecordID postgres.ColumnString
UserID postgres.ColumnString
SanctionCode postgres.ColumnString
Scope postgres.ColumnString
ReasonCode postgres.ColumnString
ActorType postgres.ColumnString
ActorID postgres.ColumnString
AppliedAt postgres.ColumnTimestampz
ExpiresAt postgres.ColumnTimestampz
RemovedAt postgres.ColumnTimestampz
RemovedByType postgres.ColumnString
RemovedByID postgres.ColumnString
RemovedReasonCode postgres.ColumnString
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type SanctionRecordsTable struct {
sanctionRecordsTable
EXCLUDED sanctionRecordsTable
}
// AS creates new SanctionRecordsTable with assigned alias
func (a SanctionRecordsTable) AS(alias string) *SanctionRecordsTable {
return newSanctionRecordsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new SanctionRecordsTable with assigned schema name
func (a SanctionRecordsTable) FromSchema(schemaName string) *SanctionRecordsTable {
return newSanctionRecordsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new SanctionRecordsTable with assigned table prefix
func (a SanctionRecordsTable) WithPrefix(prefix string) *SanctionRecordsTable {
return newSanctionRecordsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new SanctionRecordsTable with assigned table suffix
func (a SanctionRecordsTable) WithSuffix(suffix string) *SanctionRecordsTable {
return newSanctionRecordsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newSanctionRecordsTable(schemaName, tableName, alias string) *SanctionRecordsTable {
return &SanctionRecordsTable{
sanctionRecordsTable: newSanctionRecordsTableImpl(schemaName, tableName, alias),
EXCLUDED: newSanctionRecordsTableImpl("", "excluded", ""),
}
}
func newSanctionRecordsTableImpl(schemaName, tableName, alias string) sanctionRecordsTable {
var (
RecordIDColumn = postgres.StringColumn("record_id")
UserIDColumn = postgres.StringColumn("user_id")
SanctionCodeColumn = postgres.StringColumn("sanction_code")
ScopeColumn = postgres.StringColumn("scope")
ReasonCodeColumn = postgres.StringColumn("reason_code")
ActorTypeColumn = postgres.StringColumn("actor_type")
ActorIDColumn = postgres.StringColumn("actor_id")
AppliedAtColumn = postgres.TimestampzColumn("applied_at")
ExpiresAtColumn = postgres.TimestampzColumn("expires_at")
RemovedAtColumn = postgres.TimestampzColumn("removed_at")
RemovedByTypeColumn = postgres.StringColumn("removed_by_type")
RemovedByIDColumn = postgres.StringColumn("removed_by_id")
RemovedReasonCodeColumn = postgres.StringColumn("removed_reason_code")
allColumns = postgres.ColumnList{RecordIDColumn, UserIDColumn, SanctionCodeColumn, ScopeColumn, ReasonCodeColumn, ActorTypeColumn, ActorIDColumn, AppliedAtColumn, ExpiresAtColumn, RemovedAtColumn, RemovedByTypeColumn, RemovedByIDColumn, RemovedReasonCodeColumn}
mutableColumns = postgres.ColumnList{UserIDColumn, SanctionCodeColumn, ScopeColumn, ReasonCodeColumn, ActorTypeColumn, ActorIDColumn, AppliedAtColumn, ExpiresAtColumn, RemovedAtColumn, RemovedByTypeColumn, RemovedByIDColumn, RemovedReasonCodeColumn}
defaultColumns = postgres.ColumnList{}
)
return sanctionRecordsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
RecordID: RecordIDColumn,
UserID: UserIDColumn,
SanctionCode: SanctionCodeColumn,
Scope: ScopeColumn,
ReasonCode: ReasonCodeColumn,
ActorType: ActorTypeColumn,
ActorID: ActorIDColumn,
AppliedAt: AppliedAtColumn,
ExpiresAt: ExpiresAtColumn,
RemovedAt: RemovedAtColumn,
RemovedByType: RemovedByTypeColumn,
RemovedByID: RemovedByIDColumn,
RemovedReasonCode: RemovedReasonCodeColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
@@ -0,0 +1,22 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
// UseSchema sets a new schema name for all generated table SQL builder types. It is recommended to invoke
// this method only once at the beginning of the program.
func UseSchema(schema string) {
Accounts = Accounts.FromSchema(schema)
BlockedEmails = BlockedEmails.FromSchema(schema)
EntitlementRecords = EntitlementRecords.FromSchema(schema)
EntitlementSnapshots = EntitlementSnapshots.FromSchema(schema)
GooseDbVersion = GooseDbVersion.FromSchema(schema)
LimitActive = LimitActive.FromSchema(schema)
LimitRecords = LimitRecords.FromSchema(schema)
SanctionActive = SanctionActive.FromSchema(schema)
SanctionRecords = SanctionRecords.FromSchema(schema)
}
@@ -0,0 +1,169 @@
-- +goose Up
-- accounts holds the editable source-of-truth user-account state.
-- email and user_name remain UNIQUE for both live and soft-deleted records:
-- emails are never reassigned to a fresh user_id after DeleteUser, and
-- user_name is immutable for the lifetime of the account.
CREATE TABLE accounts (
user_id text PRIMARY KEY,
email text NOT NULL,
user_name text NOT NULL,
display_name text NOT NULL DEFAULT '',
preferred_language text NOT NULL,
time_zone text NOT NULL,
declared_country text,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
deleted_at timestamptz,
CONSTRAINT accounts_email_unique UNIQUE (email),
CONSTRAINT accounts_user_name_unique UNIQUE (user_name)
);
-- Newest-first listing index used by the trusted admin user-list surface.
CREATE INDEX accounts_listing_idx
ON accounts (created_at DESC, user_id DESC);
-- Reverse-lookup index for the optional declared-country filter; the partial
-- predicate keeps the index small while declared_country is mostly NULL.
CREATE INDEX accounts_declared_country_idx
ON accounts (declared_country)
WHERE declared_country IS NOT NULL;
-- blocked_emails persists pre-user blocked-email subjects that may exist
-- before any user account exists, plus the blocked subjects produced by
-- BlockByUserID/BlockByEmail. resolved_user_id is populated when the block
-- corresponds to an existing or formerly existing account.
CREATE TABLE blocked_emails (
email text PRIMARY KEY,
reason_code text NOT NULL,
blocked_at timestamptz NOT NULL,
actor_type text,
actor_id text,
resolved_user_id text
);
-- entitlement_records stores the immutable history of entitlement periods.
-- Each row represents one segment that was current at some point; closed
-- segments carry closed_* metadata.
CREATE TABLE entitlement_records (
record_id text PRIMARY KEY,
user_id text NOT NULL REFERENCES accounts(user_id),
plan_code text NOT NULL,
source text NOT NULL,
actor_type text NOT NULL,
actor_id text,
reason_code text NOT NULL,
starts_at timestamptz NOT NULL,
ends_at timestamptz,
created_at timestamptz NOT NULL,
closed_at timestamptz,
closed_by_type text,
closed_by_id text,
closed_reason_code text
);
CREATE INDEX entitlement_records_user_idx
ON entitlement_records (user_id, created_at DESC);
-- entitlement_snapshots stores the read-optimized current entitlement state.
-- Exactly one row per user_id; updated atomically together with history rows
-- by EntitlementLifecycleStore operations.
CREATE TABLE entitlement_snapshots (
user_id text PRIMARY KEY REFERENCES accounts(user_id),
plan_code text NOT NULL,
is_paid boolean NOT NULL,
starts_at timestamptz NOT NULL,
ends_at timestamptz,
source text NOT NULL,
actor_type text NOT NULL,
actor_id text,
reason_code text NOT NULL,
updated_at timestamptz NOT NULL
);
-- Coarse free-versus-paid filter used by the admin listing surface.
CREATE INDEX entitlement_snapshots_paid_state_idx
ON entitlement_snapshots (is_paid, plan_code);
-- Finite paid-expiry filter; partial predicate keeps the index limited to
-- finite paid plans (paid_monthly, paid_yearly).
CREATE INDEX entitlement_snapshots_paid_expiry_idx
ON entitlement_snapshots (ends_at)
WHERE is_paid AND ends_at IS NOT NULL;
-- sanction_records stores the immutable history of sanction mutations.
-- A row may carry removed_at + removed_* fields once the sanction is lifted.
CREATE TABLE sanction_records (
record_id text PRIMARY KEY,
user_id text NOT NULL REFERENCES accounts(user_id),
sanction_code text NOT NULL,
scope text NOT NULL,
reason_code text NOT NULL,
actor_type text NOT NULL,
actor_id text,
applied_at timestamptz NOT NULL,
expires_at timestamptz,
removed_at timestamptz,
removed_by_type text,
removed_by_id text,
removed_reason_code text
);
CREATE INDEX sanction_records_user_idx
ON sanction_records (user_id, applied_at DESC);
-- sanction_active stores the at-most-one active record per (user_id,
-- sanction_code). It is maintained by PolicyLifecycleStore in the same
-- transaction as the corresponding sanction_records mutation.
CREATE TABLE sanction_active (
user_id text NOT NULL REFERENCES accounts(user_id),
sanction_code text NOT NULL,
record_id text NOT NULL REFERENCES sanction_records(record_id),
PRIMARY KEY (user_id, sanction_code)
);
CREATE INDEX sanction_active_code_idx
ON sanction_active (sanction_code);
-- limit_records mirrors sanction_records for user-specific limit overrides.
CREATE TABLE limit_records (
record_id text PRIMARY KEY,
user_id text NOT NULL REFERENCES accounts(user_id),
limit_code text NOT NULL,
value integer NOT NULL,
reason_code text NOT NULL,
actor_type text NOT NULL,
actor_id text,
applied_at timestamptz NOT NULL,
expires_at timestamptz,
removed_at timestamptz,
removed_by_type text,
removed_by_id text,
removed_reason_code text
);
CREATE INDEX limit_records_user_idx
ON limit_records (user_id, applied_at DESC);
-- limit_active mirrors sanction_active for user-specific limits. value is
-- denormalised so the admin listing predicate can read it without joining
-- the full history.
CREATE TABLE limit_active (
user_id text NOT NULL REFERENCES accounts(user_id),
limit_code text NOT NULL,
record_id text NOT NULL REFERENCES limit_records(record_id),
value integer NOT NULL,
PRIMARY KEY (user_id, limit_code)
);
CREATE INDEX limit_active_code_idx
ON limit_active (limit_code);
-- +goose Down
DROP TABLE IF EXISTS limit_active;
DROP TABLE IF EXISTS limit_records;
DROP TABLE IF EXISTS sanction_active;
DROP TABLE IF EXISTS sanction_records;
DROP TABLE IF EXISTS entitlement_snapshots;
DROP TABLE IF EXISTS entitlement_records;
DROP TABLE IF EXISTS blocked_emails;
DROP TABLE IF EXISTS accounts;
@@ -0,0 +1,19 @@
// Package migrations exposes the embedded goose migration files used by
// User Service to provision its `user` schema in PostgreSQL.
//
// The embedded filesystem is consumed by `pkg/postgres.RunMigrations`
// during user-service startup and by `cmd/jetgen` when regenerating the
// `internal/adapters/postgres/jet/` code against a transient PostgreSQL
// instance.
package migrations
import "embed"
//go:embed *.sql
var fs embed.FS
// FS returns the embedded filesystem containing every numbered goose
// migration shipped with User Service.
func FS() embed.FS {
return fs
}
@@ -0,0 +1,375 @@
package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/ports"
pg "github.com/go-jet/jet/v2/postgres"
)
// SQL constraint names declared in 00001_init.sql; referenced from error
// translation so we can disambiguate UNIQUE violations on (email) versus
// (user_name).
const (
accountsEmailUniqueConstraint = "accounts_email_unique"
accountsUserNameUniqueConstraint = "accounts_user_name_unique"
)
// accountSelectColumns is the canonical SELECT list for accounts, matching
// scanAccountRow's column order.
var accountSelectColumns = pg.ColumnList{
pgtable.Accounts.UserID,
pgtable.Accounts.Email,
pgtable.Accounts.UserName,
pgtable.Accounts.DisplayName,
pgtable.Accounts.PreferredLanguage,
pgtable.Accounts.TimeZone,
pgtable.Accounts.DeclaredCountry,
pgtable.Accounts.CreatedAt,
pgtable.Accounts.UpdatedAt,
pgtable.Accounts.DeletedAt,
}
// Create stores one new account record. Email and user-name uniqueness are
// enforced by the schema; conflicts on those columns surface as
// ports.ErrConflict (with ports.ErrUserNameConflict for the dedicated
// user-name index).
func (store *Store) Create(ctx context.Context, input ports.CreateAccountInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("create account in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "create account in postgres")
if err != nil {
return err
}
defer cancel()
if err := insertAccount(operationCtx, store.db, input.Account); err != nil {
return err
}
return nil
}
// insertAccount runs one INSERT against accounts using the supplied Queryer
// (a *sql.DB or a *sql.Tx). It centralises the column list and error
// translation used by Create and EnsureByEmail.
func insertAccount(ctx context.Context, q queryer, record account.UserAccount) error {
stmt := pgtable.Accounts.INSERT(
pgtable.Accounts.UserID,
pgtable.Accounts.Email,
pgtable.Accounts.UserName,
pgtable.Accounts.DisplayName,
pgtable.Accounts.PreferredLanguage,
pgtable.Accounts.TimeZone,
pgtable.Accounts.DeclaredCountry,
pgtable.Accounts.CreatedAt,
pgtable.Accounts.UpdatedAt,
pgtable.Accounts.DeletedAt,
).VALUES(
record.UserID.String(),
record.Email.String(),
record.UserName.String(),
record.DisplayName.String(),
record.PreferredLanguage.String(),
record.TimeZone.String(),
nullableCountry(record.DeclaredCountry),
record.CreatedAt.UTC(),
record.UpdatedAt.UTC(),
nullableTime(record.DeletedAt),
)
query, args := stmt.Sql()
_, err := q.ExecContext(ctx, query, args...)
if err == nil {
return nil
}
if mapped := classifyUniqueViolation(err, accountsUserNameUniqueConstraint, ports.ErrUserNameConflict); mapped != nil {
return fmt.Errorf("create account %q in postgres: %w", record.UserID, mapped)
}
if isUniqueViolation(err) {
return fmt.Errorf("create account %q in postgres: %w", record.UserID, ports.ErrConflict)
}
return fmt.Errorf("create account %q in postgres: %w", record.UserID, err)
}
// queryer is the subset of *sql.DB / *sql.Tx used by helpers that need to
// run inside an existing transaction or against the bare pool.
type queryer interface {
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// GetByUserID returns the stored account identified by userID.
func (store *Store) GetByUserID(ctx context.Context, userID common.UserID) (account.UserAccount, error) {
if err := userID.Validate(); err != nil {
return account.UserAccount{}, fmt.Errorf("get account by user id from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get account by user id from postgres")
if err != nil {
return account.UserAccount{}, err
}
defer cancel()
record, err := scanAccountByUserID(operationCtx, store.db, userID)
switch {
case errors.Is(err, ports.ErrNotFound):
return account.UserAccount{}, fmt.Errorf("get account by user id %q from postgres: %w", userID, ports.ErrNotFound)
case err != nil:
return account.UserAccount{}, fmt.Errorf("get account by user id %q from postgres: %w", userID, err)
}
return record, nil
}
// GetByEmail returns the stored account identified by the normalized e-mail
// address.
func (store *Store) GetByEmail(ctx context.Context, email common.Email) (account.UserAccount, error) {
if err := email.Validate(); err != nil {
return account.UserAccount{}, fmt.Errorf("get account by email from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get account by email from postgres")
if err != nil {
return account.UserAccount{}, err
}
defer cancel()
record, err := scanAccountByEmail(operationCtx, store.db, email)
switch {
case errors.Is(err, ports.ErrNotFound):
return account.UserAccount{}, fmt.Errorf("get account by email %q from postgres: %w", email, ports.ErrNotFound)
case err != nil:
return account.UserAccount{}, fmt.Errorf("get account by email %q from postgres: %w", email, err)
}
return record, nil
}
// GetByUserName returns the stored account identified by the exact stored
// user name.
func (store *Store) GetByUserName(ctx context.Context, userName common.UserName) (account.UserAccount, error) {
if err := userName.Validate(); err != nil {
return account.UserAccount{}, fmt.Errorf("get account by user name from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get account by user name from postgres")
if err != nil {
return account.UserAccount{}, err
}
defer cancel()
record, err := scanAccountByUserName(operationCtx, store.db, userName)
switch {
case errors.Is(err, ports.ErrNotFound):
return account.UserAccount{}, fmt.Errorf("get account by user name %q from postgres: %w", userName, ports.ErrNotFound)
case err != nil:
return account.UserAccount{}, fmt.Errorf("get account by user name %q from postgres: %w", userName, err)
}
return record, nil
}
// ExistsByUserID reports whether userID currently identifies a stored account
// that is not soft-deleted. Soft-deleted accounts are treated as non-existing
// for external callers per Stage 22.
func (store *Store) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
if err := userID.Validate(); err != nil {
return false, fmt.Errorf("exists by user id from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "exists by user id from postgres")
if err != nil {
return false, err
}
defer cancel()
stmt := pg.SELECT(pgtable.Accounts.DeletedAt).
FROM(pgtable.Accounts).
WHERE(pgtable.Accounts.UserID.EQ(pg.String(userID.String())))
query, args := stmt.Sql()
var deletedAt *time.Time
err = store.db.QueryRowContext(operationCtx, query, args...).Scan(&deletedAt)
switch {
case errors.Is(err, sql.ErrNoRows):
return false, nil
case err != nil:
return false, fmt.Errorf("exists by user id %q from postgres: %w", userID, err)
}
return deletedAt == nil, nil
}
// Update replaces the stored account state for record.UserID. Email and
// user_name are immutable; mutation attempts return ports.ErrConflict.
// declared_country, display_name, preferred_language, time_zone, updated_at,
// and deleted_at are the columns affected.
func (store *Store) Update(ctx context.Context, record account.UserAccount) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("update account in postgres: %w", err)
}
return store.withTx(ctx, "update account in postgres", func(ctx context.Context, tx *sql.Tx) error {
current, err := scanAccountForUpdate(ctx, tx, record.UserID)
if err != nil {
if errors.Is(err, ports.ErrNotFound) {
return fmt.Errorf("update account %q in postgres: %w", record.UserID, ports.ErrNotFound)
}
return fmt.Errorf("update account %q in postgres: %w", record.UserID, err)
}
if current.Email != record.Email || current.UserName != record.UserName {
return fmt.Errorf("update account %q in postgres: %w", record.UserID, ports.ErrConflict)
}
stmt := pgtable.Accounts.UPDATE(
pgtable.Accounts.DisplayName,
pgtable.Accounts.PreferredLanguage,
pgtable.Accounts.TimeZone,
pgtable.Accounts.DeclaredCountry,
pgtable.Accounts.UpdatedAt,
pgtable.Accounts.DeletedAt,
).SET(
record.DisplayName.String(),
record.PreferredLanguage.String(),
record.TimeZone.String(),
nullableCountry(record.DeclaredCountry),
record.UpdatedAt.UTC(),
nullableTime(record.DeletedAt),
).WHERE(pgtable.Accounts.UserID.EQ(pg.String(record.UserID.String())))
query, args := stmt.Sql()
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("update account %q in postgres: %w", record.UserID, err)
}
return nil
})
}
// scanAccountByUserID is a thin wrapper around scanAccountWhere for the
// (user_id) column so atomic flows can reuse the same scanner with FOR
// UPDATE locking semantics.
func scanAccountByUserID(ctx context.Context, q queryer, userID common.UserID) (account.UserAccount, error) {
return scanAccountWhere(ctx, q, pgtable.Accounts.UserID.EQ(pg.String(userID.String())), false)
}
func scanAccountByEmail(ctx context.Context, q queryer, email common.Email) (account.UserAccount, error) {
return scanAccountWhere(ctx, q, pgtable.Accounts.Email.EQ(pg.String(email.String())), false)
}
func scanAccountByUserName(ctx context.Context, q queryer, userName common.UserName) (account.UserAccount, error) {
return scanAccountWhere(ctx, q, pgtable.Accounts.UserName.EQ(pg.String(userName.String())), false)
}
func scanAccountForUpdate(ctx context.Context, q queryer, userID common.UserID) (account.UserAccount, error) {
return scanAccountWhere(ctx, q, pgtable.Accounts.UserID.EQ(pg.String(userID.String())), true)
}
func scanAccountForUpdateByEmail(ctx context.Context, q queryer, email common.Email) (account.UserAccount, error) {
return scanAccountWhere(ctx, q, pgtable.Accounts.Email.EQ(pg.String(email.String())), true)
}
func scanAccountWhere(ctx context.Context, q queryer, condition pg.BoolExpression, forUpdate bool) (account.UserAccount, error) {
stmt := pg.SELECT(accountSelectColumns).
FROM(pgtable.Accounts).
WHERE(condition)
if forUpdate {
stmt = stmt.FOR(pg.UPDATE())
}
query, args := stmt.Sql()
row := q.QueryRowContext(ctx, query, args...)
return scanAccountRow(row)
}
func scanAccountRow(row *sql.Row) (account.UserAccount, error) {
var (
record account.UserAccount
userID string
email string
userName string
displayName string
preferredLang string
timeZone string
declaredCountry *string
createdAt time.Time
updatedAt time.Time
deletedAt *time.Time
)
if err := row.Scan(
&userID, &email, &userName, &displayName,
&preferredLang, &timeZone, &declaredCountry,
&createdAt, &updatedAt, &deletedAt,
); err != nil {
return account.UserAccount{}, mapNotFound(err)
}
record.UserID = common.UserID(userID)
record.Email = common.Email(email)
record.UserName = common.UserName(userName)
record.DisplayName = common.DisplayName(displayName)
record.PreferredLanguage = common.LanguageTag(preferredLang)
record.TimeZone = common.TimeZoneName(timeZone)
if declaredCountry != nil {
record.DeclaredCountry = common.CountryCode(*declaredCountry)
}
record.CreatedAt = createdAt.UTC()
record.UpdatedAt = updatedAt.UTC()
record.DeletedAt = timeFromNullable(deletedAt)
return record, nil
}
// AccountStore adapts Store to the UserAccountStore port. The wrapper is
// returned by Store.Accounts() so callers that need only the narrow port
// interface remain unaware of the broader Store surface.
type AccountStore struct {
store *Store
}
// Accounts returns one adapter that exposes the user-account store port over
// Store.
func (store *Store) Accounts() *AccountStore {
if store == nil {
return nil
}
return &AccountStore{store: store}
}
// Create stores one new account record.
func (adapter *AccountStore) Create(ctx context.Context, input ports.CreateAccountInput) error {
return adapter.store.Create(ctx, input)
}
// GetByUserID returns the stored account identified by userID.
func (adapter *AccountStore) GetByUserID(ctx context.Context, userID common.UserID) (account.UserAccount, error) {
return adapter.store.GetByUserID(ctx, userID)
}
// GetByEmail returns the stored account identified by email.
func (adapter *AccountStore) GetByEmail(ctx context.Context, email common.Email) (account.UserAccount, error) {
return adapter.store.GetByEmail(ctx, email)
}
// GetByUserName returns the stored account identified by userName.
func (adapter *AccountStore) GetByUserName(ctx context.Context, userName common.UserName) (account.UserAccount, error) {
return adapter.store.GetByUserName(ctx, userName)
}
// ExistsByUserID reports whether userID currently identifies a stored
// account.
func (adapter *AccountStore) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
return adapter.store.ExistsByUserID(ctx, userID)
}
// Update replaces the stored account state for record.UserID.
func (adapter *AccountStore) Update(ctx context.Context, record account.UserAccount) error {
return adapter.store.Update(ctx, record)
}
var _ ports.UserAccountStore = (*AccountStore)(nil)
@@ -0,0 +1,280 @@
package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/authblock"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/ports"
)
// deletedAccountBlockReasonCode is returned to auth callers when the lookup
// resolves to a soft-deleted account. Auth/Session treats this exactly like
// a regular block: it refuses to mint a session for the subject. The code is
// not a real sanction record; it lives only on the wire.
const deletedAccountBlockReasonCode common.ReasonCode = "account_deleted"
// ResolveByEmail returns the current coarse auth-facing resolution state for
// email. The decision tree, in order:
//
// 1. blocked_emails has a row for this address → blocked.
// 2. accounts has a non-soft-deleted row for this address → existing.
// 3. accounts has a soft-deleted row for this address → blocked
// (account_deleted).
// 4. otherwise → creatable.
//
// The whole sequence is a read-only path; no transaction is required.
func (store *Store) ResolveByEmail(ctx context.Context, email common.Email) (ports.ResolveByEmailResult, error) {
if err := email.Validate(); err != nil {
return ports.ResolveByEmailResult{}, fmt.Errorf("resolve by email in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "resolve by email in postgres")
if err != nil {
return ports.ResolveByEmailResult{}, err
}
defer cancel()
blocked, err := scanBlockedEmail(operationCtx, store.db, email, false)
switch {
case err == nil:
return ports.ResolveByEmailResult{
Kind: ports.AuthResolutionKindBlocked,
BlockReasonCode: blocked.ReasonCode,
}, nil
case !errors.Is(err, ports.ErrNotFound):
return ports.ResolveByEmailResult{}, fmt.Errorf("resolve by email %q in postgres: %w", email, err)
}
record, err := scanAccountByEmail(operationCtx, store.db, email)
switch {
case errors.Is(err, ports.ErrNotFound):
return ports.ResolveByEmailResult{Kind: ports.AuthResolutionKindCreatable}, nil
case err != nil:
return ports.ResolveByEmailResult{}, fmt.Errorf("resolve by email %q in postgres: %w", email, err)
}
if record.IsDeleted() {
return ports.ResolveByEmailResult{
Kind: ports.AuthResolutionKindBlocked,
BlockReasonCode: deletedAccountBlockReasonCode,
}, nil
}
return ports.ResolveByEmailResult{
Kind: ports.AuthResolutionKindExisting,
UserID: record.UserID,
}, nil
}
// EnsureByEmail atomically returns an existing user, creates a new one, or
// reports a blocked outcome. The whole flow runs in one transaction with
// row-level locks on `blocked_emails(email)` and `accounts(email)` so we
// observe a consistent snapshot of the auth-facing state.
//
// On the create branch the transaction also INSERTs the initial
// entitlement_records row and the entitlement_snapshots row. UNIQUE
// violations on user_id or user_name surface as ports.ErrConflict (with
// ports.ErrUserNameConflict for the user-name index).
func (store *Store) EnsureByEmail(ctx context.Context, input ports.EnsureByEmailInput) (ports.EnsureByEmailResult, error) {
if err := input.Validate(); err != nil {
return ports.EnsureByEmailResult{}, fmt.Errorf("ensure by email in postgres: %w", err)
}
var (
result ports.EnsureByEmailResult
handled bool
)
if err := store.withTx(ctx, "ensure by email in postgres", func(ctx context.Context, tx *sql.Tx) error {
blocked, err := scanBlockedEmail(ctx, tx, input.Email, true)
switch {
case err == nil:
result = ports.EnsureByEmailResult{
Outcome: ports.EnsureByEmailOutcomeBlocked,
BlockReasonCode: blocked.ReasonCode,
}
handled = true
return nil
case !errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("ensure by email %q in postgres: %w", input.Email, err)
}
existing, err := scanAccountForUpdateByEmail(ctx, tx, input.Email)
switch {
case err == nil:
if existing.IsDeleted() {
result = ports.EnsureByEmailResult{
Outcome: ports.EnsureByEmailOutcomeBlocked,
BlockReasonCode: deletedAccountBlockReasonCode,
}
handled = true
return nil
}
result = ports.EnsureByEmailResult{
Outcome: ports.EnsureByEmailOutcomeExisting,
UserID: existing.UserID,
}
handled = true
return nil
case !errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("ensure by email %q in postgres: %w", input.Email, err)
}
if err := insertAccount(ctx, tx, input.Account); err != nil {
return err
}
if err := insertEntitlementPeriod(ctx, tx, input.EntitlementRecord); err != nil {
return err
}
if err := upsertEntitlementSnapshot(ctx, tx, input.Entitlement); err != nil {
return err
}
result = ports.EnsureByEmailResult{
Outcome: ports.EnsureByEmailOutcomeCreated,
UserID: input.Account.UserID,
}
handled = true
return nil
}); err != nil {
return ports.EnsureByEmailResult{}, err
}
if !handled {
return ports.EnsureByEmailResult{}, fmt.Errorf("ensure by email %q in postgres: unhandled transaction outcome", input.Email)
}
return result, nil
}
// BlockByUserID applies a block to the account identified by userID. The
// block is stored as a row in blocked_emails keyed on the user's e-mail with
// resolved_user_id pointing back to the account.
func (store *Store) BlockByUserID(ctx context.Context, input ports.BlockByUserIDInput) (ports.BlockResult, error) {
if err := input.Validate(); err != nil {
return ports.BlockResult{}, fmt.Errorf("block by user id in postgres: %w", err)
}
var (
result ports.BlockResult
handled bool
)
if err := store.withTx(ctx, "block by user id in postgres", func(ctx context.Context, tx *sql.Tx) error {
acc, err := scanAccountForUpdate(ctx, tx, input.UserID)
switch {
case errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("block by user id %q in postgres: %w", input.UserID, ports.ErrNotFound)
case err != nil:
return fmt.Errorf("block by user id %q in postgres: %w", input.UserID, err)
}
if acc.IsDeleted() {
return fmt.Errorf("block by user id %q in postgres: %w", input.UserID, ports.ErrNotFound)
}
blocked, err := scanBlockedEmail(ctx, tx, acc.Email, true)
switch {
case err == nil:
result = ports.BlockResult{
Outcome: ports.AuthBlockOutcomeAlreadyBlocked,
UserID: input.UserID,
}
if !blocked.ResolvedUserID.IsZero() {
result.UserID = blocked.ResolvedUserID
}
handled = true
return nil
case !errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("block by user id %q in postgres: %w", input.UserID, err)
}
record := authblock.BlockedEmailSubject{
Email: acc.Email,
ReasonCode: input.ReasonCode,
BlockedAt: input.BlockedAt.UTC(),
ResolvedUserID: input.UserID,
}
if err := upsertBlockedEmail(ctx, tx, record); err != nil {
return fmt.Errorf("block by user id %q in postgres: %w", input.UserID, err)
}
result = ports.BlockResult{
Outcome: ports.AuthBlockOutcomeBlocked,
UserID: input.UserID,
}
handled = true
return nil
}); err != nil {
return ports.BlockResult{}, err
}
if !handled {
return ports.BlockResult{}, fmt.Errorf("block by user id %q in postgres: unhandled transaction outcome", input.UserID)
}
return result, nil
}
// BlockByEmail applies a block to email even when no account exists yet. If
// an account does exist for the e-mail, its user_id is recorded as
// resolved_user_id; soft-deleted accounts also count for this resolution.
func (store *Store) BlockByEmail(ctx context.Context, input ports.BlockByEmailInput) (ports.BlockResult, error) {
if err := input.Validate(); err != nil {
return ports.BlockResult{}, fmt.Errorf("block by email in postgres: %w", err)
}
var (
result ports.BlockResult
handled bool
)
if err := store.withTx(ctx, "block by email in postgres", func(ctx context.Context, tx *sql.Tx) error {
blocked, err := scanBlockedEmail(ctx, tx, input.Email, true)
switch {
case err == nil:
result = ports.BlockResult{
Outcome: ports.AuthBlockOutcomeAlreadyBlocked,
UserID: blocked.ResolvedUserID,
}
handled = true
return nil
case !errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("block by email %q in postgres: %w", input.Email, err)
}
var resolvedUserID common.UserID
acc, err := scanAccountForUpdateByEmail(ctx, tx, input.Email)
switch {
case err == nil:
resolvedUserID = acc.UserID
case !errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("block by email %q in postgres: %w", input.Email, err)
}
record := authblock.BlockedEmailSubject{
Email: input.Email,
ReasonCode: input.ReasonCode,
BlockedAt: input.BlockedAt.UTC(),
ResolvedUserID: resolvedUserID,
}
if err := upsertBlockedEmail(ctx, tx, record); err != nil {
return fmt.Errorf("block by email %q in postgres: %w", input.Email, err)
}
result = ports.BlockResult{
Outcome: ports.AuthBlockOutcomeBlocked,
UserID: resolvedUserID,
}
handled = true
return nil
}); err != nil {
return ports.BlockResult{}, err
}
if !handled {
return ports.BlockResult{}, fmt.Errorf("block by email %q in postgres: unhandled transaction outcome", input.Email)
}
return result, nil
}
// guard so external callers cannot mistake this file's helpers for a public
// surface.
var _ account.UserAccount = account.UserAccount{}
@@ -0,0 +1,175 @@
package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
"galaxy/user/internal/domain/authblock"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/ports"
pg "github.com/go-jet/jet/v2/postgres"
)
// blockedEmailSelectColumns is the canonical SELECT list for blocked_emails.
var blockedEmailSelectColumns = pg.ColumnList{
pgtable.BlockedEmails.Email,
pgtable.BlockedEmails.ReasonCode,
pgtable.BlockedEmails.BlockedAt,
pgtable.BlockedEmails.ActorType,
pgtable.BlockedEmails.ActorID,
pgtable.BlockedEmails.ResolvedUserID,
}
// GetBlockedEmail returns the blocked-email subject for email.
func (store *Store) GetBlockedEmail(ctx context.Context, email common.Email) (authblock.BlockedEmailSubject, error) {
if err := email.Validate(); err != nil {
return authblock.BlockedEmailSubject{}, fmt.Errorf("get blocked email subject from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get blocked email subject from postgres")
if err != nil {
return authblock.BlockedEmailSubject{}, err
}
defer cancel()
record, err := scanBlockedEmail(operationCtx, store.db, email, false)
switch {
case errors.Is(err, ports.ErrNotFound):
return authblock.BlockedEmailSubject{}, fmt.Errorf("get blocked email subject %q from postgres: %w", email, ports.ErrNotFound)
case err != nil:
return authblock.BlockedEmailSubject{}, fmt.Errorf("get blocked email subject %q from postgres: %w", email, err)
}
return record, nil
}
// PutBlockedEmail stores or replaces the blocked-email subject for
// record.Email. The schema's PRIMARY KEY on (email) makes this an UPSERT via
// `INSERT … ON CONFLICT (email) DO UPDATE`.
func (store *Store) PutBlockedEmail(ctx context.Context, record authblock.BlockedEmailSubject) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("upsert blocked email subject in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "upsert blocked email subject in postgres")
if err != nil {
return err
}
defer cancel()
if err := upsertBlockedEmail(operationCtx, store.db, record); err != nil {
return err
}
return nil
}
// upsertBlockedEmail centralises the UPSERT used by PutBlockedEmail and the
// composite block flows. q is a *sql.DB or *sql.Tx so it can run inside an
// auth-directory transaction.
func upsertBlockedEmail(ctx context.Context, q queryer, record authblock.BlockedEmailSubject) error {
stmt := pgtable.BlockedEmails.INSERT(
pgtable.BlockedEmails.Email,
pgtable.BlockedEmails.ReasonCode,
pgtable.BlockedEmails.BlockedAt,
pgtable.BlockedEmails.ActorType,
pgtable.BlockedEmails.ActorID,
pgtable.BlockedEmails.ResolvedUserID,
).VALUES(
record.Email.String(),
record.ReasonCode.String(),
record.BlockedAt.UTC(),
nullableActorType(record.Actor.Type),
nullableActorID(record.Actor.ID),
nullableUserID(record.ResolvedUserID),
).ON_CONFLICT(pgtable.BlockedEmails.Email).DO_UPDATE(
pg.SET(
pgtable.BlockedEmails.ReasonCode.SET(pgtable.BlockedEmails.EXCLUDED.ReasonCode),
pgtable.BlockedEmails.BlockedAt.SET(pgtable.BlockedEmails.EXCLUDED.BlockedAt),
pgtable.BlockedEmails.ActorType.SET(pgtable.BlockedEmails.EXCLUDED.ActorType),
pgtable.BlockedEmails.ActorID.SET(pgtable.BlockedEmails.EXCLUDED.ActorID),
pgtable.BlockedEmails.ResolvedUserID.SET(pgtable.BlockedEmails.EXCLUDED.ResolvedUserID),
),
)
query, args := stmt.Sql()
if _, err := q.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("upsert blocked email subject %q in postgres: %w", record.Email, err)
}
return nil
}
// scanBlockedEmail loads one blocked-email row. forUpdate selects the
// `FOR UPDATE` lock variant used inside the auth-directory transaction.
func scanBlockedEmail(ctx context.Context, q queryer, email common.Email, forUpdate bool) (authblock.BlockedEmailSubject, error) {
stmt := pg.SELECT(blockedEmailSelectColumns).
FROM(pgtable.BlockedEmails).
WHERE(pgtable.BlockedEmails.Email.EQ(pg.String(email.String())))
if forUpdate {
stmt = stmt.FOR(pg.UPDATE())
}
query, args := stmt.Sql()
row := q.QueryRowContext(ctx, query, args...)
return scanBlockedEmailRow(row)
}
func scanBlockedEmailRow(row *sql.Row) (authblock.BlockedEmailSubject, error) {
var (
record authblock.BlockedEmailSubject
emailValue string
reasonCode string
blockedAt time.Time
actorType *string
actorID *string
resolvedUserID *string
)
if err := row.Scan(
&emailValue, &reasonCode, &blockedAt,
&actorType, &actorID, &resolvedUserID,
); err != nil {
return authblock.BlockedEmailSubject{}, mapNotFound(err)
}
record.Email = common.Email(emailValue)
record.ReasonCode = common.ReasonCode(reasonCode)
record.BlockedAt = blockedAt.UTC()
if actorType != nil {
record.Actor.Type = common.ActorType(*actorType)
}
if actorID != nil {
record.Actor.ID = common.ActorID(*actorID)
}
if resolvedUserID != nil {
record.ResolvedUserID = common.UserID(*resolvedUserID)
}
return record, nil
}
// BlockedEmailStore adapts Store to the BlockedEmailStore port.
type BlockedEmailStore struct {
store *Store
}
// BlockedEmails returns one adapter that exposes the blocked-email store
// port over Store.
func (store *Store) BlockedEmails() *BlockedEmailStore {
if store == nil {
return nil
}
return &BlockedEmailStore{store: store}
}
// GetByEmail returns the blocked-email subject for email.
func (adapter *BlockedEmailStore) GetByEmail(ctx context.Context, email common.Email) (authblock.BlockedEmailSubject, error) {
return adapter.store.GetBlockedEmail(ctx, email)
}
// Upsert stores or replaces the blocked-email subject for record.Email.
func (adapter *BlockedEmailStore) Upsert(ctx context.Context, record authblock.BlockedEmailSubject) error {
return adapter.store.PutBlockedEmail(ctx, record)
}
var _ ports.BlockedEmailStore = (*BlockedEmailStore)(nil)
@@ -0,0 +1,729 @@
package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/ports"
pg "github.com/go-jet/jet/v2/postgres"
)
// entitlementPeriodSelectColumns is the canonical SELECT list for
// entitlement_records, matching scanEntitlementPeriod's column order.
var entitlementPeriodSelectColumns = pg.ColumnList{
pgtable.EntitlementRecords.RecordID,
pgtable.EntitlementRecords.UserID,
pgtable.EntitlementRecords.PlanCode,
pgtable.EntitlementRecords.Source,
pgtable.EntitlementRecords.ActorType,
pgtable.EntitlementRecords.ActorID,
pgtable.EntitlementRecords.ReasonCode,
pgtable.EntitlementRecords.StartsAt,
pgtable.EntitlementRecords.EndsAt,
pgtable.EntitlementRecords.CreatedAt,
pgtable.EntitlementRecords.ClosedAt,
pgtable.EntitlementRecords.ClosedByType,
pgtable.EntitlementRecords.ClosedByID,
pgtable.EntitlementRecords.ClosedReasonCode,
}
// entitlementSnapshotSelectColumns is the canonical SELECT list for
// entitlement_snapshots, matching scanEntitlementSnapshotRow's column order.
var entitlementSnapshotSelectColumns = pg.ColumnList{
pgtable.EntitlementSnapshots.UserID,
pgtable.EntitlementSnapshots.PlanCode,
pgtable.EntitlementSnapshots.IsPaid,
pgtable.EntitlementSnapshots.StartsAt,
pgtable.EntitlementSnapshots.EndsAt,
pgtable.EntitlementSnapshots.Source,
pgtable.EntitlementSnapshots.ActorType,
pgtable.EntitlementSnapshots.ActorID,
pgtable.EntitlementSnapshots.ReasonCode,
pgtable.EntitlementSnapshots.UpdatedAt,
}
// CreateEntitlementRecord stores one new entitlement period history record.
// The unique key is record_id; a duplicate record_id returns
// ports.ErrConflict.
func (store *Store) CreateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("create entitlement record in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "create entitlement record in postgres")
if err != nil {
return err
}
defer cancel()
return insertEntitlementPeriod(operationCtx, store.db, record)
}
// GetEntitlementRecordByID returns the entitlement period record identified
// by recordID.
func (store *Store) GetEntitlementRecordByID(ctx context.Context, recordID entitlement.EntitlementRecordID) (entitlement.PeriodRecord, error) {
if err := recordID.Validate(); err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get entitlement record from postgres")
if err != nil {
return entitlement.PeriodRecord{}, err
}
defer cancel()
stmt := pg.SELECT(entitlementPeriodSelectColumns).
FROM(pgtable.EntitlementRecords).
WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(recordID.String())))
query, args := stmt.Sql()
row := store.db.QueryRowContext(operationCtx, query, args...)
record, err := scanEntitlementPeriodRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record %q from postgres: %w", recordID, ports.ErrNotFound)
case err != nil:
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record %q from postgres: %w", recordID, err)
}
return record, nil
}
// ListEntitlementRecordsByUserID returns every entitlement period record
// owned by userID, ordered by created_at ascending so historical replay is
// deterministic.
func (store *Store) ListEntitlementRecordsByUserID(ctx context.Context, userID common.UserID) ([]entitlement.PeriodRecord, error) {
if err := userID.Validate(); err != nil {
return nil, fmt.Errorf("list entitlement records from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list entitlement records from postgres")
if err != nil {
return nil, err
}
defer cancel()
stmt := pg.SELECT(entitlementPeriodSelectColumns).
FROM(pgtable.EntitlementRecords).
WHERE(pgtable.EntitlementRecords.UserID.EQ(pg.String(userID.String()))).
ORDER_BY(pgtable.EntitlementRecords.CreatedAt.ASC(), pgtable.EntitlementRecords.RecordID.ASC())
query, args := stmt.Sql()
rows, err := store.db.QueryContext(operationCtx, query, args...)
if err != nil {
return nil, fmt.Errorf("list entitlement records for %q from postgres: %w", userID, err)
}
defer func() { _ = rows.Close() }()
out := make([]entitlement.PeriodRecord, 0)
for rows.Next() {
record, err := scanEntitlementPeriodRows(rows)
if err != nil {
return nil, fmt.Errorf("list entitlement records for %q from postgres: %w", userID, err)
}
out = append(out, record)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("list entitlement records for %q from postgres: %w", userID, err)
}
return out, nil
}
// UpdateEntitlementRecord replaces one stored entitlement period record. The
// statement matches by record_id; ports.ErrNotFound is returned when the
// record does not exist.
func (store *Store) UpdateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("update entitlement record in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "update entitlement record in postgres")
if err != nil {
return err
}
defer cancel()
rows, err := updateEntitlementPeriod(operationCtx, store.db, record)
if err != nil {
return fmt.Errorf("update entitlement record %q in postgres: %w", record.RecordID, err)
}
if rows == 0 {
return fmt.Errorf("update entitlement record %q in postgres: %w", record.RecordID, ports.ErrNotFound)
}
return nil
}
func updateEntitlementPeriod(ctx context.Context, q queryer, record entitlement.PeriodRecord) (int64, error) {
stmt := pgtable.EntitlementRecords.UPDATE(
pgtable.EntitlementRecords.PlanCode,
pgtable.EntitlementRecords.Source,
pgtable.EntitlementRecords.ActorType,
pgtable.EntitlementRecords.ActorID,
pgtable.EntitlementRecords.ReasonCode,
pgtable.EntitlementRecords.StartsAt,
pgtable.EntitlementRecords.EndsAt,
pgtable.EntitlementRecords.CreatedAt,
pgtable.EntitlementRecords.ClosedAt,
pgtable.EntitlementRecords.ClosedByType,
pgtable.EntitlementRecords.ClosedByID,
pgtable.EntitlementRecords.ClosedReasonCode,
).SET(
string(record.PlanCode),
record.Source.String(),
record.Actor.Type.String(),
nullableActorID(record.Actor.ID),
record.ReasonCode.String(),
record.StartsAt.UTC(),
nullableTime(record.EndsAt),
record.CreatedAt.UTC(),
nullableTime(record.ClosedAt),
nullableActorType(record.ClosedBy.Type),
nullableActorID(record.ClosedBy.ID),
nullableReasonCode(record.ClosedReasonCode),
).WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(record.RecordID.String())))
query, args := stmt.Sql()
res, err := q.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func insertEntitlementPeriod(ctx context.Context, q queryer, record entitlement.PeriodRecord) error {
stmt := pgtable.EntitlementRecords.INSERT(
pgtable.EntitlementRecords.RecordID,
pgtable.EntitlementRecords.UserID,
pgtable.EntitlementRecords.PlanCode,
pgtable.EntitlementRecords.Source,
pgtable.EntitlementRecords.ActorType,
pgtable.EntitlementRecords.ActorID,
pgtable.EntitlementRecords.ReasonCode,
pgtable.EntitlementRecords.StartsAt,
pgtable.EntitlementRecords.EndsAt,
pgtable.EntitlementRecords.CreatedAt,
pgtable.EntitlementRecords.ClosedAt,
pgtable.EntitlementRecords.ClosedByType,
pgtable.EntitlementRecords.ClosedByID,
pgtable.EntitlementRecords.ClosedReasonCode,
).VALUES(
record.RecordID.String(),
record.UserID.String(),
string(record.PlanCode),
record.Source.String(),
record.Actor.Type.String(),
nullableActorID(record.Actor.ID),
record.ReasonCode.String(),
record.StartsAt.UTC(),
nullableTime(record.EndsAt),
record.CreatedAt.UTC(),
nullableTime(record.ClosedAt),
nullableActorType(record.ClosedBy.Type),
nullableActorID(record.ClosedBy.ID),
nullableReasonCode(record.ClosedReasonCode),
)
query, args := stmt.Sql()
_, err := q.ExecContext(ctx, query, args...)
if err == nil {
return nil
}
if isUniqueViolation(err) {
return fmt.Errorf("create entitlement record %q in postgres: %w", record.RecordID, ports.ErrConflict)
}
return fmt.Errorf("create entitlement record %q in postgres: %w", record.RecordID, err)
}
// scannableRow abstracts *sql.Row and *sql.Rows so the row-scanner can be
// shared by single-row and iterating callers.
type scannableRow interface {
Scan(dest ...any) error
}
func scanEntitlementPeriodRow(row *sql.Row) (entitlement.PeriodRecord, error) {
record, err := scanEntitlementPeriod(row)
if errors.Is(err, sql.ErrNoRows) {
return entitlement.PeriodRecord{}, ports.ErrNotFound
}
return record, err
}
func scanEntitlementPeriodRows(rows *sql.Rows) (entitlement.PeriodRecord, error) {
return scanEntitlementPeriod(rows)
}
func scanEntitlementPeriod(row scannableRow) (entitlement.PeriodRecord, error) {
var (
recordID string
userID string
planCode string
source string
actorType string
actorID *string
reasonCode string
startsAt time.Time
endsAt *time.Time
createdAt time.Time
closedAt *time.Time
closedByType *string
closedByID *string
closedReason *string
)
if err := row.Scan(
&recordID, &userID, &planCode, &source,
&actorType, &actorID, &reasonCode,
&startsAt, &endsAt, &createdAt,
&closedAt, &closedByType, &closedByID, &closedReason,
); err != nil {
return entitlement.PeriodRecord{}, err
}
record := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID(recordID),
UserID: common.UserID(userID),
PlanCode: entitlement.PlanCode(planCode),
Source: common.Source(source),
Actor: common.ActorRef{Type: common.ActorType(actorType)},
ReasonCode: common.ReasonCode(reasonCode),
StartsAt: startsAt.UTC(),
EndsAt: timeFromNullable(endsAt),
CreatedAt: createdAt.UTC(),
ClosedAt: timeFromNullable(closedAt),
}
if actorID != nil {
record.Actor.ID = common.ActorID(*actorID)
}
if closedByType != nil {
record.ClosedBy.Type = common.ActorType(*closedByType)
}
if closedByID != nil {
record.ClosedBy.ID = common.ActorID(*closedByID)
}
if closedReason != nil {
record.ClosedReasonCode = common.ReasonCode(*closedReason)
}
return record, nil
}
// GetEntitlementByUserID returns the current entitlement snapshot for userID.
func (store *Store) GetEntitlementByUserID(ctx context.Context, userID common.UserID) (entitlement.CurrentSnapshot, error) {
if err := userID.Validate(); err != nil {
return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get entitlement snapshot from postgres")
if err != nil {
return entitlement.CurrentSnapshot{}, err
}
defer cancel()
stmt := pg.SELECT(entitlementSnapshotSelectColumns).
FROM(pgtable.EntitlementSnapshots).
WHERE(pgtable.EntitlementSnapshots.UserID.EQ(pg.String(userID.String())))
query, args := stmt.Sql()
row := store.db.QueryRowContext(operationCtx, query, args...)
record, err := scanEntitlementSnapshotRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot for %q from postgres: %w", userID, ports.ErrNotFound)
case err != nil:
return entitlement.CurrentSnapshot{}, fmt.Errorf("get entitlement snapshot for %q from postgres: %w", userID, err)
}
return record, nil
}
// PutEntitlement stores the current entitlement snapshot for record.UserID.
// It is an UPSERT so the runtime path can call it on creation and on
// replacement uniformly.
func (store *Store) PutEntitlement(ctx context.Context, record entitlement.CurrentSnapshot) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("put entitlement snapshot in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "put entitlement snapshot in postgres")
if err != nil {
return err
}
defer cancel()
return upsertEntitlementSnapshot(operationCtx, store.db, record)
}
func upsertEntitlementSnapshot(ctx context.Context, q queryer, record entitlement.CurrentSnapshot) error {
stmt := pgtable.EntitlementSnapshots.INSERT(
pgtable.EntitlementSnapshots.UserID,
pgtable.EntitlementSnapshots.PlanCode,
pgtable.EntitlementSnapshots.IsPaid,
pgtable.EntitlementSnapshots.StartsAt,
pgtable.EntitlementSnapshots.EndsAt,
pgtable.EntitlementSnapshots.Source,
pgtable.EntitlementSnapshots.ActorType,
pgtable.EntitlementSnapshots.ActorID,
pgtable.EntitlementSnapshots.ReasonCode,
pgtable.EntitlementSnapshots.UpdatedAt,
).VALUES(
record.UserID.String(),
string(record.PlanCode),
record.IsPaid,
record.StartsAt.UTC(),
nullableTime(record.EndsAt),
record.Source.String(),
record.Actor.Type.String(),
nullableActorID(record.Actor.ID),
record.ReasonCode.String(),
record.UpdatedAt.UTC(),
).ON_CONFLICT(pgtable.EntitlementSnapshots.UserID).DO_UPDATE(
pg.SET(
pgtable.EntitlementSnapshots.PlanCode.SET(pgtable.EntitlementSnapshots.EXCLUDED.PlanCode),
pgtable.EntitlementSnapshots.IsPaid.SET(pgtable.EntitlementSnapshots.EXCLUDED.IsPaid),
pgtable.EntitlementSnapshots.StartsAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.StartsAt),
pgtable.EntitlementSnapshots.EndsAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.EndsAt),
pgtable.EntitlementSnapshots.Source.SET(pgtable.EntitlementSnapshots.EXCLUDED.Source),
pgtable.EntitlementSnapshots.ActorType.SET(pgtable.EntitlementSnapshots.EXCLUDED.ActorType),
pgtable.EntitlementSnapshots.ActorID.SET(pgtable.EntitlementSnapshots.EXCLUDED.ActorID),
pgtable.EntitlementSnapshots.ReasonCode.SET(pgtable.EntitlementSnapshots.EXCLUDED.ReasonCode),
pgtable.EntitlementSnapshots.UpdatedAt.SET(pgtable.EntitlementSnapshots.EXCLUDED.UpdatedAt),
),
)
query, args := stmt.Sql()
if _, err := q.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("upsert entitlement snapshot for %q in postgres: %w", record.UserID, err)
}
return nil
}
func scanEntitlementSnapshotRow(row *sql.Row) (entitlement.CurrentSnapshot, error) {
var (
userID string
planCode string
isPaid bool
startsAt time.Time
endsAt *time.Time
source string
actorType string
actorID *string
reasonCode string
updatedAt time.Time
)
err := row.Scan(
&userID, &planCode, &isPaid,
&startsAt, &endsAt,
&source, &actorType, &actorID, &reasonCode,
&updatedAt,
)
if errors.Is(err, sql.ErrNoRows) {
return entitlement.CurrentSnapshot{}, ports.ErrNotFound
}
if err != nil {
return entitlement.CurrentSnapshot{}, err
}
record := entitlement.CurrentSnapshot{
UserID: common.UserID(userID),
PlanCode: entitlement.PlanCode(planCode),
IsPaid: isPaid,
StartsAt: startsAt.UTC(),
EndsAt: timeFromNullable(endsAt),
Source: common.Source(source),
Actor: common.ActorRef{Type: common.ActorType(actorType)},
ReasonCode: common.ReasonCode(reasonCode),
UpdatedAt: updatedAt.UTC(),
}
if actorID != nil {
record.Actor.ID = common.ActorID(*actorID)
}
return record, nil
}
// GrantEntitlement atomically closes the current free period, inserts the
// new paid period, and replaces the snapshot.
func (store *Store) GrantEntitlement(ctx context.Context, input ports.GrantEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("grant entitlement in postgres: %w", err)
}
return store.withTx(ctx, "grant entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
if err := lockSnapshotMatching(ctx, tx, input.ExpectedCurrentSnapshot); err != nil {
return fmt.Errorf("grant entitlement for %q in postgres: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if err := lockPeriodMatching(ctx, tx, input.ExpectedCurrentRecord); err != nil {
return fmt.Errorf("grant entitlement for %q in postgres: %w", input.ExpectedCurrentRecord.RecordID, err)
}
if err := updateEntitlementPeriodTx(ctx, tx, input.UpdatedCurrentRecord); err != nil {
return fmt.Errorf("grant entitlement for %q in postgres: %w", input.UpdatedCurrentRecord.RecordID, err)
}
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
return err
}
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
return err
}
return nil
})
}
// ExtendEntitlement atomically appends a new paid history segment and
// replaces the snapshot.
func (store *Store) ExtendEntitlement(ctx context.Context, input ports.ExtendEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("extend entitlement in postgres: %w", err)
}
return store.withTx(ctx, "extend entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
if err := lockSnapshotMatching(ctx, tx, input.ExpectedCurrentSnapshot); err != nil {
return fmt.Errorf("extend entitlement for %q in postgres: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
return err
}
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
return err
}
return nil
})
}
// RevokeEntitlement atomically closes the current paid period, inserts a new
// free period, and replaces the snapshot.
func (store *Store) RevokeEntitlement(ctx context.Context, input ports.RevokeEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("revoke entitlement in postgres: %w", err)
}
return store.withTx(ctx, "revoke entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
if err := lockSnapshotMatching(ctx, tx, input.ExpectedCurrentSnapshot); err != nil {
return fmt.Errorf("revoke entitlement for %q in postgres: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if err := lockPeriodMatching(ctx, tx, input.ExpectedCurrentRecord); err != nil {
return fmt.Errorf("revoke entitlement for %q in postgres: %w", input.ExpectedCurrentRecord.RecordID, err)
}
if err := updateEntitlementPeriodTx(ctx, tx, input.UpdatedCurrentRecord); err != nil {
return fmt.Errorf("revoke entitlement for %q in postgres: %w", input.UpdatedCurrentRecord.RecordID, err)
}
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
return err
}
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
return err
}
return nil
})
}
// RepairExpiredEntitlement atomically replaces an expired finite paid
// snapshot with a materialised free state.
func (store *Store) RepairExpiredEntitlement(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("repair expired entitlement in postgres: %w", err)
}
return store.withTx(ctx, "repair expired entitlement in postgres", func(ctx context.Context, tx *sql.Tx) error {
if err := lockSnapshotMatching(ctx, tx, input.ExpectedExpiredSnapshot); err != nil {
return fmt.Errorf("repair expired entitlement for %q in postgres: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
if err := insertEntitlementPeriod(ctx, tx, input.NewRecord); err != nil {
return err
}
if err := upsertEntitlementSnapshot(ctx, tx, input.NewSnapshot); err != nil {
return err
}
return nil
})
}
// lockSnapshotMatching loads the current snapshot under FOR UPDATE and
// verifies it matches expected. Mismatches surface as ports.ErrConflict so
// optimistic-replacement callers can retry.
func lockSnapshotMatching(ctx context.Context, tx *sql.Tx, expected entitlement.CurrentSnapshot) error {
stmt := pg.SELECT(entitlementSnapshotSelectColumns).
FROM(pgtable.EntitlementSnapshots).
WHERE(pgtable.EntitlementSnapshots.UserID.EQ(pg.String(expected.UserID.String()))).
FOR(pg.UPDATE())
query, args := stmt.Sql()
row := tx.QueryRowContext(ctx, query, args...)
current, err := scanEntitlementSnapshotRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return ports.ErrNotFound
case err != nil:
return err
}
if !snapshotsEqual(current, expected) {
return ports.ErrConflict
}
return nil
}
func lockPeriodMatching(ctx context.Context, tx *sql.Tx, expected entitlement.PeriodRecord) error {
stmt := pg.SELECT(entitlementPeriodSelectColumns).
FROM(pgtable.EntitlementRecords).
WHERE(pgtable.EntitlementRecords.RecordID.EQ(pg.String(expected.RecordID.String()))).
FOR(pg.UPDATE())
query, args := stmt.Sql()
row := tx.QueryRowContext(ctx, query, args...)
current, err := scanEntitlementPeriodRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return ports.ErrNotFound
case err != nil:
return err
}
if !periodsEqual(current, expected) {
return ports.ErrConflict
}
return nil
}
func updateEntitlementPeriodTx(ctx context.Context, tx *sql.Tx, record entitlement.PeriodRecord) error {
rows, err := updateEntitlementPeriod(ctx, tx, record)
if err != nil {
return err
}
if rows == 0 {
return ports.ErrNotFound
}
return nil
}
func snapshotsEqual(left entitlement.CurrentSnapshot, right entitlement.CurrentSnapshot) bool {
if left.UserID != right.UserID ||
left.PlanCode != right.PlanCode ||
left.IsPaid != right.IsPaid ||
left.Source != right.Source ||
left.Actor != right.Actor ||
left.ReasonCode != right.ReasonCode {
return false
}
if !left.StartsAt.Equal(right.StartsAt) || !left.UpdatedAt.Equal(right.UpdatedAt) {
return false
}
return optionalTimeEqual(left.EndsAt, right.EndsAt)
}
func periodsEqual(left entitlement.PeriodRecord, right entitlement.PeriodRecord) bool {
if left.RecordID != right.RecordID ||
left.UserID != right.UserID ||
left.PlanCode != right.PlanCode ||
left.Source != right.Source ||
left.Actor != right.Actor ||
left.ReasonCode != right.ReasonCode ||
left.ClosedBy != right.ClosedBy ||
left.ClosedReasonCode != right.ClosedReasonCode {
return false
}
if !left.StartsAt.Equal(right.StartsAt) || !left.CreatedAt.Equal(right.CreatedAt) {
return false
}
if !optionalTimeEqual(left.EndsAt, right.EndsAt) {
return false
}
return optionalTimeEqual(left.ClosedAt, right.ClosedAt)
}
func optionalTimeEqual(left *time.Time, right *time.Time) bool {
switch {
case left == nil && right == nil:
return true
case left == nil || right == nil:
return false
default:
return left.Equal(*right)
}
}
// EntitlementSnapshotStore adapts Store to the EntitlementSnapshotStore port.
type EntitlementSnapshotStore struct {
store *Store
}
// EntitlementSnapshots returns one adapter that exposes the entitlement-
// snapshot store port over Store.
func (store *Store) EntitlementSnapshots() *EntitlementSnapshotStore {
if store == nil {
return nil
}
return &EntitlementSnapshotStore{store: store}
}
// GetByUserID returns the current entitlement snapshot for userID.
func (adapter *EntitlementSnapshotStore) GetByUserID(ctx context.Context, userID common.UserID) (entitlement.CurrentSnapshot, error) {
return adapter.store.GetEntitlementByUserID(ctx, userID)
}
// Put stores the current entitlement snapshot for record.UserID.
func (adapter *EntitlementSnapshotStore) Put(ctx context.Context, record entitlement.CurrentSnapshot) error {
return adapter.store.PutEntitlement(ctx, record)
}
var _ ports.EntitlementSnapshotStore = (*EntitlementSnapshotStore)(nil)
// EntitlementHistoryStore adapts Store to the EntitlementHistoryStore port.
type EntitlementHistoryStore struct {
store *Store
}
// EntitlementHistory returns one adapter that exposes the entitlement
// history store port over Store.
func (store *Store) EntitlementHistory() *EntitlementHistoryStore {
if store == nil {
return nil
}
return &EntitlementHistoryStore{store: store}
}
// Create stores one new entitlement history record.
func (adapter *EntitlementHistoryStore) Create(ctx context.Context, record entitlement.PeriodRecord) error {
return adapter.store.CreateEntitlementRecord(ctx, record)
}
// GetByRecordID returns the entitlement history record identified by
// recordID.
func (adapter *EntitlementHistoryStore) GetByRecordID(ctx context.Context, recordID entitlement.EntitlementRecordID) (entitlement.PeriodRecord, error) {
return adapter.store.GetEntitlementRecordByID(ctx, recordID)
}
// ListByUserID returns every entitlement history record owned by userID.
func (adapter *EntitlementHistoryStore) ListByUserID(ctx context.Context, userID common.UserID) ([]entitlement.PeriodRecord, error) {
return adapter.store.ListEntitlementRecordsByUserID(ctx, userID)
}
// Update replaces one stored entitlement history record.
func (adapter *EntitlementHistoryStore) Update(ctx context.Context, record entitlement.PeriodRecord) error {
return adapter.store.UpdateEntitlementRecord(ctx, record)
}
var _ ports.EntitlementHistoryStore = (*EntitlementHistoryStore)(nil)
// EntitlementLifecycleStore adapts Store to the EntitlementLifecycleStore
// port.
type EntitlementLifecycleStore struct {
store *Store
}
// EntitlementLifecycle returns one adapter that exposes the entitlement
// lifecycle store port over Store.
func (store *Store) EntitlementLifecycle() *EntitlementLifecycleStore {
if store == nil {
return nil
}
return &EntitlementLifecycleStore{store: store}
}
// Grant atomically closes the current free period and starts a new paid
// period.
func (adapter *EntitlementLifecycleStore) Grant(ctx context.Context, input ports.GrantEntitlementInput) error {
return adapter.store.GrantEntitlement(ctx, input)
}
// Extend appends a paid history segment.
func (adapter *EntitlementLifecycleStore) Extend(ctx context.Context, input ports.ExtendEntitlementInput) error {
return adapter.store.ExtendEntitlement(ctx, input)
}
// Revoke closes the current paid period and starts a fresh free period.
func (adapter *EntitlementLifecycleStore) Revoke(ctx context.Context, input ports.RevokeEntitlementInput) error {
return adapter.store.RevokeEntitlement(ctx, input)
}
// RepairExpired replaces an expired finite paid snapshot with a free state.
func (adapter *EntitlementLifecycleStore) RepairExpired(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
return adapter.store.RepairExpiredEntitlement(ctx, input)
}
var _ ports.EntitlementLifecycleStore = (*EntitlementLifecycleStore)(nil)
@@ -0,0 +1,203 @@
package userstore
import (
"context"
"database/sql"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"
"galaxy/postgres"
"galaxy/user/internal/adapters/postgres/migrations"
testcontainers "github.com/testcontainers/testcontainers-go"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
const (
pkgPostgresImage = "postgres:16-alpine"
pkgSuperUser = "galaxy"
pkgSuperPassword = "galaxy"
pkgSuperDatabase = "galaxy_user"
pkgServiceRole = "userservice"
pkgServicePassword = "userservice"
pkgServiceSchema = "user"
pkgContainerStartup = 90 * time.Second
pkgOperationTimeout = 10 * time.Second
)
var (
pkgContainerOnce sync.Once
pkgContainerErr error
pkgContainerEnv *postgresEnv
)
type postgresEnv struct {
container *tcpostgres.PostgresContainer
dsn string
pool *sql.DB
}
func ensurePostgresEnv(t testing.TB) *postgresEnv {
t.Helper()
pkgContainerOnce.Do(func() {
pkgContainerEnv, pkgContainerErr = startPostgresEnv()
})
if pkgContainerErr != nil {
t.Skipf("postgres container start failed (Docker unavailable?): %v", pkgContainerErr)
}
return pkgContainerEnv
}
func startPostgresEnv() (*postgresEnv, error) {
ctx := context.Background()
container, err := tcpostgres.Run(ctx, pkgPostgresImage,
tcpostgres.WithDatabase(pkgSuperDatabase),
tcpostgres.WithUsername(pkgSuperUser),
tcpostgres.WithPassword(pkgSuperPassword),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(pkgContainerStartup),
),
)
if err != nil {
return nil, err
}
baseDSN, err := container.ConnectionString(ctx, "sslmode=disable")
if err != nil {
_ = testcontainers.TerminateContainer(container)
return nil, err
}
if err := provisionRoleAndSchema(ctx, baseDSN); err != nil {
_ = testcontainers.TerminateContainer(container)
return nil, err
}
scopedDSN, err := dsnForServiceRole(baseDSN)
if err != nil {
_ = testcontainers.TerminateContainer(container)
return nil, err
}
cfg := postgres.DefaultConfig()
cfg.PrimaryDSN = scopedDSN
cfg.OperationTimeout = pkgOperationTimeout
pool, err := postgres.OpenPrimary(ctx, cfg)
if err != nil {
_ = testcontainers.TerminateContainer(container)
return nil, err
}
if err := postgres.Ping(ctx, pool, pkgOperationTimeout); err != nil {
_ = pool.Close()
_ = testcontainers.TerminateContainer(container)
return nil, err
}
if err := postgres.RunMigrations(ctx, pool, migrations.FS(), "."); err != nil {
_ = pool.Close()
_ = testcontainers.TerminateContainer(container)
return nil, err
}
return &postgresEnv{
container: container,
dsn: scopedDSN,
pool: pool,
}, nil
}
func provisionRoleAndSchema(ctx context.Context, baseDSN string) error {
cfg := postgres.DefaultConfig()
cfg.PrimaryDSN = baseDSN
cfg.OperationTimeout = pkgOperationTimeout
db, err := postgres.OpenPrimary(ctx, cfg)
if err != nil {
return err
}
defer func() { _ = db.Close() }()
statements := []string{
`DO $$ BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'userservice') THEN
CREATE ROLE userservice LOGIN PASSWORD 'userservice';
END IF;
END $$;`,
`CREATE SCHEMA IF NOT EXISTS "user" AUTHORIZATION userservice;`,
`GRANT USAGE ON SCHEMA "user" TO userservice;`,
}
for _, statement := range statements {
if _, err := db.ExecContext(ctx, statement); err != nil {
return err
}
}
return nil
}
func dsnForServiceRole(baseDSN string) (string, error) {
parsed, err := url.Parse(baseDSN)
if err != nil {
return "", err
}
values := url.Values{}
values.Set("search_path", pkgServiceSchema)
values.Set("sslmode", "disable")
scoped := url.URL{
Scheme: parsed.Scheme,
User: url.UserPassword(pkgServiceRole, pkgServicePassword),
Host: parsed.Host,
Path: parsed.Path,
RawQuery: values.Encode(),
}
return scoped.String(), nil
}
// newTestStore returns a Store backed by the package-scoped pool. Every
// invocation truncates the user-owned tables so individual tests start from
// a clean slate while sharing one container start.
func newTestStore(t *testing.T) *Store {
t.Helper()
env := ensurePostgresEnv(t)
truncateAll(t, env.pool)
store, err := New(Config{DB: env.pool, OperationTimeout: pkgOperationTimeout})
if err != nil {
t.Fatalf("new store: %v", err)
}
return store
}
func truncateAll(t *testing.T, db *sql.DB) {
t.Helper()
statement := strings.Join([]string{
"TRUNCATE TABLE",
"sanction_active, limit_active,",
"sanction_records, limit_records,",
"entitlement_snapshots, entitlement_records,",
"blocked_emails, accounts",
"RESTART IDENTITY CASCADE",
}, " ")
if _, err := db.ExecContext(context.Background(), statement); err != nil {
t.Fatalf("truncate tables: %v", err)
}
}
// TestMain runs first when `go test` enters the package. We drive it through
// a TestMain so the container started by the first test is shut down on the
// way out, even when individual tests panic.
func TestMain(m *testing.M) {
code := m.Run()
if pkgContainerEnv != nil {
if pkgContainerEnv.pool != nil {
_ = pkgContainerEnv.pool.Close()
}
if pkgContainerEnv.container != nil {
_ = testcontainers.TerminateContainer(pkgContainerEnv.container)
}
}
os.Exit(code)
}
@@ -0,0 +1,149 @@
package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/ports"
"github.com/jackc/pgx/v5/pgconn"
)
// pgUniqueViolationCode identifies the SQLSTATE returned by PostgreSQL when
// a UNIQUE constraint is violated by INSERT or UPDATE.
const pgUniqueViolationCode = "23505"
// classifyUniqueViolation maps a PostgreSQL unique-violation error to the
// matching ports sentinel. constraint identifies which UNIQUE constraint name
// the caller cares about so we can surface ports.ErrUserNameConflict for the
// dedicated user-name index. Returns nil when err is not a unique violation
// or does not match constraint.
func classifyUniqueViolation(err error, constraint string, mapped error) error {
var pgErr *pgconn.PgError
if !errors.As(err, &pgErr) || pgErr.Code != pgUniqueViolationCode {
return nil
}
if constraint != "" && pgErr.ConstraintName != constraint {
return nil
}
return mapped
}
// isUniqueViolation reports whether err is a PostgreSQL unique-violation,
// regardless of constraint name. Useful for "any conflict ⇒ ErrConflict"
// translations on simple INSERT calls.
func isUniqueViolation(err error) bool {
var pgErr *pgconn.PgError
if !errors.As(err, &pgErr) {
return false
}
return pgErr.Code == pgUniqueViolationCode
}
// nullableString returns the trimmed string when s is non-empty, otherwise
// reports a NULL stand-in usable in $-parameter lists. Empty strings are
// stored as NULL so optional columns round-trip through nil.
func nullableString(s string) any {
if s == "" {
return nil
}
return s
}
// nullableActorID converts an optional ActorID (the zero value indicates
// "no caller supplied this field") to a NULL stand-in for SQL parameters.
func nullableActorID(id common.ActorID) any {
if id.IsZero() {
return nil
}
return id.String()
}
// nullableActorType mirrors nullableActorID for ActorType.
func nullableActorType(t common.ActorType) any {
if t.IsZero() {
return nil
}
return t.String()
}
// nullableReasonCode mirrors nullableActorID for ReasonCode.
func nullableReasonCode(code common.ReasonCode) any {
if code.IsZero() {
return nil
}
return code.String()
}
// nullableUserID mirrors nullableActorID for UserID.
func nullableUserID(id common.UserID) any {
if id.IsZero() {
return nil
}
return id.String()
}
// nullableTime returns t.UTC() when non-nil, otherwise nil for NULL columns.
func nullableTime(t *time.Time) any {
if t == nil {
return nil
}
return t.UTC()
}
// nullableCountry returns the upper-cased ISO 3166-1 alpha-2 string when set,
// otherwise nil.
func nullableCountry(code common.CountryCode) any {
if code.IsZero() {
return nil
}
return code.String()
}
// stringFromNullable trims an optional sql.NullString-like *string (read from
// Postgres COLUMNAR_NULL) into an ActorID/ReasonCode/UserID-friendly string.
func stringFromNullable(value *string) string {
if value == nil {
return ""
}
return *value
}
// timeFromNullable copies an optional *time.Time read from Postgres into a
// new pointer normalised to UTC.
func timeFromNullable(value *time.Time) *time.Time {
if value == nil {
return nil
}
utc := value.UTC()
return &utc
}
// mapNotFound translates sql.ErrNoRows into ports.ErrNotFound, leaving every
// other error untouched.
func mapNotFound(err error) error {
if errors.Is(err, sql.ErrNoRows) {
return ports.ErrNotFound
}
return err
}
// withTimeout derives a child context bounded by timeout and prefixes context
// errors with operation. Callers must always invoke the returned cancel.
func withTimeout(ctx context.Context, operation string, timeout time.Duration) (context.Context, context.CancelFunc, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("%s: nil context", operation)
}
if err := ctx.Err(); err != nil {
return nil, nil, fmt.Errorf("%s: %w", operation, err)
}
if timeout <= 0 {
return nil, nil, fmt.Errorf("%s: operation timeout must be positive", operation)
}
bounded, cancel := context.WithTimeout(ctx, timeout)
return bounded, cancel, nil
}
@@ -0,0 +1,160 @@
package userstore
import (
"context"
"fmt"
"time"
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/ports"
pg "github.com/go-jet/jet/v2/postgres"
)
// ListUserIDs returns one deterministic page of user identifiers ordered by
// `created_at desc`, then `user_id desc`, mirroring the ordering used by the
// previous Redis adapter.
//
// The Postgres implementation keeps the listing surface storage-thin: it
// only paginates on `created_at` + `user_id` and does not attempt to push
// the full filter matrix into SQL. The service layer (`adminusers.Lister`)
// continues to load each candidate via the per-user loader and apply the
// filter set in memory, exactly as it did with the Redis adapter. Pushing
// the filter matrix down to SQL is a follow-up optimisation noted in
// `galaxy/user/docs/postgres-migration.md`.
func (store *Store) ListUserIDs(ctx context.Context, input ports.ListUsersInput) (ports.ListUsersResult, error) {
if err := input.Validate(); err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list users in postgres")
if err != nil {
return ports.ListUsersResult{}, err
}
defer cancel()
filters := userListFiltersFromPorts(input.Filters)
var (
cursorCreatedAt time.Time
cursorUserID common.UserID
cursored bool
)
if input.PageToken != "" {
cursor, err := decodePageToken(input.PageToken, filters)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in postgres: %w", ports.ErrInvalidPageToken)
}
cursorCreatedAt = cursor.CreatedAt
cursorUserID = cursor.UserID
cursored = true
}
limit := input.PageSize + 1
rows, err := queryListPage(operationCtx, store, cursored, cursorCreatedAt, cursorUserID, limit)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in postgres: %w", err)
}
result := ports.ListUsersResult{
UserIDs: make([]common.UserID, 0, min(len(rows), input.PageSize)),
}
visible := min(len(rows), input.PageSize)
for index := range visible {
result.UserIDs = append(result.UserIDs, rows[index].UserID)
}
if len(rows) > input.PageSize {
last := rows[input.PageSize-1]
token, err := encodePageToken(pageCursor{
CreatedAt: last.CreatedAt,
UserID: last.UserID,
}, filters)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in postgres: %w", err)
}
result.NextPageToken = token
}
return result, nil
}
// listRow is the lightweight projection returned by queryListPage; only
// (created_at, user_id) is needed for the listing index plus cursor token
// generation.
type listRow struct {
CreatedAt time.Time
UserID common.UserID
}
// queryListPage returns up to limit rows ordered by created_at DESC, user_id
// DESC. When cursored is true, the query starts strictly after the
// (cursorCreatedAt, cursorUserID) tuple per the keyset pagination rule.
func queryListPage(ctx context.Context, store *Store, cursored bool, cursorCreatedAt time.Time, cursorUserID common.UserID, limit int) ([]listRow, error) {
stmt := pg.SELECT(pgtable.Accounts.CreatedAt, pgtable.Accounts.UserID).
FROM(pgtable.Accounts)
if cursored {
// (created_at, user_id) < (cursorCreatedAt, cursorUserID) expressed as
// the equivalent OR/AND expansion since jet has no row-comparison
// builder.
ts := pg.TimestampzT(cursorCreatedAt.UTC())
uid := pg.String(cursorUserID.String())
stmt = stmt.WHERE(pg.OR(
pgtable.Accounts.CreatedAt.LT(ts),
pg.AND(
pgtable.Accounts.CreatedAt.EQ(ts),
pgtable.Accounts.UserID.LT(uid),
),
))
}
stmt = stmt.
ORDER_BY(pgtable.Accounts.CreatedAt.DESC(), pgtable.Accounts.UserID.DESC()).
LIMIT(int64(limit))
query, args := stmt.Sql()
rows, err := store.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]listRow, 0, limit)
for rows.Next() {
var (
createdAt time.Time
userID string
)
if err := rows.Scan(&createdAt, &userID); err != nil {
return nil, err
}
uid := common.UserID(userID)
if err := uid.Validate(); err != nil {
return nil, fmt.Errorf("created_at index member user id: %w", err)
}
out = append(out, listRow{CreatedAt: createdAt.UTC(), UserID: uid})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
// UserList adapts Store to the UserListStore port.
type UserList struct{ store *Store }
// UserListAdapter returns one adapter that exposes the user-list store port.
func (store *Store) UserListAdapter() *UserList {
if store == nil {
return nil
}
return &UserList{store: store}
}
// ListUserIDs returns one deterministic page of user identifiers.
func (a *UserList) ListUserIDs(ctx context.Context, input ports.ListUsersInput) (ports.ListUsersResult, error) {
return a.store.ListUserIDs(ctx, input)
}
var _ ports.UserListStore = (*UserList)(nil)
var _ ports.UserListStore = (*Store)(nil)
@@ -0,0 +1,198 @@
package userstore
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
)
// errPageTokenFiltersMismatch reports that a supplied page token was created
// for a different normalised filter set. Callers translate it to
// ports.ErrInvalidPageToken on the boundary.
var errPageTokenFiltersMismatch = errors.New("page token filters do not match current filters")
// pageCursor identifies the last (created_at, user_id) tuple visible on the
// previous listing page. The cursor is paired with a normalised filter
// fingerprint so the token cannot be replayed across a different filter set.
type pageCursor struct {
CreatedAt time.Time
UserID common.UserID
}
func (cursor pageCursor) Validate() error {
if err := common.ValidateTimestamp("page cursor created at", cursor.CreatedAt); err != nil {
return err
}
if err := cursor.UserID.Validate(); err != nil {
return fmt.Errorf("page cursor user id: %w", err)
}
return nil
}
// userListFilters mirrors ports.UserListFilters but excludes the fields that
// only the service layer enforces (display_name match, user_name) so token
// replay across a UI re-render that toggles a UI-only filter does not
// invalidate the cursor.
type userListFilters struct {
PaidState entitlement.PaidState
PaidExpiresBefore *time.Time
PaidExpiresAfter *time.Time
DeclaredCountry common.CountryCode
SanctionCode policy.SanctionCode
LimitCode policy.LimitCode
CanLogin *bool
CanCreatePrivateGame *bool
CanJoinGame *bool
}
// userListFiltersFromPorts copies the listing-stable subset of port-level
// filters into the form embedded into the page token fingerprint.
func userListFiltersFromPorts(filters ports.UserListFilters) userListFilters {
return userListFilters{
PaidState: filters.PaidState,
PaidExpiresBefore: filters.PaidExpiresBefore,
PaidExpiresAfter: filters.PaidExpiresAfter,
DeclaredCountry: filters.DeclaredCountry,
SanctionCode: filters.SanctionCode,
LimitCode: filters.LimitCode,
CanLogin: filters.CanLogin,
CanCreatePrivateGame: filters.CanCreatePrivateGame,
CanJoinGame: filters.CanJoinGame,
}
}
func (filters userListFilters) Validate() error {
if !filters.PaidState.IsKnown() {
return fmt.Errorf("paid state %q is unsupported", filters.PaidState)
}
if filters.PaidExpiresBefore != nil && filters.PaidExpiresBefore.IsZero() {
return fmt.Errorf("paid expires before must not be zero")
}
if filters.PaidExpiresAfter != nil && filters.PaidExpiresAfter.IsZero() {
return fmt.Errorf("paid expires after must not be zero")
}
if !filters.DeclaredCountry.IsZero() {
if err := filters.DeclaredCountry.Validate(); err != nil {
return fmt.Errorf("declared country: %w", err)
}
}
if filters.SanctionCode != "" && !filters.SanctionCode.IsKnown() {
return fmt.Errorf("sanction code %q is unsupported", filters.SanctionCode)
}
if filters.LimitCode != "" && !filters.LimitCode.IsKnown() {
return fmt.Errorf("limit code %q is unsupported", filters.LimitCode)
}
return nil
}
// encodePageToken encodes cursor + filters into the frozen opaque page token
// shape used by the trusted admin listing surface. The encoding is identical
// to the previous Redis implementation so existing public clients can keep
// using their stored tokens through the migration cut-over.
func encodePageToken(cursor pageCursor, filters userListFilters) (string, error) {
if err := cursor.Validate(); err != nil {
return "", fmt.Errorf("encode page token: %w", err)
}
fingerprint, err := normaliseFilters(filters)
if err != nil {
return "", fmt.Errorf("encode page token: %w", err)
}
payload, err := json.Marshal(pageTokenPayload{
CreatedAt: cursor.CreatedAt.UTC().Format(time.RFC3339Nano),
UserID: cursor.UserID.String(),
Filters: fingerprint,
})
if err != nil {
return "", fmt.Errorf("encode page token: %w", err)
}
return base64.RawURLEncoding.EncodeToString(payload), nil
}
// decodePageToken parses raw and verifies the embedded fingerprint matches
// expected. The token's wire format is preserved across the Redis-to-
// PostgreSQL adapter swap.
func decodePageToken(raw string, expected userListFilters) (pageCursor, error) {
fingerprint, err := normaliseFilters(expected)
if err != nil {
return pageCursor{}, fmt.Errorf("decode page token: %w", err)
}
payload, err := base64.RawURLEncoding.DecodeString(raw)
if err != nil {
return pageCursor{}, fmt.Errorf("decode page token: %w", err)
}
var token pageTokenPayload
if err := json.Unmarshal(payload, &token); err != nil {
return pageCursor{}, fmt.Errorf("decode page token: %w", err)
}
if token.Filters != fingerprint {
return pageCursor{}, errPageTokenFiltersMismatch
}
createdAt, err := time.Parse(time.RFC3339Nano, token.CreatedAt)
if err != nil {
return pageCursor{}, fmt.Errorf("decode page token: parse created_at: %w", err)
}
cursor := pageCursor{CreatedAt: createdAt.UTC(), UserID: common.UserID(token.UserID)}
if err := cursor.Validate(); err != nil {
return pageCursor{}, fmt.Errorf("decode page token: %w", err)
}
return cursor, nil
}
type pageTokenPayload struct {
CreatedAt string `json:"created_at"`
UserID string `json:"user_id"`
Filters normalisedFilterFields `json:"filters"`
}
type normalisedFilterFields struct {
PaidState string `json:"paid_state,omitempty"`
PaidExpiresBeforeUTC string `json:"paid_expires_before_utc,omitempty"`
PaidExpiresAfterUTC string `json:"paid_expires_after_utc,omitempty"`
DeclaredCountry string `json:"declared_country,omitempty"`
SanctionCode string `json:"sanction_code,omitempty"`
LimitCode string `json:"limit_code,omitempty"`
CanLogin string `json:"can_login,omitempty"`
CanCreatePrivateGame string `json:"can_create_private_game,omitempty"`
CanJoinGame string `json:"can_join_game,omitempty"`
}
func normaliseFilters(filters userListFilters) (normalisedFilterFields, error) {
if err := filters.Validate(); err != nil {
return normalisedFilterFields{}, err
}
return normalisedFilterFields{
PaidState: string(filters.PaidState),
PaidExpiresBeforeUTC: formatOptionalUTC(filters.PaidExpiresBefore),
PaidExpiresAfterUTC: formatOptionalUTC(filters.PaidExpiresAfter),
DeclaredCountry: filters.DeclaredCountry.String(),
SanctionCode: string(filters.SanctionCode),
LimitCode: string(filters.LimitCode),
CanLogin: formatOptionalBool(filters.CanLogin),
CanCreatePrivateGame: formatOptionalBool(filters.CanCreatePrivateGame),
CanJoinGame: formatOptionalBool(filters.CanJoinGame),
}, nil
}
func formatOptionalUTC(value *time.Time) string {
if value == nil {
return ""
}
return value.UTC().Format(time.RFC3339Nano)
}
func formatOptionalBool(value *bool) string {
if value == nil {
return ""
}
if *value {
return "true"
}
return "false"
}
@@ -0,0 +1,870 @@
package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
pgtable "galaxy/user/internal/adapters/postgres/jet/user/table"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
pg "github.com/go-jet/jet/v2/postgres"
)
// sanctionSelectColumns is the canonical SELECT list for sanction_records,
// matching scanSanction's column order.
var sanctionSelectColumns = pg.ColumnList{
pgtable.SanctionRecords.RecordID,
pgtable.SanctionRecords.UserID,
pgtable.SanctionRecords.SanctionCode,
pgtable.SanctionRecords.Scope,
pgtable.SanctionRecords.ReasonCode,
pgtable.SanctionRecords.ActorType,
pgtable.SanctionRecords.ActorID,
pgtable.SanctionRecords.AppliedAt,
pgtable.SanctionRecords.ExpiresAt,
pgtable.SanctionRecords.RemovedAt,
pgtable.SanctionRecords.RemovedByType,
pgtable.SanctionRecords.RemovedByID,
pgtable.SanctionRecords.RemovedReasonCode,
}
// limitSelectColumns is the canonical SELECT list for limit_records, matching
// scanLimit's column order.
var limitSelectColumns = pg.ColumnList{
pgtable.LimitRecords.RecordID,
pgtable.LimitRecords.UserID,
pgtable.LimitRecords.LimitCode,
pgtable.LimitRecords.Value,
pgtable.LimitRecords.ReasonCode,
pgtable.LimitRecords.ActorType,
pgtable.LimitRecords.ActorID,
pgtable.LimitRecords.AppliedAt,
pgtable.LimitRecords.ExpiresAt,
pgtable.LimitRecords.RemovedAt,
pgtable.LimitRecords.RemovedByType,
pgtable.LimitRecords.RemovedByID,
pgtable.LimitRecords.RemovedReasonCode,
}
// CreateSanction stores one new sanction history record.
func (store *Store) CreateSanction(ctx context.Context, record policy.SanctionRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("create sanction in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "create sanction in postgres")
if err != nil {
return err
}
defer cancel()
return insertSanctionRecord(operationCtx, store.db, record)
}
func insertSanctionRecord(ctx context.Context, q queryer, record policy.SanctionRecord) error {
stmt := pgtable.SanctionRecords.INSERT(
pgtable.SanctionRecords.RecordID,
pgtable.SanctionRecords.UserID,
pgtable.SanctionRecords.SanctionCode,
pgtable.SanctionRecords.Scope,
pgtable.SanctionRecords.ReasonCode,
pgtable.SanctionRecords.ActorType,
pgtable.SanctionRecords.ActorID,
pgtable.SanctionRecords.AppliedAt,
pgtable.SanctionRecords.ExpiresAt,
pgtable.SanctionRecords.RemovedAt,
pgtable.SanctionRecords.RemovedByType,
pgtable.SanctionRecords.RemovedByID,
pgtable.SanctionRecords.RemovedReasonCode,
).VALUES(
record.RecordID.String(),
record.UserID.String(),
string(record.SanctionCode),
record.Scope.String(),
record.ReasonCode.String(),
record.Actor.Type.String(),
nullableActorID(record.Actor.ID),
record.AppliedAt.UTC(),
nullableTime(record.ExpiresAt),
nullableTime(record.RemovedAt),
nullableActorType(record.RemovedBy.Type),
nullableActorID(record.RemovedBy.ID),
nullableReasonCode(record.RemovedReasonCode),
)
query, args := stmt.Sql()
_, err := q.ExecContext(ctx, query, args...)
if err == nil {
return nil
}
if isUniqueViolation(err) {
return fmt.Errorf("create sanction %q in postgres: %w", record.RecordID, ports.ErrConflict)
}
return fmt.Errorf("create sanction %q in postgres: %w", record.RecordID, err)
}
// GetSanctionByRecordID returns the sanction history record identified by
// recordID.
func (store *Store) GetSanctionByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) {
if err := recordID.Validate(); err != nil {
return policy.SanctionRecord{}, fmt.Errorf("get sanction from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get sanction from postgres")
if err != nil {
return policy.SanctionRecord{}, err
}
defer cancel()
stmt := pg.SELECT(sanctionSelectColumns).
FROM(pgtable.SanctionRecords).
WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(recordID.String())))
query, args := stmt.Sql()
row := store.db.QueryRowContext(operationCtx, query, args...)
record, err := scanSanctionRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return policy.SanctionRecord{}, fmt.Errorf("get sanction %q from postgres: %w", recordID, ports.ErrNotFound)
case err != nil:
return policy.SanctionRecord{}, fmt.Errorf("get sanction %q from postgres: %w", recordID, err)
}
return record, nil
}
// ListSanctionsByUserID returns every sanction history record owned by
// userID, ordered by applied_at ascending.
func (store *Store) ListSanctionsByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) {
if err := userID.Validate(); err != nil {
return nil, fmt.Errorf("list sanctions from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list sanctions from postgres")
if err != nil {
return nil, err
}
defer cancel()
stmt := pg.SELECT(sanctionSelectColumns).
FROM(pgtable.SanctionRecords).
WHERE(pgtable.SanctionRecords.UserID.EQ(pg.String(userID.String()))).
ORDER_BY(pgtable.SanctionRecords.AppliedAt.ASC(), pgtable.SanctionRecords.RecordID.ASC())
query, args := stmt.Sql()
rows, err := store.db.QueryContext(operationCtx, query, args...)
if err != nil {
return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err)
}
defer func() { _ = rows.Close() }()
out := make([]policy.SanctionRecord, 0)
for rows.Next() {
record, err := scanSanction(rows)
if err != nil {
return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err)
}
out = append(out, record)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("list sanctions for %q from postgres: %w", userID, err)
}
return out, nil
}
// UpdateSanction replaces one stored sanction history record. The matched
// row is identified by record_id; ports.ErrNotFound is returned when no row
// matches.
func (store *Store) UpdateSanction(ctx context.Context, record policy.SanctionRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("update sanction in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "update sanction in postgres")
if err != nil {
return err
}
defer cancel()
return updateSanctionRecordTx(operationCtx, store.db, record)
}
func updateSanctionRecordTx(ctx context.Context, q queryer, record policy.SanctionRecord) error {
stmt := pgtable.SanctionRecords.UPDATE(
pgtable.SanctionRecords.UserID,
pgtable.SanctionRecords.SanctionCode,
pgtable.SanctionRecords.Scope,
pgtable.SanctionRecords.ReasonCode,
pgtable.SanctionRecords.ActorType,
pgtable.SanctionRecords.ActorID,
pgtable.SanctionRecords.AppliedAt,
pgtable.SanctionRecords.ExpiresAt,
pgtable.SanctionRecords.RemovedAt,
pgtable.SanctionRecords.RemovedByType,
pgtable.SanctionRecords.RemovedByID,
pgtable.SanctionRecords.RemovedReasonCode,
).SET(
record.UserID.String(),
string(record.SanctionCode),
record.Scope.String(),
record.ReasonCode.String(),
record.Actor.Type.String(),
nullableActorID(record.Actor.ID),
record.AppliedAt.UTC(),
nullableTime(record.ExpiresAt),
nullableTime(record.RemovedAt),
nullableActorType(record.RemovedBy.Type),
nullableActorID(record.RemovedBy.ID),
nullableReasonCode(record.RemovedReasonCode),
).WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(record.RecordID.String())))
query, args := stmt.Sql()
res, err := q.ExecContext(ctx, query, args...)
if err != nil {
return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, err)
}
rows, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, err)
}
if rows == 0 {
return fmt.Errorf("update sanction %q in postgres: %w", record.RecordID, ports.ErrNotFound)
}
return nil
}
func scanSanctionRow(row *sql.Row) (policy.SanctionRecord, error) {
record, err := scanSanction(row)
if errors.Is(err, sql.ErrNoRows) {
return policy.SanctionRecord{}, ports.ErrNotFound
}
return record, err
}
func scanSanction(row scannableRow) (policy.SanctionRecord, error) {
var (
recordID string
userID string
code string
scope string
reason string
actorType string
actorID *string
appliedAt time.Time
expiresAt *time.Time
removedAt *time.Time
rmByType *string
rmByID *string
rmReason *string
)
if err := row.Scan(
&recordID, &userID, &code, &scope, &reason,
&actorType, &actorID, &appliedAt,
&expiresAt, &removedAt,
&rmByType, &rmByID, &rmReason,
); err != nil {
return policy.SanctionRecord{}, err
}
record := policy.SanctionRecord{
RecordID: policy.SanctionRecordID(recordID),
UserID: common.UserID(userID),
SanctionCode: policy.SanctionCode(code),
Scope: common.Scope(scope),
ReasonCode: common.ReasonCode(reason),
Actor: common.ActorRef{Type: common.ActorType(actorType)},
AppliedAt: appliedAt.UTC(),
ExpiresAt: timeFromNullable(expiresAt),
RemovedAt: timeFromNullable(removedAt),
}
if actorID != nil {
record.Actor.ID = common.ActorID(*actorID)
}
if rmByType != nil {
record.RemovedBy.Type = common.ActorType(*rmByType)
}
if rmByID != nil {
record.RemovedBy.ID = common.ActorID(*rmByID)
}
if rmReason != nil {
record.RemovedReasonCode = common.ReasonCode(*rmReason)
}
return record, nil
}
// CreateLimit stores one new limit history record.
func (store *Store) CreateLimit(ctx context.Context, record policy.LimitRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("create limit in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "create limit in postgres")
if err != nil {
return err
}
defer cancel()
return insertLimitRecord(operationCtx, store.db, record)
}
func insertLimitRecord(ctx context.Context, q queryer, record policy.LimitRecord) error {
stmt := pgtable.LimitRecords.INSERT(
pgtable.LimitRecords.RecordID,
pgtable.LimitRecords.UserID,
pgtable.LimitRecords.LimitCode,
pgtable.LimitRecords.Value,
pgtable.LimitRecords.ReasonCode,
pgtable.LimitRecords.ActorType,
pgtable.LimitRecords.ActorID,
pgtable.LimitRecords.AppliedAt,
pgtable.LimitRecords.ExpiresAt,
pgtable.LimitRecords.RemovedAt,
pgtable.LimitRecords.RemovedByType,
pgtable.LimitRecords.RemovedByID,
pgtable.LimitRecords.RemovedReasonCode,
).VALUES(
record.RecordID.String(),
record.UserID.String(),
string(record.LimitCode),
record.Value,
record.ReasonCode.String(),
record.Actor.Type.String(),
nullableActorID(record.Actor.ID),
record.AppliedAt.UTC(),
nullableTime(record.ExpiresAt),
nullableTime(record.RemovedAt),
nullableActorType(record.RemovedBy.Type),
nullableActorID(record.RemovedBy.ID),
nullableReasonCode(record.RemovedReasonCode),
)
query, args := stmt.Sql()
_, err := q.ExecContext(ctx, query, args...)
if err == nil {
return nil
}
if isUniqueViolation(err) {
return fmt.Errorf("create limit %q in postgres: %w", record.RecordID, ports.ErrConflict)
}
return fmt.Errorf("create limit %q in postgres: %w", record.RecordID, err)
}
// GetLimitByRecordID returns the limit history record identified by recordID.
func (store *Store) GetLimitByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) {
if err := recordID.Validate(); err != nil {
return policy.LimitRecord{}, fmt.Errorf("get limit from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get limit from postgres")
if err != nil {
return policy.LimitRecord{}, err
}
defer cancel()
stmt := pg.SELECT(limitSelectColumns).
FROM(pgtable.LimitRecords).
WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(recordID.String())))
query, args := stmt.Sql()
row := store.db.QueryRowContext(operationCtx, query, args...)
record, err := scanLimitRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return policy.LimitRecord{}, fmt.Errorf("get limit %q from postgres: %w", recordID, ports.ErrNotFound)
case err != nil:
return policy.LimitRecord{}, fmt.Errorf("get limit %q from postgres: %w", recordID, err)
}
return record, nil
}
// ListLimitsByUserID returns every limit history record owned by userID,
// ordered by applied_at ascending.
func (store *Store) ListLimitsByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) {
if err := userID.Validate(); err != nil {
return nil, fmt.Errorf("list limits from postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list limits from postgres")
if err != nil {
return nil, err
}
defer cancel()
stmt := pg.SELECT(limitSelectColumns).
FROM(pgtable.LimitRecords).
WHERE(pgtable.LimitRecords.UserID.EQ(pg.String(userID.String()))).
ORDER_BY(pgtable.LimitRecords.AppliedAt.ASC(), pgtable.LimitRecords.RecordID.ASC())
query, args := stmt.Sql()
rows, err := store.db.QueryContext(operationCtx, query, args...)
if err != nil {
return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err)
}
defer func() { _ = rows.Close() }()
out := make([]policy.LimitRecord, 0)
for rows.Next() {
record, err := scanLimit(rows)
if err != nil {
return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err)
}
out = append(out, record)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("list limits for %q from postgres: %w", userID, err)
}
return out, nil
}
// UpdateLimit replaces one stored limit history record.
func (store *Store) UpdateLimit(ctx context.Context, record policy.LimitRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("update limit in postgres: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "update limit in postgres")
if err != nil {
return err
}
defer cancel()
return updateLimitRecordTx(operationCtx, store.db, record)
}
func updateLimitRecordTx(ctx context.Context, q queryer, record policy.LimitRecord) error {
stmt := pgtable.LimitRecords.UPDATE(
pgtable.LimitRecords.UserID,
pgtable.LimitRecords.LimitCode,
pgtable.LimitRecords.Value,
pgtable.LimitRecords.ReasonCode,
pgtable.LimitRecords.ActorType,
pgtable.LimitRecords.ActorID,
pgtable.LimitRecords.AppliedAt,
pgtable.LimitRecords.ExpiresAt,
pgtable.LimitRecords.RemovedAt,
pgtable.LimitRecords.RemovedByType,
pgtable.LimitRecords.RemovedByID,
pgtable.LimitRecords.RemovedReasonCode,
).SET(
record.UserID.String(),
string(record.LimitCode),
record.Value,
record.ReasonCode.String(),
record.Actor.Type.String(),
nullableActorID(record.Actor.ID),
record.AppliedAt.UTC(),
nullableTime(record.ExpiresAt),
nullableTime(record.RemovedAt),
nullableActorType(record.RemovedBy.Type),
nullableActorID(record.RemovedBy.ID),
nullableReasonCode(record.RemovedReasonCode),
).WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(record.RecordID.String())))
query, args := stmt.Sql()
res, err := q.ExecContext(ctx, query, args...)
if err != nil {
return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, err)
}
rows, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, err)
}
if rows == 0 {
return fmt.Errorf("update limit %q in postgres: %w", record.RecordID, ports.ErrNotFound)
}
return nil
}
func scanLimitRow(row *sql.Row) (policy.LimitRecord, error) {
record, err := scanLimit(row)
if errors.Is(err, sql.ErrNoRows) {
return policy.LimitRecord{}, ports.ErrNotFound
}
return record, err
}
func scanLimit(row scannableRow) (policy.LimitRecord, error) {
var (
recordID string
userID string
code string
value int
reason string
actorType string
actorID *string
appliedAt time.Time
expiresAt *time.Time
removedAt *time.Time
rmByType *string
rmByID *string
rmReason *string
)
if err := row.Scan(
&recordID, &userID, &code, &value, &reason,
&actorType, &actorID, &appliedAt,
&expiresAt, &removedAt,
&rmByType, &rmByID, &rmReason,
); err != nil {
return policy.LimitRecord{}, err
}
record := policy.LimitRecord{
RecordID: policy.LimitRecordID(recordID),
UserID: common.UserID(userID),
LimitCode: policy.LimitCode(code),
Value: value,
ReasonCode: common.ReasonCode(reason),
Actor: common.ActorRef{Type: common.ActorType(actorType)},
AppliedAt: appliedAt.UTC(),
ExpiresAt: timeFromNullable(expiresAt),
RemovedAt: timeFromNullable(removedAt),
}
if actorID != nil {
record.Actor.ID = common.ActorID(*actorID)
}
if rmByType != nil {
record.RemovedBy.Type = common.ActorType(*rmByType)
}
if rmByID != nil {
record.RemovedBy.ID = common.ActorID(*rmByID)
}
if rmReason != nil {
record.RemovedReasonCode = common.ReasonCode(*rmReason)
}
return record, nil
}
// ApplySanction inserts the new sanction history row and points
// sanction_active at it. Re-applying the same code while another active
// record exists returns ports.ErrConflict.
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("apply sanction in postgres: %w", err)
}
return store.withTx(ctx, "apply sanction in postgres", func(ctx context.Context, tx *sql.Tx) error {
if err := insertSanctionRecord(ctx, tx, input.NewRecord); err != nil {
return err
}
stmt := pgtable.SanctionActive.INSERT(
pgtable.SanctionActive.UserID,
pgtable.SanctionActive.SanctionCode,
pgtable.SanctionActive.RecordID,
).VALUES(
input.NewRecord.UserID.String(),
string(input.NewRecord.SanctionCode),
input.NewRecord.RecordID.String(),
)
query, args := stmt.Sql()
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
if isUniqueViolation(err) {
return fmt.Errorf("apply sanction %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict)
}
return fmt.Errorf("apply sanction %q in postgres: %w", input.NewRecord.RecordID, err)
}
return nil
})
}
// RemoveSanction updates the existing sanction record with remove metadata
// and clears the sanction_active row that pointed at it.
func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("remove sanction in postgres: %w", err)
}
return store.withTx(ctx, "remove sanction in postgres", func(ctx context.Context, tx *sql.Tx) error {
if err := lockSanctionMatching(ctx, tx, input.ExpectedActiveRecord); err != nil {
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
}
if err := updateSanctionRecordTx(ctx, tx, input.UpdatedRecord); err != nil {
return err
}
stmt := pgtable.SanctionActive.DELETE().
WHERE(pg.AND(
pgtable.SanctionActive.UserID.EQ(pg.String(input.ExpectedActiveRecord.UserID.String())),
pgtable.SanctionActive.SanctionCode.EQ(pg.String(string(input.ExpectedActiveRecord.SanctionCode))),
pgtable.SanctionActive.RecordID.EQ(pg.String(input.ExpectedActiveRecord.RecordID.String())),
))
query, args := stmt.Sql()
res, err := tx.ExecContext(ctx, query, args...)
if err != nil {
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
}
rows, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
}
if rows == 0 {
return fmt.Errorf("remove sanction %q in postgres: %w", input.ExpectedActiveRecord.RecordID, ports.ErrConflict)
}
return nil
})
}
// SetLimit creates a new active limit (or replaces one) for the user. When
// ExpectedActiveRecord is nil the call must succeed only if no active row
// exists for (user_id, limit_code); otherwise the existing record is
// updated with remove metadata and superseded by NewRecord.
func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("set limit in postgres: %w", err)
}
return store.withTx(ctx, "set limit in postgres", func(ctx context.Context, tx *sql.Tx) error {
if input.ExpectedActiveRecord != nil {
if err := lockLimitMatching(ctx, tx, *input.ExpectedActiveRecord); err != nil {
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
}
if err := updateLimitRecordTx(ctx, tx, *input.UpdatedActiveRecord); err != nil {
return err
}
} else {
probe := pg.SELECT(pgtable.LimitActive.RecordID).
FROM(pgtable.LimitActive).
WHERE(pg.AND(
pgtable.LimitActive.UserID.EQ(pg.String(input.NewRecord.UserID.String())),
pgtable.LimitActive.LimitCode.EQ(pg.String(string(input.NewRecord.LimitCode))),
)).
FOR(pg.UPDATE())
probeQuery, probeArgs := probe.Sql()
row := tx.QueryRowContext(ctx, probeQuery, probeArgs...)
var marker string
if err := row.Scan(&marker); err == nil {
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, ports.ErrConflict)
} else if !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
}
}
if err := insertLimitRecord(ctx, tx, input.NewRecord); err != nil {
return err
}
upsert := pgtable.LimitActive.INSERT(
pgtable.LimitActive.UserID,
pgtable.LimitActive.LimitCode,
pgtable.LimitActive.RecordID,
pgtable.LimitActive.Value,
).VALUES(
input.NewRecord.UserID.String(),
string(input.NewRecord.LimitCode),
input.NewRecord.RecordID.String(),
input.NewRecord.Value,
).ON_CONFLICT(pgtable.LimitActive.UserID, pgtable.LimitActive.LimitCode).DO_UPDATE(
pg.SET(
pgtable.LimitActive.RecordID.SET(pgtable.LimitActive.EXCLUDED.RecordID),
pgtable.LimitActive.Value.SET(pgtable.LimitActive.EXCLUDED.Value),
),
)
upsertQuery, upsertArgs := upsert.Sql()
if _, err := tx.ExecContext(ctx, upsertQuery, upsertArgs...); err != nil {
return fmt.Errorf("set limit %q in postgres: %w", input.NewRecord.RecordID, err)
}
return nil
})
}
// RemoveLimit updates the limit record with remove metadata and removes the
// active row that referenced it.
func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("remove limit in postgres: %w", err)
}
return store.withTx(ctx, "remove limit in postgres", func(ctx context.Context, tx *sql.Tx) error {
if err := lockLimitMatching(ctx, tx, input.ExpectedActiveRecord); err != nil {
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
}
if err := updateLimitRecordTx(ctx, tx, input.UpdatedRecord); err != nil {
return err
}
stmt := pgtable.LimitActive.DELETE().
WHERE(pg.AND(
pgtable.LimitActive.UserID.EQ(pg.String(input.ExpectedActiveRecord.UserID.String())),
pgtable.LimitActive.LimitCode.EQ(pg.String(string(input.ExpectedActiveRecord.LimitCode))),
pgtable.LimitActive.RecordID.EQ(pg.String(input.ExpectedActiveRecord.RecordID.String())),
))
query, args := stmt.Sql()
res, err := tx.ExecContext(ctx, query, args...)
if err != nil {
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
}
rows, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, err)
}
if rows == 0 {
return fmt.Errorf("remove limit %q in postgres: %w", input.ExpectedActiveRecord.RecordID, ports.ErrConflict)
}
return nil
})
}
func lockSanctionMatching(ctx context.Context, tx *sql.Tx, expected policy.SanctionRecord) error {
stmt := pg.SELECT(sanctionSelectColumns).
FROM(pgtable.SanctionRecords).
WHERE(pgtable.SanctionRecords.RecordID.EQ(pg.String(expected.RecordID.String()))).
FOR(pg.UPDATE())
query, args := stmt.Sql()
row := tx.QueryRowContext(ctx, query, args...)
current, err := scanSanctionRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return ports.ErrNotFound
case err != nil:
return err
}
if !sanctionsEqual(current, expected) {
return ports.ErrConflict
}
return nil
}
func lockLimitMatching(ctx context.Context, tx *sql.Tx, expected policy.LimitRecord) error {
stmt := pg.SELECT(limitSelectColumns).
FROM(pgtable.LimitRecords).
WHERE(pgtable.LimitRecords.RecordID.EQ(pg.String(expected.RecordID.String()))).
FOR(pg.UPDATE())
query, args := stmt.Sql()
row := tx.QueryRowContext(ctx, query, args...)
current, err := scanLimitRow(row)
switch {
case errors.Is(err, ports.ErrNotFound):
return ports.ErrNotFound
case err != nil:
return err
}
if !limitsEqual(current, expected) {
return ports.ErrConflict
}
return nil
}
func sanctionsEqual(left policy.SanctionRecord, right policy.SanctionRecord) bool {
if left.RecordID != right.RecordID ||
left.UserID != right.UserID ||
left.SanctionCode != right.SanctionCode ||
left.Scope != right.Scope ||
left.ReasonCode != right.ReasonCode ||
left.Actor != right.Actor ||
left.RemovedBy != right.RemovedBy ||
left.RemovedReasonCode != right.RemovedReasonCode {
return false
}
if !left.AppliedAt.Equal(right.AppliedAt) {
return false
}
if !optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) {
return false
}
return optionalTimeEqual(left.RemovedAt, right.RemovedAt)
}
func limitsEqual(left policy.LimitRecord, right policy.LimitRecord) bool {
if left.RecordID != right.RecordID ||
left.UserID != right.UserID ||
left.LimitCode != right.LimitCode ||
left.Value != right.Value ||
left.ReasonCode != right.ReasonCode ||
left.Actor != right.Actor ||
left.RemovedBy != right.RemovedBy ||
left.RemovedReasonCode != right.RemovedReasonCode {
return false
}
if !left.AppliedAt.Equal(right.AppliedAt) {
return false
}
if !optionalTimeEqual(left.ExpiresAt, right.ExpiresAt) {
return false
}
return optionalTimeEqual(left.RemovedAt, right.RemovedAt)
}
// SanctionStore adapts Store to the SanctionStore port.
type SanctionStore struct{ store *Store }
// Sanctions returns one adapter that exposes the sanction store port.
func (store *Store) Sanctions() *SanctionStore {
if store == nil {
return nil
}
return &SanctionStore{store: store}
}
// Create stores one new sanction history record.
func (a *SanctionStore) Create(ctx context.Context, record policy.SanctionRecord) error {
return a.store.CreateSanction(ctx, record)
}
// GetByRecordID returns the sanction record identified by recordID.
func (a *SanctionStore) GetByRecordID(ctx context.Context, recordID policy.SanctionRecordID) (policy.SanctionRecord, error) {
return a.store.GetSanctionByRecordID(ctx, recordID)
}
// ListByUserID returns every sanction record owned by userID.
func (a *SanctionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.SanctionRecord, error) {
return a.store.ListSanctionsByUserID(ctx, userID)
}
// Update replaces one stored sanction record.
func (a *SanctionStore) Update(ctx context.Context, record policy.SanctionRecord) error {
return a.store.UpdateSanction(ctx, record)
}
var _ ports.SanctionStore = (*SanctionStore)(nil)
// LimitStore adapts Store to the LimitStore port.
type LimitStore struct{ store *Store }
// Limits returns one adapter that exposes the limit store port.
func (store *Store) Limits() *LimitStore {
if store == nil {
return nil
}
return &LimitStore{store: store}
}
// Create stores one new limit history record.
func (a *LimitStore) Create(ctx context.Context, record policy.LimitRecord) error {
return a.store.CreateLimit(ctx, record)
}
// GetByRecordID returns the limit record identified by recordID.
func (a *LimitStore) GetByRecordID(ctx context.Context, recordID policy.LimitRecordID) (policy.LimitRecord, error) {
return a.store.GetLimitByRecordID(ctx, recordID)
}
// ListByUserID returns every limit record owned by userID.
func (a *LimitStore) ListByUserID(ctx context.Context, userID common.UserID) ([]policy.LimitRecord, error) {
return a.store.ListLimitsByUserID(ctx, userID)
}
// Update replaces one stored limit record.
func (a *LimitStore) Update(ctx context.Context, record policy.LimitRecord) error {
return a.store.UpdateLimit(ctx, record)
}
var _ ports.LimitStore = (*LimitStore)(nil)
// PolicyLifecycleStore adapts Store to the PolicyLifecycleStore port.
type PolicyLifecycleStore struct{ store *Store }
// PolicyLifecycle returns one adapter that exposes the policy-lifecycle
// store port.
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore {
if store == nil {
return nil
}
return &PolicyLifecycleStore{store: store}
}
// ApplySanction atomically creates one new active sanction record.
func (a *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
return a.store.ApplySanction(ctx, input)
}
// RemoveSanction atomically removes one active sanction record.
func (a *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
return a.store.RemoveSanction(ctx, input)
}
// SetLimit atomically creates or replaces one active limit record.
func (a *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
return a.store.SetLimit(ctx, input)
}
// RemoveLimit atomically removes one active limit record.
func (a *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
return a.store.RemoveLimit(ctx, input)
}
var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)
@@ -0,0 +1,138 @@
// Package userstore implements the PostgreSQL-backed source-of-truth
// persistence used by User Service.
//
// The package owns the on-disk shape of the `user` schema (defined in
// `galaxy/user/internal/adapters/postgres/migrations`) and translates the
// schema-agnostic ports defined under `galaxy/user/internal/ports` into
// concrete `database/sql` operations driven by the pgx driver. Atomic
// composite operations (auth-directory, entitlement-lifecycle, policy-
// lifecycle) execute inside explicit `BEGIN … COMMIT` transactions with
// `SELECT … FOR UPDATE` locks on the rows they mutate.
//
// Stage 3 of `PG_PLAN.md` migrates User Service away from Redis-backed
// durable state. Two Redis Streams (`user:domain_events`,
// `user:lifecycle_events`) remain on Redis for event publication; the
// store is no longer aware of them.
package userstore
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"galaxy/user/internal/ports"
)
// Config configures one PostgreSQL-backed user store instance. The store does
// not own the underlying *sql.DB lifecycle: the caller (typically the
// service runtime) opens, instruments, migrates, and closes the pool. The
// store only borrows the pool and bounds individual round trips with
// OperationTimeout.
type Config struct {
// DB stores the connection pool the store uses for every query.
DB *sql.DB
// OperationTimeout bounds one round trip. The store creates a derived
// context for each operation so callers cannot starve the pool with an
// unbounded ctx. Multi-statement transactions inherit this bound for the
// whole BEGIN … COMMIT span.
OperationTimeout time.Duration
}
// Store persists auth-facing user state in PostgreSQL and exposes the narrow
// atomic auth-facing mutation boundary plus selected entity-store interfaces
// through the same accessor methods (`Accounts`, `BlockedEmails`,
// `EntitlementSnapshots`, `EntitlementHistory`, `EntitlementLifecycle`,
// `Sanctions`, `Limits`, `PolicyLifecycle`) that the previous Redis-backed
// store provided. This keeps the runtime wiring identical between the two
// implementations.
type Store struct {
db *sql.DB
operationTimeout time.Duration
}
// New constructs one PostgreSQL-backed user store from cfg.
func New(cfg Config) (*Store, error) {
if cfg.DB == nil {
return nil, errors.New("new postgres user store: db must not be nil")
}
if cfg.OperationTimeout <= 0 {
return nil, errors.New("new postgres user store: operation timeout must be positive")
}
return &Store{
db: cfg.DB,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close is a no-op for the PostgreSQL-backed store: the connection pool is
// owned by the caller (the runtime) and closed once the runtime shuts down.
// The accessor remains so the Redis-store contract can be preserved
// transparently in the runtime wiring.
func (store *Store) Close() error {
return nil
}
// Ping verifies that the configured PostgreSQL backend is reachable. It runs
// `db.PingContext` under the configured operation timeout.
func (store *Store) Ping(ctx context.Context) error {
operationCtx, cancel, err := withTimeout(ctx, "ping postgres user store", store.operationTimeout)
if err != nil {
return err
}
defer cancel()
if err := store.db.PingContext(operationCtx); err != nil {
return fmt.Errorf("ping postgres user store: %w", err)
}
return nil
}
// withTx runs fn inside a BEGIN … COMMIT transaction bounded by the store's
// operation timeout. It rolls back on any error or panic and returns whatever
// fn returned. The transaction uses the default isolation level
// (`READ COMMITTED`); per-row locking is achieved through `SELECT … FOR
// UPDATE` issued inside fn.
func (store *Store) withTx(ctx context.Context, operation string, fn func(ctx context.Context, tx *sql.Tx) error) error {
operationCtx, cancel, err := withTimeout(ctx, operation, store.operationTimeout)
if err != nil {
return err
}
defer cancel()
tx, err := store.db.BeginTx(operationCtx, nil)
if err != nil {
return fmt.Errorf("%s: begin: %w", operation, err)
}
if err := fn(operationCtx, tx); err != nil {
_ = tx.Rollback()
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("%s: commit: %w", operation, err)
}
return nil
}
// operationContext bounds one read or write that does not need a transaction
// envelope (single statement). It mirrors store.withTx for non-transactional
// callers.
func (store *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
return withTimeout(ctx, operation, store.operationTimeout)
}
// Store directly satisfies the user-account port (its primary entity) and the
// composite auth-directory port. The remaining ports
// (BlockedEmailStore, entitlement-*, sanction-*, limit-*, user-list) are
// implemented by adapter types declared in their respective files; those
// adapters are obtained through Accounts(), BlockedEmails(),
// EntitlementSnapshots(), EntitlementHistory(), EntitlementLifecycle(),
// Sanctions(), Limits(), PolicyLifecycle(), and UserList() accessors.
var (
_ ports.AuthDirectoryStore = (*Store)(nil)
_ ports.UserAccountStore = (*Store)(nil)
)
@@ -0,0 +1,656 @@
package userstore
import (
"context"
"errors"
"testing"
"time"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/authblock"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"github.com/stretchr/testify/require"
)
// All time values are aligned to microseconds because PostgreSQL's
// timestamptz only stores microsecond precision; using nanoseconds here
// would cause round-trip mismatches.
var fixtureCreatedAt = time.Unix(1_775_240_000, 0).UTC()
func validAccount() account.UserAccount {
return account.UserAccount{
UserID: common.UserID("user-pilot-001"),
Email: common.Email("pilot@example.com"),
UserName: common.UserName("player-aaaaaaaa"),
DisplayName: common.DisplayName("NovaPrime"),
PreferredLanguage: common.LanguageTag("en"),
TimeZone: common.TimeZoneName("Europe/Kaliningrad"),
CreatedAt: fixtureCreatedAt,
UpdatedAt: fixtureCreatedAt,
}
}
func validFreeSnapshot(userID common.UserID, at time.Time) entitlement.CurrentSnapshot {
return entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: at.UTC(),
Source: common.Source("auth_signup"),
Actor: common.ActorRef{Type: common.ActorType("auth")},
ReasonCode: common.ReasonCode("initial_free_entitlement"),
UpdatedAt: at.UTC(),
}
}
func validFreePeriod(userID common.UserID, recordID entitlement.EntitlementRecordID, at time.Time) entitlement.PeriodRecord {
return entitlement.PeriodRecord{
RecordID: recordID,
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("auth_signup"),
Actor: common.ActorRef{Type: common.ActorType("auth")},
ReasonCode: common.ReasonCode("initial_free_entitlement"),
StartsAt: at.UTC(),
CreatedAt: at.UTC(),
}
}
func paidPeriod(userID common.UserID, recordID entitlement.EntitlementRecordID, startsAt, endsAt time.Time) entitlement.PeriodRecord {
end := endsAt.UTC()
return entitlement.PeriodRecord{
RecordID: recordID,
UserID: userID,
PlanCode: entitlement.PlanCodePaidMonthly,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_grant"),
StartsAt: startsAt.UTC(),
EndsAt: &end,
CreatedAt: startsAt.UTC(),
}
}
func paidSnapshot(userID common.UserID, startsAt, endsAt, updatedAt time.Time) entitlement.CurrentSnapshot {
end := endsAt.UTC()
return entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodePaidMonthly,
IsPaid: true,
StartsAt: startsAt.UTC(),
EndsAt: &end,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_grant"),
UpdatedAt: updatedAt.UTC(),
}
}
func validSanction(userID common.UserID, code policy.SanctionCode, appliedAt time.Time) policy.SanctionRecord {
return policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-" + string(code) + "-1"),
UserID: userID,
SanctionCode: code,
Scope: common.Scope("platform"),
ReasonCode: common.ReasonCode("manual_block"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: appliedAt.UTC(),
}
}
func validLimit(userID common.UserID, code policy.LimitCode, value int, appliedAt time.Time) policy.LimitRecord {
return policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-" + string(code) + "-1"),
UserID: userID,
LimitCode: code,
Value: value,
ReasonCode: common.ReasonCode("manual_override"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: appliedAt.UTC(),
}
}
func TestAccountCreateAndLookups(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
got, err := store.GetByUserID(ctx, record.UserID)
require.NoError(t, err)
require.Equal(t, record, got)
got, err = store.GetByEmail(ctx, record.Email)
require.NoError(t, err)
require.Equal(t, record, got)
got, err = store.GetByUserName(ctx, record.UserName)
require.NoError(t, err)
require.Equal(t, record, got)
exists, err := store.ExistsByUserID(ctx, record.UserID)
require.NoError(t, err)
require.True(t, exists)
}
func TestAccountCreateConflictsAreClassified(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
// Same UserID -> generic conflict.
require.True(t, errors.Is(store.Create(ctx, ports.CreateAccountInput{Account: record}), ports.ErrConflict))
// Same UserName, different UserID/email -> ErrUserNameConflict (which
// also satisfies errors.Is(ErrConflict)).
clone := validAccount()
clone.UserID = common.UserID("user-pilot-002")
clone.Email = common.Email("pilot2@example.com")
err := store.Create(ctx, ports.CreateAccountInput{Account: clone})
require.True(t, errors.Is(err, ports.ErrUserNameConflict))
require.True(t, errors.Is(err, ports.ErrConflict))
// Same email, different UserID/user_name -> generic conflict.
clone = validAccount()
clone.UserID = common.UserID("user-pilot-003")
clone.UserName = common.UserName("player-bbbbbbbb")
err = store.Create(ctx, ports.CreateAccountInput{Account: clone})
require.True(t, errors.Is(err, ports.ErrConflict))
require.False(t, errors.Is(err, ports.ErrUserNameConflict))
}
func TestAccountUpdateRespectsImmutableFieldsAndSoftDelete(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
updated := record
updated.DisplayName = common.DisplayName("HelloWorld")
updated.DeclaredCountry = common.CountryCode("DE")
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
require.NoError(t, store.Update(ctx, updated))
got, err := store.GetByUserID(ctx, record.UserID)
require.NoError(t, err)
require.Equal(t, updated, got)
// Mutating user_name must surface as ErrConflict.
mutating := updated
mutating.UserName = common.UserName("player-xxxxxxxx")
require.True(t, errors.Is(store.Update(ctx, mutating), ports.ErrConflict))
// Soft-delete via Update sets DeletedAt; ExistsByUserID flips to false.
deletedAt := updated.UpdatedAt.Add(time.Minute)
soft := updated
soft.DeletedAt = &deletedAt
soft.UpdatedAt = deletedAt
require.NoError(t, store.Update(ctx, soft))
exists, err := store.ExistsByUserID(ctx, record.UserID)
require.NoError(t, err)
require.False(t, exists)
}
func TestBlockedEmailUpsertAndGet(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := authblock.BlockedEmailSubject{
Email: common.Email("blocked@example.com"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: fixtureCreatedAt,
}
require.NoError(t, store.PutBlockedEmail(ctx, record))
got, err := store.GetBlockedEmail(ctx, record.Email)
require.NoError(t, err)
require.Equal(t, record, got)
// Upsert replaces existing.
updated := record
updated.ReasonCode = common.ReasonCode("admin_blocked")
updated.BlockedAt = record.BlockedAt.Add(time.Hour)
updated.Actor = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
require.NoError(t, store.PutBlockedEmail(ctx, updated))
got, err = store.GetBlockedEmail(ctx, record.Email)
require.NoError(t, err)
require.Equal(t, updated, got)
}
func TestResolveByEmailReturnsCreatableExistingBlockedAndDeleted(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
creatable, err := store.ResolveByEmail(ctx, common.Email("nobody@example.com"))
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindCreatable, creatable.Kind)
require.NoError(t, store.PutBlockedEmail(ctx, authblock.BlockedEmailSubject{
Email: common.Email("blocked@example.com"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: fixtureCreatedAt,
}))
blocked, err := store.ResolveByEmail(ctx, common.Email("blocked@example.com"))
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindBlocked, blocked.Kind)
require.Equal(t, common.ReasonCode("policy_blocked"), blocked.BlockReasonCode)
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
existing, err := store.ResolveByEmail(ctx, record.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindExisting, existing.Kind)
require.Equal(t, record.UserID, existing.UserID)
// Soft-delete the account; the email lookup must now resolve to blocked.
deletedAt := record.UpdatedAt.Add(time.Minute)
soft := record
soft.DeletedAt = &deletedAt
soft.UpdatedAt = deletedAt
require.NoError(t, store.Update(ctx, soft))
deletedResult, err := store.ResolveByEmail(ctx, record.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindBlocked, deletedResult.Kind)
require.Equal(t, deletedAccountBlockReasonCode, deletedResult.BlockReasonCode)
}
func TestEnsureByEmailCoversAllOutcomes(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
snapshot := validFreeSnapshot(record.UserID, record.CreatedAt)
period := validFreePeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-initial"), record.CreatedAt)
created, err := store.EnsureByEmail(ctx, ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: snapshot,
EntitlementRecord: period,
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeCreated, created.Outcome)
require.Equal(t, record.UserID, created.UserID)
// Second call with the same email returns existing. The Account input
// describes the would-be-created record if no account existed yet; its
// email must match the request email per ports.EnsureByEmailInput.Validate.
existingCandidate := validSecondAccount()
existingCandidate.Email = record.Email
existing, err := store.EnsureByEmail(ctx, ports.EnsureByEmailInput{
Email: record.Email,
Account: existingCandidate,
Entitlement: validFreeSnapshot(existingCandidate.UserID, record.CreatedAt),
EntitlementRecord: validFreePeriod(existingCandidate.UserID, entitlement.EntitlementRecordID("entitlement-second"), record.CreatedAt),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeExisting, existing.Outcome)
require.Equal(t, record.UserID, existing.UserID)
// Blocked email path.
require.NoError(t, store.PutBlockedEmail(ctx, authblock.BlockedEmailSubject{
Email: common.Email("blocked@example.com"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: fixtureCreatedAt,
}))
blockedAccount := validSecondAccount()
blockedAccount.Email = common.Email("blocked@example.com")
blockedSnapshot := validFreeSnapshot(blockedAccount.UserID, record.CreatedAt)
blockedPeriod := validFreePeriod(blockedAccount.UserID, entitlement.EntitlementRecordID("entitlement-blocked"), record.CreatedAt)
blocked, err := store.EnsureByEmail(ctx, ports.EnsureByEmailInput{
Email: blockedAccount.Email,
Account: blockedAccount,
Entitlement: blockedSnapshot,
EntitlementRecord: blockedPeriod,
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, blocked.Outcome)
require.Equal(t, common.ReasonCode("policy_blocked"), blocked.BlockReasonCode)
// Soft-deleted account → blocked(account_deleted).
deletedAt := record.UpdatedAt.Add(time.Hour)
soft := record
soft.DeletedAt = &deletedAt
soft.UpdatedAt = deletedAt
require.NoError(t, store.Update(ctx, soft))
deletedCandidate := validSecondAccount()
deletedCandidate.Email = record.Email
deletedCandidate.UserID = common.UserID("user-third")
deletedCandidate.UserName = common.UserName("player-cccccccc")
deletedResult, err := store.EnsureByEmail(ctx, ports.EnsureByEmailInput{
Email: record.Email,
Account: deletedCandidate,
Entitlement: validFreeSnapshot(deletedCandidate.UserID, record.CreatedAt),
EntitlementRecord: validFreePeriod(deletedCandidate.UserID, entitlement.EntitlementRecordID("entitlement-second-2"), record.CreatedAt),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, deletedResult.Outcome)
require.Equal(t, deletedAccountBlockReasonCode, deletedResult.BlockReasonCode)
}
func validSecondAccount() account.UserAccount {
return account.UserAccount{
UserID: common.UserID("user-second"),
Email: common.Email("second@example.com"),
UserName: common.UserName("player-bbbbbbbb"),
PreferredLanguage: common.LanguageTag("en"),
TimeZone: common.TimeZoneName("UTC"),
CreatedAt: fixtureCreatedAt,
UpdatedAt: fixtureCreatedAt,
}
}
func TestBlockByUserIDAndBlockByEmail(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
res, err := store.BlockByUserID(ctx, ports.BlockByUserIDInput{
UserID: record.UserID,
ReasonCode: common.ReasonCode("manual_block"),
BlockedAt: fixtureCreatedAt.Add(time.Hour),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, res.Outcome)
require.Equal(t, record.UserID, res.UserID)
// Replay returns AlreadyBlocked.
res, err = store.BlockByUserID(ctx, ports.BlockByUserIDInput{
UserID: record.UserID,
ReasonCode: common.ReasonCode("manual_block"),
BlockedAt: fixtureCreatedAt.Add(2 * time.Hour),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeAlreadyBlocked, res.Outcome)
require.Equal(t, record.UserID, res.UserID)
// Block by email for a non-existing address records the block with
// nil resolved_user_id.
res, err = store.BlockByEmail(ctx, ports.BlockByEmailInput{
Email: common.Email("ghost@example.com"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: fixtureCreatedAt.Add(time.Hour),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, res.Outcome)
require.True(t, res.UserID.IsZero())
}
func TestEntitlementSnapshotPutAndGet(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
snapshot := validFreeSnapshot(record.UserID, record.CreatedAt)
require.NoError(t, store.PutEntitlement(ctx, snapshot))
got, err := store.GetEntitlementByUserID(ctx, record.UserID)
require.NoError(t, err)
require.Equal(t, snapshot, got)
// Upsert replaces.
paid := paidSnapshot(record.UserID, record.CreatedAt, record.CreatedAt.Add(30*24*time.Hour), record.CreatedAt.Add(time.Minute))
require.NoError(t, store.PutEntitlement(ctx, paid))
got, err = store.GetEntitlementByUserID(ctx, record.UserID)
require.NoError(t, err)
require.Equal(t, paid, got)
}
func TestEntitlementHistoryCRUDAndList(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
first := validFreePeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-1"), record.CreatedAt)
second := paidPeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-2"), record.CreatedAt.Add(time.Hour), record.CreatedAt.Add(48*time.Hour))
require.NoError(t, store.CreateEntitlementRecord(ctx, first))
require.NoError(t, store.CreateEntitlementRecord(ctx, second))
require.True(t, errors.Is(store.CreateEntitlementRecord(ctx, first), ports.ErrConflict))
got, err := store.GetEntitlementRecordByID(ctx, first.RecordID)
require.NoError(t, err)
require.Equal(t, first, got)
list, err := store.ListEntitlementRecordsByUserID(ctx, record.UserID)
require.NoError(t, err)
require.Len(t, list, 2)
require.Equal(t, first.RecordID, list[0].RecordID)
require.Equal(t, second.RecordID, list[1].RecordID)
closedAt := record.CreatedAt.Add(2 * time.Hour)
updated := first
updated.ClosedAt = &closedAt
updated.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
updated.ClosedReasonCode = common.ReasonCode("superseded")
require.NoError(t, store.UpdateEntitlementRecord(ctx, updated))
got, err = store.GetEntitlementRecordByID(ctx, updated.RecordID)
require.NoError(t, err)
require.Equal(t, updated, got)
}
func TestEntitlementLifecycleGrantExtendRevokeRepair(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
freeSnap := validFreeSnapshot(record.UserID, record.CreatedAt)
freeRecord := validFreePeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-free-1"), record.CreatedAt)
require.NoError(t, store.PutEntitlement(ctx, freeSnap))
require.NoError(t, store.CreateEntitlementRecord(ctx, freeRecord))
closedAt := record.CreatedAt.Add(time.Hour)
closedFree := freeRecord
closedFree.ClosedAt = &closedAt
closedFree.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFree.ClosedReasonCode = common.ReasonCode("superseded")
paidStart := closedAt
paidEnd := paidStart.Add(30 * 24 * time.Hour)
paid := paidPeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-paid-1"), paidStart, paidEnd)
paidSnap := paidSnapshot(record.UserID, paidStart, paidEnd, paidStart)
require.NoError(t, store.GrantEntitlement(ctx, ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: freeSnap,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFree,
NewRecord: paid,
NewSnapshot: paidSnap,
}))
got, err := store.GetEntitlementByUserID(ctx, record.UserID)
require.NoError(t, err)
require.Equal(t, paidSnap, got)
// Extend with a new paid segment.
extendStart := paidEnd
extendEnd := extendStart.Add(30 * 24 * time.Hour)
extendRecord := paidPeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-paid-2"), extendStart, extendEnd)
extendSnap := paidSnapshot(record.UserID, paidStart, extendEnd, extendStart)
require.NoError(t, store.ExtendEntitlement(ctx, ports.ExtendEntitlementInput{
ExpectedCurrentSnapshot: paidSnap,
NewRecord: extendRecord,
NewSnapshot: extendSnap,
}))
// Revoke -> back to free.
revokeAt := extendStart.Add(time.Hour)
revokedPaid := extendRecord
revokedPaid.ClosedAt = &revokeAt
revokedPaid.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
revokedPaid.ClosedReasonCode = common.ReasonCode("revoked")
freeAgain := validFreePeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-free-2"), revokeAt)
freeAgainSnap := validFreeSnapshot(record.UserID, revokeAt)
require.NoError(t, store.RevokeEntitlement(ctx, ports.RevokeEntitlementInput{
ExpectedCurrentSnapshot: extendSnap,
ExpectedCurrentRecord: extendRecord,
UpdatedCurrentRecord: revokedPaid,
NewRecord: freeAgain,
NewSnapshot: freeAgainSnap,
}))
got, err = store.GetEntitlementByUserID(ctx, record.UserID)
require.NoError(t, err)
require.Equal(t, freeAgainSnap, got)
}
func TestEntitlementLifecycleConflictsOnSnapshotMismatch(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
freeSnap := validFreeSnapshot(record.UserID, record.CreatedAt)
require.NoError(t, store.PutEntitlement(ctx, freeSnap))
stale := freeSnap
stale.UpdatedAt = freeSnap.UpdatedAt.Add(-time.Hour)
freeRecord := validFreePeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-free-1"), record.CreatedAt)
require.NoError(t, store.CreateEntitlementRecord(ctx, freeRecord))
closedAt := record.CreatedAt.Add(time.Hour)
closedFree := freeRecord
closedFree.ClosedAt = &closedAt
closedFree.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFree.ClosedReasonCode = common.ReasonCode("superseded")
paid := paidPeriod(record.UserID, entitlement.EntitlementRecordID("entitlement-paid-1"), closedAt, closedAt.Add(time.Hour))
paidSnap := paidSnapshot(record.UserID, closedAt, closedAt.Add(time.Hour), closedAt)
err := store.GrantEntitlement(ctx, ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: stale,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFree,
NewRecord: paid,
NewSnapshot: paidSnap,
})
require.True(t, errors.Is(err, ports.ErrConflict))
}
func TestPolicyApplyRemoveSanctionAndLimit(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
record := validAccount()
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: record}))
sanction := validSanction(record.UserID, policy.SanctionCodeLoginBlock, fixtureCreatedAt.Add(time.Minute))
require.NoError(t, store.ApplySanction(ctx, ports.ApplySanctionInput{NewRecord: sanction}))
got, err := store.GetSanctionByRecordID(ctx, sanction.RecordID)
require.NoError(t, err)
require.Equal(t, sanction, got)
// Re-applying the same sanction code without removing first must return
// ErrConflict because (user_id, sanction_code) is unique on
// sanction_active.
dup := sanction
dup.RecordID = policy.SanctionRecordID("sanction-login_block-2")
require.True(t, errors.Is(store.ApplySanction(ctx, ports.ApplySanctionInput{NewRecord: dup}), ports.ErrConflict))
removedAt := sanction.AppliedAt.Add(time.Hour)
updated := sanction
updated.RemovedAt = &removedAt
updated.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
updated.RemovedReasonCode = common.ReasonCode("manual_unblock")
require.NoError(t, store.RemoveSanction(ctx, ports.RemoveSanctionInput{
ExpectedActiveRecord: sanction,
UpdatedRecord: updated,
}))
got, err = store.GetSanctionByRecordID(ctx, sanction.RecordID)
require.NoError(t, err)
require.Equal(t, updated, got)
// Now SetLimit on a fresh code; replay must conflict.
limit := validLimit(record.UserID, policy.LimitCodeMaxOwnedPrivateGames, 5, fixtureCreatedAt.Add(2*time.Minute))
require.NoError(t, store.SetLimit(ctx, ports.SetLimitInput{NewRecord: limit}))
dupLimit := limit
dupLimit.RecordID = policy.LimitRecordID("limit-max_owned_private_games-2")
require.True(t, errors.Is(store.SetLimit(ctx, ports.SetLimitInput{NewRecord: dupLimit}), ports.ErrConflict))
// SetLimit with ExpectedActiveRecord -> replaces in the active slot.
expected := limit
expected.RemovedAt = nil
expected.RemovedBy = common.ActorRef{}
expected.RemovedReasonCode = ""
supersededTime := limit.AppliedAt.Add(time.Hour)
supersededLimit := limit
supersededLimit.RemovedAt = &supersededTime
supersededLimit.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
supersededLimit.RemovedReasonCode = common.ReasonCode("superseded")
newLimit := validLimit(record.UserID, policy.LimitCodeMaxOwnedPrivateGames, 7, supersededTime)
newLimit.RecordID = policy.LimitRecordID("limit-max_owned_private_games-3")
require.NoError(t, store.SetLimit(ctx, ports.SetLimitInput{
ExpectedActiveRecord: &expected,
UpdatedActiveRecord: &supersededLimit,
NewRecord: newLimit,
}))
gotLimit, err := store.GetLimitByRecordID(ctx, newLimit.RecordID)
require.NoError(t, err)
require.Equal(t, newLimit, gotLimit)
}
func TestUserListPaginatesNewestFirstAndDetectsFilterMismatch(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
base := fixtureCreatedAt
for index, suffix := range []string{"a", "b", "c", "d", "e"} {
acc := validAccount()
acc.UserID = common.UserID("user-list-" + suffix)
acc.Email = common.Email("list-" + suffix + "@example.com")
acc.UserName = common.UserName("player-list" + suffix + "xx")
acc.CreatedAt = base.Add(time.Duration(index) * time.Minute)
acc.UpdatedAt = acc.CreatedAt
require.NoError(t, store.Create(ctx, ports.CreateAccountInput{Account: acc}))
}
page1, err := store.ListUserIDs(ctx, ports.ListUsersInput{PageSize: 2})
require.NoError(t, err)
require.Len(t, page1.UserIDs, 2)
require.Equal(t, common.UserID("user-list-e"), page1.UserIDs[0])
require.Equal(t, common.UserID("user-list-d"), page1.UserIDs[1])
require.NotEmpty(t, page1.NextPageToken)
page2, err := store.ListUserIDs(ctx, ports.ListUsersInput{
PageSize: 2,
PageToken: page1.NextPageToken,
})
require.NoError(t, err)
require.Len(t, page2.UserIDs, 2)
require.Equal(t, common.UserID("user-list-c"), page2.UserIDs[0])
require.Equal(t, common.UserID("user-list-b"), page2.UserIDs[1])
// Mismatched filters must reject the previously-issued token.
mismatched, err := store.ListUserIDs(ctx, ports.ListUsersInput{
PageSize: 2,
PageToken: page1.NextPageToken,
Filters: ports.UserListFilters{PaidState: entitlement.PaidStatePaid},
})
require.True(t, errors.Is(err, ports.ErrInvalidPageToken), "got result %#v err %v", mismatched, err)
}
@@ -4,7 +4,6 @@ package domainevents
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
@@ -17,23 +16,11 @@ import (
"go.opentelemetry.io/otel/trace"
)
// Config configures one Redis-backed user domain-event publisher.
// Config configures one Redis-backed user domain-event publisher. The
// connection is supplied externally by the runtime so multiple publishers
// can share one *redis.Client; this struct now carries only stream-shape
// parameters.
type Config struct {
// Addr is the Redis network address in host:port form.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled enables TLS with a conservative minimum protocol version.
TLSEnabled bool
// Stream identifies the Redis Stream key used for domain events.
Stream string
@@ -53,13 +40,13 @@ type Publisher struct {
operationTimeout time.Duration
}
// New constructs a Redis-backed domain-event publisher from cfg.
func New(cfg Config) (*Publisher, error) {
// New constructs a Redis-backed domain-event publisher backed by the
// supplied client. The publisher does not own the client; the runtime is
// responsible for closing it.
func New(client *redis.Client, cfg Config) (*Publisher, error) {
switch {
case strings.TrimSpace(cfg.Addr) == "":
return nil, errors.New("new redis domain-event publisher: redis addr must not be empty")
case cfg.DB < 0:
return nil, errors.New("new redis domain-event publisher: redis db must not be negative")
case client == nil:
return nil, errors.New("new redis domain-event publisher: redis client must not be nil")
case strings.TrimSpace(cfg.Stream) == "":
return nil, errors.New("new redis domain-event publisher: stream must not be empty")
case cfg.StreamMaxLen <= 0:
@@ -68,33 +55,19 @@ func New(cfg Config) (*Publisher, error) {
return nil, errors.New("new redis domain-event publisher: operation timeout must be positive")
}
options := &redis.Options{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.DB,
Protocol: 2,
DisableIdentity: true,
}
if cfg.TLSEnabled {
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
return &Publisher{
client: redis.NewClient(options),
client: client,
stream: cfg.Stream,
streamMaxLen: cfg.StreamMaxLen,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close releases the underlying Redis client resources.
// Close is a no-op: the client is owned by the runtime, not the publisher.
// The accessor remains for API symmetry with the previous Redis adapter so
// runtime cleanup chains do not need to special-case this surface.
func (publisher *Publisher) Close() error {
if publisher == nil || publisher.client == nil {
return nil
}
return publisher.client.Close()
return nil
}
// Ping verifies that the configured Redis backend is reachable within the
@@ -10,6 +10,7 @@ import (
"galaxy/user/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
@@ -17,8 +18,7 @@ func TestPublisherPublishesFlatRedisStreamEntry(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher, err := New(Config{
Addr: server.Addr(),
publisher, err := New(redis.NewClient(&redis.Options{Addr: server.Addr()}), Config{
Stream: "user:test_events",
StreamMaxLen: 5,
OperationTimeout: time.Second,
@@ -70,8 +70,7 @@ func TestPublisherRejectsInvalidEventBeforeXAdd(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher, err := New(Config{
Addr: server.Addr(),
publisher, err := New(redis.NewClient(&redis.Options{Addr: server.Addr()}), Config{
Stream: "user:test_events",
StreamMaxLen: 5,
OperationTimeout: time.Second,
@@ -4,7 +4,6 @@ package lifecycleevents
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
@@ -17,23 +16,10 @@ import (
"go.opentelemetry.io/otel/trace"
)
// Config configures one Redis-backed user-lifecycle publisher.
// Config configures one Redis-backed user-lifecycle publisher. The
// connection is supplied externally by the runtime so multiple publishers
// can share one *redis.Client.
type Config struct {
// Addr is the Redis network address in host:port form.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled enables TLS with a conservative minimum protocol version.
TLSEnabled bool
// Stream identifies the Redis Stream key used for lifecycle events. The
// default platform key is `user:lifecycle_events`.
Stream string
@@ -55,13 +41,13 @@ type Publisher struct {
operationTimeout time.Duration
}
// New constructs a Redis-backed lifecycle-event publisher from cfg.
func New(cfg Config) (*Publisher, error) {
// New constructs a Redis-backed lifecycle-event publisher backed by the
// supplied client. The publisher does not own the client; the runtime is
// responsible for closing it.
func New(client *redis.Client, cfg Config) (*Publisher, error) {
switch {
case strings.TrimSpace(cfg.Addr) == "":
return nil, errors.New("new redis lifecycle-event publisher: redis addr must not be empty")
case cfg.DB < 0:
return nil, errors.New("new redis lifecycle-event publisher: redis db must not be negative")
case client == nil:
return nil, errors.New("new redis lifecycle-event publisher: redis client must not be nil")
case strings.TrimSpace(cfg.Stream) == "":
return nil, errors.New("new redis lifecycle-event publisher: stream must not be empty")
case cfg.StreamMaxLen <= 0:
@@ -70,33 +56,17 @@ func New(cfg Config) (*Publisher, error) {
return nil, errors.New("new redis lifecycle-event publisher: operation timeout must be positive")
}
options := &redis.Options{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.DB,
Protocol: 2,
DisableIdentity: true,
}
if cfg.TLSEnabled {
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
return &Publisher{
client: redis.NewClient(options),
client: client,
stream: cfg.Stream,
streamMaxLen: cfg.StreamMaxLen,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close releases the underlying Redis client resources.
// Close is a no-op: the client is owned by the runtime.
func (publisher *Publisher) Close() error {
if publisher == nil || publisher.client == nil {
return nil
}
return publisher.client.Close()
return nil
}
// Ping verifies that the configured Redis backend is reachable within the
@@ -10,6 +10,7 @@ import (
"galaxy/user/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
@@ -17,8 +18,7 @@ func TestPublisherPublishesPermanentBlockedEnvelope(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher, err := New(Config{
Addr: server.Addr(),
publisher, err := New(redis.NewClient(&redis.Options{Addr: server.Addr()}), Config{
Stream: "user:lifecycle_events",
StreamMaxLen: 10,
OperationTimeout: time.Second,
@@ -54,8 +54,7 @@ func TestPublisherOmitsOptionalActorIDAndTraceID(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher, err := New(Config{
Addr: server.Addr(),
publisher, err := New(redis.NewClient(&redis.Options{Addr: server.Addr()}), Config{
Stream: "user:lifecycle_events",
StreamMaxLen: 10,
OperationTimeout: time.Second,
@@ -86,8 +85,7 @@ func TestPublisherRejectsInvalidEventBeforeXAdd(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher, err := New(Config{
Addr: server.Addr(),
publisher, err := New(redis.NewClient(&redis.Options{Addr: server.Addr()}), Config{
Stream: "user:lifecycle_events",
StreamMaxLen: 10,
OperationTimeout: time.Second,
@@ -113,8 +111,7 @@ func TestPublisherTrimsBeyondMaxLen(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher, err := New(Config{
Addr: server.Addr(),
publisher, err := New(redis.NewClient(&redis.Options{Addr: server.Addr()}), Config{
Stream: "user:lifecycle_events",
StreamMaxLen: 5,
OperationTimeout: time.Second,
@@ -142,8 +139,7 @@ func TestPublisherPingReportsReachability(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher, err := New(Config{
Addr: server.Addr(),
publisher, err := New(redis.NewClient(&redis.Options{Addr: server.Addr()}), Config{
Stream: "user:lifecycle_events",
StreamMaxLen: 10,
OperationTimeout: time.Second,
@@ -1,227 +0,0 @@
package userstore
import (
"context"
"errors"
"galaxy/user/internal/adapters/redisstate"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
var knownSanctionCodes = []policy.SanctionCode{
policy.SanctionCodeLoginBlock,
policy.SanctionCodePrivateGameCreateBlock,
policy.SanctionCodePrivateGameManageBlock,
policy.SanctionCodeGameJoinBlock,
policy.SanctionCodeProfileUpdateBlock,
policy.SanctionCodePermanentBlock,
}
var knownLimitCodes = []policy.LimitCode{
policy.LimitCodeMaxOwnedPrivateGames,
policy.LimitCodeMaxPendingPublicApplications,
policy.LimitCodeMaxActiveGameMemberships,
policy.LimitCodeMaxRegisteredRaceNames,
}
var knownEligibilityMarkers = []policy.EligibilityMarker{
policy.EligibilityMarkerCanLogin,
policy.EligibilityMarkerCanCreatePrivateGame,
policy.EligibilityMarkerCanManagePrivateGame,
policy.EligibilityMarkerCanJoinGame,
policy.EligibilityMarkerCanUpdateProfile,
}
func (store *Store) addCreatedAtIndex(
pipe redis.Pipeliner,
ctx context.Context,
record account.UserAccount,
) {
pipe.ZAdd(ctx, store.keyspace.CreatedAtIndex(), redis.Z{
Score: redisstate.CreatedAtScore(record.CreatedAt),
Member: record.UserID.String(),
})
}
func (store *Store) syncDeclaredCountryIndex(
pipe redis.Pipeliner,
ctx context.Context,
previous account.UserAccount,
current account.UserAccount,
) {
if !previous.DeclaredCountry.IsZero() {
pipe.SRem(ctx, store.keyspace.DeclaredCountryIndex(previous.DeclaredCountry), current.UserID.String())
}
if !current.DeclaredCountry.IsZero() {
pipe.SAdd(ctx, store.keyspace.DeclaredCountryIndex(current.DeclaredCountry), current.UserID.String())
}
}
func (store *Store) syncEntitlementIndexes(
pipe redis.Pipeliner,
ctx context.Context,
snapshot entitlement.CurrentSnapshot,
) {
pipe.SRem(ctx, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), snapshot.UserID.String())
pipe.SRem(ctx, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), snapshot.UserID.String())
pipe.SAdd(ctx, store.keyspace.PaidStateIndex(paidStateFromSnapshot(snapshot)), snapshot.UserID.String())
pipe.ZRem(ctx, store.keyspace.FinitePaidExpiryIndex(), snapshot.UserID.String())
if snapshot.HasFiniteExpiry() {
pipe.ZAdd(ctx, store.keyspace.FinitePaidExpiryIndex(), redis.Z{
Score: redisstate.ExpiryScore(*snapshot.EndsAt),
Member: snapshot.UserID.String(),
})
}
}
func (store *Store) syncActiveSanctionCodeIndexes(
pipe redis.Pipeliner,
ctx context.Context,
userID common.UserID,
activeCodes map[policy.SanctionCode]struct{},
) {
for _, code := range knownSanctionCodes {
pipe.SRem(ctx, store.keyspace.ActiveSanctionCodeIndex(code), userID.String())
if _, ok := activeCodes[code]; ok {
pipe.SAdd(ctx, store.keyspace.ActiveSanctionCodeIndex(code), userID.String())
}
}
}
func (store *Store) syncActiveLimitCodeIndexes(
pipe redis.Pipeliner,
ctx context.Context,
userID common.UserID,
activeCodes map[policy.LimitCode]struct{},
) {
for _, code := range knownLimitCodes {
pipe.SRem(ctx, store.keyspace.ActiveLimitCodeIndex(code), userID.String())
if _, ok := activeCodes[code]; ok {
pipe.SAdd(ctx, store.keyspace.ActiveLimitCodeIndex(code), userID.String())
}
}
}
func (store *Store) syncEligibilityMarkerIndexes(
pipe redis.Pipeliner,
ctx context.Context,
userID common.UserID,
isPaid bool,
activeSanctionCodes map[policy.SanctionCode]struct{},
) {
values := deriveEligibilityMarkerValues(isPaid, activeSanctionCodes)
for _, marker := range knownEligibilityMarkers {
pipe.SRem(ctx, store.keyspace.EligibilityMarkerIndex(marker, true), userID.String())
pipe.SRem(ctx, store.keyspace.EligibilityMarkerIndex(marker, false), userID.String())
pipe.SAdd(ctx, store.keyspace.EligibilityMarkerIndex(marker, values[marker]), userID.String())
}
}
func (store *Store) loadActiveSanctionCodeSet(
ctx context.Context,
getter bytesGetter,
userID common.UserID,
) (map[policy.SanctionCode]struct{}, error) {
activeCodes := make(map[policy.SanctionCode]struct{}, len(knownSanctionCodes))
for _, code := range knownSanctionCodes {
_, err := store.loadActiveSanctionRecordID(ctx, getter, store.keyspace.ActiveSanction(userID, code))
switch {
case err == nil:
activeCodes[code] = struct{}{}
case errors.Is(err, ports.ErrNotFound):
continue
default:
return nil, err
}
}
return activeCodes, nil
}
func (store *Store) loadActiveLimitCodeSet(
ctx context.Context,
getter bytesGetter,
userID common.UserID,
) (map[policy.LimitCode]struct{}, error) {
activeCodes := make(map[policy.LimitCode]struct{}, len(knownLimitCodes))
for _, code := range knownLimitCodes {
_, err := store.loadActiveLimitRecordID(ctx, getter, store.keyspace.ActiveLimit(userID, code))
switch {
case err == nil:
activeCodes[code] = struct{}{}
case errors.Is(err, ports.ErrNotFound):
continue
default:
return nil, err
}
}
return activeCodes, nil
}
func (store *Store) activeSanctionWatchKeys(userID common.UserID) []string {
keys := make([]string, 0, len(knownSanctionCodes))
for _, code := range knownSanctionCodes {
keys = append(keys, store.keyspace.ActiveSanction(userID, code))
}
return keys
}
func (store *Store) activeLimitWatchKeys(userID common.UserID) []string {
keys := make([]string, 0, len(knownLimitCodes))
for _, code := range knownLimitCodes {
keys = append(keys, store.keyspace.ActiveLimit(userID, code))
}
return keys
}
func deriveEligibilityMarkerValues(
isPaid bool,
activeSanctionCodes map[policy.SanctionCode]struct{},
) map[policy.EligibilityMarker]bool {
if _, permanentBlocked := activeSanctionCodes[policy.SanctionCodePermanentBlock]; permanentBlocked {
return map[policy.EligibilityMarker]bool{
policy.EligibilityMarkerCanLogin: false,
policy.EligibilityMarkerCanCreatePrivateGame: false,
policy.EligibilityMarkerCanManagePrivateGame: false,
policy.EligibilityMarkerCanJoinGame: false,
policy.EligibilityMarkerCanUpdateProfile: false,
}
}
_, loginBlocked := activeSanctionCodes[policy.SanctionCodeLoginBlock]
_, createBlocked := activeSanctionCodes[policy.SanctionCodePrivateGameCreateBlock]
_, manageBlocked := activeSanctionCodes[policy.SanctionCodePrivateGameManageBlock]
_, joinBlocked := activeSanctionCodes[policy.SanctionCodeGameJoinBlock]
_, profileBlocked := activeSanctionCodes[policy.SanctionCodeProfileUpdateBlock]
canLogin := !loginBlocked
return map[policy.EligibilityMarker]bool{
policy.EligibilityMarkerCanLogin: canLogin,
policy.EligibilityMarkerCanCreatePrivateGame: canLogin && isPaid && !createBlocked,
policy.EligibilityMarkerCanManagePrivateGame: canLogin && isPaid && !manageBlocked,
policy.EligibilityMarkerCanJoinGame: canLogin && !joinBlocked,
policy.EligibilityMarkerCanUpdateProfile: canLogin && !profileBlocked,
}
}
func paidStateFromSnapshot(snapshot entitlement.CurrentSnapshot) entitlement.PaidState {
if snapshot.IsPaid {
return entitlement.PaidStatePaid
}
return entitlement.PaidStateFree
}
@@ -1,58 +0,0 @@
package userstore
import (
"testing"
"galaxy/user/internal/domain/policy"
"github.com/stretchr/testify/require"
)
func TestDeriveEligibilityMarkerValuesCollapsesUnderPermanentBlock(t *testing.T) {
t.Parallel()
activeCodes := map[policy.SanctionCode]struct{}{
policy.SanctionCodePermanentBlock: {},
}
values := deriveEligibilityMarkerValues(true, activeCodes)
require.False(t, values[policy.EligibilityMarkerCanLogin])
require.False(t, values[policy.EligibilityMarkerCanCreatePrivateGame])
require.False(t, values[policy.EligibilityMarkerCanManagePrivateGame])
require.False(t, values[policy.EligibilityMarkerCanJoinGame])
require.False(t, values[policy.EligibilityMarkerCanUpdateProfile])
}
func TestDeriveEligibilityMarkerValuesPermanentBlockDominatesOtherSanctions(t *testing.T) {
t.Parallel()
activeCodes := map[policy.SanctionCode]struct{}{
policy.SanctionCodePermanentBlock: {},
policy.SanctionCodeLoginBlock: {},
policy.SanctionCodeGameJoinBlock: {},
}
values := deriveEligibilityMarkerValues(false, activeCodes)
for marker, value := range values {
require.Falsef(t, value, "marker %q must be false under permanent_block", marker)
}
}
func TestDeriveEligibilityMarkerValuesFreeUserWithoutPermanentBlock(t *testing.T) {
t.Parallel()
values := deriveEligibilityMarkerValues(false, map[policy.SanctionCode]struct{}{})
require.True(t, values[policy.EligibilityMarkerCanLogin])
require.False(t, values[policy.EligibilityMarkerCanCreatePrivateGame])
require.False(t, values[policy.EligibilityMarkerCanManagePrivateGame])
require.True(t, values[policy.EligibilityMarkerCanJoinGame])
require.True(t, values[policy.EligibilityMarkerCanUpdateProfile])
}
func TestKnownCatalogsIncludeStage22Codes(t *testing.T) {
t.Parallel()
require.Contains(t, knownSanctionCodes, policy.SanctionCodePermanentBlock)
require.Contains(t, knownLimitCodes, policy.LimitCodeMaxRegisteredRaceNames)
}
@@ -1,445 +0,0 @@
package userstore
import (
"context"
"testing"
"time"
"galaxy/user/internal/adapters/redisstate"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"galaxy/user/internal/service/adminusers"
"galaxy/user/internal/service/entitlementsvc"
"github.com/stretchr/testify/require"
)
func TestListUserIDsCreatedAtPagination(t *testing.T) {
t.Parallel()
store := newTestStore(t)
base := time.Unix(1_775_240_000, 0).UTC()
first := validAccountRecord()
first.UserID = common.UserID("user-100")
first.Email = common.Email("u100@example.com")
first.UserName = common.UserName("player-user100aa")
first.CreatedAt = base.Add(-time.Hour)
first.UpdatedAt = first.CreatedAt
second := validAccountRecord()
second.UserID = common.UserID("user-200")
second.Email = common.Email("u200@example.com")
second.UserName = common.UserName("player-user200aa")
second.CreatedAt = base
second.UpdatedAt = second.CreatedAt
third := validAccountRecord()
third.UserID = common.UserID("user-300")
third.Email = common.Email("u300@example.com")
third.UserName = common.UserName("player-user300aa")
third.CreatedAt = base
third.UpdatedAt = third.CreatedAt
require.NoError(t, store.Create(context.Background(), createAccountInput(first)))
require.NoError(t, store.Create(context.Background(), createAccountInput(second)))
require.NoError(t, store.Create(context.Background(), createAccountInput(third)))
firstPage, err := store.ListUserIDs(context.Background(), ports.ListUsersInput{
PageSize: 2,
Filters: ports.UserListFilters{},
})
require.NoError(t, err)
require.Equal(t, []common.UserID{third.UserID, second.UserID}, firstPage.UserIDs)
require.NotEmpty(t, firstPage.NextPageToken)
secondPage, err := store.ListUserIDs(context.Background(), ports.ListUsersInput{
PageSize: 2,
PageToken: firstPage.NextPageToken,
Filters: ports.UserListFilters{},
})
require.NoError(t, err)
require.Equal(t, []common.UserID{first.UserID}, secondPage.UserIDs)
require.Empty(t, secondPage.NextPageToken)
}
func TestEnsureByEmailInitialAdminIndexes(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.DeclaredCountry = common.CountryCode("DE")
record.CreatedAt = now
record.UpdatedAt = now
result, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(record.UserID, now),
EntitlementRecord: validEntitlementRecord(record.UserID, now),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeCreated, result.Outcome)
requireSortedSetScore(t, store, store.keyspace.CreatedAtIndex(), record.UserID.String(), redisstate.CreatedAtScore(record.CreatedAt))
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
requireSetContains(t, store, store.keyspace.DeclaredCountryIndex(record.DeclaredCountry), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, false), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanJoinGame, true), record.UserID.String())
}
func TestAccountUpdateSyncsDeclaredCountryIndex(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
record.DeclaredCountry = common.CountryCode("DE")
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
updated := record
updated.DeclaredCountry = common.CountryCode("FR")
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
require.NoError(t, accountStore.Update(context.Background(), updated))
requireSetNotContains(t, store, store.keyspace.DeclaredCountryIndex(common.CountryCode("DE")), record.UserID.String())
requireSetContains(t, store, store.keyspace.DeclaredCountryIndex(common.CountryCode("FR")), record.UserID.String())
}
func TestEntitlementLifecycleSyncsAdminIndexes(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.CreatedAt = now
record.UpdatedAt = now
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(record.UserID, now),
EntitlementRecord: validEntitlementRecord(record.UserID, now),
})
require.NoError(t, err)
lifecycleStore := store.EntitlementLifecycle()
freeRecord := validEntitlementRecord(record.UserID, now)
freeSnapshot := validEntitlementSnapshot(record.UserID, now)
grantStartsAt := now.Add(time.Hour)
grantEndsAt := grantStartsAt.Add(30 * 24 * time.Hour)
grantedRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-1"),
record.UserID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
grantedSnapshot := paidEntitlementSnapshot(
record.UserID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
closedFreeRecord := freeRecord
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
require.NoError(t, lifecycleStore.Grant(context.Background(), ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: freeSnapshot,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFreeRecord,
NewRecord: grantedRecord,
NewSnapshot: grantedSnapshot,
}))
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
requireSortedSetScore(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String(), redisstate.ExpiryScore(grantEndsAt))
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, true), record.UserID.String())
extendedEndsAt := grantEndsAt.Add(30 * 24 * time.Hour)
extensionRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-2"),
record.UserID,
entitlement.PlanCodePaidMonthly,
grantEndsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
extendedSnapshot := paidEntitlementSnapshot(
record.UserID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
require.NoError(t, lifecycleStore.Extend(context.Background(), ports.ExtendEntitlementInput{
ExpectedCurrentSnapshot: grantedSnapshot,
NewRecord: extensionRecord,
NewSnapshot: extendedSnapshot,
}))
requireSortedSetScore(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String(), redisstate.ExpiryScore(extendedEndsAt))
revokeAt := grantEndsAt.Add(12 * time.Hour)
revokedCurrentRecord := extensionRecord
revokedCurrentRecord.ClosedAt = timePointer(revokeAt)
revokedCurrentRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
revokedCurrentRecord.ClosedReasonCode = common.ReasonCode("manual_revoke")
freeAfterRevokeRecord := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-free-2"),
UserID: record.UserID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
StartsAt: revokeAt,
CreatedAt: revokeAt,
}
freeAfterRevokeSnapshot := entitlement.CurrentSnapshot{
UserID: record.UserID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: revokeAt,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
UpdatedAt: revokeAt,
}
require.NoError(t, lifecycleStore.Revoke(context.Background(), ports.RevokeEntitlementInput{
ExpectedCurrentSnapshot: extendedSnapshot,
ExpectedCurrentRecord: extensionRecord,
UpdatedCurrentRecord: revokedCurrentRecord,
NewRecord: freeAfterRevokeRecord,
NewSnapshot: freeAfterRevokeSnapshot,
}))
requireSetContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStateFree), record.UserID.String())
requireSetNotContains(t, store, store.keyspace.PaidStateIndex(entitlement.PaidStatePaid), record.UserID.String())
requireSortedSetMissing(t, store, store.keyspace.FinitePaidExpiryIndex(), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanCreatePrivateGame, false), record.UserID.String())
}
func TestPolicyLifecycleSyncsAdminIndexes(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.CreatedAt = now
record.UpdatedAt = now
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(record.UserID, now),
EntitlementRecord: validEntitlementRecord(record.UserID, now),
})
require.NoError(t, err)
lifecycleStore := store.PolicyLifecycle()
sanctionRecord := policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-1"),
UserID: record.UserID,
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("auth"),
ReasonCode: common.ReasonCode("manual_block"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now,
}
require.NoError(t, lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
NewRecord: sanctionRecord,
}))
requireSetContains(t, store, store.keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, false), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanJoinGame, false), record.UserID.String())
removedSanction := sanctionRecord
removedAt := now.Add(time.Minute)
removedSanction.RemovedAt = &removedAt
removedSanction.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
removedSanction.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveSanction(context.Background(), ports.RemoveSanctionInput{
ExpectedActiveRecord: sanctionRecord,
UpdatedRecord: removedSanction,
}))
requireSetNotContains(t, store, store.keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock), record.UserID.String())
requireSetContains(t, store, store.keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true), record.UserID.String())
limitRecord := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-1"),
UserID: record.UserID,
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 5,
ReasonCode: common.ReasonCode("manual_override"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now.Add(2 * time.Minute),
}
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
NewRecord: limitRecord,
}))
requireSetContains(t, store, store.keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames), record.UserID.String())
removedLimit := limitRecord
limitRemovedAt := now.Add(3 * time.Minute)
removedLimit.RemovedAt = &limitRemovedAt
removedLimit.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
removedLimit.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveLimit(context.Background(), ports.RemoveLimitInput{
ExpectedActiveRecord: limitRecord,
UpdatedRecord: removedLimit,
}))
requireSetNotContains(t, store, store.keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames), record.UserID.String())
}
func TestAdminListerReevaluatesExpiredPaidSnapshots(t *testing.T) {
t.Parallel()
store := newTestStore(t)
userID := common.UserID("user-123")
now := time.Unix(1_775_240_000, 0).UTC()
record := validAccountRecord()
record.CreatedAt = now.Add(-2 * time.Hour)
record.UpdatedAt = record.CreatedAt
_, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: record.Email,
Account: record,
Entitlement: validEntitlementSnapshot(userID, record.CreatedAt),
EntitlementRecord: validEntitlementRecord(userID, record.CreatedAt),
})
require.NoError(t, err)
grantStartsAt := now.Add(-90 * time.Minute)
grantEndsAt := now.Add(-30 * time.Minute)
freeRecord := validEntitlementRecord(userID, record.CreatedAt)
freeSnapshot := validEntitlementSnapshot(userID, record.CreatedAt)
grantedRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-expired"),
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
grantedSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
closedFreeRecord := freeRecord
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
require.NoError(t, store.EntitlementLifecycle().Grant(context.Background(), ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: freeSnapshot,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFreeRecord,
NewRecord: grantedRecord,
NewSnapshot: grantedSnapshot,
}))
reader, err := entitlementsvc.NewReader(
store.EntitlementSnapshots(),
store.EntitlementLifecycle(),
adminStoreClock{now: now},
adminStoreIDGenerator{entitlementRecordID: entitlement.EntitlementRecordID("entitlement-free-after-expiry")},
)
require.NoError(t, err)
lister, err := adminusers.NewLister(store.Accounts(), reader, store.Sanctions(), store.Limits(), adminStoreClock{now: now}, store)
require.NoError(t, err)
result, err := lister.Execute(context.Background(), adminusers.ListUsersInput{PaidState: "free"})
require.NoError(t, err)
require.Len(t, result.Items, 1)
require.Equal(t, "user-123", result.Items[0].UserID)
require.Equal(t, "free", result.Items[0].Entitlement.PlanCode)
require.False(t, result.Items[0].Entitlement.IsPaid)
storedSnapshot, err := store.EntitlementSnapshots().GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, entitlement.PlanCodeFree, storedSnapshot.PlanCode)
require.False(t, storedSnapshot.IsPaid)
}
type adminStoreClock struct {
now time.Time
}
func (clock adminStoreClock) Now() time.Time {
return clock.now
}
type adminStoreIDGenerator struct {
entitlementRecordID entitlement.EntitlementRecordID
}
func (generator adminStoreIDGenerator) NewUserID() (common.UserID, error) {
return "", nil
}
func (generator adminStoreIDGenerator) NewUserName() (common.UserName, error) {
return "", nil
}
func (generator adminStoreIDGenerator) NewEntitlementRecordID() (entitlement.EntitlementRecordID, error) {
return generator.entitlementRecordID, nil
}
func (generator adminStoreIDGenerator) NewSanctionRecordID() (policy.SanctionRecordID, error) {
return "", nil
}
func (generator adminStoreIDGenerator) NewLimitRecordID() (policy.LimitRecordID, error) {
return "", nil
}
func requireSetContains(t *testing.T, store *Store, key string, member string) {
t.Helper()
exists, err := store.client.SIsMember(context.Background(), key, member).Result()
require.NoError(t, err)
require.True(t, exists, "expected %q to contain %q", key, member)
}
func requireSetNotContains(t *testing.T, store *Store, key string, member string) {
t.Helper()
exists, err := store.client.SIsMember(context.Background(), key, member).Result()
require.NoError(t, err)
require.False(t, exists, "expected %q not to contain %q", key, member)
}
func requireSortedSetScore(t *testing.T, store *Store, key string, member string, want float64) {
t.Helper()
got, err := store.client.ZScore(context.Background(), key, member).Result()
require.NoError(t, err)
require.Equal(t, want, got)
}
func requireSortedSetMissing(t *testing.T, store *Store, key string, member string) {
t.Helper()
_, err := store.client.ZScore(context.Background(), key, member).Result()
require.Error(t, err)
}
@@ -1,752 +0,0 @@
package userstore
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
type entitlementPeriodRecord struct {
RecordID string `json:"record_id"`
UserID string `json:"user_id"`
PlanCode string `json:"plan_code"`
Source string `json:"source"`
ActorType string `json:"actor_type"`
ActorID *string `json:"actor_id,omitempty"`
ReasonCode string `json:"reason_code"`
StartsAt string `json:"starts_at"`
EndsAt *string `json:"ends_at,omitempty"`
CreatedAt string `json:"created_at"`
ClosedAt *string `json:"closed_at,omitempty"`
ClosedByType *string `json:"closed_by_type,omitempty"`
ClosedByID *string `json:"closed_by_id,omitempty"`
ClosedReasonCode *string `json:"closed_reason_code,omitempty"`
}
// CreateEntitlementRecord stores one new entitlement history record.
func (store *Store) CreateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("create entitlement record in redis: %w", err)
}
payload, err := marshalEntitlementPeriodRecord(record)
if err != nil {
return fmt.Errorf("create entitlement record in redis: %w", err)
}
recordKey := store.keyspace.EntitlementRecord(record.RecordID)
historyKey := store.keyspace.EntitlementHistory(record.UserID)
operationCtx, cancel, err := store.operationContext(ctx, "create entitlement record in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil {
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, err)
}
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, payload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(record.StartsAt.UTC().UnixMicro()),
Member: record.RecordID.String(),
})
return nil
})
if err != nil {
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, err)
}
return nil
}, recordKey, historyKey)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("create entitlement record %q in redis: %w", record.RecordID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// GetEntitlementRecordByRecordID returns the entitlement history record
// identified by recordID.
func (store *Store) GetEntitlementRecordByRecordID(
ctx context.Context,
recordID entitlement.EntitlementRecordID,
) (entitlement.PeriodRecord, error) {
if err := recordID.Validate(); err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id from redis: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "get entitlement record by record id from redis")
if err != nil {
return entitlement.PeriodRecord{}, err
}
defer cancel()
record, err := store.loadEntitlementRecord(operationCtx, store.client, recordID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id %q from redis: %w", recordID, ports.ErrNotFound)
default:
return entitlement.PeriodRecord{}, fmt.Errorf("get entitlement record by record id %q from redis: %w", recordID, err)
}
}
return record, nil
}
// ListEntitlementRecordsByUserID returns every entitlement history record
// owned by userID.
func (store *Store) ListEntitlementRecordsByUserID(
ctx context.Context,
userID common.UserID,
) ([]entitlement.PeriodRecord, error) {
if err := userID.Validate(); err != nil {
return nil, fmt.Errorf("list entitlement records by user id from redis: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list entitlement records by user id from redis")
if err != nil {
return nil, err
}
defer cancel()
recordIDs, err := store.client.ZRange(operationCtx, store.keyspace.EntitlementHistory(userID), 0, -1).Result()
if err != nil {
return nil, fmt.Errorf("list entitlement records by user id %q from redis: %w", userID, err)
}
records := make([]entitlement.PeriodRecord, 0, len(recordIDs))
for _, rawRecordID := range recordIDs {
record, err := store.loadEntitlementRecord(operationCtx, store.client, entitlement.EntitlementRecordID(rawRecordID))
if err != nil {
return nil, fmt.Errorf("list entitlement records by user id %q from redis: %w", userID, err)
}
records = append(records, record)
}
return records, nil
}
// UpdateEntitlementRecord replaces one stored entitlement history record.
func (store *Store) UpdateEntitlementRecord(ctx context.Context, record entitlement.PeriodRecord) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("update entitlement record in redis: %w", err)
}
payload, err := marshalEntitlementPeriodRecord(record)
if err != nil {
return fmt.Errorf("update entitlement record in redis: %w", err)
}
recordKey := store.keyspace.EntitlementRecord(record.RecordID)
operationCtx, cancel, err := store.operationContext(ctx, "update entitlement record in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if _, err := store.loadEntitlementRecord(operationCtx, tx, record.RecordID); err != nil {
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, err)
}
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, payload, 0)
return nil
})
if err != nil {
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, err)
}
return nil
}, recordKey)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("update entitlement record %q in redis: %w", record.RecordID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// GrantEntitlement atomically closes the current free history record, creates
// one paid history record, and replaces the current snapshot.
func (store *Store) GrantEntitlement(ctx context.Context, input ports.GrantEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
updatedCurrentRecordPayload, err := marshalEntitlementPeriodRecord(input.UpdatedCurrentRecord)
if err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("grant entitlement in redis: %w", err)
}
currentRecordKey := store.keyspace.EntitlementRecord(input.ExpectedCurrentRecord.RecordID)
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{currentRecordKey, newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "grant entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
storedCurrentRecord, err := store.loadEntitlementRecord(operationCtx, tx, input.ExpectedCurrentRecord.RecordID)
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementPeriodRecords(storedCurrentRecord, input.ExpectedCurrentRecord) {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, currentRecordKey, updatedCurrentRecordPayload, 0)
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("grant entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// ExtendEntitlement atomically appends one paid history segment and replaces
// the current paid snapshot.
func (store *Store) ExtendEntitlement(ctx context.Context, input ports.ExtendEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("extend entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("extend entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("extend entitlement in redis: %w", err)
}
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "extend entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
if err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("extend entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RevokeEntitlement atomically closes the current paid history record,
// creates one free history record, and replaces the current snapshot.
func (store *Store) RevokeEntitlement(ctx context.Context, input ports.RevokeEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
updatedCurrentRecordPayload, err := marshalEntitlementPeriodRecord(input.UpdatedCurrentRecord)
if err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("revoke entitlement in redis: %w", err)
}
currentRecordKey := store.keyspace.EntitlementRecord(input.ExpectedCurrentRecord.RecordID)
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{currentRecordKey, newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "revoke entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedCurrentSnapshot.UserID)
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedCurrentSnapshot) {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
storedCurrentRecord, err := store.loadEntitlementRecord(operationCtx, tx, input.ExpectedCurrentRecord.RecordID)
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
if !equalEntitlementPeriodRecords(storedCurrentRecord, input.ExpectedCurrentRecord) {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, currentRecordKey, updatedCurrentRecordPayload, 0)
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("revoke entitlement for user %q in redis: %w", input.ExpectedCurrentSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RepairExpiredEntitlement atomically replaces one expired finite paid
// snapshot with a materialized free state.
func (store *Store) RepairExpiredEntitlement(ctx context.Context, input ports.RepairExpiredEntitlementInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("repair expired entitlement in redis: %w", err)
}
newRecordPayload, err := marshalEntitlementPeriodRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("repair expired entitlement in redis: %w", err)
}
newSnapshotPayload, err := marshalEntitlementSnapshotRecord(input.NewSnapshot)
if err != nil {
return fmt.Errorf("repair expired entitlement in redis: %w", err)
}
newRecordKey := store.keyspace.EntitlementRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.EntitlementHistory(input.NewRecord.UserID)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewSnapshot.UserID)
watchedKeys := append(
[]string{newRecordKey, historyKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewSnapshot.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "repair expired entitlement in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
storedSnapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedExpiredSnapshot.UserID)
if err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
if !equalEntitlementSnapshots(storedSnapshot, input.ExpectedExpiredSnapshot) {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, ports.ErrConflict)
}
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewSnapshot.UserID)
if err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.StartsAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
pipe.Set(operationCtx, snapshotKey, newSnapshotPayload, 0)
store.syncEntitlementIndexes(pipe, operationCtx, input.NewSnapshot)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewSnapshot.UserID, input.NewSnapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("repair expired entitlement for user %q in redis: %w", input.ExpectedExpiredSnapshot.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
func (store *Store) loadEntitlementRecord(
ctx context.Context,
getter bytesGetter,
recordID entitlement.EntitlementRecordID,
) (entitlement.PeriodRecord, error) {
payload, err := getter.Get(ctx, store.keyspace.EntitlementRecord(recordID)).Bytes()
switch {
case errors.Is(err, redis.Nil):
return entitlement.PeriodRecord{}, ports.ErrNotFound
case err != nil:
return entitlement.PeriodRecord{}, err
}
return decodeEntitlementPeriodRecord(payload)
}
func marshalEntitlementPeriodRecord(record entitlement.PeriodRecord) ([]byte, error) {
encoded := entitlementPeriodRecord{
RecordID: record.RecordID.String(),
UserID: record.UserID.String(),
PlanCode: string(record.PlanCode),
Source: record.Source.String(),
ActorType: record.Actor.Type.String(),
ReasonCode: record.ReasonCode.String(),
StartsAt: record.StartsAt.UTC().Format(time.RFC3339Nano),
CreatedAt: record.CreatedAt.UTC().Format(time.RFC3339Nano),
}
if !record.Actor.ID.IsZero() {
value := record.Actor.ID.String()
encoded.ActorID = &value
}
if record.EndsAt != nil {
value := record.EndsAt.UTC().Format(time.RFC3339Nano)
encoded.EndsAt = &value
}
if record.ClosedAt != nil {
value := record.ClosedAt.UTC().Format(time.RFC3339Nano)
encoded.ClosedAt = &value
}
if !record.ClosedBy.Type.IsZero() {
value := record.ClosedBy.Type.String()
encoded.ClosedByType = &value
}
if !record.ClosedBy.ID.IsZero() {
value := record.ClosedBy.ID.String()
encoded.ClosedByID = &value
}
if !record.ClosedReasonCode.IsZero() {
value := record.ClosedReasonCode.String()
encoded.ClosedReasonCode = &value
}
return json.Marshal(encoded)
}
func decodeEntitlementPeriodRecord(payload []byte) (entitlement.PeriodRecord, error) {
var encoded entitlementPeriodRecord
if err := decodeJSONPayload(payload, &encoded); err != nil {
return entitlement.PeriodRecord{}, err
}
startsAt, err := time.Parse(time.RFC3339Nano, encoded.StartsAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record starts_at: %w", err)
}
createdAt, err := time.Parse(time.RFC3339Nano, encoded.CreatedAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record created_at: %w", err)
}
record := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID(encoded.RecordID),
UserID: common.UserID(encoded.UserID),
PlanCode: entitlement.PlanCode(encoded.PlanCode),
Source: common.Source(encoded.Source),
Actor: common.ActorRef{Type: common.ActorType(encoded.ActorType)},
ReasonCode: common.ReasonCode(encoded.ReasonCode),
StartsAt: startsAt.UTC(),
CreatedAt: createdAt.UTC(),
}
if encoded.ActorID != nil {
record.Actor.ID = common.ActorID(*encoded.ActorID)
}
if encoded.EndsAt != nil {
value, err := time.Parse(time.RFC3339Nano, *encoded.EndsAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record ends_at: %w", err)
}
value = value.UTC()
record.EndsAt = &value
}
if encoded.ClosedAt != nil {
value, err := time.Parse(time.RFC3339Nano, *encoded.ClosedAt)
if err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record closed_at: %w", err)
}
value = value.UTC()
record.ClosedAt = &value
}
if encoded.ClosedByType != nil {
record.ClosedBy.Type = common.ActorType(*encoded.ClosedByType)
}
if encoded.ClosedByID != nil {
record.ClosedBy.ID = common.ActorID(*encoded.ClosedByID)
}
if encoded.ClosedReasonCode != nil {
record.ClosedReasonCode = common.ReasonCode(*encoded.ClosedReasonCode)
}
if err := record.Validate(); err != nil {
return entitlement.PeriodRecord{}, fmt.Errorf("decode entitlement period record: %w", err)
}
return record, nil
}
func equalEntitlementSnapshots(left entitlement.CurrentSnapshot, right entitlement.CurrentSnapshot) bool {
return left.UserID == right.UserID &&
left.PlanCode == right.PlanCode &&
left.IsPaid == right.IsPaid &&
left.StartsAt.Equal(right.StartsAt) &&
equalOptionalTime(left.EndsAt, right.EndsAt) &&
left.Source == right.Source &&
left.Actor == right.Actor &&
left.ReasonCode == right.ReasonCode &&
left.UpdatedAt.Equal(right.UpdatedAt)
}
func equalEntitlementPeriodRecords(left entitlement.PeriodRecord, right entitlement.PeriodRecord) bool {
return left.RecordID == right.RecordID &&
left.UserID == right.UserID &&
left.PlanCode == right.PlanCode &&
left.Source == right.Source &&
left.Actor == right.Actor &&
left.ReasonCode == right.ReasonCode &&
left.StartsAt.Equal(right.StartsAt) &&
equalOptionalTime(left.EndsAt, right.EndsAt) &&
left.CreatedAt.Equal(right.CreatedAt) &&
equalOptionalTime(left.ClosedAt, right.ClosedAt) &&
left.ClosedBy == right.ClosedBy &&
left.ClosedReasonCode == right.ClosedReasonCode
}
func equalOptionalTime(left *time.Time, right *time.Time) bool {
switch {
case left == nil && right == nil:
return true
case left == nil || right == nil:
return false
default:
return left.Equal(*right)
}
}
// EntitlementHistoryStore adapts Store to the existing
// EntitlementHistoryStore port.
type EntitlementHistoryStore struct {
store *Store
}
// EntitlementHistory returns one adapter that exposes the entitlement-history
// store port over Store.
func (store *Store) EntitlementHistory() *EntitlementHistoryStore {
if store == nil {
return nil
}
return &EntitlementHistoryStore{store: store}
}
// Create stores one new entitlement history record.
func (adapter *EntitlementHistoryStore) Create(ctx context.Context, record entitlement.PeriodRecord) error {
return adapter.store.CreateEntitlementRecord(ctx, record)
}
// GetByRecordID returns the entitlement history record identified by recordID.
func (adapter *EntitlementHistoryStore) GetByRecordID(
ctx context.Context,
recordID entitlement.EntitlementRecordID,
) (entitlement.PeriodRecord, error) {
return adapter.store.GetEntitlementRecordByRecordID(ctx, recordID)
}
// ListByUserID returns every entitlement history record owned by userID.
func (adapter *EntitlementHistoryStore) ListByUserID(
ctx context.Context,
userID common.UserID,
) ([]entitlement.PeriodRecord, error) {
return adapter.store.ListEntitlementRecordsByUserID(ctx, userID)
}
// Update replaces one stored entitlement history record.
func (adapter *EntitlementHistoryStore) Update(ctx context.Context, record entitlement.PeriodRecord) error {
return adapter.store.UpdateEntitlementRecord(ctx, record)
}
var _ ports.EntitlementHistoryStore = (*EntitlementHistoryStore)(nil)
// EntitlementLifecycleStore adapts Store to the existing
// EntitlementLifecycleStore port.
type EntitlementLifecycleStore struct {
store *Store
}
// EntitlementLifecycle returns one adapter that exposes the atomic
// entitlement-lifecycle store port over Store.
func (store *Store) EntitlementLifecycle() *EntitlementLifecycleStore {
if store == nil {
return nil
}
return &EntitlementLifecycleStore{store: store}
}
// Grant atomically applies one free-to-paid transition.
func (adapter *EntitlementLifecycleStore) Grant(ctx context.Context, input ports.GrantEntitlementInput) error {
return adapter.store.GrantEntitlement(ctx, input)
}
// Extend atomically appends one paid extension segment and updates the current
// snapshot.
func (adapter *EntitlementLifecycleStore) Extend(ctx context.Context, input ports.ExtendEntitlementInput) error {
return adapter.store.ExtendEntitlement(ctx, input)
}
// Revoke atomically applies one paid-to-free transition.
func (adapter *EntitlementLifecycleStore) Revoke(ctx context.Context, input ports.RevokeEntitlementInput) error {
return adapter.store.RevokeEntitlement(ctx, input)
}
// RepairExpired atomically repairs one expired finite paid snapshot.
func (adapter *EntitlementLifecycleStore) RepairExpired(
ctx context.Context,
input ports.RepairExpiredEntitlementInput,
) error {
return adapter.store.RepairExpiredEntitlement(ctx, input)
}
var _ ports.EntitlementLifecycleStore = (*EntitlementLifecycleStore)(nil)
@@ -1,137 +0,0 @@
package userstore
import (
"context"
"errors"
"fmt"
"time"
"galaxy/user/internal/adapters/redisstate"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
// ListUserIDs returns one deterministic page of user identifiers ordered by
// `created_at desc`, then `user_id desc`.
func (store *Store) ListUserIDs(ctx context.Context, input ports.ListUsersInput) (ports.ListUsersResult, error) {
if err := input.Validate(); err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
operationCtx, cancel, err := store.operationContext(ctx, "list users in redis")
if err != nil {
return ports.ListUsersResult{}, err
}
defer cancel()
startIndex := int64(0)
filters := userListFiltersFromPorts(input.Filters)
if input.PageToken != "" {
cursor, err := redisstate.DecodePageToken(input.PageToken, filters)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
}
score, err := store.client.ZScore(operationCtx, store.keyspace.CreatedAtIndex(), cursor.UserID.String()).Result()
switch {
case errors.Is(err, redis.Nil):
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
case err != nil:
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
if !time.UnixMicro(int64(score)).UTC().Equal(cursor.CreatedAt.UTC()) {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
}
rank, err := store.client.ZRevRank(operationCtx, store.keyspace.CreatedAtIndex(), cursor.UserID.String()).Result()
switch {
case errors.Is(err, redis.Nil):
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", ports.ErrInvalidPageToken)
case err != nil:
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
startIndex = rank + 1
}
rawPage, err := store.client.ZRevRangeWithScores(
operationCtx,
store.keyspace.CreatedAtIndex(),
startIndex,
startIndex+int64(input.PageSize),
).Result()
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
result := ports.ListUsersResult{
UserIDs: make([]common.UserID, 0, min(len(rawPage), input.PageSize)),
}
visibleCount := min(len(rawPage), input.PageSize)
for index := 0; index < visibleCount; index++ {
userID, err := memberUserID(rawPage[index].Member)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
result.UserIDs = append(result.UserIDs, userID)
}
if len(rawPage) > input.PageSize {
lastVisible := rawPage[input.PageSize-1]
lastUserID, err := memberUserID(lastVisible.Member)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
token, err := redisstate.EncodePageToken(redisstate.PageCursor{
CreatedAt: time.UnixMicro(int64(lastVisible.Score)).UTC(),
UserID: lastUserID,
}, filters)
if err != nil {
return ports.ListUsersResult{}, fmt.Errorf("list users in redis: %w", err)
}
result.NextPageToken = token
}
return result, nil
}
func userListFiltersFromPorts(filters ports.UserListFilters) redisstate.UserListFilters {
return redisstate.UserListFilters{
PaidState: filters.PaidState,
PaidExpiresBefore: filters.PaidExpiresBefore,
PaidExpiresAfter: filters.PaidExpiresAfter,
DeclaredCountry: filters.DeclaredCountry,
SanctionCode: filters.SanctionCode,
LimitCode: filters.LimitCode,
CanLogin: filters.CanLogin,
CanCreatePrivateGame: filters.CanCreatePrivateGame,
CanJoinGame: filters.CanJoinGame,
}
}
func memberUserID(member any) (common.UserID, error) {
value, ok := member.(string)
if !ok {
return "", fmt.Errorf("unexpected created-at index member type %T", member)
}
userID := common.UserID(value)
if err := userID.Validate(); err != nil {
return "", fmt.Errorf("created-at index member user id: %w", err)
}
return userID, nil
}
func min(left int, right int) int {
if left < right {
return left
}
return right
}
var _ ports.UserListStore = (*Store)(nil)
@@ -1,445 +0,0 @@
package userstore
import (
"context"
"errors"
"fmt"
"time"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"github.com/redis/go-redis/v9"
)
// ApplySanction atomically creates one new active sanction record.
func (store *Store) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("apply sanction in redis: %w", err)
}
recordPayload, err := marshalSanctionRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("apply sanction in redis: %w", err)
}
recordKey := store.keyspace.SanctionRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.SanctionHistory(input.NewRecord.UserID)
activeKey := store.keyspace.ActiveSanction(input.NewRecord.UserID, input.NewRecord.SanctionCode)
snapshotKey := store.keyspace.EntitlementSnapshot(input.NewRecord.UserID)
watchedKeys := append(
[]string{recordKey, historyKey, activeKey, snapshotKey},
store.activeSanctionWatchKeys(input.NewRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "apply sanction in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if err := ensureKeyAbsent(operationCtx, tx, recordKey); err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.NewRecord.UserID)
if err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.NewRecord.UserID)
if err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
activeSanctionCodes[input.NewRecord.SanctionCode] = struct{}{}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, recordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeSanctionCodes)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.NewRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("apply sanction for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RemoveSanction atomically removes one active sanction record.
func (store *Store) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("remove sanction in redis: %w", err)
}
updatedPayload, err := marshalSanctionRecord(input.UpdatedRecord)
if err != nil {
return fmt.Errorf("remove sanction in redis: %w", err)
}
recordKey := store.keyspace.SanctionRecord(input.ExpectedActiveRecord.RecordID)
activeKey := store.keyspace.ActiveSanction(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.SanctionCode)
snapshotKey := store.keyspace.EntitlementSnapshot(input.ExpectedActiveRecord.UserID)
watchedKeys := append(
[]string{recordKey, activeKey, snapshotKey},
store.activeSanctionWatchKeys(input.ExpectedActiveRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "remove sanction in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
activeRecordID, err := store.loadActiveSanctionRecordID(operationCtx, tx, activeKey)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if activeRecordID != input.ExpectedActiveRecord.RecordID {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
storedRecord, err := store.loadSanctionRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if !equalSanctionRecords(storedRecord, input.ExpectedActiveRecord) {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
snapshot, err := store.loadEntitlementSnapshot(operationCtx, tx, input.ExpectedActiveRecord.UserID)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
activeSanctionCodes, err := store.loadActiveSanctionCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
delete(activeSanctionCodes, input.ExpectedActiveRecord.SanctionCode)
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, updatedPayload, 0)
pipe.Del(operationCtx, activeKey)
store.syncActiveSanctionCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeSanctionCodes)
store.syncEligibilityMarkerIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, snapshot.IsPaid, activeSanctionCodes)
return nil
})
if err != nil {
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("remove sanction for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// SetLimit atomically creates or replaces one active limit record.
func (store *Store) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("set limit in redis: %w", err)
}
newRecordPayload, err := marshalLimitRecord(input.NewRecord)
if err != nil {
return fmt.Errorf("set limit in redis: %w", err)
}
newRecordKey := store.keyspace.LimitRecord(input.NewRecord.RecordID)
historyKey := store.keyspace.LimitHistory(input.NewRecord.UserID)
activeKey := store.keyspace.ActiveLimit(input.NewRecord.UserID, input.NewRecord.LimitCode)
watchedKeys := append(
[]string{newRecordKey, historyKey, activeKey},
store.activeLimitWatchKeys(input.NewRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "set limit in redis")
if err != nil {
return err
}
defer cancel()
if input.ExpectedActiveRecord != nil {
watchedKeys = append(watchedKeys, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID))
}
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
if err := ensureKeyAbsent(operationCtx, tx, newRecordKey); err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
var updatedPayload []byte
if input.ExpectedActiveRecord == nil {
if err := ensureKeyAbsent(operationCtx, tx, activeKey); err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
} else {
activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
if activeRecordID != input.ExpectedActiveRecord.RecordID {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
}
storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
if !equalLimitRecords(storedRecord, *input.ExpectedActiveRecord) {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
}
updatedPayload, err = marshalLimitRecord(*input.UpdatedActiveRecord)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
}
activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.NewRecord.UserID)
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
activeLimitCodes[input.NewRecord.LimitCode] = struct{}{}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
if input.ExpectedActiveRecord != nil {
pipe.Set(operationCtx, store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID), updatedPayload, 0)
}
pipe.Set(operationCtx, newRecordKey, newRecordPayload, 0)
pipe.ZAdd(operationCtx, historyKey, redis.Z{
Score: float64(input.NewRecord.AppliedAt.UTC().UnixMicro()),
Member: input.NewRecord.RecordID.String(),
})
setActiveSlot(pipe, operationCtx, activeKey, input.NewRecord.RecordID.String(), input.NewRecord.ExpiresAt)
store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.NewRecord.UserID, activeLimitCodes)
return nil
})
if err != nil {
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("set limit for user %q in redis: %w", input.NewRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// RemoveLimit atomically removes one active limit record.
func (store *Store) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
if err := input.Validate(); err != nil {
return fmt.Errorf("remove limit in redis: %w", err)
}
updatedPayload, err := marshalLimitRecord(input.UpdatedRecord)
if err != nil {
return fmt.Errorf("remove limit in redis: %w", err)
}
recordKey := store.keyspace.LimitRecord(input.ExpectedActiveRecord.RecordID)
activeKey := store.keyspace.ActiveLimit(input.ExpectedActiveRecord.UserID, input.ExpectedActiveRecord.LimitCode)
watchedKeys := append(
[]string{recordKey, activeKey},
store.activeLimitWatchKeys(input.ExpectedActiveRecord.UserID)...,
)
operationCtx, cancel, err := store.operationContext(ctx, "remove limit in redis")
if err != nil {
return err
}
defer cancel()
watchErr := store.client.Watch(operationCtx, func(tx *redis.Tx) error {
activeRecordID, err := store.loadActiveLimitRecordID(operationCtx, tx, activeKey)
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if activeRecordID != input.ExpectedActiveRecord.RecordID {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
storedRecord, err := store.loadLimitRecord(operationCtx, tx, input.ExpectedActiveRecord.RecordID)
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
if !equalLimitRecords(storedRecord, input.ExpectedActiveRecord) {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
}
activeLimitCodes, err := store.loadActiveLimitCodeSet(operationCtx, tx, input.ExpectedActiveRecord.UserID)
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
delete(activeLimitCodes, input.ExpectedActiveRecord.LimitCode)
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, recordKey, updatedPayload, 0)
pipe.Del(operationCtx, activeKey)
store.syncActiveLimitCodeIndexes(pipe, operationCtx, input.ExpectedActiveRecord.UserID, activeLimitCodes)
return nil
})
if err != nil {
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, err)
}
return nil
}, watchedKeys...)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("remove limit for user %q in redis: %w", input.ExpectedActiveRecord.UserID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
func (store *Store) loadActiveSanctionRecordID(
ctx context.Context,
getter bytesGetter,
key string,
) (policy.SanctionRecordID, error) {
value, err := getter.Get(ctx, key).Result()
switch {
case errors.Is(err, redis.Nil):
return "", ports.ErrNotFound
case err != nil:
return "", err
}
recordID := policy.SanctionRecordID(value)
if err := recordID.Validate(); err != nil {
return "", fmt.Errorf("active sanction record id: %w", err)
}
return recordID, nil
}
func (store *Store) loadActiveLimitRecordID(
ctx context.Context,
getter bytesGetter,
key string,
) (policy.LimitRecordID, error) {
value, err := getter.Get(ctx, key).Result()
switch {
case errors.Is(err, redis.Nil):
return "", ports.ErrNotFound
case err != nil:
return "", err
}
recordID := policy.LimitRecordID(value)
if err := recordID.Validate(); err != nil {
return "", fmt.Errorf("active limit record id: %w", err)
}
return recordID, nil
}
func setActiveSlot(
pipe redis.Pipeliner,
ctx context.Context,
key string,
recordID string,
expiresAt *time.Time,
) {
pipe.Set(ctx, key, recordID, 0)
if expiresAt != nil {
pipe.PExpireAt(ctx, key, expiresAt.UTC())
}
}
func equalSanctionRecords(left policy.SanctionRecord, right policy.SanctionRecord) bool {
return left.RecordID == right.RecordID &&
left.UserID == right.UserID &&
left.SanctionCode == right.SanctionCode &&
left.Scope == right.Scope &&
left.ReasonCode == right.ReasonCode &&
left.Actor == right.Actor &&
left.AppliedAt.Equal(right.AppliedAt) &&
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
left.RemovedBy == right.RemovedBy &&
left.RemovedReasonCode == right.RemovedReasonCode
}
func equalLimitRecords(left policy.LimitRecord, right policy.LimitRecord) bool {
return left.RecordID == right.RecordID &&
left.UserID == right.UserID &&
left.LimitCode == right.LimitCode &&
left.Value == right.Value &&
left.ReasonCode == right.ReasonCode &&
left.Actor == right.Actor &&
left.AppliedAt.Equal(right.AppliedAt) &&
equalOptionalTime(left.ExpiresAt, right.ExpiresAt) &&
equalOptionalTime(left.RemovedAt, right.RemovedAt) &&
left.RemovedBy == right.RemovedBy &&
left.RemovedReasonCode == right.RemovedReasonCode
}
// PolicyLifecycleStore adapts Store to the existing PolicyLifecycleStore
// port.
type PolicyLifecycleStore struct {
store *Store
}
// PolicyLifecycle returns one adapter that exposes the atomic policy-lifecycle
// store port over Store.
func (store *Store) PolicyLifecycle() *PolicyLifecycleStore {
if store == nil {
return nil
}
return &PolicyLifecycleStore{store: store}
}
// ApplySanction atomically creates one new active sanction record.
func (adapter *PolicyLifecycleStore) ApplySanction(ctx context.Context, input ports.ApplySanctionInput) error {
return adapter.store.ApplySanction(ctx, input)
}
// RemoveSanction atomically removes one active sanction record.
func (adapter *PolicyLifecycleStore) RemoveSanction(ctx context.Context, input ports.RemoveSanctionInput) error {
return adapter.store.RemoveSanction(ctx, input)
}
// SetLimit atomically creates or replaces one active limit record.
func (adapter *PolicyLifecycleStore) SetLimit(ctx context.Context, input ports.SetLimitInput) error {
return adapter.store.SetLimit(ctx, input)
}
// RemoveLimit atomically removes one active limit record.
func (adapter *PolicyLifecycleStore) RemoveLimit(ctx context.Context, input ports.RemoveLimitInput) error {
return adapter.store.RemoveLimit(ctx, input)
}
var _ ports.PolicyLifecycleStore = (*PolicyLifecycleStore)(nil)
File diff suppressed because it is too large Load Diff
@@ -1,879 +0,0 @@
package userstore
import (
"context"
"testing"
"time"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/authblock"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/require"
)
func TestAccountStoreCreateAndLookups(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.Equal(t, record, byUserID)
byEmail, err := accountStore.GetByEmail(context.Background(), record.Email)
require.NoError(t, err)
require.Equal(t, record, byEmail)
byUserName, err := accountStore.GetByUserName(context.Background(), record.UserName)
require.NoError(t, err)
require.Equal(t, record, byUserName)
exists, err := accountStore.ExistsByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.True(t, exists)
}
func TestBlockedEmailStoreUpsertAndGet(t *testing.T) {
t.Parallel()
store := newTestStore(t)
blockedEmailStore := store.BlockedEmails()
record := authblock.BlockedEmailSubject{
Email: common.Email("blocked@example.com"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: time.Unix(1_775_240_100, 0).UTC(),
ResolvedUserID: common.UserID("user-123"),
}
require.NoError(t, blockedEmailStore.Upsert(context.Background(), record))
got, err := blockedEmailStore.GetByEmail(context.Background(), record.Email)
require.NoError(t, err)
require.Equal(t, record, got)
}
func TestEnsureResolveAndBlockFlows(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
accountRecord := validAccountRecord()
entitlementSnapshot := validEntitlementSnapshot(accountRecord.UserID, now)
created, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: accountRecord.Email,
Account: accountRecord,
Entitlement: entitlementSnapshot,
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeCreated, created.Outcome)
byUserName, err := store.GetByUserName(context.Background(), accountRecord.UserName)
require.NoError(t, err)
require.Equal(t, accountRecord.UserID, byUserName.UserID)
entitlementHistory, err := store.ListEntitlementRecordsByUserID(context.Background(), accountRecord.UserID)
require.NoError(t, err)
require.Len(t, entitlementHistory, 1)
require.Equal(t, validEntitlementRecord(accountRecord.UserID, now), entitlementHistory[0])
resolved, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindExisting, resolved.Kind)
blockedByUserID, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: accountRecord.UserID,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now.Add(time.Minute),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, blockedByUserID.Outcome)
repeatedBlock, err := store.BlockByEmail(context.Background(), ports.BlockByEmailInput{
Email: accountRecord.Email,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now.Add(2 * time.Minute),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeAlreadyBlocked, repeatedBlock.Outcome)
require.Equal(t, accountRecord.UserID, repeatedBlock.UserID)
blockedResolution, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindBlocked, blockedResolution.Kind)
ensureBlocked, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: accountRecord.Email,
Account: accountRecord,
Entitlement: entitlementSnapshot,
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, ensureBlocked.Outcome)
}
func TestBlockedEmailWithoutUserPreventsEnsureCreate(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
accountRecord := validAccountRecord()
entitlementSnapshot := validEntitlementSnapshot(accountRecord.UserID, now)
blocked, err := store.BlockByEmail(context.Background(), ports.BlockByEmailInput{
Email: accountRecord.Email,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now,
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, blocked.Outcome)
require.True(t, blocked.UserID.IsZero())
resolved, err := store.ResolveByEmail(context.Background(), accountRecord.Email)
require.NoError(t, err)
require.Equal(t, ports.AuthResolutionKindBlocked, resolved.Kind)
ensured, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: accountRecord.Email,
Account: accountRecord,
Entitlement: entitlementSnapshot,
EntitlementRecord: validEntitlementRecord(accountRecord.UserID, now),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeBlocked, ensured.Outcome)
exists, err := store.ExistsByUserID(context.Background(), accountRecord.UserID)
require.NoError(t, err)
require.False(t, exists)
}
func TestEnsureByEmailExistingDoesNotOverwriteStoredSettings(t *testing.T) {
t.Parallel()
store := newTestStore(t)
createdAt := time.Unix(1_775_240_000, 0).UTC()
existingAccount := account.UserAccount{
UserID: common.UserID("user-existing"),
Email: common.Email("pilot@example.com"),
UserName: common.UserName("player-abcdefgh"),
PreferredLanguage: common.LanguageTag("en"),
TimeZone: common.TimeZoneName("Europe/Kaliningrad"),
CreatedAt: createdAt,
UpdatedAt: createdAt,
}
require.NoError(t, store.Create(context.Background(), createAccountInput(existingAccount)))
result, err := store.EnsureByEmail(context.Background(), ports.EnsureByEmailInput{
Email: existingAccount.Email,
Account: account.UserAccount{
UserID: common.UserID("user-created"),
Email: existingAccount.Email,
UserName: common.UserName("player-newabcde"),
PreferredLanguage: common.LanguageTag("fr-FR"),
TimeZone: common.TimeZoneName("UTC"),
CreatedAt: createdAt.Add(time.Minute),
UpdatedAt: createdAt.Add(time.Minute),
},
Entitlement: validEntitlementSnapshot(common.UserID("user-created"), createdAt.Add(time.Minute)),
EntitlementRecord: validEntitlementRecord(common.UserID("user-created"), createdAt.Add(time.Minute)),
})
require.NoError(t, err)
require.Equal(t, ports.EnsureByEmailOutcomeExisting, result.Outcome)
require.Equal(t, existingAccount.UserID, result.UserID)
storedAccount, err := store.GetByEmail(context.Background(), existingAccount.Email)
require.NoError(t, err)
require.Equal(t, existingAccount, storedAccount)
}
func TestAccountStoreUpdateDisplayNamePreservesImmutableFields(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
updated := record
updated.DisplayName = common.DisplayName("NovaPrime")
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
require.NoError(t, accountStore.Update(context.Background(), updated))
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.Equal(t, updated, byUserID)
byEmail, err := accountStore.GetByEmail(context.Background(), record.Email)
require.NoError(t, err)
require.Equal(t, updated, byEmail)
byUserName, err := accountStore.GetByUserName(context.Background(), record.UserName)
require.NoError(t, err)
require.Equal(t, updated, byUserName)
}
func TestAccountStoreUpdateRejectsUserNameMutation(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
attempted := record
attempted.UserName = common.UserName("player-changed")
attempted.UpdatedAt = record.UpdatedAt.Add(time.Minute)
err := accountStore.Update(context.Background(), attempted)
require.ErrorIs(t, err, ports.ErrConflict)
}
func TestAccountStoreUpdateDeclaredCountryPreservesLookups(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
updated := record
updated.DeclaredCountry = common.CountryCode("FR")
updated.UpdatedAt = record.UpdatedAt.Add(time.Minute)
require.NoError(t, accountStore.Update(context.Background(), updated))
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.Equal(t, updated, byUserID)
byEmail, err := accountStore.GetByEmail(context.Background(), record.Email)
require.NoError(t, err)
require.Equal(t, updated, byEmail)
byUserName, err := accountStore.GetByUserName(context.Background(), record.UserName)
require.NoError(t, err)
require.Equal(t, updated, byUserName)
}
func TestAccountStorePersistsSoftDeleteMarker(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
record := validAccountRecord()
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(record)))
deletedAt := record.UpdatedAt.Add(time.Hour)
updated := record
updated.UpdatedAt = deletedAt
updated.DeletedAt = &deletedAt
require.NoError(t, accountStore.Update(context.Background(), updated))
byUserID, err := accountStore.GetByUserID(context.Background(), record.UserID)
require.NoError(t, err)
require.NotNil(t, byUserID.DeletedAt)
require.True(t, byUserID.DeletedAt.Equal(deletedAt))
require.True(t, byUserID.IsDeleted())
}
func TestAccountStoreCreateReturnsUserNameConflict(t *testing.T) {
t.Parallel()
store := newTestStore(t)
accountStore := store.Accounts()
first := validAccountRecord()
second := validAccountRecord()
second.UserID = common.UserID("user-456")
second.Email = common.Email("other@example.com")
require.NoError(t, accountStore.Create(context.Background(), createAccountInput(first)))
err := accountStore.Create(context.Background(), createAccountInput(second))
require.ErrorIs(t, err, ports.ErrUserNameConflict)
}
func TestBlockByUserIDRepeatedCallsStayIdempotent(t *testing.T) {
t.Parallel()
store := newTestStore(t)
now := time.Unix(1_775_240_000, 0).UTC()
accountRecord := validAccountRecord()
require.NoError(t, store.Create(context.Background(), createAccountInput(accountRecord)))
first, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: accountRecord.UserID,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now,
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeBlocked, first.Outcome)
second, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: accountRecord.UserID,
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: now.Add(time.Minute),
})
require.NoError(t, err)
require.Equal(t, ports.AuthBlockOutcomeAlreadyBlocked, second.Outcome)
require.Equal(t, accountRecord.UserID, second.UserID)
}
func TestBlockByUserIDUnknownUserReturnsNotFound(t *testing.T) {
t.Parallel()
store := newTestStore(t)
_, err := store.BlockByUserID(context.Background(), ports.BlockByUserIDInput{
UserID: common.UserID("user-missing"),
ReasonCode: common.ReasonCode("policy_blocked"),
BlockedAt: time.Unix(1_775_240_000, 0).UTC(),
})
require.ErrorIs(t, err, ports.ErrNotFound)
}
func TestSanctionAndLimitStoresRoundTrip(t *testing.T) {
t.Parallel()
store := newTestStore(t)
sanctionStore := store.Sanctions()
limitStore := store.Limits()
now := time.Unix(1_775_240_000, 0).UTC()
sanctionRecord := policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-1"),
UserID: common.UserID("user-123"),
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("self_service"),
ReasonCode: common.ReasonCode("policy_enforced"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
AppliedAt: now,
}
require.NoError(t, sanctionStore.Create(context.Background(), sanctionRecord))
gotSanction, err := sanctionStore.GetByRecordID(context.Background(), sanctionRecord.RecordID)
require.NoError(t, err)
require.Equal(t, sanctionRecord, gotSanction)
sanctions, err := sanctionStore.ListByUserID(context.Background(), sanctionRecord.UserID)
require.NoError(t, err)
require.Len(t, sanctions, 1)
expiresAt := now.Add(time.Hour)
sanctionRecord.ExpiresAt = &expiresAt
require.NoError(t, sanctionStore.Update(context.Background(), sanctionRecord))
gotSanction, err = sanctionStore.GetByRecordID(context.Background(), sanctionRecord.RecordID)
require.NoError(t, err)
require.Equal(t, sanctionRecord.RecordID, gotSanction.RecordID)
require.Equal(t, sanctionRecord.UserID, gotSanction.UserID)
require.Equal(t, sanctionRecord.SanctionCode, gotSanction.SanctionCode)
require.Equal(t, sanctionRecord.Scope, gotSanction.Scope)
require.Equal(t, sanctionRecord.ReasonCode, gotSanction.ReasonCode)
require.Equal(t, sanctionRecord.Actor, gotSanction.Actor)
require.True(t, gotSanction.AppliedAt.Equal(sanctionRecord.AppliedAt))
require.NotNil(t, gotSanction.ExpiresAt)
require.True(t, gotSanction.ExpiresAt.Equal(*sanctionRecord.ExpiresAt))
limitRecord := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-1"),
UserID: common.UserID("user-123"),
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 3,
ReasonCode: common.ReasonCode("policy_enforced"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
AppliedAt: now,
}
require.NoError(t, limitStore.Create(context.Background(), limitRecord))
gotLimit, err := limitStore.GetByRecordID(context.Background(), limitRecord.RecordID)
require.NoError(t, err)
require.Equal(t, limitRecord, gotLimit)
limits, err := limitStore.ListByUserID(context.Background(), limitRecord.UserID)
require.NoError(t, err)
require.Len(t, limits, 1)
limitRecord.Value = 5
require.NoError(t, limitStore.Update(context.Background(), limitRecord))
gotLimit, err = limitStore.GetByRecordID(context.Background(), limitRecord.RecordID)
require.NoError(t, err)
require.Equal(t, limitRecord, gotLimit)
}
func TestPolicyLifecycleApplyAndRemoveSanction(t *testing.T) {
t.Parallel()
store := newTestStore(t)
lifecycleStore := store.PolicyLifecycle()
sanctionStore := store.Sanctions()
snapshotStore := store.EntitlementSnapshots()
now := time.Unix(1_775_240_000, 0).UTC()
userID := common.UserID("user-123")
require.NoError(t, snapshotStore.Put(context.Background(), validEntitlementSnapshot(userID, now)))
record := policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-1"),
UserID: userID,
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("auth"),
ReasonCode: common.ReasonCode("manual_block"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now,
}
require.NoError(t, lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
NewRecord: record,
}))
activeRecordID, err := store.loadActiveSanctionRecordID(
context.Background(),
store.client,
store.keyspace.ActiveSanction(userID, policy.SanctionCodeLoginBlock),
)
require.NoError(t, err)
require.Equal(t, record.RecordID, activeRecordID)
err = lifecycleStore.ApplySanction(context.Background(), ports.ApplySanctionInput{
NewRecord: policy.SanctionRecord{
RecordID: policy.SanctionRecordID("sanction-2"),
UserID: userID,
SanctionCode: policy.SanctionCodeLoginBlock,
Scope: common.Scope("auth"),
ReasonCode: common.ReasonCode("manual_block"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")},
AppliedAt: now.Add(time.Minute),
},
})
require.ErrorIs(t, err, ports.ErrConflict)
removed := record
removedAt := now.Add(30 * time.Minute)
removed.RemovedAt = &removedAt
removed.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")}
removed.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveSanction(context.Background(), ports.RemoveSanctionInput{
ExpectedActiveRecord: record,
UpdatedRecord: removed,
}))
stored, err := sanctionStore.GetByRecordID(context.Background(), record.RecordID)
require.NoError(t, err)
require.Equal(t, removed, stored)
_, err = store.loadActiveSanctionRecordID(
context.Background(),
store.client,
store.keyspace.ActiveSanction(userID, policy.SanctionCodeLoginBlock),
)
require.ErrorIs(t, err, ports.ErrNotFound)
}
func TestPolicyLifecycleSetAndRemoveLimit(t *testing.T) {
t.Parallel()
store := newTestStore(t)
lifecycleStore := store.PolicyLifecycle()
limitStore := store.Limits()
now := time.Unix(1_775_240_000, 0).UTC()
userID := common.UserID("user-123")
first := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-1"),
UserID: userID,
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 3,
ReasonCode: common.ReasonCode("manual_override"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
AppliedAt: now,
}
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
NewRecord: first,
}))
activeRecordID, err := store.loadActiveLimitRecordID(
context.Background(),
store.client,
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
)
require.NoError(t, err)
require.Equal(t, first.RecordID, activeRecordID)
second := policy.LimitRecord{
RecordID: policy.LimitRecordID("limit-2"),
UserID: userID,
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
Value: 5,
ReasonCode: common.ReasonCode("manual_override"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-2")},
AppliedAt: now.Add(time.Hour),
}
updatedFirst := first
removedAt := second.AppliedAt
updatedFirst.RemovedAt = &removedAt
updatedFirst.RemovedBy = second.Actor
updatedFirst.RemovedReasonCode = second.ReasonCode
require.NoError(t, lifecycleStore.SetLimit(context.Background(), ports.SetLimitInput{
ExpectedActiveRecord: &first,
UpdatedActiveRecord: &updatedFirst,
NewRecord: second,
}))
storedFirst, err := limitStore.GetByRecordID(context.Background(), first.RecordID)
require.NoError(t, err)
require.Equal(t, updatedFirst, storedFirst)
activeRecordID, err = store.loadActiveLimitRecordID(
context.Background(),
store.client,
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
)
require.NoError(t, err)
require.Equal(t, second.RecordID, activeRecordID)
removedSecond := second
removeAt := now.Add(90 * time.Minute)
removedSecond.RemovedAt = &removeAt
removedSecond.RemovedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-3")}
removedSecond.RemovedReasonCode = common.ReasonCode("manual_remove")
require.NoError(t, lifecycleStore.RemoveLimit(context.Background(), ports.RemoveLimitInput{
ExpectedActiveRecord: second,
UpdatedRecord: removedSecond,
}))
storedSecond, err := limitStore.GetByRecordID(context.Background(), second.RecordID)
require.NoError(t, err)
require.Equal(t, removedSecond, storedSecond)
_, err = store.loadActiveLimitRecordID(
context.Background(),
store.client,
store.keyspace.ActiveLimit(userID, policy.LimitCodeMaxOwnedPrivateGames),
)
require.ErrorIs(t, err, ports.ErrNotFound)
}
func TestEntitlementLifecycleTransitions(t *testing.T) {
t.Parallel()
store := newTestStore(t)
historyStore := store.EntitlementHistory()
snapshotStore := store.EntitlementSnapshots()
lifecycleStore := store.EntitlementLifecycle()
userID := common.UserID("user-123")
startedFreeAt := time.Unix(1_775_240_000, 0).UTC()
freeRecord := validEntitlementRecord(userID, startedFreeAt)
freeSnapshot := validEntitlementSnapshot(userID, startedFreeAt)
require.NoError(t, historyStore.Create(context.Background(), freeRecord))
require.NoError(t, snapshotStore.Put(context.Background(), freeSnapshot))
grantStartsAt := startedFreeAt.Add(24 * time.Hour)
grantEndsAt := grantStartsAt.Add(30 * 24 * time.Hour)
grantedRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-1"),
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
grantedSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
grantEndsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
closedFreeRecord := freeRecord
closedFreeRecord.ClosedAt = timePointer(grantStartsAt)
closedFreeRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
closedFreeRecord.ClosedReasonCode = common.ReasonCode("manual_grant")
require.NoError(t, lifecycleStore.Grant(context.Background(), ports.GrantEntitlementInput{
ExpectedCurrentSnapshot: freeSnapshot,
ExpectedCurrentRecord: freeRecord,
UpdatedCurrentRecord: closedFreeRecord,
NewRecord: grantedRecord,
NewSnapshot: grantedSnapshot,
}))
storedSnapshot, err := snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, grantedSnapshot, storedSnapshot)
storedFreeRecord, err := historyStore.GetByRecordID(context.Background(), freeRecord.RecordID)
require.NoError(t, err)
require.Equal(t, closedFreeRecord, storedFreeRecord)
extendedEndsAt := grantEndsAt.Add(30 * 24 * time.Hour)
extensionRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-2"),
userID,
entitlement.PlanCodePaidMonthly,
grantEndsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
extendedSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
grantStartsAt,
extendedEndsAt,
common.Source("admin"),
common.ReasonCode("manual_extend"),
)
require.NoError(t, lifecycleStore.Extend(context.Background(), ports.ExtendEntitlementInput{
ExpectedCurrentSnapshot: grantedSnapshot,
NewRecord: extensionRecord,
NewSnapshot: extendedSnapshot,
}))
storedSnapshot, err = snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, extendedSnapshot, storedSnapshot)
revokeAt := grantEndsAt.Add(12 * time.Hour)
revokedCurrentRecord := extensionRecord
revokedCurrentRecord.ClosedAt = timePointer(revokeAt)
revokedCurrentRecord.ClosedBy = common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")}
revokedCurrentRecord.ClosedReasonCode = common.ReasonCode("manual_revoke")
freeAfterRevokeRecord := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-free-2"),
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
StartsAt: revokeAt,
CreatedAt: revokeAt,
}
freeAfterRevokeSnapshot := entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: revokeAt,
Source: common.Source("admin"),
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: common.ReasonCode("manual_revoke"),
UpdatedAt: revokeAt,
}
require.NoError(t, lifecycleStore.Revoke(context.Background(), ports.RevokeEntitlementInput{
ExpectedCurrentSnapshot: extendedSnapshot,
ExpectedCurrentRecord: extensionRecord,
UpdatedCurrentRecord: revokedCurrentRecord,
NewRecord: freeAfterRevokeRecord,
NewSnapshot: freeAfterRevokeSnapshot,
}))
storedSnapshot, err = snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, freeAfterRevokeSnapshot, storedSnapshot)
historyRecords, err := historyStore.ListByUserID(context.Background(), userID)
require.NoError(t, err)
require.Len(t, historyRecords, 4)
}
func TestRepairExpiredEntitlementMaterializesFreeSnapshot(t *testing.T) {
t.Parallel()
store := newTestStore(t)
historyStore := store.EntitlementHistory()
snapshotStore := store.EntitlementSnapshots()
lifecycleStore := store.EntitlementLifecycle()
userID := common.UserID("user-123")
startsAt := time.Unix(1_775_240_000, 0).UTC()
endsAt := startsAt.Add(24 * time.Hour)
expiredSnapshot := paidEntitlementSnapshot(
userID,
entitlement.PlanCodePaidMonthly,
startsAt,
endsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
expiredSnapshot.UpdatedAt = endsAt.Add(24 * time.Hour)
expiredRecord := paidEntitlementRecord(
entitlement.EntitlementRecordID("entitlement-paid-1"),
userID,
entitlement.PlanCodePaidMonthly,
startsAt,
endsAt,
common.Source("admin"),
common.ReasonCode("manual_grant"),
)
require.NoError(t, historyStore.Create(context.Background(), expiredRecord))
require.NoError(t, snapshotStore.Put(context.Background(), expiredSnapshot))
repairedAt := endsAt.Add(2 * time.Hour)
freeRecord := entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-free-after-expiry"),
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("entitlement_expiry_repair"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("paid_entitlement_expired"),
StartsAt: endsAt,
CreatedAt: repairedAt,
}
freeSnapshot := entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: endsAt,
Source: common.Source("entitlement_expiry_repair"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("paid_entitlement_expired"),
UpdatedAt: repairedAt,
}
require.NoError(t, lifecycleStore.RepairExpired(context.Background(), ports.RepairExpiredEntitlementInput{
ExpectedExpiredSnapshot: expiredSnapshot,
NewRecord: freeRecord,
NewSnapshot: freeSnapshot,
}))
storedSnapshot, err := snapshotStore.GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, freeSnapshot, storedSnapshot)
historyRecords, err := historyStore.ListByUserID(context.Background(), userID)
require.NoError(t, err)
require.Len(t, historyRecords, 2)
require.Equal(t, freeRecord, historyRecords[1])
}
func newTestStore(t *testing.T) *Store {
t.Helper()
server := miniredis.RunT(t)
store, err := New(Config{
Addr: server.Addr(),
DB: 0,
KeyspacePrefix: "user:test:",
OperationTimeout: 250 * time.Millisecond,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = store.Close()
})
return store
}
func validAccountRecord() account.UserAccount {
createdAt := time.Unix(1_775_240_000, 0).UTC()
return account.UserAccount{
UserID: common.UserID("user-123"),
Email: common.Email("pilot@example.com"),
UserName: common.UserName("player-abcdefgh"),
PreferredLanguage: common.LanguageTag("en"),
TimeZone: common.TimeZoneName("Europe/Kaliningrad"),
CreatedAt: createdAt,
UpdatedAt: createdAt,
}
}
func validEntitlementSnapshot(userID common.UserID, now time.Time) entitlement.CurrentSnapshot {
return entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
IsPaid: false,
StartsAt: now,
Source: common.Source("auth_registration"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("initial_free_entitlement"),
UpdatedAt: now,
}
}
func validEntitlementRecord(userID common.UserID, now time.Time) entitlement.PeriodRecord {
return entitlement.PeriodRecord{
RecordID: entitlement.EntitlementRecordID("entitlement-" + userID.String()),
UserID: userID,
PlanCode: entitlement.PlanCodeFree,
Source: common.Source("auth_registration"),
Actor: common.ActorRef{Type: common.ActorType("service"), ID: common.ActorID("user-service")},
ReasonCode: common.ReasonCode("initial_free_entitlement"),
StartsAt: now,
CreatedAt: now,
}
}
func paidEntitlementRecord(
recordID entitlement.EntitlementRecordID,
userID common.UserID,
planCode entitlement.PlanCode,
startsAt time.Time,
endsAt time.Time,
source common.Source,
reasonCode common.ReasonCode,
) entitlement.PeriodRecord {
return entitlement.PeriodRecord{
RecordID: recordID,
UserID: userID,
PlanCode: planCode,
Source: source,
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: reasonCode,
StartsAt: startsAt,
EndsAt: timePointer(endsAt),
CreatedAt: startsAt,
}
}
func paidEntitlementSnapshot(
userID common.UserID,
planCode entitlement.PlanCode,
startsAt time.Time,
endsAt time.Time,
source common.Source,
reasonCode common.ReasonCode,
) entitlement.CurrentSnapshot {
return entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: planCode,
IsPaid: true,
StartsAt: startsAt,
EndsAt: timePointer(endsAt),
Source: source,
Actor: common.ActorRef{Type: common.ActorType("admin"), ID: common.ActorID("admin-1")},
ReasonCode: reasonCode,
UpdatedAt: startsAt,
}
}
func timePointer(value time.Time) *time.Time {
utcValue := value.UTC()
return &utcValue
}
func createAccountInput(record account.UserAccount) ports.CreateAccountInput {
return ports.CreateAccountInput{
Account: record,
}
}
@@ -1,193 +0,0 @@
// Package redisstate defines the frozen Redis logical keyspace and pagination
// helpers used by future User Service storage adapters.
package redisstate
import (
"encoding/base64"
"fmt"
"strings"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
)
const defaultPrefix = "user:"
// Keyspace builds the frozen Redis logical keys used by future storage
// adapters. The package intentionally exposes key construction only and does
// not depend on any Redis client.
type Keyspace struct {
// Prefix stores the namespace prefix applied to every key. The zero value
// uses `user:`.
Prefix string
}
// Account returns the primary user-account key for userID.
func (k Keyspace) Account(userID common.UserID) string {
return k.prefix() + "account:" + encodeKeyComponent(userID.String())
}
// EmailLookup returns the exact normalized e-mail lookup key.
func (k Keyspace) EmailLookup(email common.Email) string {
return k.prefix() + "lookup:email:" + encodeKeyComponent(email.String())
}
// UserNameLookup returns the exact stored user-name lookup key.
func (k Keyspace) UserNameLookup(userName common.UserName) string {
return k.prefix() + "lookup:user-name:" + encodeKeyComponent(userName.String())
}
// BlockedEmailSubject returns the dedicated blocked-email-subject key.
func (k Keyspace) BlockedEmailSubject(email common.Email) string {
return k.prefix() + "blocked-email:" + encodeKeyComponent(email.String())
}
// EntitlementRecord returns the primary entitlement history-record key.
func (k Keyspace) EntitlementRecord(recordID entitlement.EntitlementRecordID) string {
return k.prefix() + "entitlement:record:" + encodeKeyComponent(recordID.String())
}
// EntitlementHistory returns the per-user entitlement-history index key.
func (k Keyspace) EntitlementHistory(userID common.UserID) string {
return k.prefix() + "entitlement:history:" + encodeKeyComponent(userID.String())
}
// EntitlementSnapshot returns the current entitlement-snapshot key.
func (k Keyspace) EntitlementSnapshot(userID common.UserID) string {
return k.prefix() + "entitlement:snapshot:" + encodeKeyComponent(userID.String())
}
// SanctionRecord returns the primary sanction history-record key.
func (k Keyspace) SanctionRecord(recordID policy.SanctionRecordID) string {
return k.prefix() + "sanction:record:" + encodeKeyComponent(recordID.String())
}
// SanctionHistory returns the per-user sanction-history index key.
func (k Keyspace) SanctionHistory(userID common.UserID) string {
return k.prefix() + "sanction:history:" + encodeKeyComponent(userID.String())
}
// ActiveSanction returns the per-user active-sanction slot for one sanction
// code. The slot guarantees at most one active sanction per `user_id +
// sanction_code`.
func (k Keyspace) ActiveSanction(userID common.UserID, code policy.SanctionCode) string {
return k.prefix() + "sanction:active:" + encodeKeyComponent(userID.String()) + ":" + encodeKeyComponent(string(code))
}
// LimitRecord returns the primary limit history-record key.
func (k Keyspace) LimitRecord(recordID policy.LimitRecordID) string {
return k.prefix() + "limit:record:" + encodeKeyComponent(recordID.String())
}
// LimitHistory returns the per-user limit-history index key.
func (k Keyspace) LimitHistory(userID common.UserID) string {
return k.prefix() + "limit:history:" + encodeKeyComponent(userID.String())
}
// ActiveLimit returns the per-user active-limit slot for one limit code. The
// slot guarantees at most one active limit per `user_id + limit_code`.
func (k Keyspace) ActiveLimit(userID common.UserID, code policy.LimitCode) string {
return k.prefix() + "limit:active:" + encodeKeyComponent(userID.String()) + ":" + encodeKeyComponent(string(code))
}
// CreatedAtIndex returns the deterministic newest-first user-ordering index.
func (k Keyspace) CreatedAtIndex() string {
return k.prefix() + "index:created-at"
}
// PaidStateIndex returns the coarse free-versus-paid index key.
func (k Keyspace) PaidStateIndex(state entitlement.PaidState) string {
return k.prefix() + "index:paid-state:" + encodeKeyComponent(string(state))
}
// FinitePaidExpiryIndex returns the finite paid-expiry index key. Lifetime
// plans intentionally do not participate in this index.
func (k Keyspace) FinitePaidExpiryIndex() string {
return k.prefix() + "index:paid-expiry:finite"
}
// DeclaredCountryIndex returns the current declared-country reverse-lookup
// index key.
func (k Keyspace) DeclaredCountryIndex(code common.CountryCode) string {
return k.prefix() + "index:declared-country:" + encodeKeyComponent(code.String())
}
// ActiveSanctionCodeIndex returns the reverse-lookup index key for users with
// an active sanction code.
func (k Keyspace) ActiveSanctionCodeIndex(code policy.SanctionCode) string {
return k.prefix() + "index:active-sanction:" + encodeKeyComponent(string(code))
}
// ActiveLimitCodeIndex returns the reverse-lookup index key for users with an
// active limit code.
func (k Keyspace) ActiveLimitCodeIndex(code policy.LimitCode) string {
return k.prefix() + "index:active-limit:" + encodeKeyComponent(string(code))
}
// EligibilityMarkerIndex returns the reverse-lookup index key for one derived
// eligibility marker boolean.
func (k Keyspace) EligibilityMarkerIndex(marker policy.EligibilityMarker, value bool) string {
return fmt.Sprintf("%sindex:eligibility:%s:%t", k.prefix(), encodeKeyComponent(string(marker)), value)
}
// CreatedAtScore returns the frozen ZSET score representation for created-at
// ordering and deterministic pagination.
func CreatedAtScore(createdAt time.Time) float64 {
return float64(createdAt.UTC().UnixMicro())
}
// ExpiryScore returns the frozen ZSET score representation for finite paid
// expiry ordering.
func ExpiryScore(expiresAt time.Time) float64 {
return float64(expiresAt.UTC().UnixMicro())
}
// PageCursor identifies the last seen `(created_at, user_id)` tuple used by
// deterministic newest-first pagination.
type PageCursor struct {
// CreatedAt stores the created-at component of the last seen row.
CreatedAt time.Time
// UserID stores the user-id tiebreaker component of the last seen row.
UserID common.UserID
}
// Validate reports whether PageCursor contains a complete cursor tuple.
func (cursor PageCursor) Validate() error {
if err := common.ValidateTimestamp("page cursor created at", cursor.CreatedAt); err != nil {
return err
}
if err := cursor.UserID.Validate(); err != nil {
return fmt.Errorf("page cursor user id: %w", err)
}
return nil
}
// ComparePageOrder compares two listing positions using the frozen ordering:
// `created_at desc`, then `user_id desc`.
func ComparePageOrder(left PageCursor, right PageCursor) int {
switch {
case left.CreatedAt.After(right.CreatedAt):
return -1
case left.CreatedAt.Before(right.CreatedAt):
return 1
default:
return -strings.Compare(left.UserID.String(), right.UserID.String())
}
}
func (k Keyspace) prefix() string {
prefix := strings.TrimSpace(k.Prefix)
if prefix == "" {
return defaultPrefix
}
return prefix
}
func encodeKeyComponent(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
@@ -1,57 +0,0 @@
package redisstate
import (
"testing"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"github.com/stretchr/testify/require"
)
func TestKeyspaceBuildsStableKeys(t *testing.T) {
t.Parallel()
keyspace := Keyspace{Prefix: "custom:"}
require.Equal(t, "custom:account:dXNlci0xMjM", keyspace.Account(common.UserID("user-123")))
require.Equal(t, "custom:lookup:email:cGlsb3RAZXhhbXBsZS5jb20", keyspace.EmailLookup(common.Email("pilot@example.com")))
require.Equal(t, "custom:lookup:user-name:cGxheWVyLWFiY2RlZmdo", keyspace.UserNameLookup(common.UserName("player-abcdefgh")))
require.Equal(t, "custom:blocked-email:cGlsb3RAZXhhbXBsZS5jb20", keyspace.BlockedEmailSubject(common.Email("pilot@example.com")))
require.Equal(t, "custom:entitlement:record:ZW50aXRsZW1lbnQtMTIz", keyspace.EntitlementRecord(entitlement.EntitlementRecordID("entitlement-123")))
require.Equal(t, "custom:sanction:record:c2FuY3Rpb24tMQ", keyspace.SanctionRecord(policy.SanctionRecordID("sanction-1")))
require.Equal(t, "custom:limit:record:bGltaXQtMQ", keyspace.LimitRecord(policy.LimitRecordID("limit-1")))
require.Equal(t, "custom:sanction:active:dXNlci0xMjM:bG9naW5fYmxvY2s", keyspace.ActiveSanction(common.UserID("user-123"), policy.SanctionCodeLoginBlock))
require.Equal(t, "custom:limit:active:dXNlci0xMjM:bWF4X293bmVkX3ByaXZhdGVfZ2FtZXM", keyspace.ActiveLimit(common.UserID("user-123"), policy.LimitCodeMaxOwnedPrivateGames))
require.Equal(t, "custom:index:created-at", keyspace.CreatedAtIndex())
require.Equal(t, "custom:index:paid-state:cGFpZA", keyspace.PaidStateIndex(entitlement.PaidStatePaid))
require.Equal(t, "custom:index:paid-expiry:finite", keyspace.FinitePaidExpiryIndex())
require.Equal(t, "custom:index:declared-country:REU", keyspace.DeclaredCountryIndex(common.CountryCode("DE")))
require.Equal(t, "custom:index:active-sanction:bG9naW5fYmxvY2s", keyspace.ActiveSanctionCodeIndex(policy.SanctionCodeLoginBlock))
require.Equal(t, "custom:index:active-limit:bWF4X293bmVkX3ByaXZhdGVfZ2FtZXM", keyspace.ActiveLimitCodeIndex(policy.LimitCodeMaxOwnedPrivateGames))
require.Equal(t, "custom:index:eligibility:Y2FuX2xvZ2lu:true", keyspace.EligibilityMarkerIndex(policy.EligibilityMarkerCanLogin, true))
}
func TestComparePageOrder(t *testing.T) {
t.Parallel()
newer := PageCursor{CreatedAt: time.Unix(20, 0).UTC(), UserID: common.UserID("user-200")}
older := PageCursor{CreatedAt: time.Unix(10, 0).UTC(), UserID: common.UserID("user-100")}
sameTimeHigherUserID := PageCursor{CreatedAt: time.Unix(20, 0).UTC(), UserID: common.UserID("user-300")}
require.Negative(t, ComparePageOrder(newer, older))
require.Positive(t, ComparePageOrder(older, newer))
require.Negative(t, ComparePageOrder(sameTimeHigherUserID, newer))
}
func TestScoresUseUnixMicro(t *testing.T) {
t.Parallel()
value := time.Unix(1_775_240_000, 123_000).UTC()
want := float64(value.UnixMicro())
require.Equal(t, want, CreatedAtScore(value))
require.Equal(t, want, ExpiryScore(value))
}
@@ -1,191 +0,0 @@
package redisstate
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
)
var (
// ErrPageTokenFiltersMismatch reports that a supplied page token was created
// for a different normalized filter set.
ErrPageTokenFiltersMismatch = errors.New("page token filters do not match current filters")
)
// UserListFilters stores the frozen admin-listing filter set that becomes part
// of the opaque page token fingerprint.
type UserListFilters struct {
// PaidState stores the coarse free-versus-paid filter.
PaidState entitlement.PaidState
// PaidExpiresBefore stores the optional finite-paid expiry upper bound.
PaidExpiresBefore *time.Time
// PaidExpiresAfter stores the optional finite-paid expiry lower bound.
PaidExpiresAfter *time.Time
// DeclaredCountry stores the optional declared-country filter.
DeclaredCountry common.CountryCode
// SanctionCode stores the optional active-sanction filter.
SanctionCode policy.SanctionCode
// LimitCode stores the optional active-limit filter.
LimitCode policy.LimitCode
// CanLogin stores the optional login-eligibility filter.
CanLogin *bool
// CanCreatePrivateGame stores the optional private-game-create eligibility
// filter.
CanCreatePrivateGame *bool
// CanJoinGame stores the optional join-game eligibility filter.
CanJoinGame *bool
}
// Validate reports whether UserListFilters is structurally valid.
func (filters UserListFilters) Validate() error {
if !filters.PaidState.IsKnown() {
return fmt.Errorf("paid state %q is unsupported", filters.PaidState)
}
if filters.PaidExpiresBefore != nil && filters.PaidExpiresBefore.IsZero() {
return fmt.Errorf("paid expires before must not be zero")
}
if filters.PaidExpiresAfter != nil && filters.PaidExpiresAfter.IsZero() {
return fmt.Errorf("paid expires after must not be zero")
}
if !filters.DeclaredCountry.IsZero() {
if err := filters.DeclaredCountry.Validate(); err != nil {
return fmt.Errorf("declared country: %w", err)
}
}
if filters.SanctionCode != "" && !filters.SanctionCode.IsKnown() {
return fmt.Errorf("sanction code %q is unsupported", filters.SanctionCode)
}
if filters.LimitCode != "" && !filters.LimitCode.IsKnown() {
return fmt.Errorf("limit code %q is unsupported", filters.LimitCode)
}
return nil
}
// EncodePageToken encodes cursor and filters into the frozen opaque page token
// format.
func EncodePageToken(cursor PageCursor, filters UserListFilters) (string, error) {
if err := cursor.Validate(); err != nil {
return "", fmt.Errorf("encode page token: %w", err)
}
fingerprint, err := normalizeFilters(filters)
if err != nil {
return "", fmt.Errorf("encode page token: %w", err)
}
payload, err := json.Marshal(pageTokenPayload{
CreatedAt: cursor.CreatedAt.UTC().Format(time.RFC3339Nano),
UserID: cursor.UserID.String(),
Filters: fingerprint,
})
if err != nil {
return "", fmt.Errorf("encode page token: %w", err)
}
return base64.RawURLEncoding.EncodeToString(payload), nil
}
// DecodePageToken decodes raw into the frozen page cursor and verifies that
// the embedded normalized filter set matches expectedFilters.
func DecodePageToken(raw string, expectedFilters UserListFilters) (PageCursor, error) {
fingerprint, err := normalizeFilters(expectedFilters)
if err != nil {
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
}
payload, err := base64.RawURLEncoding.DecodeString(raw)
if err != nil {
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
}
var token pageTokenPayload
if err := json.Unmarshal(payload, &token); err != nil {
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
}
if token.Filters != fingerprint {
return PageCursor{}, ErrPageTokenFiltersMismatch
}
createdAt, err := time.Parse(time.RFC3339Nano, token.CreatedAt)
if err != nil {
return PageCursor{}, fmt.Errorf("decode page token: parse created_at: %w", err)
}
cursor := PageCursor{
CreatedAt: createdAt.UTC(),
UserID: common.UserID(token.UserID),
}
if err := cursor.Validate(); err != nil {
return PageCursor{}, fmt.Errorf("decode page token: %w", err)
}
return cursor, nil
}
type pageTokenPayload struct {
CreatedAt string `json:"created_at"`
UserID string `json:"user_id"`
Filters normalizedFilterPayload `json:"filters"`
}
type normalizedFilterPayload struct {
PaidState string `json:"paid_state,omitempty"`
PaidExpiresBeforeUTC string `json:"paid_expires_before_utc,omitempty"`
PaidExpiresAfterUTC string `json:"paid_expires_after_utc,omitempty"`
DeclaredCountry string `json:"declared_country,omitempty"`
SanctionCode string `json:"sanction_code,omitempty"`
LimitCode string `json:"limit_code,omitempty"`
CanLogin string `json:"can_login,omitempty"`
CanCreatePrivateGame string `json:"can_create_private_game,omitempty"`
CanJoinGame string `json:"can_join_game,omitempty"`
}
func normalizeFilters(filters UserListFilters) (normalizedFilterPayload, error) {
if err := filters.Validate(); err != nil {
return normalizedFilterPayload{}, err
}
return normalizedFilterPayload{
PaidState: string(filters.PaidState),
PaidExpiresBeforeUTC: formatOptionalTime(filters.PaidExpiresBefore),
PaidExpiresAfterUTC: formatOptionalTime(filters.PaidExpiresAfter),
DeclaredCountry: filters.DeclaredCountry.String(),
SanctionCode: string(filters.SanctionCode),
LimitCode: string(filters.LimitCode),
CanLogin: formatOptionalBool(filters.CanLogin),
CanCreatePrivateGame: formatOptionalBool(filters.CanCreatePrivateGame),
CanJoinGame: formatOptionalBool(filters.CanJoinGame),
}, nil
}
func formatOptionalTime(value *time.Time) string {
if value == nil {
return ""
}
return value.UTC().Format(time.RFC3339Nano)
}
func formatOptionalBool(value *bool) string {
if value == nil {
return ""
}
if *value {
return "true"
}
return "false"
}
@@ -1,70 +0,0 @@
package redisstate
import (
"testing"
"time"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"github.com/stretchr/testify/require"
)
func TestEncodeDecodePageToken(t *testing.T) {
t.Parallel()
before := time.Unix(1_775_250_000, 0).UTC()
after := time.Unix(1_775_240_000, 0).UTC()
canLogin := true
canCreate := false
canJoin := true
filters := UserListFilters{
PaidState: entitlement.PaidStatePaid,
PaidExpiresBefore: &before,
PaidExpiresAfter: &after,
DeclaredCountry: common.CountryCode("DE"),
SanctionCode: policy.SanctionCodeLoginBlock,
LimitCode: policy.LimitCodeMaxOwnedPrivateGames,
CanLogin: &canLogin,
CanCreatePrivateGame: &canCreate,
CanJoinGame: &canJoin,
}
cursor := PageCursor{
CreatedAt: time.Unix(1_775_240_100, 987_000_000).UTC(),
UserID: common.UserID("user-123"),
}
token, err := EncodePageToken(cursor, filters)
require.NoError(t, err)
decoded, err := DecodePageToken(token, filters)
require.NoError(t, err)
require.Equal(t, cursor, decoded)
}
func TestDecodePageTokenFilterMismatch(t *testing.T) {
t.Parallel()
cursor := PageCursor{
CreatedAt: time.Unix(1_775_240_100, 0).UTC(),
UserID: common.UserID("user-123"),
}
filters := UserListFilters{
PaidState: entitlement.PaidStatePaid,
}
token, err := EncodePageToken(cursor, filters)
require.NoError(t, err)
_, err = DecodePageToken(token, UserListFilters{PaidState: entitlement.PaidStateFree})
require.ErrorIs(t, err, ErrPageTokenFiltersMismatch)
}
func TestDecodePageTokenRejectsInvalidInput(t *testing.T) {
t.Parallel()
_, err := DecodePageToken("%%%not-base64%%%", UserListFilters{})
require.Error(t, err)
}
+80 -44
View File
@@ -3,16 +3,20 @@ package app
import (
"context"
"database/sql"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
"galaxy/postgres"
"galaxy/redisconn"
"galaxy/user/internal/adapters/local"
"galaxy/user/internal/adapters/postgres/migrations"
pguserstore "galaxy/user/internal/adapters/postgres/userstore"
"galaxy/user/internal/adapters/redis/domainevents"
"galaxy/user/internal/adapters/redis/lifecycleevents"
"galaxy/user/internal/adapters/redis/userstore"
"galaxy/user/internal/adminapi"
"galaxy/user/internal/api/internalhttp"
"galaxy/user/internal/config"
@@ -25,16 +29,14 @@ import (
"galaxy/user/internal/service/policysvc"
"galaxy/user/internal/service/selfservice"
"galaxy/user/internal/telemetry"
goredis "github.com/redis/go-redis/v9"
)
type pinger interface {
Ping(context.Context) error
}
type closer interface {
Close() error
}
// Runtime owns the runnable user-service process plus the cleanup functions
// that release runtime resources after shutdown.
type Runtime struct {
@@ -93,61 +95,75 @@ func NewRuntime(ctx context.Context, cfg config.Config, logger *slog.Logger) (*R
return telemetryRuntime.Shutdown(shutdownCtx)
})
store, err := userstore.New(userstore.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
KeyspacePrefix: cfg.Redis.KeyspacePrefix,
OperationTimeout: cfg.Redis.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: redis user store: %w", err))
// Open the shared Redis master client for both stream publishers. The
// client is owned by the runtime; publishers borrow it through their
// New(client, cfg) constructors.
redisClient := redisconn.NewMasterClient(cfg.Redis.Conn)
if err := redisconn.Instrument(redisClient,
redisconn.WithTracerProvider(telemetryRuntime.TracerProvider()),
redisconn.WithMeterProvider(telemetryRuntime.MeterProvider()),
); err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: instrument redis client: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, store.Close)
if err := pingDependency(ctx, "redis user store", store); err != nil {
runtime.cleanupFns = append(runtime.cleanupFns, redisClient.Close)
if err := pingRedisClient(ctx, redisClient, cfg.Redis.Conn); err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: %w", err))
}
domainEventPublisher, err := domainevents.New(domainevents.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
// Open the PostgreSQL pool, attach instrumentation, ping it, and apply
// embedded migrations strictly before any HTTP listener opens. A failure
// at any of these steps is fatal: the service exits with non-zero status.
pgPool, err := postgres.OpenPrimary(ctx, cfg.Postgres.Conn,
postgres.WithTracerProvider(telemetryRuntime.TracerProvider()),
postgres.WithMeterProvider(telemetryRuntime.MeterProvider()),
)
if err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: open postgres primary: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, pgPool.Close)
unregisterDBStats, err := postgres.InstrumentDBStats(pgPool,
postgres.WithMeterProvider(telemetryRuntime.MeterProvider()),
)
if err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: instrument postgres db stats: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, unregisterDBStats)
if err := postgres.Ping(ctx, pgPool, cfg.Postgres.Conn.OperationTimeout); err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: %w", err))
}
migrationsFS := migrations.FS()
if err := postgres.RunMigrations(ctx, pgPool, migrationsFS, "."); err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: run postgres migrations: %w", err))
}
store, err := pguserstore.New(pguserstore.Config{
DB: pgPool,
OperationTimeout: cfg.Postgres.Conn.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: postgres user store: %w", err))
}
if err := pingDependency(ctx, "postgres user store", store); err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: %w", err))
}
domainEventPublisher, err := domainevents.New(redisClient, domainevents.Config{
Stream: cfg.Redis.DomainEventsStream,
StreamMaxLen: cfg.Redis.DomainEventsStreamMaxLen,
OperationTimeout: cfg.Redis.OperationTimeout,
OperationTimeout: cfg.Redis.Conn.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: redis domain-event publisher: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, domainEventPublisher.Close)
if err := pingDependency(ctx, "redis domain-event publisher", domainEventPublisher); err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: %w", err))
}
lifecycleEventPublisher, err := lifecycleevents.New(lifecycleevents.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
lifecycleEventPublisher, err := lifecycleevents.New(redisClient, lifecycleevents.Config{
Stream: cfg.Redis.LifecycleEventsStream,
StreamMaxLen: cfg.Redis.LifecycleEventsStreamMaxLen,
OperationTimeout: cfg.Redis.OperationTimeout,
OperationTimeout: cfg.Redis.Conn.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: redis lifecycle-event publisher: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, lifecycleEventPublisher.Close)
if err := pingDependency(ctx, "redis lifecycle-event publisher", lifecycleEventPublisher); err != nil {
return cleanupOnError(fmt.Errorf("new user-service runtime: %w", err))
}
clock := local.Clock{}
idGenerator := local.IDGenerator{}
@@ -517,4 +533,24 @@ func pingDependency(ctx context.Context, name string, dependency pinger) error {
return nil
}
var _ closer = (*userstore.Store)(nil)
func pingRedisClient(ctx context.Context, client *goredis.Client, cfg redisconn.Config) error {
pingCtx, cancel := context.WithTimeout(ctx, cfg.OperationTimeout)
defer cancel()
if err := client.Ping(pingCtx).Err(); err != nil {
return fmt.Errorf("ping redis master: %w", err)
}
return nil
}
// Compile-time guard that the postgres-backed user store implements the
// closer pattern relied on by cleanupFns. Close is a no-op on the postgres
// store; the underlying *sql.DB is closed via cleanupFns appended above.
var _ interface{ Close() error } = (*pguserstore.Store)(nil)
// Compile-time guard that the postgres-backed user store also satisfies the
// pinger contract used by pingDependency.
var _ pinger = (*pguserstore.Store)(nil)
// Compile-time guard kept from the previous implementation so future readers
// can trust the *sql.DB life cycle remains consistent with cleanupFns.
var _ *sql.DB = (*sql.DB)(nil)
+71 -98
View File
@@ -3,16 +3,20 @@
package config
import (
"crypto/tls"
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
"galaxy/postgres"
"galaxy/redisconn"
)
const (
envPrefix = "USERSERVICE"
shutdownTimeoutEnvVar = "USERSERVICE_SHUTDOWN_TIMEOUT"
logLevelEnvVar = "USERSERVICE_LOG_LEVEL"
@@ -27,13 +31,6 @@ const (
adminHTTPReadTimeoutEnvVar = "USERSERVICE_ADMIN_HTTP_READ_TIMEOUT"
adminHTTPIdleTimeoutEnvVar = "USERSERVICE_ADMIN_HTTP_IDLE_TIMEOUT"
redisAddrEnvVar = "USERSERVICE_REDIS_ADDR"
redisUsernameEnvVar = "USERSERVICE_REDIS_USERNAME"
redisPasswordEnvVar = "USERSERVICE_REDIS_PASSWORD"
redisDBEnvVar = "USERSERVICE_REDIS_DB"
redisTLSEnabledEnvVar = "USERSERVICE_REDIS_TLS_ENABLED"
redisOperationTimeoutEnvVar = "USERSERVICE_REDIS_OPERATION_TIMEOUT"
redisKeyspacePrefixEnvVar = "USERSERVICE_REDIS_KEYSPACE_PREFIX"
redisDomainEventsStreamEnvVar = "USERSERVICE_REDIS_DOMAIN_EVENTS_STREAM"
redisDomainEventsStreamMaxLenEnvVar = "USERSERVICE_REDIS_DOMAIN_EVENTS_STREAM_MAX_LEN"
redisLifecycleEventsStreamEnvVar = "USERSERVICE_REDIS_LIFECYCLE_EVENTS_STREAM"
@@ -48,26 +45,23 @@ const (
otelStdoutTracesEnabledEnvVar = "USERSERVICE_OTEL_STDOUT_TRACES_ENABLED"
otelStdoutMetricsEnabledEnvVar = "USERSERVICE_OTEL_STDOUT_METRICS_ENABLED"
defaultShutdownTimeout = 5 * time.Second
defaultLogLevel = "info"
defaultInternalHTTPAddr = ":8091"
defaultAdminHTTPAddr = ""
defaultReadHeaderTimeout = 2 * time.Second
defaultReadTimeout = 10 * time.Second
defaultIdleTimeout = time.Minute
defaultRequestTimeout = 3 * time.Second
defaultRedisDB = 0
defaultRedisOperationTimeout = 250 * time.Millisecond
defaultRedisKeyspacePrefix = "user:"
defaultShutdownTimeout = 5 * time.Second
defaultLogLevel = "info"
defaultInternalHTTPAddr = ":8091"
defaultAdminHTTPAddr = ""
defaultReadHeaderTimeout = 2 * time.Second
defaultReadTimeout = 10 * time.Second
defaultIdleTimeout = time.Minute
defaultRequestTimeout = 3 * time.Second
defaultDomainEventsStream = "user:domain_events"
defaultDomainEventsStreamMaxLen = 1024
defaultLifecycleEventsStream = "user:lifecycle_events"
defaultLifecycleEventsStreamMaxLen = 1024
defaultOTelServiceName = "galaxy-user"
otelExporterNone = "none"
otelExporterOTLP = "otlp"
otelProtocolHTTPProtobuf = "http/protobuf"
otelProtocolGRPC = "grpc"
defaultOTelServiceName = "galaxy-user"
otelExporterNone = "none"
otelExporterOTLP = "otlp"
otelProtocolHTTPProtobuf = "http/protobuf"
otelProtocolGRPC = "grpc"
)
// Config stores the full user-service process configuration.
@@ -85,9 +79,14 @@ type Config struct {
// AdminHTTP configures the optional private admin HTTP listener.
AdminHTTP AdminHTTPConfig
// Redis configures the Redis-backed user store and domain-event publisher.
// Redis configures the Redis-backed event publishers (domain + lifecycle
// streams) plus the connection topology consumed via `pkg/redisconn`.
Redis RedisConfig
// Postgres configures the PostgreSQL-backed durable store consumed via
// `pkg/postgres`.
Postgres PostgresConfig
// Telemetry configures the process-wide OpenTelemetry runtime.
Telemetry TelemetryConfig
}
@@ -171,28 +170,12 @@ func (cfg AdminHTTPConfig) Validate() error {
}
}
// RedisConfig configures the Redis-backed store and domain-event publisher.
// RedisConfig configures the Redis-backed event publishers and the connection
// topology shared with `pkg/redisconn`.
type RedisConfig struct {
// Addr stores the Redis network address.
Addr string
// Username stores the optional Redis ACL username.
Username string
// Password stores the optional Redis ACL password.
Password string
// DB stores the Redis logical database index.
DB int
// TLSEnabled reports whether TLS must be used for Redis connections.
TLSEnabled bool
// OperationTimeout bounds one Redis round trip.
OperationTimeout time.Duration
// KeyspacePrefix stores the root prefix of the service-owned Redis keyspace.
KeyspacePrefix string
// Conn carries the connection topology (master, replicas, password, db,
// per-call timeout). Loaded via redisconn.LoadFromEnv("USERSERVICE").
Conn redisconn.Config
// DomainEventsStream stores the Redis Stream key used for auxiliary
// post-commit domain events.
@@ -203,8 +186,8 @@ type RedisConfig struct {
DomainEventsStreamMaxLen int64
// LifecycleEventsStream stores the Redis Stream key used for trusted
// user-lifecycle events (permanent_block, delete) consumed by
// `Game Lobby` for Race Name Directory cascade release.
// user-lifecycle events (permanent_block, delete) consumed by `Game
// Lobby` for Race Name Directory cascade release.
LifecycleEventsStream string
// LifecycleEventsStreamMaxLen bounds the lifecycle-events Redis Stream
@@ -212,27 +195,12 @@ type RedisConfig struct {
LifecycleEventsStreamMaxLen int64
}
// TLSConfig returns the conservative TLS configuration used by Redis adapters
// when TLSEnabled is true.
func (cfg RedisConfig) TLSConfig() *tls.Config {
if !cfg.TLSEnabled {
return nil
}
return &tls.Config{MinVersion: tls.VersionTLS12}
}
// Validate reports whether cfg stores a usable Redis configuration.
func (cfg RedisConfig) Validate() error {
if err := cfg.Conn.Validate(); err != nil {
return err
}
switch {
case strings.TrimSpace(cfg.Addr) == "":
return fmt.Errorf("redis addr must not be empty")
case cfg.DB < 0:
return fmt.Errorf("redis db must not be negative")
case cfg.OperationTimeout <= 0:
return fmt.Errorf("redis operation timeout must be positive")
case strings.TrimSpace(cfg.KeyspacePrefix) == "":
return fmt.Errorf("redis keyspace prefix must not be empty")
case strings.TrimSpace(cfg.DomainEventsStream) == "":
return fmt.Errorf("redis domain events stream must not be empty")
case cfg.DomainEventsStreamMaxLen <= 0:
@@ -246,6 +214,20 @@ func (cfg RedisConfig) Validate() error {
}
}
// PostgresConfig configures the PostgreSQL-backed durable store. It wraps
// the shared `pkg/postgres.Config` so callers receive the same struct shape
// across services.
type PostgresConfig struct {
// Conn stores the primary plus replica DSN topology and pool tuning.
// Loaded via postgres.LoadFromEnv("USERSERVICE").
Conn postgres.Config
}
// Validate reports whether cfg stores a usable PostgreSQL configuration.
func (cfg PostgresConfig) Validate() error {
return cfg.Conn.Validate()
}
// TelemetryConfig configures the user-service OpenTelemetry runtime.
type TelemetryConfig struct {
// ServiceName overrides the default OpenTelemetry service name.
@@ -313,7 +295,9 @@ func DefaultAdminHTTPConfig() AdminHTTPConfig {
}
// DefaultConfig returns the default process configuration with all optional
// values filled.
// values filled. Required connection coordinates (Redis master/password,
// Postgres primary DSN) remain zero-valued and must be supplied via
// LoadFromEnv.
func DefaultConfig() Config {
return Config{
ShutdownTimeout: defaultShutdownTimeout,
@@ -329,14 +313,15 @@ func DefaultConfig() Config {
},
AdminHTTP: DefaultAdminHTTPConfig(),
Redis: RedisConfig{
DB: defaultRedisDB,
OperationTimeout: defaultRedisOperationTimeout,
KeyspacePrefix: defaultRedisKeyspacePrefix,
Conn: redisconn.DefaultConfig(),
DomainEventsStream: defaultDomainEventsStream,
DomainEventsStreamMaxLen: defaultDomainEventsStreamMaxLen,
LifecycleEventsStream: defaultLifecycleEventsStream,
LifecycleEventsStreamMaxLen: defaultLifecycleEventsStreamMaxLen,
},
Postgres: PostgresConfig{
Conn: postgres.DefaultConfig(),
},
Telemetry: TelemetryConfig{
ServiceName: defaultOTelServiceName,
TracesExporter: otelExporterNone,
@@ -360,6 +345,9 @@ func (cfg Config) Validate() error {
if err := cfg.Redis.Validate(); err != nil {
return fmt.Errorf("redis config: %w", err)
}
if err := cfg.Postgres.Validate(); err != nil {
return fmt.Errorf("postgres config: %w", err)
}
if _, err := parseLogLevel(cfg.Logging.Level); err != nil {
return fmt.Errorf("logging config: %w", err)
}
@@ -370,7 +358,11 @@ func (cfg Config) Validate() error {
return nil
}
// LoadFromEnv loads Config from the process environment.
// LoadFromEnv loads Config from the process environment. Connection topology
// for Redis and PostgreSQL is delegated to the shared `pkg/redisconn` and
// `pkg/postgres` LoadFromEnv helpers, which enforce the architectural rules
// (mandatory Redis password, deprecated TLS/USERNAME variables hard-fail,
// required Postgres primary DSN).
func LoadFromEnv() (Config, error) {
cfg := DefaultConfig()
@@ -413,22 +405,11 @@ func LoadFromEnv() (Config, error) {
return Config{}, err
}
cfg.Redis.Addr = loadString(redisAddrEnvVar, cfg.Redis.Addr)
cfg.Redis.Username = loadString(redisUsernameEnvVar, cfg.Redis.Username)
cfg.Redis.Password = loadString(redisPasswordEnvVar, cfg.Redis.Password)
cfg.Redis.DB, err = loadInt(redisDBEnvVar, cfg.Redis.DB)
redisConn, err := redisconn.LoadFromEnv(envPrefix)
if err != nil {
return Config{}, err
}
cfg.Redis.TLSEnabled, err = loadBool(redisTLSEnabledEnvVar, cfg.Redis.TLSEnabled)
if err != nil {
return Config{}, err
}
cfg.Redis.OperationTimeout, err = loadDuration(redisOperationTimeoutEnvVar, cfg.Redis.OperationTimeout)
if err != nil {
return Config{}, err
}
cfg.Redis.KeyspacePrefix = loadString(redisKeyspacePrefixEnvVar, cfg.Redis.KeyspacePrefix)
cfg.Redis.Conn = redisConn
cfg.Redis.DomainEventsStream = loadString(redisDomainEventsStreamEnvVar, cfg.Redis.DomainEventsStream)
cfg.Redis.DomainEventsStreamMaxLen, err = loadInt64(redisDomainEventsStreamMaxLenEnvVar, cfg.Redis.DomainEventsStreamMaxLen)
if err != nil {
@@ -440,6 +421,12 @@ func LoadFromEnv() (Config, error) {
return Config{}, err
}
pgConn, err := postgres.LoadFromEnv(envPrefix)
if err != nil {
return Config{}, err
}
cfg.Postgres.Conn = pgConn
cfg.Telemetry.ServiceName = loadString(otelServiceNameEnvVar, cfg.Telemetry.ServiceName)
cfg.Telemetry.TracesExporter = normalizeExporterValue(loadString(otelTracesExporterEnvVar, cfg.Telemetry.TracesExporter))
cfg.Telemetry.MetricsExporter = normalizeExporterValue(loadString(otelMetricsExporterEnvVar, cfg.Telemetry.MetricsExporter))
@@ -492,20 +479,6 @@ func loadDuration(envName string, defaultValue time.Duration) (time.Duration, er
return duration, nil
}
func loadInt(envName string, defaultValue int) (int, error) {
value, ok := os.LookupEnv(envName)
if !ok {
return defaultValue, nil
}
parsedValue, err := strconv.Atoi(strings.TrimSpace(value))
if err != nil {
return 0, fmt.Errorf("%s: parse int: %w", envName, err)
}
return parsedValue, nil
}
func loadInt64(envName string, defaultValue int64) (int64, error) {
value, ok := os.LookupEnv(envName)
if !ok {
+130 -23
View File
@@ -1,14 +1,37 @@
package config
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const (
redisMasterAddrEnvVar = "USERSERVICE_REDIS_MASTER_ADDR"
redisReplicaAddrsEnvVar = "USERSERVICE_REDIS_REPLICA_ADDRS"
redisPasswordEnvVar = "USERSERVICE_REDIS_PASSWORD"
redisDBEnvVar = "USERSERVICE_REDIS_DB"
redisOperationTimeoutEnvVar = "USERSERVICE_REDIS_OPERATION_TIMEOUT"
redisLegacyAddrEnvVar = "USERSERVICE_REDIS_ADDR"
redisLegacyUsernameEnvVar = "USERSERVICE_REDIS_USERNAME"
redisLegacyTLSEnabledEnvVar = "USERSERVICE_REDIS_TLS_ENABLED"
redisLegacyKeyspacePrefixEnv = "USERSERVICE_REDIS_KEYSPACE_PREFIX"
postgresPrimaryDSNEnvVar = "USERSERVICE_POSTGRES_PRIMARY_DSN"
postgresReplicaDSNsEnvVar = "USERSERVICE_POSTGRES_REPLICA_DSNS"
postgresOperationTimeoutEnvVar = "USERSERVICE_POSTGRES_OPERATION_TIMEOUT"
postgresMaxOpenConnsEnvVar = "USERSERVICE_POSTGRES_MAX_OPEN_CONNS"
postgresMaxIdleConnsEnvVar = "USERSERVICE_POSTGRES_MAX_IDLE_CONNS"
postgresConnMaxLifetimeEnvVar = "USERSERVICE_POSTGRES_CONN_MAX_LIFETIME"
defaultPostgresDSN = "postgres://userservice:secret@127.0.0.1:5432/galaxy?search_path=user&sslmode=disable"
)
func TestLoadFromEnvUsesDefaults(t *testing.T) {
t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
t.Setenv(redisMasterAddrEnvVar, "127.0.0.1:6379")
t.Setenv(redisPasswordEnvVar, "secret")
t.Setenv(postgresPrimaryDSNEnvVar, defaultPostgresDSN)
cfg, err := LoadFromEnv()
require.NoError(t, err)
@@ -18,10 +41,18 @@ func TestLoadFromEnvUsesDefaults(t *testing.T) {
require.Equal(t, defaults.Logging.Level, cfg.Logging.Level)
require.Equal(t, defaults.InternalHTTP, cfg.InternalHTTP)
require.Equal(t, defaults.AdminHTTP, cfg.AdminHTTP)
require.Equal(t, "127.0.0.1:6379", cfg.Redis.Addr)
require.Equal(t, defaults.Redis.DB, cfg.Redis.DB)
require.Equal(t, "127.0.0.1:6379", cfg.Redis.Conn.MasterAddr)
require.Equal(t, "secret", cfg.Redis.Conn.Password)
require.Equal(t, defaults.Redis.Conn.DB, cfg.Redis.Conn.DB)
require.Equal(t, defaults.Redis.DomainEventsStream, cfg.Redis.DomainEventsStream)
require.Equal(t, defaults.Redis.DomainEventsStreamMaxLen, cfg.Redis.DomainEventsStreamMaxLen)
require.Equal(t, defaults.Redis.LifecycleEventsStream, cfg.Redis.LifecycleEventsStream)
require.Equal(t, defaults.Redis.LifecycleEventsStreamMaxLen, cfg.Redis.LifecycleEventsStreamMaxLen)
require.Equal(t, defaultPostgresDSN, cfg.Postgres.Conn.PrimaryDSN)
require.Equal(t, defaults.Postgres.Conn.OperationTimeout, cfg.Postgres.Conn.OperationTimeout)
require.Equal(t, defaults.Postgres.Conn.MaxOpenConns, cfg.Postgres.Conn.MaxOpenConns)
require.Equal(t, defaults.Postgres.Conn.MaxIdleConns, cfg.Postgres.Conn.MaxIdleConns)
require.Equal(t, defaults.Postgres.Conn.ConnMaxLifetime, cfg.Postgres.Conn.ConnMaxLifetime)
require.Equal(t, defaults.Telemetry, cfg.Telemetry)
}
@@ -33,15 +64,21 @@ func TestLoadFromEnvAppliesOverrides(t *testing.T) {
t.Setenv(internalHTTPRequestTimeoutEnvVar, "750ms")
t.Setenv(adminHTTPAddrEnvVar, "127.0.0.1:19091")
t.Setenv(adminHTTPIdleTimeoutEnvVar, "90s")
t.Setenv(redisAddrEnvVar, "127.0.0.1:6380")
t.Setenv(redisUsernameEnvVar, "alice")
t.Setenv(redisPasswordEnvVar, "secret")
t.Setenv(redisMasterAddrEnvVar, "127.0.0.1:6380")
t.Setenv(redisReplicaAddrsEnvVar, "127.0.0.1:6381,127.0.0.1:6382")
t.Setenv(redisPasswordEnvVar, "redis-secret")
t.Setenv(redisDBEnvVar, "3")
t.Setenv(redisTLSEnabledEnvVar, "true")
t.Setenv(redisOperationTimeoutEnvVar, "900ms")
t.Setenv(redisKeyspacePrefixEnvVar, "user:custom:")
t.Setenv(redisDomainEventsStreamEnvVar, "user:test_events")
t.Setenv(redisDomainEventsStreamMaxLenEnvVar, "2048")
t.Setenv(redisLifecycleEventsStreamEnvVar, "user:test_lifecycle")
t.Setenv(redisLifecycleEventsStreamMaxLenEnvVar, "512")
t.Setenv(postgresPrimaryDSNEnvVar, defaultPostgresDSN)
t.Setenv(postgresReplicaDSNsEnvVar, "postgres://userservice:secret@replica-a/galaxy?sslmode=disable,postgres://userservice:secret@replica-b/galaxy?sslmode=disable")
t.Setenv(postgresOperationTimeoutEnvVar, "2s")
t.Setenv(postgresMaxOpenConnsEnvVar, "40")
t.Setenv(postgresMaxIdleConnsEnvVar, "8")
t.Setenv(postgresConnMaxLifetimeEnvVar, "45m")
t.Setenv(otelServiceNameEnvVar, "galaxy-user-stage12")
t.Setenv(otelTracesExporterEnvVar, "otlp")
t.Setenv(otelMetricsExporterEnvVar, "otlp")
@@ -60,15 +97,24 @@ func TestLoadFromEnvAppliesOverrides(t *testing.T) {
require.Equal(t, 750*time.Millisecond, cfg.InternalHTTP.RequestTimeout)
require.Equal(t, "127.0.0.1:19091", cfg.AdminHTTP.Addr)
require.Equal(t, 90*time.Second, cfg.AdminHTTP.IdleTimeout)
require.Equal(t, "127.0.0.1:6380", cfg.Redis.Addr)
require.Equal(t, "alice", cfg.Redis.Username)
require.Equal(t, "secret", cfg.Redis.Password)
require.Equal(t, 3, cfg.Redis.DB)
require.True(t, cfg.Redis.TLSEnabled)
require.Equal(t, 900*time.Millisecond, cfg.Redis.OperationTimeout)
require.Equal(t, "user:custom:", cfg.Redis.KeyspacePrefix)
require.Equal(t, "127.0.0.1:6380", cfg.Redis.Conn.MasterAddr)
require.Equal(t, []string{"127.0.0.1:6381", "127.0.0.1:6382"}, cfg.Redis.Conn.ReplicaAddrs)
require.Equal(t, "redis-secret", cfg.Redis.Conn.Password)
require.Equal(t, 3, cfg.Redis.Conn.DB)
require.Equal(t, 900*time.Millisecond, cfg.Redis.Conn.OperationTimeout)
require.Equal(t, "user:test_events", cfg.Redis.DomainEventsStream)
require.Equal(t, int64(2048), cfg.Redis.DomainEventsStreamMaxLen)
require.Equal(t, "user:test_lifecycle", cfg.Redis.LifecycleEventsStream)
require.Equal(t, int64(512), cfg.Redis.LifecycleEventsStreamMaxLen)
require.Equal(t, defaultPostgresDSN, cfg.Postgres.Conn.PrimaryDSN)
require.Equal(t, []string{
"postgres://userservice:secret@replica-a/galaxy?sslmode=disable",
"postgres://userservice:secret@replica-b/galaxy?sslmode=disable",
}, cfg.Postgres.Conn.ReplicaDSNs)
require.Equal(t, 2*time.Second, cfg.Postgres.Conn.OperationTimeout)
require.Equal(t, 40, cfg.Postgres.Conn.MaxOpenConns)
require.Equal(t, 8, cfg.Postgres.Conn.MaxIdleConns)
require.Equal(t, 45*time.Minute, cfg.Postgres.Conn.ConnMaxLifetime)
require.Equal(t, "galaxy-user-stage12", cfg.Telemetry.ServiceName)
require.Equal(t, "otlp", cfg.Telemetry.TracesExporter)
require.Equal(t, "otlp", cfg.Telemetry.MetricsExporter)
@@ -78,29 +124,90 @@ func TestLoadFromEnvAppliesOverrides(t *testing.T) {
require.True(t, cfg.Telemetry.StdoutMetricsEnabled)
}
// TestLoadFromEnvRejectsLegacyRedisVars verifies the architectural rule from
// PG_PLAN.md §3 / ARCHITECTURE.md §Persistence Backends: legacy
// USERSERVICE_REDIS_TLS_ENABLED and USERSERVICE_REDIS_USERNAME variables must
// produce a startup error from `pkg/redisconn` so operators see the breaking
// rename immediately.
func TestLoadFromEnvRejectsLegacyRedisVars(t *testing.T) {
cases := []struct {
name string
envName string
}{
{name: "tls_enabled deprecated", envName: redisLegacyTLSEnabledEnvVar},
{name: "username deprecated", envName: redisLegacyUsernameEnvVar},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(redisMasterAddrEnvVar, "127.0.0.1:6379")
t.Setenv(redisPasswordEnvVar, "secret")
t.Setenv(postgresPrimaryDSNEnvVar, defaultPostgresDSN)
t.Setenv(tc.envName, "true")
_, err := LoadFromEnv()
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "no longer supported"))
})
}
}
// TestLoadFromEnvRequiresMandatoryFields covers the architectural rule that
// Redis password, master address and Postgres primary DSN are mandatory;
// missing any one returns a startup error.
func TestLoadFromEnvRequiresMandatoryFields(t *testing.T) {
t.Run("missing redis password", func(t *testing.T) {
t.Setenv(redisMasterAddrEnvVar, "127.0.0.1:6379")
t.Setenv(postgresPrimaryDSNEnvVar, defaultPostgresDSN)
_, err := LoadFromEnv()
require.Error(t, err)
})
t.Run("missing redis master addr", func(t *testing.T) {
t.Setenv(redisPasswordEnvVar, "secret")
t.Setenv(postgresPrimaryDSNEnvVar, defaultPostgresDSN)
_, err := LoadFromEnv()
require.Error(t, err)
})
t.Run("missing postgres dsn", func(t *testing.T) {
t.Setenv(redisMasterAddrEnvVar, "127.0.0.1:6379")
t.Setenv(redisPasswordEnvVar, "secret")
_, err := LoadFromEnv()
require.Error(t, err)
})
}
func TestLoadFromEnvRejectsInvalidValues(t *testing.T) {
tests := []struct {
cases := []struct {
name string
envName string
envVal string
}{
{name: "invalid duration", envName: shutdownTimeoutEnvVar, envVal: "later"},
{name: "invalid bool", envName: redisTLSEnabledEnvVar, envVal: "sometimes"},
{name: "invalid log level", envName: logLevelEnvVar, envVal: "verbose"},
{name: "invalid int", envName: redisDBEnvVar, envVal: "db-three"},
{name: "invalid redis db", envName: redisDBEnvVar, envVal: "db-three"},
{name: "invalid stream max len", envName: redisDomainEventsStreamMaxLenEnvVar, envVal: "many"},
{name: "invalid traces exporter", envName: otelTracesExporterEnvVar, envVal: "zipkin"},
{name: "invalid metrics protocol", envName: otelExporterOTLPMetricsProtocolEnvVar, envVal: "udp"},
{name: "invalid postgres operation timeout", envName: postgresOperationTimeoutEnvVar, envVal: "soon"},
{name: "invalid postgres max open conns", envName: postgresMaxOpenConnsEnvVar, envVal: "none"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
t.Setenv(tt.envName, tt.envVal)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(redisMasterAddrEnvVar, "127.0.0.1:6379")
t.Setenv(redisPasswordEnvVar, "secret")
t.Setenv(postgresPrimaryDSNEnvVar, defaultPostgresDSN)
t.Setenv(tc.envName, tc.envVal)
_, err := LoadFromEnv()
require.Error(t, err)
})
}
}
// Suppress unused-warning for legacy keyspace prefix env reference: keep the
// constant in test scope for documentation, though no current code uses it.
var _ = redisLegacyAddrEnvVar
var _ = redisLegacyKeyspacePrefixEnv
@@ -5,15 +5,12 @@ import (
"testing"
"time"
"galaxy/user/internal/adapters/redis/userstore"
"galaxy/user/internal/domain/account"
"galaxy/user/internal/domain/common"
"galaxy/user/internal/domain/entitlement"
"galaxy/user/internal/domain/policy"
"galaxy/user/internal/ports"
"galaxy/user/internal/service/entitlementsvc"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/require"
)
@@ -249,66 +246,14 @@ func TestSnapshotReaderExecutePermanentBlockCollapsesMarkers(t *testing.T) {
}
}
func TestSnapshotReaderExecuteRepairsExpiredPaidSnapshotWithStore(t *testing.T) {
t.Parallel()
now := time.Unix(1_775_240_500, 0).UTC()
store := newRedisStore(t)
userID := common.UserID("user-123")
accountRecord := validAccountRecord()
require.NoError(t, store.Accounts().Create(context.Background(), ports.CreateAccountInput{
Account: accountRecord,
}))
expiredEndsAt := now.Add(-time.Minute)
require.NoError(t, store.EntitlementSnapshots().Put(context.Background(), entitlement.CurrentSnapshot{
UserID: userID,
PlanCode: entitlement.PlanCodePaidMonthly,
IsPaid: true,
StartsAt: now.Add(-30 * 24 * time.Hour),
EndsAt: timePointer(expiredEndsAt),
Source: common.Source("billing"),
Actor: common.ActorRef{Type: common.ActorType("billing"), ID: common.ActorID("invoice-1")},
ReasonCode: common.ReasonCode("renewal"),
UpdatedAt: now.Add(-2 * time.Hour),
}))
entitlementReader, err := entitlementsvc.NewReader(
store.EntitlementSnapshots(),
store.EntitlementLifecycle(),
fixedClock{now: now},
fixedIDGenerator{entitlementRecordID: entitlement.EntitlementRecordID("entitlement-expiry-repair")},
)
require.NoError(t, err)
service, err := NewSnapshotReader(
store.Accounts(),
entitlementReader,
store.Sanctions(),
store.Limits(),
fixedClock{now: now},
)
require.NoError(t, err)
result, err := service.Execute(context.Background(), GetUserEligibilityInput{UserID: userID.String()})
require.NoError(t, err)
require.True(t, result.Exists)
require.NotNil(t, result.Entitlement)
require.Equal(t, "free", result.Entitlement.PlanCode)
require.False(t, result.Entitlement.IsPaid)
require.Equal(t, expiredEndsAt, result.Entitlement.StartsAt)
require.Equal(t, []EffectiveLimitView{
{LimitCode: "max_pending_public_applications", Value: 3},
{LimitCode: "max_active_game_memberships", Value: 3},
{LimitCode: "max_registered_race_names", Value: 1},
}, result.EffectiveLimits)
storedSnapshot, err := store.EntitlementSnapshots().GetByUserID(context.Background(), userID)
require.NoError(t, err)
require.Equal(t, entitlement.PlanCodeFree, storedSnapshot.PlanCode)
require.False(t, storedSnapshot.IsPaid)
}
// The expired-snapshot repair is exercised end-to-end through the
// runtime-contract test (`runtime_contract_test.go`), which boots a real
// PostgreSQL container and the full runtime. The original miniredis-based
// version of this test was removed in PG_PLAN.md §3 because the
// adapter-level RepairExpired path no longer exists in this package; the
// in-memory fake stores below cover the service-layer logic for every other
// scenario in the file.
var _ = entitlement.EntitlementRecordID("")
type fakeAccountStore struct {
existsByUserID map[common.UserID]bool
@@ -553,24 +498,6 @@ func validAccountRecord() account.UserAccount {
}
}
func newRedisStore(t *testing.T) *userstore.Store {
t.Helper()
server := miniredis.RunT(t)
store, err := userstore.New(userstore.Config{
Addr: server.Addr(),
DB: 0,
KeyspacePrefix: "user:test:",
OperationTimeout: 250 * time.Millisecond,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = store.Close()
})
return store
}
func timePointer(value time.Time) *time.Time {
utcValue := value.UTC()
return &utcValue