// Source path: galaxy-game/authsession/internal/adapters/redis/sessionstore/store.go
// Snapshot metadata: 2026-04-08 16:23:07 +02:00, 724 lines, 23 KiB, Go.

// Package sessionstore implements ports.SessionStore with Redis-backed strict
// JSON source-of-truth session records and per-user indexes.
package sessionstore
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// mutationRetryLimit is the total number of attempts runMutation makes for one
// optimistic (WATCH/MULTI/EXEC) mutation before it gives up with a
// "mutation retry limit exceeded" error.
const mutationRetryLimit = 3
// Config configures one Redis-backed session store instance.
//
// New rejects a blank Addr, a negative DB, blank key prefixes, and a
// non-positive OperationTimeout.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string
	// Username is the optional Redis ACL username.
	Username string
	// Password is the optional Redis ACL password.
	Password string
	// DB is the Redis logical database index; it must not be negative.
	DB int
	// TLSEnabled enables TLS with TLS 1.2 as the minimum protocol version.
	TLSEnabled bool
	// SessionKeyPrefix is the namespace prefix applied to primary session keys.
	// The full key is this prefix followed by the unpadded base64url encoding
	// of the device session ID. Keep all three prefixes distinct: keys are
	// plain prefix+ID concatenations, so a shared prefix can make keys collide.
	SessionKeyPrefix string
	// UserSessionsKeyPrefix is the namespace prefix applied to all-session user
	// indexes (one sorted set per user, keyed by the encoded user ID).
	UserSessionsKeyPrefix string
	// UserActiveSessionsKeyPrefix is the namespace prefix applied to active
	// session user indexes (one sorted set per user, keyed by the encoded
	// user ID).
	UserActiveSessionsKeyPrefix string
	// OperationTimeout is the deadline applied to each store call. The whole
	// call, including all of its Redis round trips, shares one deadline. For
	// Revoke and RevokeAllByUserID, each optimistic-retry attempt gets a
	// fresh deadline.
	OperationTimeout time.Duration
}
// Store persists source-of-truth sessions in Redis and maintains user-scoped
// indexes for list and count operations.
//
// Each session is a JSON string stored under its primary session key. Per
// user, two sorted sets scored by creation time in microseconds hold device
// session IDs:
//   - the all-sessions index, which this adapter only adds to;
//   - the active-sessions index, from which revocations remove entries.
//
// Construct a Store with New.
type Store struct {
	// client is the go-redis client; a Store with a nil client (or a nil
	// *Store) is rejected by operationContext, and Close treats it as a no-op.
	client *redis.Client
	// sessionKeyPrefix prefixes primary session record keys.
	sessionKeyPrefix string
	// userSessionsKeyPrefix prefixes the per-user all-sessions indexes.
	userSessionsKeyPrefix string
	// userActiveSessionsKeyPrefix prefixes the per-user active-sessions indexes.
	userActiveSessionsKeyPrefix string
	// operationTimeout is the deadline operationContext applies to each call.
	operationTimeout time.Duration
}
// redisRecord is the JSON document stored under each primary session key.
//
// decodeSessionRecord rejects unknown fields and trailing input, so renaming a
// tag or adding a field is a storage-format change. json.Marshal emits fields
// in declaration order, so reordering fields also changes the stored bytes.
type redisRecord struct {
	// DeviceSessionID must equal the ID the record was looked up by.
	DeviceSessionID string `json:"device_session_id"`
	// UserID is the owning user.
	UserID string `json:"user_id"`
	// ClientPublicKeyBase64 is decoded with strict standard base64.
	ClientPublicKeyBase64 string `json:"client_public_key_base64"`
	// Status is the session lifecycle status.
	Status devicesession.Status `json:"status"`
	// CreatedAt is a canonical UTC RFC3339Nano timestamp.
	CreatedAt string `json:"created_at"`
	// RevokedAt, RevokeReasonCode and RevokeActorType are written only when
	// the session carries revocation metadata and must be present together;
	// RevokeActorID is optional.
	RevokedAt *string `json:"revoked_at,omitempty"`
	RevokeReasonCode string `json:"revoke_reason_code,omitempty"`
	RevokeActorType string `json:"revoke_actor_type,omitempty"`
	RevokeActorID string `json:"revoke_actor_id,omitempty"`
}
// New constructs a Redis-backed session store from cfg.
//
// Validation rules:
//   - Addr and every key prefix must be non-blank.
//   - DB must not be negative.
//   - OperationTimeout must be positive.
//   - The three key prefixes must differ from one another.
//
// Keys are plain prefix + encoded-ID concatenations. Two index prefixes that
// are equal would address the same sorted set for a given user, so the
// all-sessions and active-sessions indexes would silently merge. A session
// prefix equal to an index prefix can collide whenever a user ID equals a
// device session ID.
//
// The client is configured with Protocol 2 (RESP2) and DisableIdentity set.
// When cfg.TLSEnabled is true, TLS 1.2 is the minimum accepted version. Use
// Ping to verify that the server is reachable.
func New(cfg Config) (*Store, error) {
	switch {
	case strings.TrimSpace(cfg.Addr) == "":
		return nil, errors.New("new redis session store: redis addr must not be empty")
	case cfg.DB < 0:
		return nil, errors.New("new redis session store: redis db must not be negative")
	case strings.TrimSpace(cfg.SessionKeyPrefix) == "":
		return nil, errors.New("new redis session store: session key prefix must not be empty")
	case strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "":
		return nil, errors.New("new redis session store: user sessions key prefix must not be empty")
	case strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "":
		return nil, errors.New("new redis session store: user active sessions key prefix must not be empty")
	case cfg.OperationTimeout <= 0:
		return nil, errors.New("new redis session store: operation timeout must be positive")
	case cfg.SessionKeyPrefix == cfg.UserSessionsKeyPrefix,
		cfg.SessionKeyPrefix == cfg.UserActiveSessionsKeyPrefix,
		cfg.UserSessionsKeyPrefix == cfg.UserActiveSessionsKeyPrefix:
		return nil, errors.New("new redis session store: session, user sessions, and user active sessions key prefixes must be distinct")
	}

	options := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	return &Store{
		client:                      redis.NewClient(options),
		sessionKeyPrefix:            cfg.SessionKeyPrefix,
		userSessionsKeyPrefix:       cfg.UserSessionsKeyPrefix,
		userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix,
		operationTimeout:            cfg.OperationTimeout,
	}, nil
}
// Close shuts down the Redis client owned by the store.
//
// Calling Close on a nil *Store, or on a Store without a client, does nothing
// and returns nil.
func (s *Store) Close() error {
	if s != nil && s.client != nil {
		return s.client.Close()
	}
	return nil
}
// Ping sends a PING to the configured Redis server. The round trip must
// complete within the store's operation timeout.
func (s *Store) Ping(ctx context.Context) error {
	pingCtx, cancel, err := s.operationContext(ctx, "ping redis session store")
	if err != nil {
		return err
	}
	defer cancel()

	err = s.client.Ping(pingCtx).Err()
	if err == nil {
		return nil
	}
	return fmt.Errorf("ping redis session store: %w", err)
}
// Get loads the session stored for deviceSessionID.
//
// A missing record yields an error wrapping ports.ErrNotFound. Validation,
// Redis, and decoding failures are returned wrapped with the session ID.
func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
	if err := deviceSessionID.Validate(); err != nil {
		return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err)
	}

	getCtx, cancel, err := s.operationContext(ctx, "get session from redis")
	if err != nil {
		return devicesession.Session{}, err
	}
	defer cancel()

	session, err := s.loadSession(getCtx, deviceSessionID)
	if err == nil {
		return session, nil
	}
	// loadSession reports a missing key as the bare ports.ErrNotFound
	// sentinel, so one wrap covers both the not-found and failure cases.
	return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err)
}
// ListByUserID returns every stored session for userID in newest-first order.
// Sessions with equal creation times are ordered by ascending device session ID.
//
// The user's all-sessions index is read first. All referenced records are
// then fetched in one pipelined round trip, so the call costs two Redis round
// trips regardless of how many sessions the user owns. Both round trips share
// a single operation deadline.
//
// An index entry whose record is missing, or whose record belongs to another
// user, is reported as an error rather than skipped, because it means the
// index and the source-of-truth records disagree.
func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
	if err := userID.Validate(); err != nil {
		return nil, fmt.Errorf("list sessions by user id from redis: %w", err)
	}

	operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis")
	if err != nil {
		return nil, err
	}
	defer cancel()

	deviceSessionIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result()
	if err != nil {
		return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err)
	}
	if len(deviceSessionIDs) == 0 {
		return []devicesession.Session{}, nil
	}

	// Queue one GET per indexed session and send them as a single pipeline
	// instead of paying one round trip per session.
	getCmds := make([]*redis.StringCmd, len(deviceSessionIDs))
	_, err = s.client.Pipelined(operationCtx, func(pipe redis.Pipeliner) error {
		for i, rawDeviceSessionID := range deviceSessionIDs {
			getCmds[i] = pipe.Get(operationCtx, s.sessionKey(common.DeviceSessionID(rawDeviceSessionID)))
		}
		return nil
	})
	// The pipeline reports redis.Nil when any key is missing. That case is
	// diagnosed per session below, with the offending ID, so only other
	// failures stop here.
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err)
	}

	records := make([]devicesession.Session, 0, len(deviceSessionIDs))
	for i, rawDeviceSessionID := range deviceSessionIDs {
		deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
		// Hand the already-executed GET to the shared decode path so missing
		// keys and malformed payloads are classified exactly as in loadSession.
		getCmd := getCmds[i]
		record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, func(context.Context, string) *redis.StringCmd {
			return getCmd
		})
		if err != nil {
			switch {
			case errors.Is(err, ports.ErrNotFound):
				return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, deviceSessionID)
			default:
				return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, deviceSessionID, err)
			}
		}
		if record.UserID != userID {
			return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, deviceSessionID, record.UserID)
		}
		records = append(records, record)
	}

	sortSessionsNewestFirst(records)
	return records, nil
}
// CountActiveByUserID reports how many sessions the active-sessions index of
// userID currently holds.
//
// It issues a single ZCARD and does not load the session records, so the
// count reflects the index as stored.
func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
	if err := userID.Validate(); err != nil {
		return 0, fmt.Errorf("count active sessions by user id from redis: %w", err)
	}

	countCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis")
	if err != nil {
		return 0, err
	}
	defer cancel()

	activeKey := s.userActiveSessionsKey(userID)
	cardinality, err := s.client.ZCard(countCtx, activeKey).Result()
	if err != nil {
		return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err)
	}
	return int(cardinality), nil
}
// Create persists record as a new device session.
//
// Under WATCH on the session key, it checks that the key does not exist yet.
// A single MULTI/EXEC then:
//   - writes the JSON record;
//   - adds the session to the owner's all-sessions index;
//   - adds it to the active-sessions index as well when record.Status is
//     active.
//
// Index scores are the creation time in microseconds. An existing key, or a
// concurrent write to it that aborts the transaction, yields an error wrapping
// ports.ErrConflict. Create does not retry.
func (s *Store) Create(ctx context.Context, record devicesession.Session) error {
	if err := record.Validate(); err != nil {
		return fmt.Errorf("create session in redis: %w", err)
	}
	payload, err := marshalSessionRecord(record)
	if err != nil {
		return fmt.Errorf("create session in redis: %w", err)
	}

	// Obtain the operation context before reading any Store field. This way a
	// nil *Store is reported by operationContext as an error, as in every
	// other method, instead of panicking inside the key helpers.
	operationCtx, cancel, err := s.operationContext(ctx, "create session in redis")
	if err != nil {
		return err
	}
	defer cancel()

	deviceSessionKey := s.sessionKey(record.ID)
	allSessionsKey := s.userSessionsKey(record.UserID)
	activeSessionsKey := s.userActiveSessionsKey(record.UserID)
	indexEntry := redis.Z{
		Score:  createdAtScore(record.CreatedAt),
		Member: record.ID.String(),
	}

	watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
		// Only a missing key (redis.Nil) may proceed; an existing record is a
		// conflict and any other read failure is surfaced as-is.
		_, err := tx.Get(operationCtx, deviceSessionKey).Bytes()
		switch {
		case errors.Is(err, redis.Nil):
		case err != nil:
			return fmt.Errorf("create session %q in redis: %w", record.ID, err)
		default:
			return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
		}

		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			pipe.Set(operationCtx, deviceSessionKey, payload, 0)
			pipe.ZAdd(operationCtx, allSessionsKey, indexEntry)
			if record.Status == devicesession.StatusActive {
				pipe.ZAdd(operationCtx, activeSessionsKey, indexEntry)
			}
			return nil
		})
		if err != nil {
			return fmt.Errorf("create session %q in redis: %w", record.ID, err)
		}
		return nil
	}, deviceSessionKey)

	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		// The watched key changed between the existence check and EXEC,
		// i.e. another writer created it first.
		return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}
