package redisstate

import (
	"context"
	"errors"
	"fmt"
	"strings"

	"galaxy/lobby/internal/domain/common"
	"galaxy/lobby/internal/domain/membership"
	"galaxy/lobby/internal/ports"

	"github.com/redis/go-redis/v9"
)

// MembershipStore provides Redis-backed durable storage for membership
// records.
type MembershipStore struct {
	// client issues every Redis command for the store.
	client *redis.Client
	// keys derives the primary-record and index-set key names.
	keys Keyspace
}

// NewMembershipStore constructs one Redis-backed membership store. It
// returns an error when client is nil.
func NewMembershipStore(client *redis.Client) (*MembershipStore, error) {
	if client == nil {
		return nil, errors.New("new membership store: nil redis client")
	}

	return &MembershipStore{
		client: client,
		keys:   Keyspace{},
	}, nil
}

// Save persists a new active membership record. Save is create-only; a
// second save against the same membership id returns
// membership.ErrConflict.
//
// The record must pass Validate and carry membership.StatusActive. The
// primary record and the per-game and per-user index entries are written
// in one MULTI/EXEC transaction guarded by WATCH on the primary key, so a
// concurrent write to that key also surfaces as membership.ErrConflict.
func (store *MembershipStore) Save(ctx context.Context, record membership.Membership) error {
	if store == nil || store.client == nil {
		return errors.New("save membership: nil store")
	}
	if ctx == nil {
		return errors.New("save membership: nil context")
	}
	if err := record.Validate(); err != nil {
		return fmt.Errorf("save membership: %w", err)
	}
	if record.Status != membership.StatusActive {
		return fmt.Errorf(
			"save membership: status must be %q, got %q",
			membership.StatusActive, record.Status,
		)
	}

	payload, err := MarshalMembership(record)
	if err != nil {
		return fmt.Errorf("save membership: %w", err)
	}

	primaryKey := store.keys.Membership(record.MembershipID)
	gameIndexKey := store.keys.MembershipsByGame(record.GameID)
	userIndexKey := store.keys.MembershipsByUser(record.UserID)
	member := record.MembershipID.String()

	watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error {
		existing, getErr := tx.Exists(ctx, primaryKey).Result()
		if getErr != nil {
			return fmt.Errorf("save membership: %w", getErr)
		}
		if existing != 0 {
			return fmt.Errorf("save membership: %w", membership.ErrConflict)
		}

		_, execErr := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
			pipe.Set(ctx, primaryKey, payload, MembershipRecordTTL)
			pipe.SAdd(ctx, gameIndexKey, member)
			pipe.SAdd(ctx, userIndexKey, member)
			return nil
		})
		if execErr != nil {
			// Wrap with %w so the operation context is attached while
			// errors.Is(…, redis.TxFailedErr) below still matches.
			return fmt.Errorf("save membership: %w", execErr)
		}
		return nil
	}, primaryKey)

	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		// The watched key changed between WATCH and EXEC.
		return fmt.Errorf("save membership: %w", membership.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}

// Get returns the record identified by membershipID. It returns
// membership.ErrNotFound when no value is stored under the primary key.
func (store *MembershipStore) Get(ctx context.Context, membershipID common.MembershipID) (membership.Membership, error) {
	if store == nil || store.client == nil {
		return membership.Membership{}, errors.New("get membership: nil store")
	}
	if ctx == nil {
		return membership.Membership{}, errors.New("get membership: nil context")
	}
	if err := membershipID.Validate(); err != nil {
		return membership.Membership{}, fmt.Errorf("get membership: %w", err)
	}

	payload, err := store.client.Get(ctx, store.keys.Membership(membershipID)).Bytes()
	switch {
	case errors.Is(err, redis.Nil):
		// redis.Nil signals a missing key; it is reported as the bare
		// domain sentinel.
		return membership.Membership{}, membership.ErrNotFound
	case err != nil:
		return membership.Membership{}, fmt.Errorf("get membership: %w", err)
	}

	record, err := UnmarshalMembership(payload)
	if err != nil {
		return membership.Membership{}, fmt.Errorf("get membership: %w", err)
	}
	return record, nil
}

// GetByGame returns every membership attached to gameID. The ids are read
// from the per-game index set and the records are then loaded from their
// primary keys.
func (store *MembershipStore) GetByGame(ctx context.Context, gameID common.GameID) ([]membership.Membership, error) {
	if store == nil || store.client == nil {
		return nil, errors.New("get memberships by game: nil store")
	}
	if ctx == nil {
		return nil, errors.New("get memberships by game: nil context")
	}
	if err := gameID.Validate(); err != nil {
		return nil, fmt.Errorf("get memberships by game: %w", err)
	}

	return store.loadMembershipsBySet(ctx,
		"get memberships by game",
		store.keys.MembershipsByGame(gameID),
	)
}

