feat: authsession service
This commit is contained in:
@@ -0,0 +1,723 @@
|
||||
// Package sessionstore implements ports.SessionStore with Redis-backed strict
|
||||
// JSON source-of-truth session records and per-user indexes.
|
||||
package sessionstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const mutationRetryLimit = 3
|
||||
|
||||
// Config configures one Redis-backed session store instance.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// SessionKeyPrefix is the namespace prefix applied to primary session keys.
|
||||
SessionKeyPrefix string
|
||||
|
||||
// UserSessionsKeyPrefix is the namespace prefix applied to all-session user
|
||||
// indexes.
|
||||
UserSessionsKeyPrefix string
|
||||
|
||||
// UserActiveSessionsKeyPrefix is the namespace prefix applied to active
|
||||
// session user indexes.
|
||||
UserActiveSessionsKeyPrefix string
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Store persists source-of-truth sessions in Redis and maintains user-scoped
|
||||
// indexes for list and count operations.
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
sessionKeyPrefix string
|
||||
userSessionsKeyPrefix string
|
||||
userActiveSessionsKeyPrefix string
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
type redisRecord struct {
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ClientPublicKeyBase64 string `json:"client_public_key_base64"`
|
||||
Status devicesession.Status `json:"status"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
RevokedAt *string `json:"revoked_at,omitempty"`
|
||||
RevokeReasonCode string `json:"revoke_reason_code,omitempty"`
|
||||
RevokeActorType string `json:"revoke_actor_type,omitempty"`
|
||||
RevokeActorID string `json:"revoke_actor_id,omitempty"`
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed session store from cfg.
|
||||
func New(cfg Config) (*Store, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.Addr) == "":
|
||||
return nil, errors.New("new redis session store: redis addr must not be empty")
|
||||
case cfg.DB < 0:
|
||||
return nil, errors.New("new redis session store: redis db must not be negative")
|
||||
case strings.TrimSpace(cfg.SessionKeyPrefix) == "":
|
||||
return nil, errors.New("new redis session store: session key prefix must not be empty")
|
||||
case strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "":
|
||||
return nil, errors.New("new redis session store: user sessions key prefix must not be empty")
|
||||
case strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "":
|
||||
return nil, errors.New("new redis session store: user active sessions key prefix must not be empty")
|
||||
case cfg.OperationTimeout <= 0:
|
||||
return nil, errors.New("new redis session store: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Store{
|
||||
client: redis.NewClient(options),
|
||||
sessionKeyPrefix: cfg.SessionKeyPrefix,
|
||||
userSessionsKeyPrefix: cfg.UserSessionsKeyPrefix,
|
||||
userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *Store) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (s *Store) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "ping redis session store")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis session store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the stored session for deviceSessionID.
|
||||
func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
|
||||
if err := deviceSessionID.Validate(); err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "get session from redis")
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
record, err := s.loadSession(operationCtx, deviceSessionID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, ports.ErrNotFound)
|
||||
default:
|
||||
return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// ListByUserID returns every stored session for userID in newest-first order.
|
||||
func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("list sessions by user id from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
deviceSessionIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err)
|
||||
}
|
||||
if len(deviceSessionIDs) == 0 {
|
||||
return []devicesession.Session{}, nil
|
||||
}
|
||||
|
||||
records := make([]devicesession.Session, 0, len(deviceSessionIDs))
|
||||
for _, rawDeviceSessionID := range deviceSessionIDs {
|
||||
deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
|
||||
record, err := s.loadSession(operationCtx, deviceSessionID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, deviceSessionID)
|
||||
default:
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, deviceSessionID, err)
|
||||
}
|
||||
}
|
||||
if record.UserID != userID {
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, deviceSessionID, record.UserID)
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
sortSessionsNewestFirst(records)
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// CountActiveByUserID returns the number of active sessions currently stored
|
||||
// for userID.
|
||||
func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return 0, fmt.Errorf("count active sessions by user id from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
count, err := s.client.ZCard(operationCtx, s.userActiveSessionsKey(userID)).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err)
|
||||
}
|
||||
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
// Create persists record as a new device session.
|
||||
func (s *Store) Create(ctx context.Context, record devicesession.Session) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create session in redis: %w", err)
|
||||
}
|
||||
|
||||
payload, err := marshalSessionRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session in redis: %w", err)
|
||||
}
|
||||
|
||||
deviceSessionKey := s.sessionKey(record.ID)
|
||||
allSessionsKey := s.userSessionsKey(record.UserID)
|
||||
activeSessionsKey := s.userActiveSessionsKey(record.UserID)
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "create session in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
_, err := tx.Get(operationCtx, deviceSessionKey).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
case err != nil:
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, err)
|
||||
default:
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, deviceSessionKey, payload, 0)
|
||||
pipe.ZAdd(operationCtx, allSessionsKey, redis.Z{
|
||||
Score: createdAtScore(record.CreatedAt),
|
||||
Member: record.ID.String(),
|
||||
})
|
||||
if record.Status == devicesession.StatusActive {
|
||||
pipe.ZAdd(operationCtx, activeSessionsKey, redis.Z{
|
||||
Score: createdAtScore(record.CreatedAt),
|
||||
Member: record.ID.String(),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, deviceSessionKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke stores a revoked view of one target session.
|
||||
func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err)
|
||||
}
|
||||
|
||||
var result ports.RevokeSessionResult
|
||||
err := s.runMutation(ctx, "revoke session in redis", func(operationCtx context.Context) error {
|
||||
deviceSessionKey := s.sessionKey(input.DeviceSessionID)
|
||||
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
current, err := s.loadSessionWithGetter(operationCtx, input.DeviceSessionID, tx.Get)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, ports.ErrNotFound)
|
||||
default:
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Status == devicesession.StatusRevoked {
|
||||
result = ports.RevokeSessionResult{
|
||||
Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
|
||||
Session: current,
|
||||
}
|
||||
return result.Validate()
|
||||
}
|
||||
|
||||
next := current
|
||||
next.Status = devicesession.StatusRevoked
|
||||
revocation := input.Revocation
|
||||
next.Revocation = &revocation
|
||||
if err := next.Validate(); err != nil {
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
|
||||
payload, err := marshalSessionRecord(next)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, deviceSessionKey, payload, 0)
|
||||
pipe.ZRem(operationCtx, s.userActiveSessionsKey(current.UserID), current.ID.String())
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
|
||||
result = ports.RevokeSessionResult{
|
||||
Outcome: ports.RevokeSessionOutcomeRevoked,
|
||||
Session: next,
|
||||
}
|
||||
return result.Validate()
|
||||
}, deviceSessionKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return errRetryMutation
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return ports.RevokeSessionResult{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RevokeAllByUserID stores revoked views for all currently active sessions
|
||||
// owned by input.UserID.
|
||||
func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err)
|
||||
}
|
||||
|
||||
var result ports.RevokeUserSessionsResult
|
||||
err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error {
|
||||
activeSessionsKey := s.userActiveSessionsKey(input.UserID)
|
||||
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
|
||||
}
|
||||
if len(deviceSessionIDs) == 0 {
|
||||
// Force EXEC so WATCH observes concurrent active-index changes even
|
||||
// for the no-op path.
|
||||
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.ZCard(operationCtx, activeSessionsKey)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
|
||||
}
|
||||
|
||||
result = ports.RevokeUserSessionsResult{
|
||||
Outcome: ports.RevokeUserSessionsOutcomeNoActiveSessions,
|
||||
UserID: input.UserID,
|
||||
Sessions: []devicesession.Session{},
|
||||
}
|
||||
return result.Validate()
|
||||
}
|
||||
|
||||
records := make([]devicesession.Session, 0, len(deviceSessionIDs))
|
||||
for _, rawDeviceSessionID := range deviceSessionIDs {
|
||||
deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
|
||||
record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, tx.Get)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID)
|
||||
default:
|
||||
return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
|
||||
}
|
||||
}
|
||||
if record.UserID != input.UserID {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID)
|
||||
}
|
||||
if record.Status != devicesession.StatusActive {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status)
|
||||
}
|
||||
|
||||
next := record
|
||||
next.Status = devicesession.StatusRevoked
|
||||
revocation := input.Revocation
|
||||
next.Revocation = &revocation
|
||||
if err := next.Validate(); err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
|
||||
}
|
||||
records = append(records, next)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
for _, record := range records {
|
||||
payload, err := marshalSessionRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("session %q: %w", record.ID, err)
|
||||
}
|
||||
pipe.Set(operationCtx, s.sessionKey(record.ID), payload, 0)
|
||||
pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
|
||||
}
|
||||
|
||||
sortSessionsNewestFirst(records)
|
||||
result = ports.RevokeUserSessionsResult{
|
||||
Outcome: ports.RevokeUserSessionsOutcomeRevoked,
|
||||
UserID: input.UserID,
|
||||
Sessions: records,
|
||||
}
|
||||
return result.Validate()
|
||||
}, activeSessionsKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return errRetryMutation
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return ports.RevokeUserSessionsResult{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var errRetryMutation = errors.New("redis session store: retry mutation")
|
||||
|
||||
func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error {
|
||||
for attempt := 0; attempt < mutationRetryLimit; attempt++ {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, operation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = execute(operationCtx)
|
||||
cancel()
|
||||
|
||||
switch {
|
||||
case errors.Is(err, errRetryMutation):
|
||||
if attempt == mutationRetryLimit-1 {
|
||||
return fmt.Errorf("%s: mutation retry limit exceeded", operation)
|
||||
}
|
||||
continue
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s: mutation retry limit exceeded", operation)
|
||||
}
|
||||
|
||||
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if s == nil || s.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil store", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
|
||||
return s.loadSessionWithGetter(ctx, deviceSessionID, s.client.Get)
|
||||
}
|
||||
|
||||
func (s *Store) loadSessionWithGetter(
|
||||
ctx context.Context,
|
||||
deviceSessionID common.DeviceSessionID,
|
||||
getter func(context.Context, string) *redis.StringCmd,
|
||||
) (devicesession.Session, error) {
|
||||
payload, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return devicesession.Session{}, ports.ErrNotFound
|
||||
case err != nil:
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
|
||||
record, err := decodeSessionRecord(deviceSessionID, payload)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string {
|
||||
return s.sessionKeyPrefix + encodeKeyComponent(deviceSessionID.String())
|
||||
}
|
||||
|
||||
func (s *Store) userSessionsKey(userID common.UserID) string {
|
||||
return s.userSessionsKeyPrefix + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
func (s *Store) userActiveSessionsKey(userID common.UserID) string {
|
||||
return s.userActiveSessionsKeyPrefix + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
func encodeKeyComponent(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func marshalSessionRecord(record devicesession.Session) ([]byte, error) {
|
||||
stored, err := redisRecordFromSession(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(stored)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode redis session record: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func redisRecordFromSession(record devicesession.Session) (redisRecord, error) {
|
||||
if err := record.Validate(); err != nil {
|
||||
return redisRecord{}, fmt.Errorf("encode redis session record: %w", err)
|
||||
}
|
||||
|
||||
stored := redisRecord{
|
||||
DeviceSessionID: record.ID.String(),
|
||||
UserID: record.UserID.String(),
|
||||
ClientPublicKeyBase64: record.ClientPublicKey.String(),
|
||||
Status: record.Status,
|
||||
CreatedAt: formatTimestamp(record.CreatedAt),
|
||||
}
|
||||
if record.Revocation != nil {
|
||||
stored.RevokedAt = formatOptionalTimestamp(&record.Revocation.At)
|
||||
stored.RevokeReasonCode = record.Revocation.ReasonCode.String()
|
||||
stored.RevokeActorType = record.Revocation.ActorType.String()
|
||||
stored.RevokeActorID = record.Revocation.ActorID
|
||||
}
|
||||
|
||||
return stored, nil
|
||||
}
|
||||
|
||||
func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input")
|
||||
}
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
record, err := sessionFromRedisRecord(stored)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
if record.ID != expectedDeviceSessionID {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", record.ID, expectedDeviceSessionID)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func sessionFromRedisRecord(stored redisRecord) (devicesession.Session, error) {
|
||||
createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
|
||||
rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ClientPublicKeyBase64)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
|
||||
}
|
||||
clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
|
||||
}
|
||||
|
||||
record := devicesession.Session{
|
||||
ID: common.DeviceSessionID(stored.DeviceSessionID),
|
||||
UserID: common.UserID(stored.UserID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: stored.Status,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
revocation, err := parseRevocation(stored)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
record.Revocation = revocation
|
||||
|
||||
if err := record.Validate(); err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func parseRevocation(stored redisRecord) (*devicesession.Revocation, error) {
|
||||
hasRevokedAt := stored.RevokedAt != nil
|
||||
hasReasonCode := strings.TrimSpace(stored.RevokeReasonCode) != ""
|
||||
hasActorType := strings.TrimSpace(stored.RevokeActorType) != ""
|
||||
hasActorID := strings.TrimSpace(stored.RevokeActorID) != ""
|
||||
|
||||
if !hasRevokedAt && !hasReasonCode && !hasActorType && !hasActorID {
|
||||
return nil, nil
|
||||
}
|
||||
if !hasRevokedAt || !hasReasonCode || !hasActorType {
|
||||
return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent")
|
||||
}
|
||||
|
||||
revokedAt, err := parseTimestamp("revoked_at", *stored.RevokedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &devicesession.Revocation{
|
||||
At: revokedAt,
|
||||
ReasonCode: common.RevokeReasonCode(stored.RevokeReasonCode),
|
||||
ActorType: common.RevokeActorType(stored.RevokeActorType),
|
||||
ActorID: stored.RevokeActorID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseTimestamp(fieldName string, value string) (time.Time, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName)
|
||||
}
|
||||
|
||||
parsed, err := time.Parse(time.RFC3339Nano, value)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err)
|
||||
}
|
||||
|
||||
canonical := parsed.UTC().Format(time.RFC3339Nano)
|
||||
if value != canonical {
|
||||
return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
|
||||
}
|
||||
|
||||
return parsed.UTC(), nil
|
||||
}
|
||||
|
||||
func formatTimestamp(value time.Time) string {
|
||||
return value.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func formatOptionalTimestamp(value *time.Time) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
formatted := formatTimestamp(*value)
|
||||
return &formatted
|
||||
}
|
||||
|
||||
func createdAtScore(createdAt time.Time) float64 {
|
||||
return float64(createdAt.UTC().UnixMicro())
|
||||
}
|
||||
|
||||
func sortSessionsNewestFirst(records []devicesession.Session) {
|
||||
slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int {
|
||||
switch {
|
||||
case left.CreatedAt.Equal(right.CreatedAt):
|
||||
return strings.Compare(left.ID.String(), right.ID.String())
|
||||
case left.CreatedAt.After(right.CreatedAt):
|
||||
return -1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var _ ports.SessionStore = (*Store)(nil)
|
||||
@@ -0,0 +1,635 @@
|
||||
package sessionstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/adapters/contracttest"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStoreContract(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
contracttest.RunSessionStoreContractTests(t, func(t *testing.T) ports.SessionStore {
|
||||
t.Helper()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
return newTestStore(t, server, Config{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 1,
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty session prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty all sessions prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "user sessions key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty active sessions prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "user active sessions key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorePing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
require.NoError(t, store.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetActive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
|
||||
got.Revocation = &devicesession.Revocation{
|
||||
At: got.CreatedAt.Add(time.Minute),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
}
|
||||
|
||||
again, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, again.Revocation)
|
||||
assert.Equal(t, record, again)
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetRevoked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(1_775_240_100, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
}
|
||||
|
||||
func TestStoreGetNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestStoreCreateConflict(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_200, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
err := store.Create(context.Background(), record)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
}
|
||||
|
||||
func TestStoreIndexesAndOrdering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
older := activeSessionFixture("device-session-old", "user-1", time.Unix(10, 0).UTC())
|
||||
newer := activeSessionFixture("device-session-new", "user-1", time.Unix(20, 0).UTC())
|
||||
revoked := revokedSessionFixture("device-session-revoked", "user-1", time.Unix(15, 0).UTC())
|
||||
otherUser := activeSessionFixture("device-session-other", "user-2", time.Unix(30, 0).UTC())
|
||||
|
||||
for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
}
|
||||
|
||||
got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got, 3)
|
||||
assert.Equal(t, []common.DeviceSessionID{newer.ID, revoked.ID, older.ID}, []common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID})
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
unknown, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unknown)
|
||||
}
|
||||
|
||||
func TestStoreKeyPrefixesAndEncodedPrimaryKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{
|
||||
SessionKeyPrefix: "custom:session:",
|
||||
UserSessionsKeyPrefix: "custom:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "custom:user-active-sessions:",
|
||||
})
|
||||
|
||||
record := activeSessionFixture("device/session:opaque?1", "user/opaque:1", time.Unix(40, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
primaryKey := store.sessionKey(record.ID)
|
||||
assert.Equal(t, "custom:session:"+encodeKeyComponent(record.ID.String()), primaryKey)
|
||||
assert.True(t, server.Exists(primaryKey))
|
||||
|
||||
allSessionsKey := store.userSessionsKey(record.UserID)
|
||||
activeSessionsKey := store.userActiveSessionsKey(record.UserID)
|
||||
assert.Equal(t, "custom:user-sessions:"+encodeKeyComponent(record.UserID.String()), allSessionsKey)
|
||||
assert.Equal(t, "custom:user-active-sessions:"+encodeKeyComponent(record.UserID.String()), activeSessionsKey)
|
||||
|
||||
allMembers, err := server.ZMembers(allSessionsKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{record.ID.String()}, allMembers)
|
||||
|
||||
activeMembers, err := server.ZMembers(activeSessionsKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{record.ID.String()}, activeMembers)
|
||||
}
|
||||
|
||||
func TestStoreRevoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("active session", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
revocation := devicesession.Revocation{
|
||||
At: time.Unix(200, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonLogoutAll,
|
||||
ActorType: common.RevokeActorType("system"),
|
||||
}
|
||||
|
||||
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: record.ID,
|
||||
Revocation: revocation,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
|
||||
require.NotNil(t, result.Session.Revocation)
|
||||
assert.Equal(t, revocation, *result.Session.Revocation)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
})
|
||||
|
||||
t.Run("already revoked keeps stored revocation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: record.ID,
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(300, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
ActorID: "admin-1",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
|
||||
require.NotNil(t, result.Session.Revocation)
|
||||
assert.Equal(t, *record.Revocation, *result.Session.Revocation)
|
||||
})
|
||||
|
||||
t.Run("unknown session", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: common.DeviceSessionID("missing-session"),
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(200, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonLogoutAll,
|
||||
ActorType: common.RevokeActorType("system"),
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRevokeAllByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("revokes active sessions newest first and clears active index", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
older := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
newer := activeSessionFixture("device-session-2", "user-1", time.Unix(200, 0).UTC())
|
||||
alreadyRevoked := revokedSessionFixture("device-session-3", "user-1", time.Unix(150, 0).UTC())
|
||||
otherUser := activeSessionFixture("device-session-4", "user-2", time.Unix(250, 0).UTC())
|
||||
|
||||
for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
}
|
||||
|
||||
revocation := devicesession.Revocation{
|
||||
At: time.Unix(300, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
ActorID: "admin-1",
|
||||
}
|
||||
|
||||
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
Revocation: revocation,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
|
||||
require.Len(t, result.Sessions, 2)
|
||||
assert.Equal(t, []common.DeviceSessionID{newer.ID, older.ID}, []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID})
|
||||
assert.Equal(t, revocation, *result.Sessions[0].Revocation)
|
||||
assert.Equal(t, revocation, *result.Sessions[1].Revocation)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
|
||||
otherCount, err := store.CountActiveByUserID(context.Background(), common.UserID("user-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, otherCount)
|
||||
})
|
||||
|
||||
t.Run("no active sessions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-5", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(400, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
|
||||
assert.Empty(t, result.Sessions)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreStrictDecodeCorruption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_240_300, 0).UTC()
|
||||
baseRecord := revokedSessionFixture("device-session-corrupt", "user-1", now)
|
||||
stored, err := redisRecordFromSession(baseRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(redisRecord) string
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "malformed json",
|
||||
mutate: func(_ redisRecord) string {
|
||||
return "{"
|
||||
},
|
||||
wantErrText: "decode redis session record",
|
||||
},
|
||||
{
|
||||
name: "trailing json input",
|
||||
mutate: func(record redisRecord) string {
|
||||
return mustMarshalJSON(t, record) + "{}"
|
||||
},
|
||||
wantErrText: "unexpected trailing JSON input",
|
||||
},
|
||||
{
|
||||
name: "unknown field",
|
||||
mutate: func(record redisRecord) string {
|
||||
payload := map[string]any{
|
||||
"device_session_id": record.DeviceSessionID,
|
||||
"user_id": record.UserID,
|
||||
"client_public_key_base64": record.ClientPublicKeyBase64,
|
||||
"status": record.Status,
|
||||
"created_at": record.CreatedAt,
|
||||
"revoked_at": record.RevokedAt,
|
||||
"revoke_reason_code": record.RevokeReasonCode,
|
||||
"revoke_actor_type": record.RevokeActorType,
|
||||
"revoke_actor_id": record.RevokeActorID,
|
||||
"unexpected": true,
|
||||
}
|
||||
return mustMarshalJSON(t, payload)
|
||||
},
|
||||
wantErrText: "unknown field",
|
||||
},
|
||||
{
|
||||
name: "unsupported status",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.Status = devicesession.Status("paused")
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `status "paused" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "non canonical timestamp",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.CreatedAt = "2026-04-04T12:00:00+03:00"
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "canonical UTC RFC3339Nano timestamp",
|
||||
},
|
||||
{
|
||||
name: "incomplete revocation metadata",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.RevokeActorType = ""
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "revocation metadata must be either fully present or fully absent",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
server.Set(store.sessionKey(baseRecord.ID), tt.mutate(stored))
|
||||
|
||||
_, err := store.Get(context.Background(), baseRecord.ID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreListByUserIDDetectsCorruptIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("missing primary record", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
userID := common.UserID("user-1")
|
||||
_, err := server.ZAdd(store.userSessionsKey(userID), 100, "missing-session")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.ListByUserID(context.Background(), userID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "references missing session")
|
||||
})
|
||||
|
||||
t.Run("wrong user id in primary record", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-2", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
|
||||
_, err := server.ZAdd(store.userSessionsKey(common.UserID("user-1")), createdAtScore(record.CreatedAt), record.ID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.ListByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `belongs to "user-2"`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
|
||||
_, err := server.ZAdd(store.userActiveSessionsKey(record.UserID), createdAtScore(record.CreatedAt), record.ID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: record.UserID,
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(200, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `is "revoked"`)
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.SessionKeyPrefix == "" {
|
||||
cfg.SessionKeyPrefix = "authsession:session:"
|
||||
}
|
||||
if cfg.UserSessionsKeyPrefix == "" {
|
||||
cfg.UserSessionsKeyPrefix = "authsession:user-sessions:"
|
||||
}
|
||||
if cfg.UserActiveSessionsKeyPrefix == "" {
|
||||
cfg.UserActiveSessionsKeyPrefix = "authsession:user-active-sessions:"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
|
||||
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
record := activeSessionFixture(deviceSessionID, userID, createdAt)
|
||||
record.Status = devicesession.StatusRevoked
|
||||
record.Revocation = &devicesession.Revocation{
|
||||
At: createdAt.Add(time.Minute),
|
||||
ReasonCode: devicesession.RevokeReasonDeviceLogout,
|
||||
ActorType: common.RevokeActorType("user"),
|
||||
ActorID: "user-actor",
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
func seedSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record devicesession.Session) error {
|
||||
t.Helper()
|
||||
|
||||
stored, err := redisRecordFromSession(record)
|
||||
require.NoError(t, err)
|
||||
server.Set(key, mustMarshalJSON(t, stored))
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(payload)
|
||||
}
|
||||
Reference in New Issue
Block a user