package redisstate import ( "context" "errors" "fmt" "strings" "galaxy/lobby/internal/domain/common" "galaxy/lobby/internal/domain/invite" "galaxy/lobby/internal/ports" "github.com/redis/go-redis/v9" ) // InviteStore provides Redis-backed durable storage for invite records. type InviteStore struct { client *redis.Client keys Keyspace } // NewInviteStore constructs one Redis-backed invite store. It returns an // error when client is nil. func NewInviteStore(client *redis.Client) (*InviteStore, error) { if client == nil { return nil, errors.New("new invite store: nil redis client") } return &InviteStore{ client: client, keys: Keyspace{}, }, nil } // Save persists a new created invite record. Save is create-only; a // second save against the same invite id returns invite.ErrConflict. func (store *InviteStore) Save(ctx context.Context, record invite.Invite) error { if store == nil || store.client == nil { return errors.New("save invite: nil store") } if ctx == nil { return errors.New("save invite: nil context") } if err := record.Validate(); err != nil { return fmt.Errorf("save invite: %w", err) } if record.Status != invite.StatusCreated { return fmt.Errorf( "save invite: status must be %q, got %q", invite.StatusCreated, record.Status, ) } payload, err := MarshalInvite(record) if err != nil { return fmt.Errorf("save invite: %w", err) } primaryKey := store.keys.Invite(record.InviteID) gameIndexKey := store.keys.InvitesByGame(record.GameID) userIndexKey := store.keys.InvitesByUser(record.InviteeUserID) inviterIndexKey := store.keys.InvitesByInviter(record.InviterUserID) member := record.InviteID.String() watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error { existing, getErr := tx.Exists(ctx, primaryKey).Result() if getErr != nil { return fmt.Errorf("save invite: %w", getErr) } if existing != 0 { return fmt.Errorf("save invite: %w", invite.ErrConflict) } _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Set(ctx, primaryKey, payload, InviteRecordTTL) pipe.SAdd(ctx, gameIndexKey, member) pipe.SAdd(ctx, userIndexKey, member) pipe.SAdd(ctx, inviterIndexKey, member) return nil }) return err }, primaryKey) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("save invite: %w", invite.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // Get returns the record identified by inviteID. func (store *InviteStore) Get(ctx context.Context, inviteID common.InviteID) (invite.Invite, error) { if store == nil || store.client == nil { return invite.Invite{}, errors.New("get invite: nil store") } if ctx == nil { return invite.Invite{}, errors.New("get invite: nil context") } if err := inviteID.Validate(); err != nil { return invite.Invite{}, fmt.Errorf("get invite: %w", err) } payload, err := store.client.Get(ctx, store.keys.Invite(inviteID)).Bytes() switch { case errors.Is(err, redis.Nil): return invite.Invite{}, invite.ErrNotFound case err != nil: return invite.Invite{}, fmt.Errorf("get invite: %w", err) } record, err := UnmarshalInvite(payload) if err != nil { return invite.Invite{}, fmt.Errorf("get invite: %w", err) } return record, nil } // GetByGame returns every invite attached to gameID. func (store *InviteStore) GetByGame(ctx context.Context, gameID common.GameID) ([]invite.Invite, error) { if store == nil || store.client == nil { return nil, errors.New("get invites by game: nil store") } if ctx == nil { return nil, errors.New("get invites by game: nil context") } if err := gameID.Validate(); err != nil { return nil, fmt.Errorf("get invites by game: %w", err) } return store.loadInvitesBySet(ctx, "get invites by game", store.keys.InvitesByGame(gameID), ) } // GetByUser returns every invite addressed to inviteeUserID. func (store *InviteStore) GetByUser(ctx context.Context, inviteeUserID string) ([]invite.Invite, error) { if store == nil || store.client == nil { return nil, errors.New("get invites by user: nil store") } if ctx == nil { return nil, errors.New("get invites by user: nil context") } trimmed := strings.TrimSpace(inviteeUserID) if trimmed == "" { return nil, fmt.Errorf("get invites by user: invitee user id must not be empty") } return store.loadInvitesBySet(ctx, "get invites by user", store.keys.InvitesByUser(trimmed), ) } // GetByInviter returns every invite created by inviterUserID. func (store *InviteStore) GetByInviter(ctx context.Context, inviterUserID string) ([]invite.Invite, error) { if store == nil || store.client == nil { return nil, errors.New("get invites by inviter: nil store") } if ctx == nil { return nil, errors.New("get invites by inviter: nil context") } trimmed := strings.TrimSpace(inviterUserID) if trimmed == "" { return nil, fmt.Errorf("get invites by inviter: inviter user id must not be empty") } return store.loadInvitesBySet(ctx, "get invites by inviter", store.keys.InvitesByInviter(trimmed), ) } // loadInvitesBySet materializes invites whose ids are stored in setKey. // Stale set members (primary key removed out-of-band) are dropped silently. func (store *InviteStore) loadInvitesBySet(ctx context.Context, operation, setKey string) ([]invite.Invite, error) { members, err := store.client.SMembers(ctx, setKey).Result() if err != nil { return nil, fmt.Errorf("%s: %w", operation, err) } if len(members) == 0 { return nil, nil } primaryKeys := make([]string, len(members)) for index, member := range members { primaryKeys[index] = store.keys.Invite(common.InviteID(member)) } payloads, err := store.client.MGet(ctx, primaryKeys...).Result() if err != nil { return nil, fmt.Errorf("%s: %w", operation, err) } records := make([]invite.Invite, 0, len(payloads)) for _, entry := range payloads { if entry == nil { continue } raw, ok := entry.(string) if !ok { return nil, fmt.Errorf("%s: unexpected payload type %T", operation, entry) } record, err := UnmarshalInvite([]byte(raw)) if err != nil { return nil, fmt.Errorf("%s: %w", operation, err) } records = append(records, record) } return records, nil } // UpdateStatus applies one status transition in a compare-and-swap fashion. func (store *InviteStore) UpdateStatus(ctx context.Context, input ports.UpdateInviteStatusInput) error { if store == nil || store.client == nil { return errors.New("update invite status: nil store") } if ctx == nil { return errors.New("update invite status: nil context") } if err := input.Validate(); err != nil { return fmt.Errorf("update invite status: %w", err) } if err := invite.Transition(input.ExpectedFrom, input.To); err != nil { return err } primaryKey := store.keys.Invite(input.InviteID) at := input.At.UTC() watchErr := store.client.Watch(ctx, func(tx *redis.Tx) error { payload, getErr := tx.Get(ctx, primaryKey).Bytes() switch { case errors.Is(getErr, redis.Nil): return invite.ErrNotFound case getErr != nil: return fmt.Errorf("update invite status: %w", getErr) } existing, err := UnmarshalInvite(payload) if err != nil { return fmt.Errorf("update invite status: %w", err) } if existing.Status != input.ExpectedFrom { return fmt.Errorf("update invite status: %w", invite.ErrConflict) } existing.Status = input.To decidedAt := at existing.DecidedAt = &decidedAt if input.To == invite.StatusRedeemed { existing.RaceName = strings.TrimSpace(input.RaceName) } encoded, err := MarshalInvite(existing) if err != nil { return fmt.Errorf("update invite status: %w", err) } _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Set(ctx, primaryKey, encoded, InviteRecordTTL) return nil }) return err }, primaryKey) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("update invite status: %w", invite.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // Ensure InviteStore satisfies the ports.InviteStore interface at // compile time. var _ ports.InviteStore = (*InviteStore)(nil)