318 lines
9.2 KiB
Go
318 lines
9.2 KiB
Go
package redisstate
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"galaxy/lobby/internal/domain/common"
|
|
"galaxy/lobby/internal/domain/membership"
|
|
"galaxy/lobby/internal/ports"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// MembershipStore provides Redis-backed durable storage for membership
|
|
// records.
|
|
type MembershipStore struct {
|
|
client *redis.Client
|
|
keys Keyspace
|
|
}
|
|
|
|
// NewMembershipStore constructs one Redis-backed membership store. It
|
|
// returns an error when client is nil.
|
|
func NewMembershipStore(client *redis.Client) (*MembershipStore, error) {
|
|
if client == nil {
|
|
return nil, errors.New("new membership store: nil redis client")
|
|
}
|
|
|
|
return &MembershipStore{
|
|
client: client,
|
|
keys: Keyspace{},
|
|
}, nil
|
|
}
|
|
|
|
// Save persists a new active membership record. Save is create-only; a
|
|
// second save against the same membership id returns
|
|
// membership.ErrConflict.
|
|
func (store *MembershipStore) Save(ctx context.Context, record membership.Membership) error {
|
|
if store == nil || store.client == nil {
|
|
return errors.New("save membership: nil store")
|
|
}
|
|
if ctx == nil {
|
|
return errors.New("save membership: nil context")
|
|
}
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("save membership: %w", err)
|
|
}
|
|
if record.Status != membership.StatusActive {
|
|
return fmt.Errorf(
|
|
"save membership: status must be %q, got %q",
|
|
membership.StatusActive, record.Status,
|
|
)
|
|
}
|
|
|
|
payload, err := MarshalMembership(record)
|
|
if err != nil {
|
|
return fmt.Errorf("save membership: %w", err)
|
|
}
|
|
|
|
primaryKey := store.keys.Membership(record.MembershipID)
|
|
gameIndexKey := store.keys.MembershipsByGame(record.GameID)
|
|
userIndexKey := store.keys.MembershipsByUser(record.UserID)
|
|
member := record.MembershipID.String()
|
|
|
|
watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error {
|
|
existing, getErr := tx.Exists(ctx, primaryKey).Result()
|
|
if getErr != nil {
|
|
return fmt.Errorf("save membership: %w", getErr)
|
|
}
|
|
if existing != 0 {
|
|
return fmt.Errorf("save membership: %w", membership.ErrConflict)
|
|
}
|
|
|
|
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
|
|
pipe.Set(ctx, primaryKey, payload, MembershipRecordTTL)
|
|
pipe.SAdd(ctx, gameIndexKey, member)
|
|
pipe.SAdd(ctx, userIndexKey, member)
|
|
return nil
|
|
})
|
|
return err
|
|
}, primaryKey)
|
|
|
|
switch {
|
|
case errors.Is(watchErr, redis.TxFailedErr):
|
|
return fmt.Errorf("save membership: %w", membership.ErrConflict)
|
|
case watchErr != nil:
|
|
return watchErr
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Get returns the record identified by membershipID.
|
|
func (store *MembershipStore) Get(ctx context.Context, membershipID common.MembershipID) (membership.Membership, error) {
|
|
if store == nil || store.client == nil {
|
|
return membership.Membership{}, errors.New("get membership: nil store")
|
|
}
|
|
if ctx == nil {
|
|
return membership.Membership{}, errors.New("get membership: nil context")
|
|
}
|
|
if err := membershipID.Validate(); err != nil {
|
|
return membership.Membership{}, fmt.Errorf("get membership: %w", err)
|
|
}
|
|
|
|
payload, err := store.client.Get(ctx, store.keys.Membership(membershipID)).Bytes()
|
|
switch {
|
|
case errors.Is(err, redis.Nil):
|
|
return membership.Membership{}, membership.ErrNotFound
|
|
case err != nil:
|
|
return membership.Membership{}, fmt.Errorf("get membership: %w", err)
|
|
}
|
|
|
|
record, err := UnmarshalMembership(payload)
|
|
if err != nil {
|
|
return membership.Membership{}, fmt.Errorf("get membership: %w", err)
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
// GetByGame returns every membership attached to gameID.
|
|
func (store *MembershipStore) GetByGame(ctx context.Context, gameID common.GameID) ([]membership.Membership, error) {
|
|
if store == nil || store.client == nil {
|
|
return nil, errors.New("get memberships by game: nil store")
|
|
}
|
|
if ctx == nil {
|
|
return nil, errors.New("get memberships by game: nil context")
|
|
}
|
|
if err := gameID.Validate(); err != nil {
|
|
return nil, fmt.Errorf("get memberships by game: %w", err)
|
|
}
|
|
|
|
return store.loadMembershipsBySet(ctx,
|
|
"get memberships by game",
|
|
store.keys.MembershipsByGame(gameID),
|
|
)
|
|
}
|
|
|
|
// GetByUser returns every membership held by userID.
|
|
func (store *MembershipStore) GetByUser(ctx context.Context, userID string) ([]membership.Membership, error) {
|
|
if store == nil || store.client == nil {
|
|
return nil, errors.New("get memberships by user: nil store")
|
|
}
|
|
if ctx == nil {
|
|
return nil, errors.New("get memberships by user: nil context")
|
|
}
|
|
trimmed := strings.TrimSpace(userID)
|
|
if trimmed == "" {
|
|
return nil, fmt.Errorf("get memberships by user: user id must not be empty")
|
|
}
|
|
|
|
return store.loadMembershipsBySet(ctx,
|
|
"get memberships by user",
|
|
store.keys.MembershipsByUser(trimmed),
|
|
)
|
|
}
|
|
|
|
// loadMembershipsBySet materializes memberships whose ids are stored in
|
|
// setKey. Stale set members are dropped silently.
|
|
func (store *MembershipStore) loadMembershipsBySet(ctx context.Context, operation, setKey string) ([]membership.Membership, error) {
|
|
members, err := store.client.SMembers(ctx, setKey).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: %w", operation, err)
|
|
}
|
|
if len(members) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
primaryKeys := make([]string, len(members))
|
|
for index, member := range members {
|
|
primaryKeys[index] = store.keys.Membership(common.MembershipID(member))
|
|
}
|
|
|
|
payloads, err := store.client.MGet(ctx, primaryKeys...).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: %w", operation, err)
|
|
}
|
|
|
|
records := make([]membership.Membership, 0, len(payloads))
|
|
for _, entry := range payloads {
|
|
if entry == nil {
|
|
continue
|
|
}
|
|
raw, ok := entry.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s: unexpected payload type %T", operation, entry)
|
|
}
|
|
record, err := UnmarshalMembership([]byte(raw))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: %w", operation, err)
|
|
}
|
|
records = append(records, record)
|
|
}
|
|
|
|
return records, nil
|
|
}
|
|
|
|
// UpdateStatus applies one status transition in a compare-and-swap fashion.
|
|
func (store *MembershipStore) UpdateStatus(ctx context.Context, input ports.UpdateMembershipStatusInput) error {
|
|
if store == nil || store.client == nil {
|
|
return errors.New("update membership status: nil store")
|
|
}
|
|
if ctx == nil {
|
|
return errors.New("update membership status: nil context")
|
|
}
|
|
if err := input.Validate(); err != nil {
|
|
return fmt.Errorf("update membership status: %w", err)
|
|
}
|
|
|
|
if err := membership.Transition(input.ExpectedFrom, input.To); err != nil {
|
|
return err
|
|
}
|
|
|
|
primaryKey := store.keys.Membership(input.MembershipID)
|
|
at := input.At.UTC()
|
|
|
|
watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error {
|
|
payload, getErr := tx.Get(ctx, primaryKey).Bytes()
|
|
switch {
|
|
case errors.Is(getErr, redis.Nil):
|
|
return membership.ErrNotFound
|
|
case getErr != nil:
|
|
return fmt.Errorf("update membership status: %w", getErr)
|
|
}
|
|
|
|
existing, err := UnmarshalMembership(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("update membership status: %w", err)
|
|
}
|
|
if existing.Status != input.ExpectedFrom {
|
|
return fmt.Errorf("update membership status: %w", membership.ErrConflict)
|
|
}
|
|
|
|
existing.Status = input.To
|
|
removedAt := at
|
|
existing.RemovedAt = &removedAt
|
|
|
|
encoded, err := MarshalMembership(existing)
|
|
if err != nil {
|
|
return fmt.Errorf("update membership status: %w", err)
|
|
}
|
|
|
|
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
|
|
pipe.Set(ctx, primaryKey, encoded, MembershipRecordTTL)
|
|
return nil
|
|
})
|
|
return err
|
|
}, primaryKey)
|
|
|
|
switch {
|
|
case errors.Is(watchErr, redis.TxFailedErr):
|
|
return fmt.Errorf("update membership status: %w", membership.ErrConflict)
|
|
case watchErr != nil:
|
|
return watchErr
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Delete removes the membership record identified by membershipID from
|
|
// the primary store and from the per-game and per-user index sets in
|
|
// one transaction. It returns membership.ErrNotFound when no record
|
|
// exists for the id and membership.ErrConflict when a concurrent
|
|
// mutation invalidates the watched key.
|
|
func (store *MembershipStore) Delete(ctx context.Context, membershipID common.MembershipID) error {
|
|
if store == nil || store.client == nil {
|
|
return errors.New("delete membership: nil store")
|
|
}
|
|
if ctx == nil {
|
|
return errors.New("delete membership: nil context")
|
|
}
|
|
if err := membershipID.Validate(); err != nil {
|
|
return fmt.Errorf("delete membership: %w", err)
|
|
}
|
|
|
|
primaryKey := store.keys.Membership(membershipID)
|
|
member := membershipID.String()
|
|
|
|
watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error {
|
|
payload, getErr := tx.Get(ctx, primaryKey).Bytes()
|
|
switch {
|
|
case errors.Is(getErr, redis.Nil):
|
|
return membership.ErrNotFound
|
|
case getErr != nil:
|
|
return fmt.Errorf("delete membership: %w", getErr)
|
|
}
|
|
|
|
existing, err := UnmarshalMembership(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("delete membership: %w", err)
|
|
}
|
|
|
|
gameIndexKey := store.keys.MembershipsByGame(existing.GameID)
|
|
userIndexKey := store.keys.MembershipsByUser(existing.UserID)
|
|
|
|
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
|
|
pipe.Del(ctx, primaryKey)
|
|
pipe.SRem(ctx, gameIndexKey, member)
|
|
pipe.SRem(ctx, userIndexKey, member)
|
|
return nil
|
|
})
|
|
return err
|
|
}, primaryKey)
|
|
|
|
switch {
|
|
case errors.Is(watchErr, redis.TxFailedErr):
|
|
return fmt.Errorf("delete membership: %w", membership.ErrConflict)
|
|
case watchErr != nil:
|
|
return watchErr
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Ensure MembershipStore satisfies the ports.MembershipStore interface at
|
|
// compile time.
|
|
var _ ports.MembershipStore = (*MembershipStore)(nil)
|