// Package sessionstore implements ports.SessionStore with Redis-backed strict // JSON source-of-truth session records and per-user indexes. package sessionstore import ( "bytes" "context" "crypto/tls" "encoding/base64" "encoding/json" "errors" "fmt" "io" "slices" "strings" "time" "galaxy/authsession/internal/domain/common" "galaxy/authsession/internal/domain/devicesession" "galaxy/authsession/internal/ports" "github.com/redis/go-redis/v9" ) const mutationRetryLimit = 3 // Config configures one Redis-backed session store instance. type Config struct { // Addr is the Redis network address in host:port form. Addr string // Username is the optional Redis ACL username. Username string // Password is the optional Redis ACL password. Password string // DB is the Redis logical database index. DB int // TLSEnabled enables TLS with a conservative minimum protocol version. TLSEnabled bool // SessionKeyPrefix is the namespace prefix applied to primary session keys. SessionKeyPrefix string // UserSessionsKeyPrefix is the namespace prefix applied to all-session user // indexes. UserSessionsKeyPrefix string // UserActiveSessionsKeyPrefix is the namespace prefix applied to active // session user indexes. UserActiveSessionsKeyPrefix string // OperationTimeout bounds each Redis round trip performed by the adapter. OperationTimeout time.Duration } // Store persists source-of-truth sessions in Redis and maintains user-scoped // indexes for list and count operations. type Store struct { client *redis.Client sessionKeyPrefix string userSessionsKeyPrefix string userActiveSessionsKeyPrefix string operationTimeout time.Duration } type redisRecord struct { DeviceSessionID string `json:"device_session_id"` UserID string `json:"user_id"` ClientPublicKeyBase64 string `json:"client_public_key_base64"` Status devicesession.Status `json:"status"` CreatedAt string `json:"created_at"` RevokedAt *string `json:"revoked_at,omitempty"` RevokeReasonCode string `json:"revoke_reason_code,omitempty"` RevokeActorType string `json:"revoke_actor_type,omitempty"` RevokeActorID string `json:"revoke_actor_id,omitempty"` } // New constructs a Redis-backed session store from cfg. func New(cfg Config) (*Store, error) { switch { case strings.TrimSpace(cfg.Addr) == "": return nil, errors.New("new redis session store: redis addr must not be empty") case cfg.DB < 0: return nil, errors.New("new redis session store: redis db must not be negative") case strings.TrimSpace(cfg.SessionKeyPrefix) == "": return nil, errors.New("new redis session store: session key prefix must not be empty") case strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "": return nil, errors.New("new redis session store: user sessions key prefix must not be empty") case strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "": return nil, errors.New("new redis session store: user active sessions key prefix must not be empty") case cfg.OperationTimeout <= 0: return nil, errors.New("new redis session store: operation timeout must be positive") } options := &redis.Options{ Addr: cfg.Addr, Username: cfg.Username, Password: cfg.Password, DB: cfg.DB, Protocol: 2, DisableIdentity: true, } if cfg.TLSEnabled { options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} } return &Store{ client: redis.NewClient(options), sessionKeyPrefix: cfg.SessionKeyPrefix, userSessionsKeyPrefix: cfg.UserSessionsKeyPrefix, userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix, operationTimeout: cfg.OperationTimeout, }, nil } // Close releases the underlying Redis client resources. func (s *Store) Close() error { if s == nil || s.client == nil { return nil } return s.client.Close() } // Ping verifies that the configured Redis backend is reachable within the // adapter operation timeout budget. func (s *Store) Ping(ctx context.Context) error { operationCtx, cancel, err := s.operationContext(ctx, "ping redis session store") if err != nil { return err } defer cancel() if err := s.client.Ping(operationCtx).Err(); err != nil { return fmt.Errorf("ping redis session store: %w", err) } return nil } // Get returns the stored session for deviceSessionID. func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) { if err := deviceSessionID.Validate(); err != nil { return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err) } operationCtx, cancel, err := s.operationContext(ctx, "get session from redis") if err != nil { return devicesession.Session{}, err } defer cancel() record, err := s.loadSession(operationCtx, deviceSessionID) if err != nil { switch { case errors.Is(err, ports.ErrNotFound): return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, ports.ErrNotFound) default: return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err) } } return record, nil } // ListByUserID returns every stored session for userID in newest-first order. func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) { if err := userID.Validate(); err != nil { return nil, fmt.Errorf("list sessions by user id from redis: %w", err) } operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis") if err != nil { return nil, err } defer cancel() deviceSessionIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result() if err != nil { return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err) } if len(deviceSessionIDs) == 0 { return []devicesession.Session{}, nil } records := make([]devicesession.Session, 0, len(deviceSessionIDs)) for _, rawDeviceSessionID := range deviceSessionIDs { deviceSessionID := common.DeviceSessionID(rawDeviceSessionID) record, err := s.loadSession(operationCtx, deviceSessionID) if err != nil { switch { case errors.Is(err, ports.ErrNotFound): return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, deviceSessionID) default: return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, deviceSessionID, err) } } if record.UserID != userID { return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, deviceSessionID, record.UserID) } records = append(records, record) } sortSessionsNewestFirst(records) return records, nil } // CountActiveByUserID returns the number of active sessions currently stored // for userID. func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) { if err := userID.Validate(); err != nil { return 0, fmt.Errorf("count active sessions by user id from redis: %w", err) } operationCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis") if err != nil { return 0, err } defer cancel() count, err := s.client.ZCard(operationCtx, s.userActiveSessionsKey(userID)).Result() if err != nil { return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err) } return int(count), nil } // Create persists record as a new device session. func (s *Store) Create(ctx context.Context, record devicesession.Session) error { if err := record.Validate(); err != nil { return fmt.Errorf("create session in redis: %w", err) } payload, err := marshalSessionRecord(record) if err != nil { return fmt.Errorf("create session in redis: %w", err) } deviceSessionKey := s.sessionKey(record.ID) allSessionsKey := s.userSessionsKey(record.UserID) activeSessionsKey := s.userActiveSessionsKey(record.UserID) operationCtx, cancel, err := s.operationContext(ctx, "create session in redis") if err != nil { return err } defer cancel() watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { _, err := tx.Get(operationCtx, deviceSessionKey).Bytes() switch { case errors.Is(err, redis.Nil): case err != nil: return fmt.Errorf("create session %q in redis: %w", record.ID, err) default: return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict) } _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { pipe.Set(operationCtx, deviceSessionKey, payload, 0) pipe.ZAdd(operationCtx, allSessionsKey, redis.Z{ Score: createdAtScore(record.CreatedAt), Member: record.ID.String(), }) if record.Status == devicesession.StatusActive { pipe.ZAdd(operationCtx, activeSessionsKey, redis.Z{ Score: createdAtScore(record.CreatedAt), Member: record.ID.String(), }) } return nil }) if err != nil { return fmt.Errorf("create session %q in redis: %w", record.ID, err) } return nil }, deviceSessionKey) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict) case watchErr != nil: return watchErr default: return nil } } // Revoke stores a revoked view of one target session. func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) { if err := input.Validate(); err != nil { return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err) } var result ports.RevokeSessionResult err := s.runMutation(ctx, "revoke session in redis", func(operationCtx context.Context) error { deviceSessionKey := s.sessionKey(input.DeviceSessionID) watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { current, err := s.loadSessionWithGetter(operationCtx, input.DeviceSessionID, tx.Get) if err != nil { switch { case errors.Is(err, ports.ErrNotFound): return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, ports.ErrNotFound) default: return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) } } if current.Status == devicesession.StatusRevoked { result = ports.RevokeSessionResult{ Outcome: ports.RevokeSessionOutcomeAlreadyRevoked, Session: current, } return result.Validate() } next := current next.Status = devicesession.StatusRevoked revocation := input.Revocation next.Revocation = &revocation if err := next.Validate(); err != nil { return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) } payload, err := marshalSessionRecord(next) if err != nil { return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) } _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { pipe.Set(operationCtx, deviceSessionKey, payload, 0) pipe.ZRem(operationCtx, s.userActiveSessionsKey(current.UserID), current.ID.String()) return nil }) if err != nil { return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err) } result = ports.RevokeSessionResult{ Outcome: ports.RevokeSessionOutcomeRevoked, Session: next, } return result.Validate() }, deviceSessionKey) switch { case errors.Is(watchErr, redis.TxFailedErr): return errRetryMutation case watchErr != nil: return watchErr default: return nil } }) if err != nil { return ports.RevokeSessionResult{}, err } return result, nil } // RevokeAllByUserID stores revoked views for all currently active sessions // owned by input.UserID. func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) { if err := input.Validate(); err != nil { return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err) } var result ports.RevokeUserSessionsResult err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error { activeSessionsKey := s.userActiveSessionsKey(input.UserID) watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result() if err != nil { return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err) } if len(deviceSessionIDs) == 0 { // Force EXEC so WATCH observes concurrent active-index changes even // for the no-op path. _, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { pipe.ZCard(operationCtx, activeSessionsKey) return nil }) if err != nil { return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err) } result = ports.RevokeUserSessionsResult{ Outcome: ports.RevokeUserSessionsOutcomeNoActiveSessions, UserID: input.UserID, Sessions: []devicesession.Session{}, } return result.Validate() } records := make([]devicesession.Session, 0, len(deviceSessionIDs)) for _, rawDeviceSessionID := range deviceSessionIDs { deviceSessionID := common.DeviceSessionID(rawDeviceSessionID) record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, tx.Get) if err != nil { switch { case errors.Is(err, ports.ErrNotFound): return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID) default: return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err) } } if record.UserID != input.UserID { return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID) } if record.Status != devicesession.StatusActive { return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status) } next := record next.Status = devicesession.StatusRevoked revocation := input.Revocation next.Revocation = &revocation if err := next.Validate(); err != nil { return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err) } records = append(records, next) } _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { for _, record := range records { payload, err := marshalSessionRecord(record) if err != nil { return fmt.Errorf("session %q: %w", record.ID, err) } pipe.Set(operationCtx, s.sessionKey(record.ID), payload, 0) pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String()) } return nil }) if err != nil { return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err) } sortSessionsNewestFirst(records) result = ports.RevokeUserSessionsResult{ Outcome: ports.RevokeUserSessionsOutcomeRevoked, UserID: input.UserID, Sessions: records, } return result.Validate() }, activeSessionsKey) switch { case errors.Is(watchErr, redis.TxFailedErr): return errRetryMutation case watchErr != nil: return watchErr default: return nil } }) if err != nil { return ports.RevokeUserSessionsResult{}, err } return result, nil } var errRetryMutation = errors.New("redis session store: retry mutation") func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error { for attempt := 0; attempt < mutationRetryLimit; attempt++ { operationCtx, cancel, err := s.operationContext(ctx, operation) if err != nil { return err } err = execute(operationCtx) cancel() switch { case errors.Is(err, errRetryMutation): if attempt == mutationRetryLimit-1 { return fmt.Errorf("%s: mutation retry limit exceeded", operation) } continue default: return err } } return fmt.Errorf("%s: mutation retry limit exceeded", operation) } func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) { if s == nil || s.client == nil { return nil, nil, fmt.Errorf("%s: nil store", operation) } if ctx == nil { return nil, nil, fmt.Errorf("%s: nil context", operation) } operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout) return operationCtx, cancel, nil } func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) { return s.loadSessionWithGetter(ctx, deviceSessionID, s.client.Get) } func (s *Store) loadSessionWithGetter( ctx context.Context, deviceSessionID common.DeviceSessionID, getter func(context.Context, string) *redis.StringCmd, ) (devicesession.Session, error) { payload, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes() switch { case errors.Is(err, redis.Nil): return devicesession.Session{}, ports.ErrNotFound case err != nil: return devicesession.Session{}, err } record, err := decodeSessionRecord(deviceSessionID, payload) if err != nil { return devicesession.Session{}, err } return record, nil } func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string { return s.sessionKeyPrefix + encodeKeyComponent(deviceSessionID.String()) } func (s *Store) userSessionsKey(userID common.UserID) string { return s.userSessionsKeyPrefix + encodeKeyComponent(userID.String()) } func (s *Store) userActiveSessionsKey(userID common.UserID) string { return s.userActiveSessionsKeyPrefix + encodeKeyComponent(userID.String()) } func encodeKeyComponent(value string) string { return base64.RawURLEncoding.EncodeToString([]byte(value)) } func marshalSessionRecord(record devicesession.Session) ([]byte, error) { stored, err := redisRecordFromSession(record) if err != nil { return nil, err } payload, err := json.Marshal(stored) if err != nil { return nil, fmt.Errorf("encode redis session record: %w", err) } return payload, nil } func redisRecordFromSession(record devicesession.Session) (redisRecord, error) { if err := record.Validate(); err != nil { return redisRecord{}, fmt.Errorf("encode redis session record: %w", err) } stored := redisRecord{ DeviceSessionID: record.ID.String(), UserID: record.UserID.String(), ClientPublicKeyBase64: record.ClientPublicKey.String(), Status: record.Status, CreatedAt: formatTimestamp(record.CreatedAt), } if record.Revocation != nil { stored.RevokedAt = formatOptionalTimestamp(&record.Revocation.At) stored.RevokeReasonCode = record.Revocation.ReasonCode.String() stored.RevokeActorType = record.Revocation.ActorType.String() stored.RevokeActorID = record.Revocation.ActorID } return stored, nil } func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) { decoder := json.NewDecoder(bytes.NewReader(payload)) decoder.DisallowUnknownFields() var stored redisRecord if err := decoder.Decode(&stored); err != nil { return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err) } if err := decoder.Decode(&struct{}{}); err != io.EOF { if err == nil { return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input") } return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err) } record, err := sessionFromRedisRecord(stored) if err != nil { return devicesession.Session{}, err } if record.ID != expectedDeviceSessionID { return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", record.ID, expectedDeviceSessionID) } return record, nil } func sessionFromRedisRecord(stored redisRecord) (devicesession.Session, error) { createdAt, err := parseTimestamp("created_at", stored.CreatedAt) if err != nil { return devicesession.Session{}, err } rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ClientPublicKeyBase64) if err != nil { return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err) } clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey) if err != nil { return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err) } record := devicesession.Session{ ID: common.DeviceSessionID(stored.DeviceSessionID), UserID: common.UserID(stored.UserID), ClientPublicKey: clientPublicKey, Status: stored.Status, CreatedAt: createdAt, } revocation, err := parseRevocation(stored) if err != nil { return devicesession.Session{}, err } record.Revocation = revocation if err := record.Validate(); err != nil { return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err) } return record, nil } func parseRevocation(stored redisRecord) (*devicesession.Revocation, error) { hasRevokedAt := stored.RevokedAt != nil hasReasonCode := strings.TrimSpace(stored.RevokeReasonCode) != "" hasActorType := strings.TrimSpace(stored.RevokeActorType) != "" hasActorID := strings.TrimSpace(stored.RevokeActorID) != "" if !hasRevokedAt && !hasReasonCode && !hasActorType && !hasActorID { return nil, nil } if !hasRevokedAt || !hasReasonCode || !hasActorType { return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent") } revokedAt, err := parseTimestamp("revoked_at", *stored.RevokedAt) if err != nil { return nil, err } return &devicesession.Revocation{ At: revokedAt, ReasonCode: common.RevokeReasonCode(stored.RevokeReasonCode), ActorType: common.RevokeActorType(stored.RevokeActorType), ActorID: stored.RevokeActorID, }, nil } func parseTimestamp(fieldName string, value string) (time.Time, error) { if strings.TrimSpace(value) == "" { return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName) } parsed, err := time.Parse(time.RFC3339Nano, value) if err != nil { return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err) } canonical := parsed.UTC().Format(time.RFC3339Nano) if value != canonical { return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName) } return parsed.UTC(), nil } func formatTimestamp(value time.Time) string { return value.UTC().Format(time.RFC3339Nano) } func formatOptionalTimestamp(value *time.Time) *string { if value == nil { return nil } formatted := formatTimestamp(*value) return &formatted } func createdAtScore(createdAt time.Time) float64 { return float64(createdAt.UTC().UnixMicro()) } func sortSessionsNewestFirst(records []devicesession.Session) { slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int { switch { case left.CreatedAt.Equal(right.CreatedAt): return strings.Compare(left.ID.String(), right.ID.String()) case left.CreatedAt.After(right.CreatedAt): return -1 default: return 1 } }) } var _ ports.SessionStore = (*Store)(nil)