// GetByUser returns every membership held by userID. Surrounding
// whitespace in userID is ignored and a blank userID is rejected.
func (store *MembershipStore) GetByUser(ctx context.Context, userID string) ([]membership.Membership, error) { if store == nil || store.client == nil { return nil, errors.New("get memberships by user: nil store") } if ctx == nil { return nil, errors.New("get memberships by user: nil context") } trimmed := strings.TrimSpace(userID) if trimmed == "" { return nil, fmt.Errorf("get memberships by user: user id must not be empty") } return store.loadMembershipsBySet(ctx, "get memberships by user", store.keys.MembershipsByUser(trimmed), ) } // loadMembershipsBySet materializes memberships whose ids are stored in // setKey. Stale set members are dropped silently. func (store *MembershipStore) loadMembershipsBySet(ctx context.Context, operation, setKey string) ([]membership.Membership, error) { members, err := store.client.SMembers(ctx, setKey).Result() if err != nil { return nil, fmt.Errorf("%s: %w", operation, err) } if len(members) == 0 { return nil, nil } primaryKeys := make([]string, len(members)) for index, member := range members { primaryKeys[index] = store.keys.Membership(common.MembershipID(member)) } payloads, err := store.client.MGet(ctx, primaryKeys...).Result() if err != nil { return nil, fmt.Errorf("%s: %w", operation, err) } records := make([]membership.Membership, 0, len(payloads)) for _, entry := range payloads { if entry == nil { continue } raw, ok := entry.(string) if !ok { return nil, fmt.Errorf("%s: unexpected payload type %T", operation, entry) } record, err := UnmarshalMembership([]byte(raw)) if err != nil { return nil, fmt.Errorf("%s: %w", operation, err) } records = append(records, record) } return records, nil } // UpdateStatus applies one status transition in a compare-and-swap fashion. 
func (store *MembershipStore) UpdateStatus(ctx context.Context, input ports.UpdateMembershipStatusInput) error { if store == nil || store.client == nil { return errors.New("update membership status: nil store") } if ctx == nil { return errors.New("update membership status: nil context") } if err := input.Validate(); err != nil { return fmt.Errorf("update membership status: %w", err) } if err := membership.Transition(input.ExpectedFrom, input.To); err != nil { return err } primaryKey := store.keys.Membership(input.MembershipID) at := input.At.UTC() watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error { payload, getErr := tx.Get(ctx, primaryKey).Bytes() switch { case errors.Is(getErr, redis.Nil): return membership.ErrNotFound case getErr != nil: return fmt.Errorf("update membership status: %w", getErr) } existing, err := UnmarshalMembership(payload) if err != nil { return fmt.Errorf("update membership status: %w", err) } if existing.Status != input.ExpectedFrom { return fmt.Errorf("update membership status: %w", membership.ErrConflict) } existing.Status = input.To removedAt := at existing.RemovedAt = &removedAt encoded, err := MarshalMembership(existing) if err != nil { return fmt.Errorf("update membership status: %w", err) } _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Set(ctx, primaryKey, encoded, MembershipRecordTTL) return nil }) return err }, primaryKey) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("update membership status: %w", membership.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // Delete removes the membership record identified by membershipID from // the primary store and from the per-game and per-user index sets in // one transaction. It returns membership.ErrNotFound when no record // exists for the id and membership.ErrConflict when a concurrent // mutation invalidates the watched key. 
func (store *MembershipStore) Delete(ctx context.Context, membershipID common.MembershipID) error { if store == nil || store.client == nil { return errors.New("delete membership: nil store") } if ctx == nil { return errors.New("delete membership: nil context") } if err := membershipID.Validate(); err != nil { return fmt.Errorf("delete membership: %w", err) } primaryKey := store.keys.Membership(membershipID) member := membershipID.String() watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error { payload, getErr := tx.Get(ctx, primaryKey).Bytes() switch { case errors.Is(getErr, redis.Nil): return membership.ErrNotFound case getErr != nil: return fmt.Errorf("delete membership: %w", getErr) } existing, err := UnmarshalMembership(payload) if err != nil { return fmt.Errorf("delete membership: %w", err) } gameIndexKey := store.keys.MembershipsByGame(existing.GameID) userIndexKey := store.keys.MembershipsByUser(existing.UserID) _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Del(ctx, primaryKey) pipe.SRem(ctx, gameIndexKey, member) pipe.SRem(ctx, userIndexKey, member) return nil }) return err }, primaryKey) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("delete membership: %w", membership.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // Ensure MembershipStore satisfies the ports.MembershipStore interface at // compile time. var _ ports.MembershipStore = (*MembershipStore)(nil)