// Revoke marks the session input.DeviceSessionID as revoked with
// input.Revocation.
//
// The read-modify-write runs under WATCH on the session key. The updated JSON
// record and the removal from the owner's active-sessions index are committed
// in one MULTI/EXEC. If the key changes before EXEC, runMutation re-runs the
// attempt.
//
// Outcomes:
//   - already revoked: the stored session is returned unchanged with
//     RevokeSessionOutcomeAlreadyRevoked;
//   - missing: the error wraps ports.ErrNotFound.
func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
	if err := input.Validate(); err != nil {
		return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err)
	}

	var outcome ports.RevokeSessionResult
	attempt := func(attemptCtx context.Context) error {
		primaryKey := s.sessionKey(input.DeviceSessionID)

		revokeTx := func(tx *redis.Tx) error {
			current, err := s.loadSessionWithGetter(attemptCtx, input.DeviceSessionID, tx.Get)
			if err != nil {
				// A missing key arrives as the bare ports.ErrNotFound sentinel,
				// so this single wrap also covers the not-found case.
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}

			if current.Status == devicesession.StatusRevoked {
				outcome = ports.RevokeSessionResult{
					Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
					Session: current,
				}
				return outcome.Validate()
			}

			revocation := input.Revocation
			revoked := current
			revoked.Status = devicesession.StatusRevoked
			revoked.Revocation = &revocation
			if err := revoked.Validate(); err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}

			payload, err := marshalSessionRecord(revoked)
			if err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}

			if _, err := tx.TxPipelined(attemptCtx, func(pipe redis.Pipeliner) error {
				pipe.Set(attemptCtx, primaryKey, payload, 0)
				pipe.ZRem(attemptCtx, s.userActiveSessionsKey(current.UserID), current.ID.String())
				return nil
			}); err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}

			outcome = ports.RevokeSessionResult{
				Outcome: ports.RevokeSessionOutcomeRevoked,
				Session: revoked,
			}
			return outcome.Validate()
		}

		err := s.client.Watch(attemptCtx, revokeTx, primaryKey)
		if errors.Is(err, redis.TxFailedErr) {
			// Lost the optimistic race; ask runMutation for another attempt.
			return errRetryMutation
		}
		return err
	}

	if err := s.runMutation(ctx, "revoke session in redis", attempt); err != nil {
		return ports.RevokeSessionResult{}, err
	}
	return outcome, nil
}
// RevokeAllByUserID revokes every session listed in the active-sessions index
// of input.UserID, applying input.Revocation to each.
//
// The operation runs under WATCH on that active-sessions index:
//  1. the index is read;
//  2. every referenced record is fetched in one non-transactional pipeline on
//     the watched connection;
//  3. the revoked records and their index removals are committed in a single
//     MULTI/EXEC.
//
// Any concurrent change to the index aborts the transaction, and runMutation
// re-runs the attempt.
//
// Results:
//   - empty index: RevokeUserSessionsOutcomeNoActiveSessions with an empty,
//     non-nil Sessions slice;
//   - otherwise: Sessions holds the revoked records newest-first.
//
// An index entry whose record is missing, owned by another user, or not
// active is reported as an error, because the index and the records disagree.
func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
	if err := input.Validate(); err != nil {
		return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err)
	}

	var result ports.RevokeUserSessionsResult
	err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error {
		activeSessionsKey := s.userActiveSessionsKey(input.UserID)
		watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
			deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result()
			if err != nil {
				return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
			}
			if len(deviceSessionIDs) == 0 {
				// Force EXEC so WATCH observes concurrent active-index changes even
				// for the no-op path.
				_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
					pipe.ZCard(operationCtx, activeSessionsKey)
					return nil
				})
				if err != nil {
					return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
				}
				result = ports.RevokeUserSessionsResult{
					Outcome:  ports.RevokeUserSessionsOutcomeNoActiveSessions,
					UserID:   input.UserID,
					Sessions: []devicesession.Session{},
				}
				return result.Validate()
			}

			// Fetch every indexed record in one pipelined, non-transactional
			// round trip on the watched connection instead of one GET per
			// session. The WATCH remains in force until the EXEC below.
			getCmds := make([]*redis.StringCmd, len(deviceSessionIDs))
			_, err = tx.Pipelined(operationCtx, func(pipe redis.Pipeliner) error {
				for i, rawDeviceSessionID := range deviceSessionIDs {
					getCmds[i] = pipe.Get(operationCtx, s.sessionKey(common.DeviceSessionID(rawDeviceSessionID)))
				}
				return nil
			})
			// redis.Nil only means some key is missing; that is diagnosed per
			// session below together with the offending ID.
			if err != nil && !errors.Is(err, redis.Nil) {
				return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
			}

			records := make([]devicesession.Session, 0, len(deviceSessionIDs))
			for i, rawDeviceSessionID := range deviceSessionIDs {
				deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
				// Reuse the shared decode path with the already-executed GET.
				getCmd := getCmds[i]
				record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, func(context.Context, string) *redis.StringCmd {
					return getCmd
				})
				if err != nil {
					switch {
					case errors.Is(err, ports.ErrNotFound):
						return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID)
					default:
						return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
					}
				}
				if record.UserID != input.UserID {
					return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID)
				}
				if record.Status != devicesession.StatusActive {
					return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status)
				}

				next := record
				next.Status = devicesession.StatusRevoked
				revocation := input.Revocation
				next.Revocation = &revocation
				if err := next.Validate(); err != nil {
					return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
				}
				records = append(records, next)
			}

			_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
				for _, record := range records {
					payload, err := marshalSessionRecord(record)
					if err != nil {
						return fmt.Errorf("session %q: %w", record.ID, err)
					}
					pipe.Set(operationCtx, s.sessionKey(record.ID), payload, 0)
					pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String())
				}
				return nil
			})
			if err != nil {
				return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
			}

			sortSessionsNewestFirst(records)
			result = ports.RevokeUserSessionsResult{
				Outcome:  ports.RevokeUserSessionsOutcomeRevoked,
				UserID:   input.UserID,
				Sessions: records,
			}
			return result.Validate()
		}, activeSessionsKey)

		switch {
		case errors.Is(watchErr, redis.TxFailedErr):
			return errRetryMutation
		case watchErr != nil:
			return watchErr
		default:
			return nil
		}
	})
	if err != nil {
		return ports.RevokeUserSessionsResult{}, err
	}
	return result, nil
}
// errRetryMutation is returned by a runMutation callback to request another
// attempt. Revoke and RevokeAllByUserID return it when their WATCHed
// transaction fails with redis.TxFailedErr. runMutation never returns it
// itself: running out of attempts produces a separate "mutation retry limit
// exceeded" error.
var errRetryMutation = errors.New("redis session store: retry mutation")
// runMutation calls execute at most mutationRetryLimit times. Each call gets a
// fresh operation deadline derived from ctx.
//
// It stops as soon as execute returns anything other than errRetryMutation,
// including nil, and returns that result unchanged. If every attempt asks for
// a retry, it returns "<operation>: mutation retry limit exceeded".
func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error {
	for attempt := 1; attempt <= mutationRetryLimit; attempt++ {
		attemptCtx, cancel, err := s.operationContext(ctx, operation)
		if err != nil {
			return err
		}

		err = execute(attemptCtx)
		cancel()

		if !errors.Is(err, errRetryMutation) {
			return err
		}
	}
	return fmt.Errorf("%s: mutation retry limit exceeded", operation)
}
// operationContext derives a child of ctx that expires after the store's
// operation timeout. The caller must invoke the returned cancel function.
//
// It returns an error prefixed with operation for a nil store, a store
// without a client, or a nil ctx.
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	switch {
	case s == nil || s.client == nil:
		return nil, nil, fmt.Errorf("%s: nil store", operation)
	case ctx == nil:
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}

	boundedCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
	return boundedCtx, cancel, nil
}
// loadSession reads and decodes the session for deviceSessionID with a plain,
// unwatched GET on the store's client.
func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
	plainGet := s.client.Get
	return s.loadSessionWithGetter(ctx, deviceSessionID, plainGet)
}
// loadSessionWithGetter fetches the primary record for deviceSessionID using
// getter, then decodes and validates it. Callers pass either s.client.Get or
// a WATCHed transaction's tx.Get as the getter.
//
// A missing key yields the bare ports.ErrNotFound sentinel (unwrapped). Other
// Redis errors are returned as-is. Decoding errors come from
// decodeSessionRecord.
func (s *Store) loadSessionWithGetter(
	ctx context.Context,
	deviceSessionID common.DeviceSessionID,
	getter func(context.Context, string) *redis.StringCmd,
) (devicesession.Session, error) {
	raw, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes()
	if errors.Is(err, redis.Nil) {
		return devicesession.Session{}, ports.ErrNotFound
	}
	if err != nil {
		return devicesession.Session{}, err
	}
	return decodeSessionRecord(deviceSessionID, raw)
}
// sessionKey returns the Redis key of the primary JSON record for
// deviceSessionID: the session key prefix followed by the encoded ID.
func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string {
	encodedID := encodeKeyComponent(deviceSessionID.String())
	return s.sessionKeyPrefix + encodedID
}
// userSessionsKey returns the key of the sorted set that indexes every
// session created for userID.
func (s *Store) userSessionsKey(userID common.UserID) string {
	encodedUser := encodeKeyComponent(userID.String())
	return s.userSessionsKeyPrefix + encodedUser
}
// userActiveSessionsKey returns the key of the sorted set that indexes the
// currently active sessions of userID.
func (s *Store) userActiveSessionsKey(userID common.UserID) string {
	encodedUser := encodeKeyComponent(userID.String())
	return s.userActiveSessionsKeyPrefix + encodedUser
}
// encodeKeyComponent renders value as unpadded URL-safe base64.
//
// The output uses only the characters A-Z, a-z, 0-9, '-' and '_'. IDs
// containing separators such as ':' therefore cannot change the structure of
// the key they are appended to.
func encodeKeyComponent(value string) string {
	encoding := base64.RawURLEncoding
	encoded := make([]byte, encoding.EncodedLen(len(value)))
	encoding.Encode(encoded, []byte(value))
	return string(encoded)
}
// marshalSessionRecord validates record and encodes it as the JSON document
// stored under the session's primary key.
func marshalSessionRecord(record devicesession.Session) ([]byte, error) {
	var (
		stored redisRecord
		err    error
	)
	if stored, err = redisRecordFromSession(record); err != nil {
		return nil, err
	}

	encoded, marshalErr := json.Marshal(stored)
	if marshalErr != nil {
		return nil, fmt.Errorf("encode redis session record: %w", marshalErr)
	}
	return encoded, nil
}
// redisRecordFromSession validates record and maps it onto its storage form.
// Timestamps are converted to canonical UTC strings. The revocation fields
// are filled only when record.Revocation is set.
func redisRecordFromSession(record devicesession.Session) (redisRecord, error) {
	if err := record.Validate(); err != nil {
		return redisRecord{}, fmt.Errorf("encode redis session record: %w", err)
	}

	var stored redisRecord
	stored.DeviceSessionID = record.ID.String()
	stored.UserID = record.UserID.String()
	stored.ClientPublicKeyBase64 = record.ClientPublicKey.String()
	stored.Status = record.Status
	stored.CreatedAt = formatTimestamp(record.CreatedAt)

	if revocation := record.Revocation; revocation != nil {
		stored.RevokedAt = formatOptionalTimestamp(&revocation.At)
		stored.RevokeReasonCode = revocation.ReasonCode.String()
		stored.RevokeActorType = revocation.ActorType.String()
		stored.RevokeActorID = revocation.ActorID
	}
	return stored, nil
}
// decodeSessionRecord parses payload as exactly one redisRecord JSON value
// and converts it to a validated session.
//
// Decoding fails on unknown fields or any trailing JSON value. It also fails
// when the stored device_session_id differs from expectedDeviceSessionID.
func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) {
	dec := json.NewDecoder(bytes.NewReader(payload))
	dec.DisallowUnknownFields()

	var stored redisRecord
	if err := dec.Decode(&stored); err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
	}

	// A second Decode must report io.EOF; a successful decode or any other
	// error means the payload carries more than one JSON value.
	switch trailingErr := dec.Decode(&struct{}{}); {
	case trailingErr == nil:
		return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input")
	case trailingErr != io.EOF:
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", trailingErr)
	}

	session, err := sessionFromRedisRecord(stored)
	if err != nil {
		return devicesession.Session{}, err
	}
	if session.ID != expectedDeviceSessionID {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", session.ID, expectedDeviceSessionID)
	}
	return session, nil
}
// sessionFromRedisRecord converts a decoded storage record into a validated
// domain session. Fields are checked in this order:
//  1. created_at timestamp;
//  2. client public key (strict base64, then domain construction);
//  3. revocation metadata;
//  4. whole-session validation.
func sessionFromRedisRecord(stored redisRecord) (devicesession.Session, error) {
	createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
	if err != nil {
		return devicesession.Session{}, err
	}

	rawKey, err := base64.StdEncoding.Strict().DecodeString(stored.ClientPublicKeyBase64)
	if err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
	}
	publicKey, err := common.NewClientPublicKey(rawKey)
	if err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
	}

	revocation, err := parseRevocation(stored)
	if err != nil {
		return devicesession.Session{}, err
	}

	session := devicesession.Session{
		ID:              common.DeviceSessionID(stored.DeviceSessionID),
		UserID:          common.UserID(stored.UserID),
		ClientPublicKey: publicKey,
		Status:          stored.Status,
		CreatedAt:       createdAt,
		Revocation:      revocation,
	}
	if err := session.Validate(); err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
	}
	return session, nil
}
// parseRevocation rebuilds revocation metadata from the stored fields.
//
// Behaviour by field presence:
//   - all four absent: no revocation (nil, nil);
//   - revoked_at, revoke_reason_code and revoke_actor_type all present: the
//     metadata is parsed; revoke_actor_id is optional;
//   - any other partial combination: error.
//
// String fields that are only whitespace count as absent.
func parseRevocation(stored redisRecord) (*devicesession.Revocation, error) {
	present := func(value string) bool { return strings.TrimSpace(value) != "" }

	hasRevokedAt := stored.RevokedAt != nil
	hasReasonCode := present(stored.RevokeReasonCode)
	hasActorType := present(stored.RevokeActorType)
	hasActorID := present(stored.RevokeActorID)

	switch {
	case !(hasRevokedAt || hasReasonCode || hasActorType || hasActorID):
		return nil, nil
	case !(hasRevokedAt && hasReasonCode && hasActorType):
		return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent")
	}

	revokedAt, err := parseTimestamp("revoked_at", *stored.RevokedAt)
	if err != nil {
		return nil, err
	}

	revocation := &devicesession.Revocation{
		At:         revokedAt,
		ReasonCode: common.RevokeReasonCode(stored.RevokeReasonCode),
		ActorType:  common.RevokeActorType(stored.RevokeActorType),
		ActorID:    stored.RevokeActorID,
	}
	return revocation, nil
}
// parseTimestamp parses a stored timestamp field and returns it in UTC.
//
// Only the exact text formatTimestamp would produce is accepted: an RFC3339
// timestamp in UTC ("Z") with trailing fractional zeros trimmed. Offsets such
// as "+01:00" and padded fractions such as ".500" are rejected. fieldName is
// used only in error messages.
func parseTimestamp(fieldName string, value string) (time.Time, error) {
	if len(strings.TrimSpace(value)) == 0 {
		return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName)
	}

	parsed, err := time.Parse(time.RFC3339Nano, value)
	if err != nil {
		return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err)
	}

	utc := parsed.UTC()
	if utc.Format(time.RFC3339Nano) != value {
		return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
	}
	return utc, nil
}
// formatTimestamp renders value in UTC using RFC3339Nano. Trailing zeros in
// the fractional seconds are trimmed, which is the canonical form that
// parseTimestamp accepts.
func formatTimestamp(value time.Time) string {
	utc := value.UTC()
	return utc.Format(time.RFC3339Nano)
}
// formatOptionalTimestamp is the pointer-preserving variant of
// formatTimestamp. A nil input yields nil, so an omitempty JSON field stays
// absent.
func formatOptionalTimestamp(value *time.Time) *string {
	var formatted *string
	if value != nil {
		text := formatTimestamp(*value)
		formatted = &text
	}
	return formatted
}
// createdAtScore converts a creation time into a sorted-set score: Unix
// microseconds as float64. Sub-microsecond precision is truncated. float64
// represents integers exactly up to 2^53, which covers microsecond
// timestamps far beyond the present.
func createdAtScore(createdAt time.Time) float64 {
	micros := createdAt.UTC().UnixMicro()
	return float64(micros)
}
// sortSessionsNewestFirst sorts records in place by descending CreatedAt.
// Equal creation instants are ordered by ascending device session ID, so the
// result is deterministic.
func sortSessionsNewestFirst(records []devicesession.Session) {
	slices.SortFunc(records, func(a, b devicesession.Session) int {
		// Comparing b to a yields descending time order.
		if byTime := b.CreatedAt.Compare(a.CreatedAt); byTime != 0 {
			return byTime
		}
		return strings.Compare(a.ID.String(), b.ID.String())
	})
}
// Compile-time check that *Store implements ports.SessionStore.
var _ ports.SessionStore = (*Store)(nil)