feat: authsession service

This commit is contained in:
Ilia Denisov
2026-04-08 16:23:07 +02:00
committed by GitHub
parent 28f04916af
commit 86a68ed9d0
174 changed files with 31732 additions and 112 deletions
@@ -0,0 +1,484 @@
// Package challengestore implements ports.ChallengeStore with Redis-backed
// strict JSON challenge records.
package challengestore
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
"strings"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// expirationGracePeriod pads every Redis key TTL beyond the challenge's
// logical expiry so a just-expired record remains readable briefly instead of
// vanishing the instant ExpiresAt passes (see redisTTL).
const expirationGracePeriod = 5 * time.Minute
// Config configures one Redis-backed challenge store instance.
// All fields except Username, Password, and TLSEnabled are required; New
// rejects empty/invalid values.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string
	// Username is the optional Redis ACL username.
	Username string
	// Password is the optional Redis ACL password.
	Password string
	// DB is the Redis logical database index. Must not be negative.
	DB int
	// TLSEnabled enables TLS with a conservative minimum protocol version
	// (TLS 1.2).
	TLSEnabled bool
	// KeyPrefix is the namespace prefix applied to every challenge key.
	KeyPrefix string
	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}
// Store persists challenges as one strict JSON value per Redis key.
// The zero value is unusable; construct instances with New.
type Store struct {
	// client is the owned Redis client; Close releases it.
	client *redis.Client
	// keyPrefix namespaces every key written by this store.
	keyPrefix string
	// operationTimeout bounds each individual Redis round trip.
	operationTimeout time.Duration
}
// redisRecord is the strict JSON document persisted for one challenge.
// Timestamps are canonical UTC RFC3339Nano strings; binary values are
// base64-encoded so the payload stays plain-text JSON. Confirmation fields
// must be either all present or all absent (enforced by parseConfirmation).
type redisRecord struct {
	ChallengeID string `json:"challenge_id"`
	Email string `json:"email"`
	// CodeHashBase64 holds the challenge code hash in standard base64.
	CodeHashBase64 string `json:"code_hash_base64"`
	Status challenge.Status `json:"status"`
	DeliveryState challenge.DeliveryState `json:"delivery_state"`
	CreatedAt string `json:"created_at"`
	ExpiresAt string `json:"expires_at"`
	SendAttemptCount int `json:"send_attempt_count"`
	ConfirmAttemptCount int `json:"confirm_attempt_count"`
	LastAttemptAt *string `json:"last_attempt_at,omitempty"`
	ConfirmedSessionID string `json:"confirmed_session_id,omitempty"`
	ConfirmedClientPublicKey string `json:"confirmed_client_public_key,omitempty"`
	ConfirmedAt *string `json:"confirmed_at,omitempty"`
}
// New constructs a Redis-backed challenge store from cfg.
// It validates cfg eagerly and returns an error describing the first invalid
// field; it does not contact Redis (use Ping for reachability).
func New(cfg Config) (*Store, error) {
	switch {
	case strings.TrimSpace(cfg.Addr) == "":
		return nil, errors.New("new redis challenge store: redis addr must not be empty")
	case cfg.DB < 0:
		return nil, errors.New("new redis challenge store: redis db must not be negative")
	case strings.TrimSpace(cfg.KeyPrefix) == "":
		return nil, errors.New("new redis challenge store: redis key prefix must not be empty")
	case cfg.OperationTimeout <= 0:
		return nil, errors.New("new redis challenge store: operation timeout must be positive")
	}
	clientOptions := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		// Conservative floor: refuse anything below TLS 1.2.
		clientOptions.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	store := &Store{
		client:           redis.NewClient(clientOptions),
		keyPrefix:        cfg.KeyPrefix,
		operationTimeout: cfg.OperationTimeout,
	}
	return store, nil
}
// Close releases the underlying Redis client resources.
// It is safe to call on a nil receiver or an uninitialized store.
func (s *Store) Close() error {
	if s != nil && s.client != nil {
		return s.client.Close()
	}
	return nil
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
	opCtx, cancel, ctxErr := s.operationContext(ctx, "ping redis challenge store")
	if ctxErr != nil {
		return ctxErr
	}
	defer cancel()
	pingErr := s.client.Ping(opCtx).Err()
	if pingErr != nil {
		return fmt.Errorf("ping redis challenge store: %w", pingErr)
	}
	return nil
}
// Get returns the stored challenge for challengeID.
// It reports ports.ErrNotFound (wrapped) when no record exists, and an
// adapter error when the stored payload fails strict decoding.
func (s *Store) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
	if err := challengeID.Validate(); err != nil {
		return challenge.Challenge{}, fmt.Errorf("get challenge from redis: %w", err)
	}
	opCtx, cancel, ctxErr := s.operationContext(ctx, "get challenge from redis")
	if ctxErr != nil {
		return challenge.Challenge{}, ctxErr
	}
	defer cancel()
	payload, getErr := s.client.Get(opCtx, s.lookupKey(challengeID)).Bytes()
	if getErr != nil {
		if errors.Is(getErr, redis.Nil) {
			// Missing key maps onto the port-level sentinel.
			return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, ports.ErrNotFound)
		}
		return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, getErr)
	}
	record, decodeErr := decodeChallengeRecord(challengeID, payload)
	if decodeErr != nil {
		return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, decodeErr)
	}
	return record, nil
}
// Create persists record as a new challenge.
// It reports ports.ErrConflict (wrapped) when a record with the same ID
// already exists; the key TTL is derived from record.ExpiresAt.
func (s *Store) Create(ctx context.Context, record challenge.Challenge) error {
	if err := record.Validate(); err != nil {
		return fmt.Errorf("create challenge in redis: %w", err)
	}
	// Encode before opening the operation context so marshal failures never
	// consume Redis budget.
	payload, marshalErr := marshalChallengeRecord(record)
	if marshalErr != nil {
		return fmt.Errorf("create challenge in redis: %w", marshalErr)
	}
	opCtx, cancel, ctxErr := s.operationContext(ctx, "create challenge in redis")
	if ctxErr != nil {
		return ctxErr
	}
	defer cancel()
	// SetNX gives atomic create-if-absent semantics.
	wasCreated, setErr := s.client.SetNX(opCtx, s.lookupKey(record.ID), payload, redisTTL(record.ExpiresAt)).Result()
	switch {
	case setErr != nil:
		return fmt.Errorf("create challenge %q in redis: %w", record.ID, setErr)
	case !wasCreated:
		return fmt.Errorf("create challenge %q in redis: %w", record.ID, ports.ErrConflict)
	}
	return nil
}
// CompareAndSwap replaces previous with next when the currently stored
// challenge matches previous exactly in canonical Redis representation.
//
// Concurrency scheme: the key is WATCHed, the stored record is read and
// compared against previous, and the write happens inside MULTI/EXEC so any
// concurrent modification of the key aborts the transaction. An aborted
// transaction (redis.TxFailedErr) is surfaced as ports.ErrConflict.
func (s *Store) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
	if err := ports.ValidateComparableChallenges(previous, next); err != nil {
		return fmt.Errorf("compare and swap challenge in redis: %w", err)
	}
	// Encode next up front so encoding failures abort before touching Redis.
	nextPayload, err := marshalChallengeRecord(next)
	if err != nil {
		return fmt.Errorf("compare and swap challenge in redis: %w", err)
	}
	operationCtx, cancel, err := s.operationContext(ctx, "compare and swap challenge in redis")
	if err != nil {
		return err
	}
	defer cancel()
	key := s.lookupKey(previous.ID)
	watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
		payload, err := tx.Get(operationCtx, key).Bytes()
		switch {
		case errors.Is(err, redis.Nil):
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrNotFound)
		case err != nil:
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}
		current, err := decodeChallengeRecord(previous.ID, payload)
		if err != nil {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}
		// Equality is defined over the canonical stored representation, not
		// raw struct comparison, so both sides normalize identically.
		matches, err := equalStoredChallenges(current, previous)
		if err != nil {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}
		if !matches {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
		}
		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			// The TTL is recomputed from next so key expiry tracks the
			// replacement record.
			pipe.Set(operationCtx, key, nextPayload, redisTTL(next.ExpiresAt))
			return nil
		})
		if err != nil {
			return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
		}
		return nil
	}, key)
	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		// The WATCHed key changed between the read and EXEC.
		return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}
// operationContext validates the receiver and ctx, then derives a context
// bounded by the configured per-operation timeout. The returned cancel must
// be deferred by the caller. operation names the caller for error messages.
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	switch {
	case s == nil || s.client == nil:
		return nil, nil, fmt.Errorf("%s: nil store", operation)
	case ctx == nil:
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}
	bounded, cancel := context.WithTimeout(ctx, s.operationTimeout)
	return bounded, cancel, nil
}
// lookupKey derives the namespaced Redis key for challengeID.
func (s *Store) lookupKey(challengeID common.ChallengeID) string {
	encoded := encodeKeyComponent(challengeID.String())
	return s.keyPrefix + encoded
}
// encodeKeyComponent renders value as unpadded URL-safe base64 so arbitrary
// identifier bytes (colons, slashes, spaces) cannot collide with or escape
// the key-prefix namespace.
func encodeKeyComponent(value string) string {
	raw := []byte(value)
	return base64.RawURLEncoding.EncodeToString(raw)
}
// marshalChallengeRecord converts record to its canonical stored JSON bytes,
// validating it via redisRecordFromChallenge first.
func marshalChallengeRecord(record challenge.Challenge) ([]byte, error) {
	stored, convErr := redisRecordFromChallenge(record)
	if convErr != nil {
		return nil, convErr
	}
	encoded, encErr := json.Marshal(stored)
	if encErr != nil {
		return nil, fmt.Errorf("encode redis challenge record: %w", encErr)
	}
	return encoded, nil
}
// redisRecordFromChallenge maps a validated domain challenge onto the strict
// JSON storage shape: timestamps become canonical UTC RFC3339Nano strings and
// the code hash becomes standard base64.
func redisRecordFromChallenge(record challenge.Challenge) (redisRecord, error) {
	if err := record.Validate(); err != nil {
		return redisRecord{}, fmt.Errorf("encode redis challenge record: %w", err)
	}
	stored := redisRecord{
		ChallengeID: record.ID.String(),
		Email: record.Email.String(),
		CodeHashBase64: base64.StdEncoding.EncodeToString(record.CodeHash),
		Status: record.Status,
		DeliveryState: record.DeliveryState,
		CreatedAt: formatTimestamp(record.CreatedAt),
		ExpiresAt: formatTimestamp(record.ExpiresAt),
		SendAttemptCount: record.Attempts.Send,
		ConfirmAttemptCount: record.Attempts.Confirm,
		LastAttemptAt: formatOptionalTimestamp(record.Abuse.LastAttemptAt),
	}
	// Confirmation fields are written all-or-nothing; absent confirmation
	// leaves the omitempty-tagged fields out of the JSON entirely.
	if record.Confirmation != nil {
		stored.ConfirmedSessionID = record.Confirmation.SessionID.String()
		stored.ConfirmedClientPublicKey = record.Confirmation.ClientPublicKey.String()
		stored.ConfirmedAt = formatOptionalTimestamp(&record.Confirmation.ConfirmedAt)
	}
	return stored, nil
}
// decodeChallengeRecord strictly decodes payload into a domain challenge and
// verifies it belongs to expectedChallengeID. Strictness: unknown JSON fields
// are rejected, and any trailing input after the first JSON value is an error.
func decodeChallengeRecord(expectedChallengeID common.ChallengeID, payload []byte) (challenge.Challenge, error) {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()
	var stored redisRecord
	if err := decoder.Decode(&stored); err != nil {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
	}
	// A second Decode must hit clean EOF; anything else means trailing junk
	// (err == nil) or malformed trailing input (other error).
	if err := decoder.Decode(&struct{}{}); err != io.EOF {
		if err == nil {
			return challenge.Challenge{}, errors.New("decode redis challenge record: unexpected trailing JSON input")
		}
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
	}
	record, err := challengeFromRedisRecord(stored)
	if err != nil {
		return challenge.Challenge{}, err
	}
	// Guard against a payload copied under the wrong key.
	if record.ID != expectedChallengeID {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: challenge_id %q does not match requested %q", record.ID, expectedChallengeID)
	}
	return record, nil
}
// challengeFromRedisRecord reconstructs a domain challenge from its stored
// JSON shape, parsing canonical timestamps, strict base64, and optional
// confirmation metadata, then revalidating the result with domain rules.
func challengeFromRedisRecord(stored redisRecord) (challenge.Challenge, error) {
	createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
	if err != nil {
		return challenge.Challenge{}, err
	}
	expiresAt, err := parseTimestamp("expires_at", stored.ExpiresAt)
	if err != nil {
		return challenge.Challenge{}, err
	}
	lastAttemptAt, err := parseOptionalTimestamp("last_attempt_at", stored.LastAttemptAt)
	if err != nil {
		return challenge.Challenge{}, err
	}
	// Strict() rejects non-canonical base64 padding/trailing bits.
	codeHash, err := base64.StdEncoding.Strict().DecodeString(stored.CodeHashBase64)
	if err != nil {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: code_hash_base64: %w", err)
	}
	record := challenge.Challenge{
		ID: common.ChallengeID(stored.ChallengeID),
		Email: common.Email(stored.Email),
		CodeHash: codeHash,
		Status: stored.Status,
		DeliveryState: stored.DeliveryState,
		CreatedAt: createdAt,
		ExpiresAt: expiresAt,
		Attempts: challenge.AttemptCounters{
			Send: stored.SendAttemptCount,
			Confirm: stored.ConfirmAttemptCount,
		},
		Abuse: challenge.AbuseMetadata{
			LastAttemptAt: lastAttemptAt,
		},
	}
	confirmation, err := parseConfirmation(stored)
	if err != nil {
		return challenge.Challenge{}, err
	}
	record.Confirmation = confirmation
	// Domain validation is the final gate: status/state enums, field
	// invariants, etc. are enforced here rather than duplicated above.
	if err := record.Validate(); err != nil {
		return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
	}
	return record, nil
}
// parseConfirmation extracts optional confirmation metadata from stored.
// All three confirmation fields must be present together or absent together;
// a partial set is rejected as a corrupt record.
func parseConfirmation(stored redisRecord) (*challenge.Confirmation, error) {
	hasSessionID := strings.TrimSpace(stored.ConfirmedSessionID) != ""
	hasClientPublicKey := strings.TrimSpace(stored.ConfirmedClientPublicKey) != ""
	hasConfirmedAt := stored.ConfirmedAt != nil
	if !hasSessionID && !hasClientPublicKey && !hasConfirmedAt {
		// Unconfirmed challenge: no confirmation block at all.
		return nil, nil
	}
	if !hasSessionID || !hasClientPublicKey || !hasConfirmedAt {
		return nil, errors.New("decode redis challenge record: confirmation metadata must be either fully present or fully absent")
	}
	confirmedAt, err := parseTimestamp("confirmed_at", *stored.ConfirmedAt)
	if err != nil {
		return nil, err
	}
	rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ConfirmedClientPublicKey)
	if err != nil {
		return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
	}
	// NewClientPublicKey applies domain validation (e.g. key length).
	clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
	if err != nil {
		return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
	}
	return &challenge.Confirmation{
		SessionID: common.DeviceSessionID(stored.ConfirmedSessionID),
		ClientPublicKey: clientPublicKey,
		ConfirmedAt: confirmedAt,
	}, nil
}
// parseOptionalTimestamp parses a nullable stored timestamp. A nil input maps
// to a nil result; otherwise it applies the same canonical-format rules as
// parseTimestamp, using fieldName in error messages.
func parseOptionalTimestamp(fieldName string, value *string) (*time.Time, error) {
	if value == nil {
		return nil, nil
	}
	when, parseErr := parseTimestamp(fieldName, *value)
	if parseErr != nil {
		return nil, parseErr
	}
	return &when, nil
}
func parseTimestamp(fieldName string, value string) (time.Time, error) {
if strings.TrimSpace(value) == "" {
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must not be empty", fieldName)
}
parsed, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return time.Time{}, fmt.Errorf("decode redis challenge record: %s: %w", fieldName, err)
}
canonical := parsed.UTC().Format(time.RFC3339Nano)
if value != canonical {
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
}
return parsed.UTC(), nil
}
func formatTimestamp(value time.Time) string {
return value.UTC().Format(time.RFC3339Nano)
}
// formatOptionalTimestamp renders a nullable timestamp for storage: nil stays
// nil, anything else goes through formatTimestamp.
func formatOptionalTimestamp(value *time.Time) *string {
	if value == nil {
		return nil
	}
	rendered := formatTimestamp(*value)
	return &rendered
}
// redisTTL converts a logical expiry instant into a Redis key TTL. Already
// expired records still get the grace period, so readers can observe them
// briefly before the key disappears.
func redisTTL(expiresAt time.Time) time.Duration {
	remaining := time.Until(expiresAt.UTC())
	if remaining < 0 {
		remaining = 0
	}
	return remaining + expirationGracePeriod
}
// equalStoredChallenges reports whether left and right are identical in their
// canonical stored representation; both sides are normalized through
// redisRecordFromChallenge before comparison.
func equalStoredChallenges(left challenge.Challenge, right challenge.Challenge) (bool, error) {
	normalizedLeft, leftErr := redisRecordFromChallenge(left)
	if leftErr != nil {
		return false, leftErr
	}
	normalizedRight, rightErr := redisRecordFromChallenge(right)
	if rightErr != nil {
		return false, rightErr
	}
	return reflect.DeepEqual(normalizedLeft, normalizedRight), nil
}
// Compile-time check that *Store satisfies ports.ChallengeStore.
var _ ports.ChallengeStore = (*Store)(nil)
@@ -0,0 +1,531 @@
package challengestore
import (
"context"
"crypto/ed25519"
"encoding/json"
"testing"
"time"
"galaxy/authsession/internal/adapters/contracttest"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStoreContract runs the shared ChallengeStore contract suite against a
// fresh miniredis-backed store per invocation.
func TestStoreContract(t *testing.T) {
	t.Parallel()
	contracttest.RunChallengeStoreContractTests(t, func(t *testing.T) ports.ChallengeStore {
		t.Helper()
		server := miniredis.RunT(t)
		return newTestStore(t, server, Config{})
	})
}
// TestNew table-tests constructor validation: a fully valid config succeeds,
// and each missing/invalid field produces its specific error message.
func TestNew(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	tests := []struct {
		name string
		cfg Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr: server.Addr(),
				DB: 2,
				KeyPrefix: "authsession:challenge:",
				OperationTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				KeyPrefix: "authsession:challenge:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr: server.Addr(),
				DB: -1,
				KeyPrefix: "authsession:challenge:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty key prefix",
			cfg: Config{
				Addr: server.Addr(),
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis key prefix must not be empty",
		},
		{
			name: "non-positive operation timeout",
			cfg: Config{
				Addr: server.Addr(),
				KeyPrefix: "authsession:challenge:",
			},
			wantErr: "operation timeout must be positive",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			store, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}
			require.NoError(t, err)
			// Only successful constructions own a client to close.
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}
// TestStorePing checks that Ping succeeds against a live miniredis backend.
func TestStorePing(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	require.NoError(t, store.Ping(context.Background()))
}
// TestStoreCreateAndGet round-trips a fully confirmed challenge and then
// verifies that mutating byte slices returned by Get does not corrupt the
// stored record (each Get decodes fresh copies).
func TestStoreCreateAndGet(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	now := time.Unix(1_775_130_000, 0).UTC()
	record := testChallenge(now)
	require.NoError(t, store.Create(context.Background(), record))
	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)
	// Deliberately scribble over the returned byte slices...
	got.CodeHash[0] = 0xFF
	keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
	keyBytes[0] = 0xFE
	// ...then re-read and confirm the store was unaffected.
	again, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record.CodeHash, again.CodeHash)
	require.NotNil(t, again.Confirmation)
	assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
}
// TestStoreCreateAndGetPendingChallenge round-trips a minimal pending-send
// challenge (no confirmation, no attempt metadata).
func TestStoreCreateAndGetPendingChallenge(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	now := time.Unix(1_775_130_100, 0).UTC()
	record := testPendingChallenge(now)
	require.NoError(t, store.Create(context.Background(), record))
	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)
}
// TestStoreCreateAndGetThrottledChallenge round-trips a delivery-throttled
// challenge carrying attempt and abuse metadata but no confirmation.
func TestStoreCreateAndGetThrottledChallenge(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	now := time.Unix(1_775_130_150, 0).UTC()
	record := testPendingChallenge(now)
	record.Status = challenge.StatusDeliveryThrottled
	record.DeliveryState = challenge.DeliveryThrottled
	record.Attempts.Send = 1
	record.Abuse.LastAttemptAt = timePointer(now)
	// Sanity-check the fixture itself before exercising the store.
	require.NoError(t, record.Validate())
	require.NoError(t, store.Create(context.Background(), record))
	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)
}
// TestStoreGetStrictDecode seeds Redis with deliberately malformed payload
// variants (each produced by mutating a valid stored record) and asserts Get
// rejects every one with the expected error text.
func TestStoreGetStrictDecode(t *testing.T) {
	t.Parallel()
	now := time.Unix(1_775_130_200, 0).UTC()
	baseRecord := testChallenge(now)
	baseStored, err := redisRecordFromChallenge(baseRecord)
	require.NoError(t, err)
	tests := []struct {
		name string
		// mutate returns the raw payload to plant under the record's key.
		mutate func(redisRecord) string
		wantErrText string
	}{
		{
			name: "malformed json",
			mutate: func(_ redisRecord) string {
				return "{"
			},
			wantErrText: "decode redis challenge record",
		},
		{
			name: "trailing json input",
			mutate: func(record redisRecord) string {
				return mustMarshalJSON(t, record) + "{}"
			},
			wantErrText: "unexpected trailing JSON input",
		},
		{
			name: "unknown field",
			mutate: func(record redisRecord) string {
				// Rebuild the JSON by hand so an extra key can be injected.
				payload := map[string]any{
					"challenge_id": record.ChallengeID,
					"email": record.Email,
					"code_hash_base64": record.CodeHashBase64,
					"status": record.Status,
					"delivery_state": record.DeliveryState,
					"created_at": record.CreatedAt,
					"expires_at": record.ExpiresAt,
					"send_attempt_count": record.SendAttemptCount,
					"confirm_attempt_count": record.ConfirmAttemptCount,
					"last_attempt_at": record.LastAttemptAt,
					"confirmed_session_id": record.ConfirmedSessionID,
					"confirmed_client_public_key": record.ConfirmedClientPublicKey,
					"confirmed_at": record.ConfirmedAt,
					"unexpected": true,
				}
				return mustMarshalJSON(t, payload)
			},
			wantErrText: "unknown field",
		},
		{
			name: "unsupported status",
			mutate: func(record redisRecord) string {
				record.Status = challenge.Status("paused")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			name: "unsupported delivery state",
			mutate: func(record redisRecord) string {
				record.DeliveryState = challenge.DeliveryState("queued")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `delivery state "queued" is unsupported`,
		},
		{
			name: "missing required email",
			mutate: func(record redisRecord) string {
				record.Email = ""
				return mustMarshalJSON(t, record)
			},
			wantErrText: "challenge email",
		},
		{
			name: "challenge id mismatch",
			mutate: func(record redisRecord) string {
				record.ChallengeID = "other-challenge"
				return mustMarshalJSON(t, record)
			},
			wantErrText: `does not match requested`,
		},
		{
			name: "non canonical utc timestamp",
			mutate: func(record redisRecord) string {
				record.CreatedAt = "2026-04-04T12:00:00+03:00"
				return mustMarshalJSON(t, record)
			},
			wantErrText: "canonical UTC RFC3339Nano timestamp",
		},
		{
			name: "partial confirmation metadata",
			mutate: func(record redisRecord) string {
				record.ConfirmedAt = nil
				return mustMarshalJSON(t, record)
			},
			wantErrText: "confirmation metadata must be either fully present or fully absent",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			server := miniredis.RunT(t)
			store := newTestStore(t, server, Config{})
			// Plant the corrupted payload directly, bypassing Create.
			server.Set(store.lookupKey(baseRecord.ID), tt.mutate(baseStored))
			_, err := store.Get(context.Background(), baseRecord.ID)
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErrText)
		})
	}
}
// TestStoreKeySchemeAndTTL verifies that keys are prefix + base64-encoded ID
// (so special characters in IDs are safe), that fresh records get roughly
// InitialTTL plus the grace period, and that already-expired records still
// get the grace period alone.
func TestStoreKeySchemeAndTTL(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{KeyPrefix: "authsession:challenge:"})
	now := time.Now().UTC()
	prefixed := testPendingChallenge(now)
	// ID contains characters that must not leak into the key space raw.
	prefixed.ID = common.ChallengeID("challenge:opaque/id?value")
	require.NoError(t, store.Create(context.Background(), prefixed))
	key := store.lookupKey(prefixed.ID)
	assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key)
	assert.True(t, server.Exists(key))
	// TTL bounds allow a 2-second scheduling slack.
	freshTTL := server.TTL(key)
	assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod)
	assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second)
	expired := testPendingChallenge(now.Add(-10 * time.Minute))
	expired.ID = common.ChallengeID("expired-challenge")
	expired.CreatedAt = now.Add(-20 * time.Minute)
	expired.ExpiresAt = now.Add(-1 * time.Minute)
	require.NoError(t, store.Create(context.Background(), expired))
	expiredTTL := server.TTL(store.lookupKey(expired.ID))
	assert.LessOrEqual(t, expiredTTL, expirationGracePeriod)
	assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second)
}
// TestStoreCreateConflict checks that creating the same challenge twice
// surfaces ports.ErrConflict.
func TestStoreCreateConflict(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC())
	require.NoError(t, store.Create(context.Background(), record))
	err := store.Create(context.Background(), record)
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrConflict)
}
// TestStoreGetNotFound checks that Get on a missing key surfaces
// ports.ErrNotFound.
func TestStoreGetNotFound(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrNotFound)
}
// TestStoreCompareAndSwap exercises the CAS paths: successful swap, conflict
// when the stored record differs from the caller's expectation, not-found,
// and corrupt stored data (which must NOT look like a conflict).
func TestStoreCompareAndSwap(t *testing.T) {
	t.Parallel()
	now := time.Unix(1_775_130_400, 0).UTC()
	t.Run("success", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		previous := testPendingChallenge(now)
		next := previous
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		next.Attempts.Send = 1
		next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute))
		require.NoError(t, store.Create(context.Background(), previous))
		require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))
		got, err := store.Get(context.Background(), previous.ID)
		require.NoError(t, err)
		assert.Equal(t, next, got)
	})
	t.Run("conflict when stored record differs", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		stored := testPendingChallenge(now)
		// previous diverges from what is actually stored, so CAS must fail.
		previous := stored
		previous.Attempts.Send = 99
		next := stored
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		require.NoError(t, store.Create(context.Background(), stored))
		err := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrConflict)
	})
	t.Run("not found", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		previous := testPendingChallenge(now)
		next := previous
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		err := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})
	t.Run("corrupt stored record returns adapter error", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		previous := testPendingChallenge(now)
		next := previous
		next.Status = challenge.StatusSent
		next.DeliveryState = challenge.DeliverySent
		// Plant unparseable JSON directly under the record's key.
		server.Set(store.lookupKey(previous.ID), "{")
		err := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, err)
		assert.NotErrorIs(t, err, ports.ErrConflict)
		assert.ErrorContains(t, err, "decode redis challenge record")
	})
}
// newTestStore builds a Store against server, filling any unset cfg fields
// with test defaults, and registers a cleanup that closes the store.
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
	t.Helper()
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}
	if cfg.KeyPrefix == "" {
		cfg.KeyPrefix = "authsession:challenge:"
	}
	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	built, err := New(cfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, built.Close())
	})
	return built
}
// testPendingChallenge returns a minimal valid pending-send challenge fixture
// anchored at now, with no attempts, abuse metadata, or confirmation.
func testPendingChallenge(now time.Time) challenge.Challenge {
	return challenge.Challenge{
		ID: common.ChallengeID("challenge-pending"),
		Email: common.Email("pilot@example.com"),
		CodeHash: []byte("hashed-pending-code"),
		Status: challenge.StatusPendingSend,
		DeliveryState: challenge.DeliveryPending,
		CreatedAt: now,
		ExpiresAt: now.Add(challenge.InitialTTL),
	}
}
// testChallenge returns a fully confirmed challenge fixture anchored at now,
// populated with attempt counters, abuse metadata, and confirmation details
// (including a deterministic 32-byte ed25519 public key).
func testChallenge(now time.Time) challenge.Challenge {
	clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
		0, 1, 2, 3, 4, 5, 6, 7,
		8, 9, 10, 11, 12, 13, 14, 15,
		16, 17, 18, 19, 20, 21, 22, 23,
		24, 25, 26, 27, 28, 29, 30, 31,
	})
	if err != nil {
		// Fixture construction cannot fail with the fixed key above; a panic
		// here indicates a programming error in the test helper itself.
		panic(err)
	}
	return challenge.Challenge{
		ID: common.ChallengeID("challenge-confirmed"),
		Email: common.Email("pilot@example.com"),
		CodeHash: []byte("hashed-code"),
		Status: challenge.StatusConfirmedPendingExpire,
		DeliveryState: challenge.DeliverySent,
		CreatedAt: now,
		ExpiresAt: now.Add(challenge.ConfirmedRetention),
		Attempts: challenge.AttemptCounters{
			Send: 1,
			Confirm: 2,
		},
		Abuse: challenge.AbuseMetadata{
			LastAttemptAt: timePointer(now.Add(30 * time.Second)),
		},
		Confirmation: &challenge.Confirmation{
			SessionID: common.DeviceSessionID("device-session-1"),
			ClientPublicKey: clientPublicKey,
			ConfirmedAt: now.Add(1 * time.Minute),
		},
	}
}
func timePointer(value time.Time) *time.Time {
return &value
}
// mustMarshalJSON encodes value as JSON, failing the test immediately on any
// marshal error.
func mustMarshalJSON(t *testing.T, value any) string {
	t.Helper()
	encoded, err := json.Marshal(value)
	require.NoError(t, err)
	return string(encoded)
}
// TestStorePingNilContext checks that Ping rejects a nil context with a
// descriptive error instead of panicking.
func TestStorePingNilContext(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	//nolint:staticcheck // nil context passed deliberately to exercise the guard.
	err := store.Ping(nil)
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}
// TestStoreGetNilContext checks that Get rejects a nil context with a
// descriptive error instead of panicking.
func TestStoreGetNilContext(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	//nolint:staticcheck // nil context passed deliberately to exercise the guard.
	_, err := store.Get(nil, common.ChallengeID("challenge"))
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}
@@ -0,0 +1,169 @@
// Package configprovider implements ports.ConfigProvider with Redis-backed
// dynamic auth/session configuration.
package configprovider
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
"strings"
"time"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// Config configures one Redis-backed config provider instance.
// All fields except Username, Password, and TLSEnabled are required; New
// rejects empty/invalid values.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string
	// Username is the optional Redis ACL username.
	Username string
	// Password is the optional Redis ACL password.
	Password string
	// DB is the Redis logical database index. Must not be negative.
	DB int
	// TLSEnabled enables TLS with a conservative minimum protocol version
	// (TLS 1.2).
	TLSEnabled bool
	// SessionLimitKey identifies the single Redis string key that stores the
	// active-session-limit configuration value.
	SessionLimitKey string
	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}
// Store reads dynamic auth/session configuration from Redis.
// The zero value is unusable; construct instances with New.
type Store struct {
	// client is the owned Redis client; Close releases it.
	client *redis.Client
	// sessionLimitKey is the Redis key read by LoadSessionLimit.
	sessionLimitKey string
	// operationTimeout bounds each individual Redis round trip.
	operationTimeout time.Duration
}
// New constructs a Redis-backed config provider from cfg.
// It validates cfg eagerly and returns an error describing the first invalid
// field; it does not contact Redis (use Ping for reachability).
func New(cfg Config) (*Store, error) {
	if strings.TrimSpace(cfg.Addr) == "" {
		return nil, errors.New("new redis config provider: redis addr must not be empty")
	}
	if cfg.DB < 0 {
		return nil, errors.New("new redis config provider: redis db must not be negative")
	}
	if strings.TrimSpace(cfg.SessionLimitKey) == "" {
		return nil, errors.New("new redis config provider: session limit key must not be empty")
	}
	if cfg.OperationTimeout <= 0 {
		return nil, errors.New("new redis config provider: operation timeout must be positive")
	}
	clientOptions := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		// Conservative floor: refuse anything below TLS 1.2.
		clientOptions.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	provider := &Store{
		client:           redis.NewClient(clientOptions),
		sessionLimitKey:  cfg.SessionLimitKey,
		operationTimeout: cfg.OperationTimeout,
	}
	return provider, nil
}
// Close releases the underlying Redis client resources.
// It is safe to call on a nil receiver or an uninitialized store.
func (s *Store) Close() error {
	if s != nil && s.client != nil {
		return s.client.Close()
	}
	return nil
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
	opCtx, cancel, ctxErr := s.operationContext(ctx, "ping redis config provider")
	if ctxErr != nil {
		return ctxErr
	}
	defer cancel()
	pingErr := s.client.Ping(opCtx).Err()
	if pingErr != nil {
		return fmt.Errorf("ping redis config provider: %w", pingErr)
	}
	return nil
}
// LoadSessionLimit returns the current active-session-limit configuration.
// Missing or invalid Redis values are treated as “limit absent” by policy:
// only Redis/context failures produce an error, never a bad stored value.
func (s *Store) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) {
	opCtx, cancel, ctxErr := s.operationContext(ctx, "load session limit from redis")
	if ctxErr != nil {
		return ports.SessionLimitConfig{}, ctxErr
	}
	defer cancel()
	value, getErr := s.client.Get(opCtx, s.sessionLimitKey).Result()
	if getErr != nil {
		if errors.Is(getErr, redis.Nil) {
			// Key absent: no limit configured.
			return ports.SessionLimitConfig{}, nil
		}
		return ports.SessionLimitConfig{}, fmt.Errorf("load session limit from redis: %w", getErr)
	}
	parsed, valid := parseSessionLimitConfig(value)
	if !valid {
		return ports.SessionLimitConfig{}, nil
	}
	// A parseable-but-invalid limit is also absent by policy.
	if validateErr := parsed.Validate(); validateErr != nil {
		return ports.SessionLimitConfig{}, nil
	}
	return parsed, nil
}
// operationContext validates the receiver and ctx, then derives a context
// bounded by the configured per-operation timeout. The returned cancel must
// be deferred by the caller. operation names the caller for error messages.
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	switch {
	case s == nil || s.client == nil:
		return nil, nil, fmt.Errorf("%s: nil store", operation)
	case ctx == nil:
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}
	bounded, cancel := context.WithTimeout(ctx, s.operationTimeout)
	return bounded, cancel, nil
}
// parseSessionLimitConfig interprets raw as a strictly formatted positive
// decimal limit. It accepts only trimmed ASCII-digit strings whose value is
// positive and fits in int; everything else reports invalid (false) rather
// than an error, matching the "invalid means absent" policy.
func parseSessionLimitConfig(raw string) (ports.SessionLimitConfig, bool) {
	var none ports.SessionLimitConfig
	if raw == "" || strings.TrimSpace(raw) != raw {
		return none, false
	}
	for _, r := range raw {
		if r < '0' || r > '9' {
			return none, false
		}
	}
	value, parseErr := strconv.ParseInt(raw, 10, strconv.IntSize)
	if parseErr != nil || value <= 0 {
		return none, false
	}
	limit := int(value)
	return ports.SessionLimitConfig{ActiveSessionLimit: &limit}, true
}
// Compile-time check that *Store satisfies ports.ConfigProvider.
var _ ports.ConfigProvider = (*Store)(nil)
@@ -0,0 +1,283 @@
package configprovider
import (
"context"
"strconv"
"testing"
"time"
"galaxy/authsession/internal/adapters/contracttest"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStoreContract runs the shared ConfigProvider contract suite against a
// miniredis-backed store.
func TestStoreContract(t *testing.T) {
	t.Parallel()
	contracttest.RunConfigProviderContractTests(t, func(t *testing.T) contracttest.ConfigProviderHarness {
		t.Helper()
		backend := miniredis.RunT(t)
		store := newTestStore(t, backend, Config{})
		harness := contracttest.ConfigProviderHarness{Provider: store}
		// Disabling the limit is modeled by deleting the config key.
		harness.SeedDisabled = func(t *testing.T) {
			t.Helper()
			backend.Del(store.sessionLimitKey)
		}
		// Seeding a limit writes the raw decimal value the store parses.
		harness.SeedLimit = func(t *testing.T, limit int) {
			t.Helper()
			backend.Set(store.sessionLimitKey, strconv.Itoa(limit))
		}
		return harness
	})
}
// TestNew exercises Config validation in New: each invalid field yields its
// sentinel error message, and a valid config produces a closable store.
func TestNew(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               2,
				SessionLimitKey:  "authsession:config:active-session-limit",
				OperationTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionLimitKey:  "authsession:config:active-session-limit",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               -1,
				SessionLimitKey:  "authsession:config:active-session-limit",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session limit key",
			cfg: Config{
				Addr:             server.Addr(),
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "session limit key must not be empty",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:            server.Addr(),
				SessionLimitKey: "authsession:config:active-session-limit",
			},
			wantErr: "operation timeout must be positive",
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			store, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}
			require.NoError(t, err)
			// Successful construction owns a client; release it on cleanup.
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}
// TestStorePing verifies Ping succeeds against a reachable backend.
func TestStorePing(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	store := newTestStore(t, backend, Config{})
	require.NoError(t, store.Ping(context.Background()))
}
// TestStoreLoadSessionLimit verifies the "missing or invalid means disabled"
// policy: only a strictly formatted positive integer yields a limit; every
// other stored value degrades to the zero (disabled) config without error.
func TestStoreLoadSessionLimit(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name       string
		seed       func(*testing.T, *miniredis.Miniredis, *Store)
		wantConfig ports.SessionLimitConfig
	}{
		{
			name:       "missing key means disabled",
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "valid positive integer",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "5")
			},
			wantConfig: configWithLimit(5),
		},
		{
			name: "empty string is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "whitespace only is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, " ")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "whitespace padded integer is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, " 5 ")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "non integer text is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "five")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "zero is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "0")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "negative integer is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "-3")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
		{
			name: "overflow is invalid and disabled",
			seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
				t.Helper()
				server.Set(store.sessionLimitKey, "999999999999999999999999999999")
			},
			wantConfig: ports.SessionLimitConfig{},
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Each subtest gets its own backend so seeds cannot interfere.
			server := miniredis.RunT(t)
			store := newTestStore(t, server, Config{})
			if tt.seed != nil {
				tt.seed(t, server, store)
			}
			got, err := store.LoadSessionLimit(context.Background())
			require.NoError(t, err)
			assert.Equal(t, tt.wantConfig, got)
		})
	}
}
// TestStoreLoadSessionLimitBackendFailure verifies a Redis outage surfaces as
// a wrapped error rather than a silent "limit absent" result.
func TestStoreLoadSessionLimitBackendFailure(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	store := newTestStore(t, backend, Config{})
	// Stop the backend before the call so the GET round trip fails.
	backend.Close()
	_, err := store.LoadSessionLimit(context.Background())
	require.Error(t, err)
	assert.ErrorContains(t, err, "load session limit from redis")
}
// TestStoreLoadSessionLimitNilContext verifies the explicit nil-context guard.
func TestStoreLoadSessionLimitNilContext(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	store := newTestStore(t, backend, Config{})
	_, err := store.LoadSessionLimit(nil) //nolint:staticcheck // nil context is the case under test
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}
// TestStorePingNilContext verifies Ping rejects a nil context up front.
func TestStorePingNilContext(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	store := newTestStore(t, backend, Config{})
	err := store.Ping(nil) //nolint:staticcheck // nil context is the case under test
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}
// newTestStore builds a Store against server, filling zero-valued cfg fields
// with test defaults and registering client cleanup on t.
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
	t.Helper()
	// Defaults keep table tests focused on the fields they care about.
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}
	if cfg.SessionLimitKey == "" {
		cfg.SessionLimitKey = "authsession:config:active-session-limit"
	}
	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	store, err := New(cfg)
	require.NoError(t, err)
	t.Cleanup(func() { assert.NoError(t, store.Close()) })
	return store
}
// configWithLimit builds a SessionLimitConfig whose active-session limit
// points at a copy of limit.
func configWithLimit(limit int) ports.SessionLimitConfig {
	value := limit
	return ports.SessionLimitConfig{ActiveSessionLimit: &value}
}
@@ -0,0 +1,223 @@
// Package projectionpublisher implements
// ports.GatewaySessionProjectionPublisher with Redis-backed gateway-compatible
// cache snapshots and session lifecycle events.
package projectionpublisher
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// Config configures one Redis-backed gateway session projection publisher.
// The zero value is not usable: New rejects a missing Addr, a negative DB,
// empty key/stream names, and non-positive StreamMaxLen or OperationTimeout.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string
	// Username is the optional Redis ACL username.
	Username string
	// Password is the optional Redis ACL password.
	Password string
	// DB is the Redis logical database index.
	DB int
	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool
	// SessionCacheKeyPrefix is the namespace prefix applied to gateway session
	// cache keys. The raw device session identifier is appended directly.
	SessionCacheKeyPrefix string
	// SessionEventsStream identifies the gateway session lifecycle Redis Stream.
	SessionEventsStream string
	// StreamMaxLen bounds the session lifecycle stream with approximate
	// trimming via XADD MAXLEN ~.
	StreamMaxLen int64
	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}
// Publisher publishes gateway-compatible session projections into Redis cache
// and stream namespaces.
type Publisher struct {
	// client is the go-redis client constructed by New.
	client *redis.Client
	// sessionCacheKeyPrefix prefixes every per-session cache key.
	sessionCacheKeyPrefix string
	// sessionEventsStream names the session lifecycle stream.
	sessionEventsStream string
	// streamMaxLen is the approximate XADD MAXLEN bound for the stream.
	streamMaxLen int64
	// operationTimeout bounds each Redis round trip.
	operationTimeout time.Duration
}
// cacheRecord is the strict JSON wire shape of one gateway session cache
// entry; revoked_at_ms is emitted only when the session carries a revocation
// timestamp.
type cacheRecord struct {
	DeviceSessionID string                   `json:"device_session_id"`
	UserID          string                   `json:"user_id"`
	ClientPublicKey string                   `json:"client_public_key"`
	Status          gatewayprojection.Status `json:"status"`
	RevokedAtMS     *int64                   `json:"revoked_at_ms,omitempty"`
}
// New constructs a Redis-backed gateway session projection publisher from
// cfg. Every required field is validated before any client is created.
func New(cfg Config) (*Publisher, error) {
	if strings.TrimSpace(cfg.Addr) == "" {
		return nil, errors.New("new redis projection publisher: redis addr must not be empty")
	}
	if cfg.DB < 0 {
		return nil, errors.New("new redis projection publisher: redis db must not be negative")
	}
	if strings.TrimSpace(cfg.SessionCacheKeyPrefix) == "" {
		return nil, errors.New("new redis projection publisher: session cache key prefix must not be empty")
	}
	if strings.TrimSpace(cfg.SessionEventsStream) == "" {
		return nil, errors.New("new redis projection publisher: session events stream must not be empty")
	}
	if cfg.StreamMaxLen <= 0 {
		return nil, errors.New("new redis projection publisher: stream max len must be positive")
	}
	if cfg.OperationTimeout <= 0 {
		return nil, errors.New("new redis projection publisher: operation timeout must be positive")
	}
	clientOptions := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		// Require at least TLS 1.2 when transport security is requested.
		clientOptions.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	publisher := &Publisher{
		client:                redis.NewClient(clientOptions),
		sessionCacheKeyPrefix: cfg.SessionCacheKeyPrefix,
		sessionEventsStream:   cfg.SessionEventsStream,
		streamMaxLen:          cfg.StreamMaxLen,
		operationTimeout:      cfg.OperationTimeout,
	}
	return publisher, nil
}
// Close releases the underlying Redis client resources. It is safe to call
// on a nil receiver or an unconfigured publisher.
func (p *Publisher) Close() error {
	if p != nil && p.client != nil {
		return p.client.Close()
	}
	return nil
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (p *Publisher) Ping(ctx context.Context) error {
	const operation = "ping redis projection publisher"
	pingCtx, cancel, err := p.operationContext(ctx, operation)
	if err != nil {
		return err
	}
	defer cancel()
	if pingErr := p.client.Ping(pingCtx).Err(); pingErr != nil {
		return fmt.Errorf(operation+": %w", pingErr)
	}
	return nil
}
// PublishSession writes one gateway-compatible session snapshot into the
// gateway cache namespace and appends the same snapshot to the gateway session
// event stream within one Redis transaction.
func (p *Publisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error {
	const operation = "publish session projection to redis"
	if err := snapshot.Validate(); err != nil {
		return fmt.Errorf(operation+": %w", err)
	}
	cachePayload, err := marshalCacheRecord(snapshot)
	if err != nil {
		return fmt.Errorf(operation+": %w", err)
	}
	streamValues := buildStreamValues(snapshot)
	publishCtx, cancel, err := p.operationContext(ctx, operation)
	if err != nil {
		return err
	}
	defer cancel()
	cacheKey := p.sessionCacheKey(snapshot.DeviceSessionID)
	// SET and XADD ride one MULTI/EXEC so cache and stream stay in step.
	if _, txErr := p.client.TxPipelined(publishCtx, func(pipe redis.Pipeliner) error {
		pipe.Set(publishCtx, cacheKey, cachePayload, 0)
		pipe.XAdd(publishCtx, &redis.XAddArgs{
			Stream: p.sessionEventsStream,
			MaxLen: p.streamMaxLen,
			Approx: true,
			Values: streamValues,
		})
		return nil
	}); txErr != nil {
		return fmt.Errorf("publish session projection %q to redis: %w", snapshot.DeviceSessionID, txErr)
	}
	return nil
}
// operationContext validates the receiver and caller context, then derives a
// child context bounded by the configured per-operation timeout.
func (p *Publisher) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	switch {
	case p == nil || p.client == nil:
		return nil, nil, fmt.Errorf("%s: nil publisher", operation)
	case ctx == nil:
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}
	bounded, cancel := context.WithTimeout(ctx, p.operationTimeout)
	return bounded, cancel, nil
}
// sessionCacheKey derives the gateway cache key for one device session by
// appending the raw session identifier to the configured prefix.
// fmt.Stringer replaces the original anonymous interface{ String() string }:
// identical method set, but the stdlib name documents intent.
func (p *Publisher) sessionCacheKey(deviceSessionID fmt.Stringer) string {
	return p.sessionCacheKeyPrefix + deviceSessionID.String()
}
// marshalCacheRecord converts snapshot into its strict JSON cache encoding.
// RevokedAt is flattened to UTC epoch milliseconds and omitted while nil.
func marshalCacheRecord(snapshot gatewayprojection.Snapshot) ([]byte, error) {
	record := cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          snapshot.Status,
	}
	if revokedAt := snapshot.RevokedAt; revokedAt != nil {
		milliseconds := revokedAt.UTC().UnixMilli()
		record.RevokedAtMS = &milliseconds
	}
	encoded, err := json.Marshal(record)
	if err != nil {
		return nil, fmt.Errorf("marshal gateway session cache record: %w", err)
	}
	return encoded, nil
}
// buildStreamValues flattens snapshot into the field map appended to the
// session lifecycle stream; revoked_at_ms appears only when RevokedAt is set.
func buildStreamValues(snapshot gatewayprojection.Snapshot) map[string]any {
	fields := map[string]any{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(snapshot.Status),
	}
	if revokedAt := snapshot.RevokedAt; revokedAt != nil {
		fields["revoked_at_ms"] = fmt.Sprint(revokedAt.UTC().UnixMilli())
	}
	return fields
}
var _ ports.GatewaySessionProjectionPublisher = (*Publisher)(nil)
@@ -0,0 +1,442 @@
package projectionpublisher
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/gatewayprojection"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNew exercises Config validation in New: each invalid field yields its
// sentinel error message, and a valid config produces a closable publisher.
func TestNew(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:                  server.Addr(),
				DB:                    3,
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:                  server.Addr(),
				DB:                    -1,
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session cache key prefix",
			cfg: Config{
				Addr:                server.Addr(),
				SessionEventsStream: "gateway:session_events",
				StreamMaxLen:        1024,
				OperationTimeout:    250 * time.Millisecond,
			},
			wantErr: "session cache key prefix must not be empty",
		},
		{
			name: "empty session events stream",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "session events stream must not be empty",
		},
		{
			name: "non positive stream max len",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "stream max len must be positive",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
			},
			wantErr: "operation timeout must be positive",
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			publisher, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}
			require.NoError(t, err)
			// Successful construction owns a client; release it on cleanup.
			t.Cleanup(func() {
				assert.NoError(t, publisher.Close())
			})
		})
	}
}
// TestPublisherPing verifies Ping succeeds against a reachable backend.
func TestPublisherPing(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	publisher := newTestPublisher(t, backend, Config{})
	require.NoError(t, publisher.Ping(context.Background()))
}
// TestPublisherPublishSessionActive verifies that publishing an active
// snapshot writes one persistent cache entry keyed by the raw session
// identifier and appends exactly one matching stream event.
func TestPublisherPublishSessionActive(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})
	// The identifier deliberately contains URL-reserved characters to pin
	// down that the key uses the raw value, not an encoded form.
	snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)
	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
	assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key)
	assert.True(t, server.Exists(key))
	// Guard against accidental base64url encoding of the identifier.
	assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String())))
	payload, err := server.Get(key)
	require.NoError(t, err)
	record := decodeCachePayload(t, payload)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}, record)
	// Cache entries are written without expiration (TTL zero).
	assert.Zero(t, server.TTL(key))
	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 1)
	assert.Equal(t, map[string]string{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}, stringifyValues(entries[0].Values))
}
// TestPublisherPublishSessionRevoked verifies that a revoked snapshot carries
// its revocation instant as UTC epoch milliseconds in both the cache record
// and the stream event.
func TestPublisherPublishSessionRevoked(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})
	// Sub-millisecond precision in the fixture pins down the UnixMilli
	// truncation behavior.
	revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
	snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)
	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
	payload, err := server.Get(key)
	require.NoError(t, err)
	record := decodeCachePayload(t, payload)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusRevoked,
		RevokedAtMS:     int64Pointer(revokedAt.UnixMilli()),
	}, record)
	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 1)
	assert.Equal(t, map[string]string{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776000123456",
	}, stringifyValues(entries[0].Values))
}
// TestPublisherPublishSessionLaterSnapshotWinsInCache verifies last-writer-wins
// semantics for the cache entry while the stream retains both lifecycle
// events in publish order.
func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
	deviceSessionID := "device-session-456"
	active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
	revokedAt := time.Unix(1_776_010_000, 0).UTC()
	revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)
	require.NoError(t, publisher.PublishSession(context.Background(), active))
	require.NoError(t, publisher.PublishSession(context.Background(), revoked))
	// The cache holds only the second (revoked) snapshot.
	payload, err := server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID))
	require.NoError(t, err)
	record := decodeCachePayload(t, payload)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
	assert.Equal(t, gatewayprojection.StatusRevoked, record.Status)
	// The stream keeps both events, oldest first.
	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 2)
	assert.Equal(t, map[string]string{
		"device_session_id": active.DeviceSessionID.String(),
		"user_id":           active.UserID.String(),
		"client_public_key": active.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}, stringifyValues(entries[0].Values))
	assert.Equal(t, map[string]string{
		"device_session_id": revoked.DeviceSessionID.String(),
		"user_id":           revoked.UserID.String(),
		"client_public_key": revoked.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776010000000",
	}, stringifyValues(entries[1].Values))
}
// TestPublisherPublishSessionRepeatedPublishIsRetrySafe verifies a duplicate
// publish overwrites the cache idempotently while appending a second,
// identical stream event.
func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	publisher := newTestPublisher(t, backend, Config{StreamMaxLen: 8})
	snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)
	for range 2 {
		require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	}
	payload, err := backend.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID))
	require.NoError(t, err)
	want := cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}
	assert.Equal(t, want, decodeCachePayload(t, payload))
	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 2)
	// Both appended events carry identical field values.
	assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values))
}
// TestPublisherPublishSessionStreamMaxLenApprox verifies the stream is
// trimmed to the configured bound as more sessions are published.
func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	publisher := newTestPublisher(t, backend, Config{StreamMaxLen: 2})
	// Publish six distinct sessions against a max length of two.
	for _, suffix := range []string{"a", "b", "c", "d", "e", "f"} {
		snapshot := testSnapshot(
			common.DeviceSessionID("device-session-"+suffix).String(),
			gatewayprojection.StatusActive,
			nil,
		)
		require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	}
	streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result()
	require.NoError(t, err)
	assert.LessOrEqual(t, streamLength, int64(2))
}
// TestPublisherPublishSessionInvalidSnapshot verifies validation rejects a
// snapshot missing its client public key and leaves Redis untouched.
func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	publisher := newTestPublisher(t, backend, Config{})
	invalid := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID("device-session-123"),
		UserID:          common.UserID("user-123"),
		Status:          gatewayprojection.StatusActive,
	}
	err := publisher.PublishSession(context.Background(), invalid)
	require.Error(t, err)
	assert.ErrorContains(t, err, "gateway projection client public key")
	// Validation failures must not leave partial writes behind.
	assert.Empty(t, backend.Keys())
}
// TestPublisherPublishSessionNilContext verifies the nil-context guard.
func TestPublisherPublishSessionNilContext(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	publisher := newTestPublisher(t, backend, Config{})
	snapshot := testSnapshot("device-session-123", gatewayprojection.StatusActive, nil)
	err := publisher.PublishSession(nil, snapshot) //nolint:staticcheck // nil context is the case under test
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}
// TestPublisherPublishSessionBackendFailure verifies a Redis outage surfaces
// as a wrapped publish error.
func TestPublisherPublishSessionBackendFailure(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	publisher := newTestPublisher(t, backend, Config{})
	// Stop the backend before publishing so the transaction fails.
	backend.Close()
	snapshot := testSnapshot("device-session-123", gatewayprojection.StatusActive, nil)
	err := publisher.PublishSession(context.Background(), snapshot)
	require.Error(t, err)
	assert.ErrorContains(t, err, "publish session projection")
}
// TestPublisherPingNilContext verifies Ping rejects a nil context up front.
func TestPublisherPingNilContext(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	publisher := newTestPublisher(t, backend, Config{})
	err := publisher.Ping(nil) //nolint:staticcheck // nil context is the case under test
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}
// newTestPublisher builds a Publisher against server, filling zero-valued cfg
// fields with test defaults and registering client cleanup on t.
func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
	t.Helper()
	// Defaults keep tests focused on the fields they care about.
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}
	if cfg.StreamMaxLen == 0 {
		cfg.StreamMaxLen = 1024
	}
	if cfg.SessionEventsStream == "" {
		cfg.SessionEventsStream = "gateway:session_events"
	}
	if cfg.SessionCacheKeyPrefix == "" {
		cfg.SessionCacheKeyPrefix = "gateway:session:"
	}
	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	publisher, err := New(cfg)
	require.NoError(t, err)
	t.Cleanup(func() { assert.NoError(t, publisher.Close()) })
	return publisher
}
// testSnapshot builds a valid snapshot for deviceSessionID with a
// deterministic Ed25519-sized public key; revoked snapshots also carry the
// mandatory revoke metadata.
func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
	publicKey := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for i := range publicKey {
		publicKey[i] = byte(i + 1)
	}
	snapshot := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID("user-123"),
		ClientPublicKey: base64.StdEncoding.EncodeToString(publicKey),
		Status:          status,
		RevokedAt:       revokedAt,
	}
	if status != gatewayprojection.StatusRevoked {
		return snapshot
	}
	snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
	snapshot.RevokeActorType = common.RevokeActorType("system")
	return snapshot
}
// decodeCachePayload strictly decodes one cache payload: it must hold exactly
// one JSON value, no fields unknown to cacheRecord, and no fields beyond the
// expected strict set (revoked_at_ms only when the decoded record carries it).
func decodeCachePayload(t *testing.T, payload string) cacheRecord {
	t.Helper()
	decoder := json.NewDecoder(bytes.NewReader([]byte(payload)))
	decoder.DisallowUnknownFields()
	var record cacheRecord
	require.NoError(t, decoder.Decode(&record))
	// A second decode must hit io.EOF: the payload holds one value only.
	err := decoder.Decode(&struct{}{})
	if err == nil {
		require.FailNow(t, "expected cache payload EOF after first JSON value")
	}
	require.ErrorIs(t, err, io.EOF)
	// Re-decode generically to assert the exact top-level field set.
	var fieldSet map[string]json.RawMessage
	require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet))
	expectedFields := map[string]struct{}{
		"device_session_id": {},
		"user_id":           {},
		"client_public_key": {},
		"status":            {},
	}
	if record.RevokedAtMS != nil {
		expectedFields["revoked_at_ms"] = struct{}{}
	}
	// Equal lengths plus the membership loop below prove set equality.
	assert.Equal(t, len(expectedFields), len(fieldSet))
	for field := range fieldSet {
		_, ok := expectedFields[field]
		assert.Truef(t, ok, "unexpected cache payload field %q", field)
	}
	return record
}
// stringifyValues renders every stream field value with fmt so expected maps
// can be written as plain strings in assertions.
func stringifyValues(values map[string]any) map[string]string {
	rendered := make(map[string]string, len(values))
	for field, value := range values {
		rendered[field] = fmt.Sprint(value)
	}
	return rendered
}
// encodeBase64URL returns the unpadded URL-safe base64 form of value.
func encodeBase64URL(value string) string {
	raw := []byte(value)
	return base64.RawURLEncoding.EncodeToString(raw)
}
// int64Pointer returns a pointer to a copy of value.
func int64Pointer(value int64) *int64 {
	copied := value
	return &copied
}
@@ -0,0 +1,152 @@
// Package sendemailcodeabuse implements ports.SendEmailCodeAbuseProtector with
// one Redis TTL key per normalized e-mail address.
package sendemailcodeabuse
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// Config configures one Redis-backed send-email-code abuse protector.
// The zero value is not usable: New rejects a missing Addr, a negative DB, an
// empty KeyPrefix, and a non-positive OperationTimeout.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string
	// Username is the optional Redis ACL username.
	Username string
	// Password is the optional Redis ACL password.
	Password string
	// DB is the Redis logical database index.
	DB int
	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool
	// KeyPrefix is the namespace prefix applied to every resend-throttle key.
	KeyPrefix string
	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}
// Protector applies the fixed resend cooldown with one Redis key per
// normalized e-mail address.
type Protector struct {
	// client is the go-redis client constructed by New.
	client *redis.Client
	// keyPrefix namespaces every throttle key.
	keyPrefix string
	// operationTimeout bounds each Redis round trip.
	operationTimeout time.Duration
}
// New constructs a Redis-backed resend-throttle protector from cfg. Every
// required field is validated before any client is created.
func New(cfg Config) (*Protector, error) {
	if strings.TrimSpace(cfg.Addr) == "" {
		return nil, errors.New("new redis send email code abuse protector: redis addr must not be empty")
	}
	if cfg.DB < 0 {
		return nil, errors.New("new redis send email code abuse protector: redis db must not be negative")
	}
	if strings.TrimSpace(cfg.KeyPrefix) == "" {
		return nil, errors.New("new redis send email code abuse protector: redis key prefix must not be empty")
	}
	if cfg.OperationTimeout <= 0 {
		return nil, errors.New("new redis send email code abuse protector: operation timeout must be positive")
	}
	clientOptions := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		// Require at least TLS 1.2 when transport security is requested.
		clientOptions.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	protector := &Protector{
		client:           redis.NewClient(clientOptions),
		keyPrefix:        cfg.KeyPrefix,
		operationTimeout: cfg.OperationTimeout,
	}
	return protector, nil
}
// Close releases the underlying Redis client resources. It is safe to call
// on a nil receiver or an unconfigured protector.
func (p *Protector) Close() error {
	if p != nil && p.client != nil {
		return p.client.Close()
	}
	return nil
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (p *Protector) Ping(ctx context.Context) error {
	const operation = "ping redis send email code abuse protector"
	pingCtx, cancel, err := p.operationContext(ctx, operation)
	if err != nil {
		return err
	}
	defer cancel()
	if pingErr := p.client.Ping(pingCtx).Err(); pingErr != nil {
		return fmt.Errorf(operation+": %w", pingErr)
	}
	return nil
}
// CheckAndReserve applies the fixed resend cooldown using one TTL key per
// normalized e-mail address. The first caller inside a cooldown window is
// allowed and reserves the key; later callers are throttled until it expires.
func (p *Protector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
	const operation = "check and reserve send email code abuse"
	if err := input.Validate(); err != nil {
		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf(operation+": %w", err)
	}
	reserveCtx, cancel, err := p.operationContext(ctx, operation)
	if err != nil {
		return ports.SendEmailCodeAbuseResult{}, err
	}
	defer cancel()
	// The stored value records when the cooldown ends; the key TTL is what
	// actually enforces it. SETNX makes check-and-reserve atomic.
	cooldownEnd := input.Now.UTC().Add(challenge.ResendThrottleCooldown).Format(time.RFC3339Nano)
	reserved, err := p.client.SetNX(reserveCtx, p.lookupKey(input.Email), cooldownEnd, challenge.ResendThrottleCooldown).Result()
	if err != nil {
		return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse for %q: %w", input.Email, err)
	}
	outcome := ports.SendEmailCodeAbuseOutcomeThrottled
	if reserved {
		outcome = ports.SendEmailCodeAbuseOutcomeAllowed
	}
	return ports.SendEmailCodeAbuseResult{Outcome: outcome}, nil
}
// operationContext validates the receiver and caller context, then derives a
// child context bounded by the configured per-operation timeout.
func (p *Protector) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	switch {
	case p == nil || p.client == nil:
		return nil, nil, fmt.Errorf("%s: nil protector", operation)
	case ctx == nil:
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}
	bounded, cancel := context.WithTimeout(ctx, p.operationTimeout)
	return bounded, cancel, nil
}
// lookupKey namespaces the e-mail under the configured prefix, base64url
// encoding the address so the key contains no reserved characters.
func (p *Protector) lookupKey(email common.Email) string {
	encoded := base64.RawURLEncoding.EncodeToString([]byte(email.String()))
	return p.keyPrefix + encoded
}
var _ ports.SendEmailCodeAbuseProtector = (*Protector)(nil)
@@ -0,0 +1,176 @@
package sendemailcodeabuse
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNew exercises Config validation in New: each invalid field yields its
// sentinel error message, and a valid config produces a closable protector.
func TestNew(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               1,
				KeyPrefix:        "authsession:send-email-code-throttle:",
				OperationTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				KeyPrefix:        "authsession:send-email-code-throttle:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:             server.Addr(),
				DB:               -1,
				KeyPrefix:        "authsession:send-email-code-throttle:",
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty key prefix",
			cfg: Config{
				Addr:             server.Addr(),
				OperationTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis key prefix must not be empty",
		},
		{
			name: "non-positive timeout",
			cfg: Config{
				Addr:      server.Addr(),
				KeyPrefix: "authsession:send-email-code-throttle:",
			},
			wantErr: "operation timeout must be positive",
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			protector, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}
			require.NoError(t, err)
			// Successful construction owns a client; release it on cleanup.
			t.Cleanup(func() {
				assert.NoError(t, protector.Close())
			})
		})
	}
}
// TestProtectorPing verifies Ping succeeds against a reachable backend.
func TestProtectorPing(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	protector := newTestProtector(t, backend, Config{})
	require.NoError(t, protector.Ping(context.Background()))
}
// TestProtectorCheckAndReserve walks one full cooldown cycle: the first call
// is allowed and reserves a TTL key, a second call inside the window is
// throttled without extending the TTL, and the window reopens after expiry.
func TestProtectorCheckAndReserve(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	protector := newTestProtector(t, server, Config{})
	email := common.Email("pilot@example.com")
	now := time.Unix(10, 0).UTC()
	result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now,
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
	key := protector.lookupKey(email)
	assert.True(t, server.Exists(key))
	// The TTL should sit at (or just under) the full cooldown.
	ttl := server.TTL(key)
	assert.LessOrEqual(t, ttl, challenge.ResendThrottleCooldown)
	assert.GreaterOrEqual(t, ttl, challenge.ResendThrottleCooldown-2*time.Second)
	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(30 * time.Second),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
	// A throttled attempt must not refresh the reservation TTL.
	ttlAfterThrottle := server.TTL(key)
	assert.LessOrEqual(t, ttlAfterThrottle, ttl)
	// Expire the key via miniredis' clock, reopening the window.
	server.FastForward(challenge.ResendThrottleCooldown)
	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(challenge.ResendThrottleCooldown),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
}
// TestProtectorNilContext asserts that a nil context is rejected with a
// descriptive error instead of being handed to the Redis client.
func TestProtectorNilContext(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	protector := newTestProtector(t, redisServer, Config{})
	input := ports.SendEmailCodeAbuseInput{
		Email: common.Email("pilot@example.com"),
		Now:   time.Unix(10, 0).UTC(),
	}
	// Passing nil on purpose to exercise the adapter's guard.
	_, err := protector.CheckAndReserve(nil, input)
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}
// newTestProtector builds a Protector against server, filling any Config
// fields the caller left zero with test defaults, and registers a cleanup
// that closes it at test end.
func newTestProtector(t *testing.T, server *miniredis.Miniredis, cfg Config) *Protector {
	t.Helper()
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}
	if cfg.KeyPrefix == "" {
		cfg.KeyPrefix = "authsession:send-email-code-throttle:"
	}
	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	protector, err := New(cfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, protector.Close())
	})
	return protector
}
@@ -0,0 +1,723 @@
// Package sessionstore implements ports.SessionStore with Redis-backed strict
// JSON source-of-truth session records and per-user indexes.
package sessionstore
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// mutationRetryLimit caps how many times an optimistic WATCH/EXEC mutation is
// re-executed after losing a concurrent race before giving up.
const mutationRetryLimit = 3

// Config configures one Redis-backed session store instance.
//
// The zero value is not usable: New validates that Addr, the three key
// prefixes, and OperationTimeout are all set.
type Config struct {
	// Addr is the Redis network address in host:port form.
	Addr string
	// Username is the optional Redis ACL username.
	Username string
	// Password is the optional Redis ACL password.
	Password string
	// DB is the Redis logical database index.
	DB int
	// TLSEnabled enables TLS with a conservative minimum protocol version.
	TLSEnabled bool
	// SessionKeyPrefix is the namespace prefix applied to primary session keys.
	SessionKeyPrefix string
	// UserSessionsKeyPrefix is the namespace prefix applied to all-session user
	// indexes.
	UserSessionsKeyPrefix string
	// UserActiveSessionsKeyPrefix is the namespace prefix applied to active
	// session user indexes.
	UserActiveSessionsKeyPrefix string
	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}
// Store persists source-of-truth sessions in Redis and maintains user-scoped
// indexes for list and count operations.
//
// Layout: one strict-JSON value per session under sessionKeyPrefix, plus two
// sorted sets per user — all sessions and active-only — scored by creation
// time (see createdAtScore).
type Store struct {
	client                      *redis.Client
	sessionKeyPrefix            string
	userSessionsKeyPrefix       string
	userActiveSessionsKeyPrefix string
	operationTimeout            time.Duration
}
// redisRecord is the strict JSON wire form of one session as stored under the
// primary session key. Timestamps are canonical UTC RFC3339Nano strings (see
// parseTimestamp) and the revoke_* fields are present only for revoked
// sessions (see parseRevocation for the all-or-nothing rule).
type redisRecord struct {
	DeviceSessionID       string               `json:"device_session_id"`
	UserID                string               `json:"user_id"`
	ClientPublicKeyBase64 string               `json:"client_public_key_base64"`
	Status                devicesession.Status `json:"status"`
	CreatedAt             string               `json:"created_at"`
	RevokedAt             *string              `json:"revoked_at,omitempty"`
	RevokeReasonCode      string               `json:"revoke_reason_code,omitempty"`
	RevokeActorType       string               `json:"revoke_actor_type,omitempty"`
	RevokeActorID         string               `json:"revoke_actor_id,omitempty"`
}
// New constructs a Redis-backed session store from cfg.
//
// Configuration is validated eagerly and the first violated requirement is
// reported; no connection is attempted here (use Ping for that).
func New(cfg Config) (*Store, error) {
	const prefix = "new redis session store: "
	if strings.TrimSpace(cfg.Addr) == "" {
		return nil, errors.New(prefix + "redis addr must not be empty")
	}
	if cfg.DB < 0 {
		return nil, errors.New(prefix + "redis db must not be negative")
	}
	if strings.TrimSpace(cfg.SessionKeyPrefix) == "" {
		return nil, errors.New(prefix + "session key prefix must not be empty")
	}
	if strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "" {
		return nil, errors.New(prefix + "user sessions key prefix must not be empty")
	}
	if strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "" {
		return nil, errors.New(prefix + "user active sessions key prefix must not be empty")
	}
	if cfg.OperationTimeout <= 0 {
		return nil, errors.New(prefix + "operation timeout must be positive")
	}
	clientOptions := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		clientOptions.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	store := &Store{
		client:                      redis.NewClient(clientOptions),
		sessionKeyPrefix:            cfg.SessionKeyPrefix,
		userSessionsKeyPrefix:       cfg.UserSessionsKeyPrefix,
		userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix,
		operationTimeout:            cfg.OperationTimeout,
	}
	return store, nil
}
// Close releases the underlying Redis client resources. Calling Close on a
// nil or zero-value store is a safe no-op.
func (s *Store) Close() error {
	if s != nil && s.client != nil {
		return s.client.Close()
	}
	return nil
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
	const operation = "ping redis session store"
	operationCtx, cancel, err := s.operationContext(ctx, operation)
	if err != nil {
		return err
	}
	defer cancel()
	if pingErr := s.client.Ping(operationCtx).Err(); pingErr != nil {
		return fmt.Errorf("%s: %w", operation, pingErr)
	}
	return nil
}
// Get returns the stored session for deviceSessionID.
//
// When no record exists, the returned error wraps ports.ErrNotFound so
// callers can detect the miss with errors.Is.
func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
	if err := deviceSessionID.Validate(); err != nil {
		return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err)
	}
	operationCtx, cancel, err := s.operationContext(ctx, "get session from redis")
	if err != nil {
		return devicesession.Session{}, err
	}
	defer cancel()
	record, err := s.loadSession(operationCtx, deviceSessionID)
	if err != nil {
		// loadSession returns the bare ports.ErrNotFound sentinel for missing
		// keys, so a single wrap preserves both that sentinel and any other
		// cause; the previous not-found special case produced an identical
		// error and was redundant.
		return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err)
	}
	return record, nil
}
// ListByUserID returns every stored session for userID in newest-first order.
//
// Every member of the all-sessions index must resolve to a record owned by
// userID; a dangling or foreign entry is reported as an error because it
// means the store invariants are broken.
func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
	if err := userID.Validate(); err != nil {
		return nil, fmt.Errorf("list sessions by user id from redis: %w", err)
	}
	operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis")
	if err != nil {
		return nil, err
	}
	defer cancel()
	memberIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result()
	if err != nil {
		return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err)
	}
	if len(memberIDs) == 0 {
		return []devicesession.Session{}, nil
	}
	sessions := make([]devicesession.Session, 0, len(memberIDs))
	for _, member := range memberIDs {
		memberID := common.DeviceSessionID(member)
		session, loadErr := s.loadSession(operationCtx, memberID)
		if errors.Is(loadErr, ports.ErrNotFound) {
			return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, memberID)
		}
		if loadErr != nil {
			return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, memberID, loadErr)
		}
		if session.UserID != userID {
			return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, memberID, session.UserID)
		}
		sessions = append(sessions, session)
	}
	sortSessionsNewestFirst(sessions)
	return sessions, nil
}
// CountActiveByUserID returns the number of active sessions currently stored
// for userID, computed as the cardinality of the active-session index.
func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
	if err := userID.Validate(); err != nil {
		return 0, fmt.Errorf("count active sessions by user id from redis: %w", err)
	}
	operationCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis")
	if err != nil {
		return 0, err
	}
	defer cancel()
	cardinality, err := s.client.ZCard(operationCtx, s.userActiveSessionsKey(userID)).Result()
	if err != nil {
		return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err)
	}
	return int(cardinality), nil
}
// Create persists record as a new device session.
//
// The write runs under a WATCH on the primary session key: the MULTI/EXEC
// only commits when the key is still untouched at EXEC time, so a concurrent
// Create of the same ID surfaces as ports.ErrConflict instead of silently
// overwriting. The primary record, the all-sessions index entry, and (for
// active sessions) the active-index entry are written in one atomic pipeline.
func (s *Store) Create(ctx context.Context, record devicesession.Session) error {
	if err := record.Validate(); err != nil {
		return fmt.Errorf("create session in redis: %w", err)
	}
	// Serialize outside the transaction so marshal errors never abort a WATCH.
	payload, err := marshalSessionRecord(record)
	if err != nil {
		return fmt.Errorf("create session in redis: %w", err)
	}
	deviceSessionKey := s.sessionKey(record.ID)
	allSessionsKey := s.userSessionsKey(record.UserID)
	activeSessionsKey := s.userActiveSessionsKey(record.UserID)
	operationCtx, cancel, err := s.operationContext(ctx, "create session in redis")
	if err != nil {
		return err
	}
	defer cancel()
	watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
		// Duplicate check: any existing value under the key is a conflict.
		_, err := tx.Get(operationCtx, deviceSessionKey).Bytes()
		switch {
		case errors.Is(err, redis.Nil):
			// Key absent — proceed with the insert.
		case err != nil:
			return fmt.Errorf("create session %q in redis: %w", record.ID, err)
		default:
			return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
		}
		_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
			// Expiration 0: the primary record carries no TTL; session
			// lifecycle is managed explicitly via Revoke.
			pipe.Set(operationCtx, deviceSessionKey, payload, 0)
			pipe.ZAdd(operationCtx, allSessionsKey, redis.Z{
				Score:  createdAtScore(record.CreatedAt),
				Member: record.ID.String(),
			})
			if record.Status == devicesession.StatusActive {
				pipe.ZAdd(operationCtx, activeSessionsKey, redis.Z{
					Score:  createdAtScore(record.CreatedAt),
					Member: record.ID.String(),
				})
			}
			return nil
		})
		if err != nil {
			return fmt.Errorf("create session %q in redis: %w", record.ID, err)
		}
		return nil
	}, deviceSessionKey)
	switch {
	case errors.Is(watchErr, redis.TxFailedErr):
		// Lost the WATCH race: another writer touched the key between the GET
		// and EXEC. Report a conflict rather than retrying.
		return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
	case watchErr != nil:
		return watchErr
	default:
		return nil
	}
}
// Revoke stores a revoked view of one target session.
//
// The mutation runs under a WATCH on the primary session key and is retried
// via runMutation when the optimistic transaction loses a concurrent race.
// Revoking an already-revoked session is idempotent: the stored revocation
// metadata is kept and the outcome is AlreadyRevoked.
func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
	if err := input.Validate(); err != nil {
		return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err)
	}
	var result ports.RevokeSessionResult
	err := s.runMutation(ctx, "revoke session in redis", func(operationCtx context.Context) error {
		deviceSessionKey := s.sessionKey(input.DeviceSessionID)
		watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
			// Read through tx.Get so this load is covered by the WATCH.
			current, err := s.loadSessionWithGetter(operationCtx, input.DeviceSessionID, tx.Get)
			if err != nil {
				switch {
				case errors.Is(err, ports.ErrNotFound):
					return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, ports.ErrNotFound)
				default:
					return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
				}
			}
			if current.Status == devicesession.StatusRevoked {
				// Idempotent path: report the original revocation unchanged.
				result = ports.RevokeSessionResult{
					Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
					Session: current,
				}
				return result.Validate()
			}
			next := current
			next.Status = devicesession.StatusRevoked
			// Copy so the stored record does not alias the caller's input.
			revocation := input.Revocation
			next.Revocation = &revocation
			if err := next.Validate(); err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}
			payload, err := marshalSessionRecord(next)
			if err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}
			_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
				// Overwrite the record and drop it from the active index in
				// one atomic MULTI/EXEC.
				pipe.Set(operationCtx, deviceSessionKey, payload, 0)
				pipe.ZRem(operationCtx, s.userActiveSessionsKey(current.UserID), current.ID.String())
				return nil
			})
			if err != nil {
				return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
			}
			result = ports.RevokeSessionResult{
				Outcome: ports.RevokeSessionOutcomeRevoked,
				Session: next,
			}
			return result.Validate()
		}, deviceSessionKey)
		switch {
		case errors.Is(watchErr, redis.TxFailedErr):
			// Lost the WATCH race; ask runMutation for another attempt.
			return errRetryMutation
		case watchErr != nil:
			return watchErr
		default:
			return nil
		}
	})
	if err != nil {
		return ports.RevokeSessionResult{}, err
	}
	return result, nil
}
// RevokeAllByUserID stores revoked views for all currently active sessions
// owned by input.UserID.
//
// The mutation runs under a WATCH on the user's active-session index and is
// retried via runMutation when the optimistic transaction loses a race. When
// the index is empty an EXEC is still forced so concurrent changes are
// detected, and the outcome is NoActiveSessions.
func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
	if err := input.Validate(); err != nil {
		return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err)
	}
	var result ports.RevokeUserSessionsResult
	err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error {
		activeSessionsKey := s.userActiveSessionsKey(input.UserID)
		watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
			deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result()
			if err != nil {
				return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
			}
			if len(deviceSessionIDs) == 0 {
				// Force EXEC so WATCH observes concurrent active-index changes even
				// for the no-op path.
				_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
					pipe.ZCard(operationCtx, activeSessionsKey)
					return nil
				})
				if err != nil {
					return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
				}
				result = ports.RevokeUserSessionsResult{
					Outcome:  ports.RevokeUserSessionsOutcomeNoActiveSessions,
					UserID:   input.UserID,
					Sessions: []devicesession.Session{},
				}
				return result.Validate()
			}
			// Load and transform every indexed session before any write, so
			// an invariant violation (missing, foreign, or non-active record)
			// aborts the whole transaction cleanly.
			records := make([]devicesession.Session, 0, len(deviceSessionIDs))
			for _, rawDeviceSessionID := range deviceSessionIDs {
				deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
				record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, tx.Get)
				if err != nil {
					switch {
					case errors.Is(err, ports.ErrNotFound):
						return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID)
					default:
						return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
					}
				}
				if record.UserID != input.UserID {
					return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID)
				}
				if record.Status != devicesession.StatusActive {
					return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status)
				}
				next := record
				next.Status = devicesession.StatusRevoked
				// Copy so stored records do not alias the caller's input.
				revocation := input.Revocation
				next.Revocation = &revocation
				if err := next.Validate(); err != nil {
					return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
				}
				records = append(records, next)
			}
			_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
				for _, record := range records {
					payload, err := marshalSessionRecord(record)
					if err != nil {
						return fmt.Errorf("session %q: %w", record.ID, err)
					}
					pipe.Set(operationCtx, s.sessionKey(record.ID), payload, 0)
					pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String())
				}
				return nil
			})
			if err != nil {
				return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
			}
			sortSessionsNewestFirst(records)
			result = ports.RevokeUserSessionsResult{
				Outcome:  ports.RevokeUserSessionsOutcomeRevoked,
				UserID:   input.UserID,
				Sessions: records,
			}
			return result.Validate()
		}, activeSessionsKey)
		switch {
		case errors.Is(watchErr, redis.TxFailedErr):
			// Lost the WATCH race; ask runMutation for another attempt.
			return errRetryMutation
		case watchErr != nil:
			return watchErr
		default:
			return nil
		}
	})
	if err != nil {
		return ports.RevokeUserSessionsResult{}, err
	}
	return result, nil
}
// errRetryMutation is the internal signal an execute callback returns when an
// optimistic transaction lost its WATCH race and should be re-run.
var errRetryMutation = errors.New("redis session store: retry mutation")

// runMutation executes one optimistic-concurrency mutation up to
// mutationRetryLimit times. execute signals a lost race by returning
// errRetryMutation; any other result — including nil — is final. Each attempt
// gets its own timeout-bounded context.
func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error {
	for attempt := 0; attempt < mutationRetryLimit; attempt++ {
		operationCtx, cancel, err := s.operationContext(ctx, operation)
		if err != nil {
			return err
		}
		err = execute(operationCtx)
		cancel()
		if !errors.Is(err, errRetryMutation) {
			return err
		}
	}
	return fmt.Errorf("%s: mutation retry limit exceeded", operation)
}
// operationContext derives a per-operation timeout context from ctx, guarding
// against a nil receiver, an unconfigured client, and a nil context. Callers
// must invoke the returned cancel func whenever err is nil.
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
	switch {
	case s == nil || s.client == nil:
		return nil, nil, fmt.Errorf("%s: nil store", operation)
	case ctx == nil:
		return nil, nil, fmt.Errorf("%s: nil context", operation)
	}
	boundedCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
	return boundedCtx, cancel, nil
}
// loadSession fetches and decodes the record for deviceSessionID using the
// plain (non-transactional) client getter.
func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
	return s.loadSessionWithGetter(ctx, deviceSessionID, s.client.Get)
}
// loadSessionWithGetter fetches the raw payload for deviceSessionID through
// getter — either the plain client Get or a transactional tx.Get — and
// decodes it into a validated session. A missing key is mapped to the bare
// ports.ErrNotFound sentinel.
func (s *Store) loadSessionWithGetter(
	ctx context.Context,
	deviceSessionID common.DeviceSessionID,
	getter func(context.Context, string) *redis.StringCmd,
) (devicesession.Session, error) {
	payload, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes()
	if errors.Is(err, redis.Nil) {
		return devicesession.Session{}, ports.ErrNotFound
	}
	if err != nil {
		return devicesession.Session{}, err
	}
	return decodeSessionRecord(deviceSessionID, payload)
}
// sessionKey returns the primary key for deviceSessionID; the ID is
// base64url-encoded so arbitrary identifier bytes cannot collide with the
// key namespace.
func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string {
	return s.sessionKeyPrefix + encodeKeyComponent(deviceSessionID.String())
}

// userSessionsKey returns the key of the all-sessions index for userID.
func (s *Store) userSessionsKey(userID common.UserID) string {
	return s.userSessionsKeyPrefix + encodeKeyComponent(userID.String())
}

// userActiveSessionsKey returns the key of the active-only index for userID.
func (s *Store) userActiveSessionsKey(userID common.UserID) string {
	return s.userActiveSessionsKeyPrefix + encodeKeyComponent(userID.String())
}
// encodeKeyComponent renders value as unpadded base64url so that delimiter
// characters in opaque identifiers (slashes, colons, question marks) stay
// inside a single key segment.
func encodeKeyComponent(value string) string {
	raw := []byte(value)
	return base64.RawURLEncoding.EncodeToString(raw)
}
// marshalSessionRecord validates record and serializes it into the strict
// JSON wire form stored under the primary session key.
func marshalSessionRecord(record devicesession.Session) ([]byte, error) {
	wire, err := redisRecordFromSession(record)
	if err != nil {
		return nil, err
	}
	encoded, err := json.Marshal(wire)
	if err != nil {
		return nil, fmt.Errorf("encode redis session record: %w", err)
	}
	return encoded, nil
}
// redisRecordFromSession converts a validated domain session into its JSON
// wire representation; the revoke_* fields are populated only when the
// session carries revocation metadata.
func redisRecordFromSession(record devicesession.Session) (redisRecord, error) {
	if err := record.Validate(); err != nil {
		return redisRecord{}, fmt.Errorf("encode redis session record: %w", err)
	}
	wire := redisRecord{
		DeviceSessionID:       record.ID.String(),
		UserID:                record.UserID.String(),
		ClientPublicKeyBase64: record.ClientPublicKey.String(),
		Status:                record.Status,
		CreatedAt:             formatTimestamp(record.CreatedAt),
	}
	if revocation := record.Revocation; revocation != nil {
		wire.RevokedAt = formatOptionalTimestamp(&revocation.At)
		wire.RevokeReasonCode = revocation.ReasonCode.String()
		wire.RevokeActorType = revocation.ActorType.String()
		wire.RevokeActorID = revocation.ActorID
	}
	return wire, nil
}
// decodeSessionRecord strictly parses payload into a validated session. It
// rejects unknown fields, trailing JSON input, and records whose stored
// device_session_id differs from the requested expectedDeviceSessionID.
func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()
	var wire redisRecord
	if err := decoder.Decode(&wire); err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
	}
	// Exactly one JSON value is allowed: a second Decode must hit EOF.
	switch trailingErr := decoder.Decode(&struct{}{}); trailingErr {
	case io.EOF:
	case nil:
		return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input")
	default:
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", trailingErr)
	}
	record, err := sessionFromRedisRecord(wire)
	if err != nil {
		return devicesession.Session{}, err
	}
	if record.ID != expectedDeviceSessionID {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", record.ID, expectedDeviceSessionID)
	}
	return record, nil
}
// sessionFromRedisRecord rebuilds and validates a domain session from its
// JSON wire representation.
func sessionFromRedisRecord(wire redisRecord) (devicesession.Session, error) {
	createdAt, err := parseTimestamp("created_at", wire.CreatedAt)
	if err != nil {
		return devicesession.Session{}, err
	}
	keyBytes, err := base64.StdEncoding.Strict().DecodeString(wire.ClientPublicKeyBase64)
	if err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
	}
	publicKey, err := common.NewClientPublicKey(keyBytes)
	if err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
	}
	revocation, err := parseRevocation(wire)
	if err != nil {
		return devicesession.Session{}, err
	}
	session := devicesession.Session{
		ID:              common.DeviceSessionID(wire.DeviceSessionID),
		UserID:          common.UserID(wire.UserID),
		ClientPublicKey: publicKey,
		Status:          wire.Status,
		CreatedAt:       createdAt,
		Revocation:      revocation,
	}
	if err := session.Validate(); err != nil {
		return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
	}
	return session, nil
}
// parseRevocation extracts optional revocation metadata from wire. The
// revoked_at, reason-code, and actor-type fields must be absent together or
// present together (actor ID is the one optional member); a partial set is
// treated as a corrupt record.
func parseRevocation(wire redisRecord) (*devicesession.Revocation, error) {
	hasTimestamp := wire.RevokedAt != nil
	hasReason := strings.TrimSpace(wire.RevokeReasonCode) != ""
	hasActorType := strings.TrimSpace(wire.RevokeActorType) != ""
	hasActorID := strings.TrimSpace(wire.RevokeActorID) != ""
	switch {
	case !hasTimestamp && !hasReason && !hasActorType && !hasActorID:
		return nil, nil
	case !hasTimestamp || !hasReason || !hasActorType:
		return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent")
	}
	revokedAt, err := parseTimestamp("revoked_at", *wire.RevokedAt)
	if err != nil {
		return nil, err
	}
	return &devicesession.Revocation{
		At:         revokedAt,
		ReasonCode: common.RevokeReasonCode(wire.RevokeReasonCode),
		ActorType:  common.RevokeActorType(wire.RevokeActorType),
		ActorID:    wire.RevokeActorID,
	}, nil
}
// parseTimestamp parses value as a canonical UTC RFC3339Nano timestamp for
// the field named fieldName. It rejects empty values and any non-canonical
// spelling (zone offsets, alternate fraction forms) so every stored
// timestamp round-trips byte-for-byte.
func parseTimestamp(fieldName string, value string) (time.Time, error) {
	if strings.TrimSpace(value) == "" {
		return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName)
	}
	parsed, err := time.Parse(time.RFC3339Nano, value)
	if err != nil {
		return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err)
	}
	utc := parsed.UTC()
	if utc.Format(time.RFC3339Nano) != value {
		return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
	}
	return utc, nil
}
func formatTimestamp(value time.Time) string {
return value.UTC().Format(time.RFC3339Nano)
}
func formatOptionalTimestamp(value *time.Time) *string {
if value == nil {
return nil
}
formatted := formatTimestamp(*value)
return &formatted
}
func createdAtScore(createdAt time.Time) float64 {
return float64(createdAt.UTC().UnixMicro())
}
// sortSessionsNewestFirst orders records by descending CreatedAt, breaking
// ties by ascending session ID so the ordering is fully deterministic.
func sortSessionsNewestFirst(records []devicesession.Session) {
	slices.SortFunc(records, func(a devicesession.Session, b devicesession.Session) int {
		if a.CreatedAt.After(b.CreatedAt) {
			return -1
		}
		if b.CreatedAt.After(a.CreatedAt) {
			return 1
		}
		return strings.Compare(a.ID.String(), b.ID.String())
	})
}
var _ ports.SessionStore = (*Store)(nil)
@@ -0,0 +1,635 @@
package sessionstore
import (
"context"
"crypto/ed25519"
"encoding/json"
"testing"
"time"
"galaxy/authsession/internal/adapters/contracttest"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStoreContract(t *testing.T) {
t.Parallel()
contracttest.RunSessionStoreContractTests(t, func(t *testing.T) ports.SessionStore {
t.Helper()
server := miniredis.RunT(t)
return newTestStore(t, server, Config{})
})
}
// TestNew exercises Config validation for every required field plus the
// happy path.
func TestNew(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	testCases := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:                        server.Addr(),
				DB:                          1,
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:                        server.Addr(),
				DB:                          -1,
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session prefix",
			cfg: Config{
				Addr:                        server.Addr(),
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "session key prefix must not be empty",
		},
		{
			name: "empty all sessions prefix",
			cfg: Config{
				Addr:                        server.Addr(),
				SessionKeyPrefix:            "authsession:session:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "user sessions key prefix must not be empty",
		},
		{
			name: "empty active sessions prefix",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionKeyPrefix:      "authsession:session:",
				UserSessionsKeyPrefix: "authsession:user-sessions:",
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "user active sessions key prefix must not be empty",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:                        server.Addr(),
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
			},
			wantErr: "operation timeout must be positive",
		},
	}
	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			store, err := New(tc.cfg)
			if tc.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tc.wantErr)
				return
			}
			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}
// TestStorePing verifies that Ping succeeds against a reachable backend.
func TestStorePing(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	store := newTestStore(t, redisServer, Config{})
	err := store.Ping(context.Background())
	require.NoError(t, err)
}
func TestStoreCreateAndGetActive(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
got.Revocation = &devicesession.Revocation{
At: got.CreatedAt.Add(time.Minute),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
}
again, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Nil(t, again.Revocation)
assert.Equal(t, record, again)
}
func TestStoreCreateAndGetRevoked(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(1_775_240_100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
require.NoError(t, err)
assert.Zero(t, count)
}
func TestStoreGetNotFound(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
}
func TestStoreCreateConflict(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_200, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
err := store.Create(context.Background(), record)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrConflict)
}
func TestStoreIndexesAndOrdering(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
older := activeSessionFixture("device-session-old", "user-1", time.Unix(10, 0).UTC())
newer := activeSessionFixture("device-session-new", "user-1", time.Unix(20, 0).UTC())
revoked := revokedSessionFixture("device-session-revoked", "user-1", time.Unix(15, 0).UTC())
otherUser := activeSessionFixture("device-session-other", "user-2", time.Unix(30, 0).UTC())
for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
require.NoError(t, store.Create(context.Background(), record))
}
got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
require.Len(t, got, 3)
assert.Equal(t, []common.DeviceSessionID{newer.ID, revoked.ID, older.ID}, []common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID})
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
assert.Equal(t, 2, count)
unknown, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
require.NoError(t, err)
assert.Empty(t, unknown)
}
// TestStoreKeyPrefixesAndEncodedPrimaryKey verifies that custom prefixes are
// honored and that opaque identifiers are base64url-encoded in every key.
func TestStoreKeyPrefixesAndEncodedPrimaryKey(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	store := newTestStore(t, redisServer, Config{
		SessionKeyPrefix:            "custom:session:",
		UserSessionsKeyPrefix:       "custom:user-sessions:",
		UserActiveSessionsKeyPrefix: "custom:user-active-sessions:",
	})
	record := activeSessionFixture("device/session:opaque?1", "user/opaque:1", time.Unix(40, 0).UTC())
	require.NoError(t, store.Create(context.Background(), record))

	primaryKey := store.sessionKey(record.ID)
	assert.Equal(t, "custom:session:"+encodeKeyComponent(record.ID.String()), primaryKey)
	assert.True(t, redisServer.Exists(primaryKey))

	allSessionsKey := store.userSessionsKey(record.UserID)
	assert.Equal(t, "custom:user-sessions:"+encodeKeyComponent(record.UserID.String()), allSessionsKey)
	allMembers, err := redisServer.ZMembers(allSessionsKey)
	require.NoError(t, err)
	assert.Equal(t, []string{record.ID.String()}, allMembers)

	activeSessionsKey := store.userActiveSessionsKey(record.UserID)
	assert.Equal(t, "custom:user-active-sessions:"+encodeKeyComponent(record.UserID.String()), activeSessionsKey)
	activeMembers, err := redisServer.ZMembers(activeSessionsKey)
	require.NoError(t, err)
	assert.Equal(t, []string{record.ID.String()}, activeMembers)
}
func TestStoreRevoke(t *testing.T) {
t.Parallel()
t.Run("active session", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
revocation := devicesession.Revocation{
At: time.Unix(200, 0).UTC(),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
}
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: revocation,
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
require.NotNil(t, result.Session.Revocation)
assert.Equal(t, revocation, *result.Session.Revocation)
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
require.NoError(t, err)
assert.Zero(t, count)
})
t.Run("already revoked keeps stored revocation", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: devicesession.Revocation{
At: time.Unix(300, 0).UTC(),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
ActorID: "admin-1",
},
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
require.NotNil(t, result.Session.Revocation)
assert.Equal(t, *record.Revocation, *result.Session.Revocation)
})
t.Run("unknown session", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: common.DeviceSessionID("missing-session"),
Revocation: devicesession.Revocation{
At: time.Unix(200, 0).UTC(),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
},
})
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
}
// TestStoreRevokeAllByUserID covers bulk revocation of all of one user's
// active sessions, including ordering of results and index cleanup.
func TestStoreRevokeAllByUserID(t *testing.T) {
	t.Parallel()
	t.Run("revokes active sessions newest first and clears active index", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		// Two active sessions for user-1, one already-revoked session for
		// user-1, and one active session for an unrelated user.
		first := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
		second := activeSessionFixture("device-session-2", "user-1", time.Unix(200, 0).UTC())
		revoked := revokedSessionFixture("device-session-3", "user-1", time.Unix(150, 0).UTC())
		foreign := activeSessionFixture("device-session-4", "user-2", time.Unix(250, 0).UTC())
		for _, fixture := range []devicesession.Session{first, second, revoked, foreign} {
			require.NoError(t, store.Create(context.Background(), fixture))
		}
		wantRevocation := devicesession.Revocation{
			At:         time.Unix(300, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonAdminRevoke,
			ActorType:  common.RevokeActorType("admin"),
			ActorID:    "admin-1",
		}
		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID:     common.UserID("user-1"),
			Revocation: wantRevocation,
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
		require.Len(t, result.Sessions, 2)
		// Revoked sessions come back newest first; the already-revoked and
		// foreign sessions are untouched.
		gotOrder := []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID}
		assert.Equal(t, []common.DeviceSessionID{second.ID, first.ID}, gotOrder)
		assert.Equal(t, wantRevocation, *result.Sessions[0].Revocation)
		assert.Equal(t, wantRevocation, *result.Sessions[1].Revocation)
		count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
		require.NoError(t, err)
		assert.Zero(t, count)
		foreignCount, err := store.CountActiveByUserID(context.Background(), common.UserID("user-2"))
		require.NoError(t, err)
		assert.Equal(t, 1, foreignCount)
	})
	t.Run("no active sessions", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		// Only a revoked session exists, so there is nothing to revoke.
		seeded := revokedSessionFixture("device-session-5", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), seeded))
		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID: common.UserID("user-1"),
			Revocation: devicesession.Revocation{
				At:         time.Unix(400, 0).UTC(),
				ReasonCode: devicesession.RevokeReasonAdminRevoke,
				ActorType:  common.RevokeActorType("admin"),
			},
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
		assert.Empty(t, result.Sessions)
	})
}
// TestStoreStrictDecodeCorruption verifies that Get rejects corrupted or
// non-canonical session payloads stored in Redis instead of decoding them
// leniently. Each case mutates a valid serialized record and asserts the
// resulting error text.
func TestStoreStrictDecodeCorruption(t *testing.T) {
	t.Parallel()
	now := time.Unix(1_775_240_300, 0).UTC()
	baseRecord := revokedSessionFixture("device-session-corrupt", "user-1", now)
	stored, err := redisRecordFromSession(baseRecord)
	require.NoError(t, err)
	tests := []struct {
		name        string
		mutate      func(redisRecord) string
		wantErrText string
	}{
		{
			name: "malformed json",
			mutate: func(_ redisRecord) string {
				return "{"
			},
			wantErrText: "decode redis session record",
		},
		{
			name: "trailing json input",
			mutate: func(record redisRecord) string {
				return mustMarshalJSON(t, record) + "{}"
			},
			wantErrText: "unexpected trailing JSON input",
		},
		{
			name: "unknown field",
			mutate: func(record redisRecord) string {
				payload := map[string]any{
					"device_session_id":        record.DeviceSessionID,
					"user_id":                  record.UserID,
					"client_public_key_base64": record.ClientPublicKeyBase64,
					"status":                   record.Status,
					"created_at":               record.CreatedAt,
					"revoked_at":               record.RevokedAt,
					"revoke_reason_code":       record.RevokeReasonCode,
					"revoke_actor_type":        record.RevokeActorType,
					"revoke_actor_id":          record.RevokeActorID,
					"unexpected":               true,
				}
				return mustMarshalJSON(t, payload)
			},
			wantErrText: "unknown field",
		},
		{
			name: "unsupported status",
			mutate: func(record redisRecord) string {
				record.Status = devicesession.Status("paused")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			name: "non canonical timestamp",
			mutate: func(record redisRecord) string {
				record.CreatedAt = "2026-04-04T12:00:00+03:00"
				return mustMarshalJSON(t, record)
			},
			wantErrText: "canonical UTC RFC3339Nano timestamp",
		},
		{
			name: "incomplete revocation metadata",
			mutate: func(record redisRecord) string {
				record.RevokeActorType = ""
				return mustMarshalJSON(t, record)
			},
			wantErrText: "revocation metadata must be either fully present or fully absent",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			server := miniredis.RunT(t)
			store := newTestStore(t, server, Config{})
			// miniredis.Set returns an error; previously it was silently
			// discarded, which could hide a failed fixture write.
			require.NoError(t, server.Set(store.sessionKey(baseRecord.ID), tt.mutate(stored)))
			_, err := store.Get(context.Background(), baseRecord.ID)
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErrText)
		})
	}
}
// TestStoreListByUserIDDetectsCorruptIndexes verifies that ListByUserID
// surfaces an error when the per-user index disagrees with the primary
// session records.
func TestStoreListByUserIDDetectsCorruptIndexes(t *testing.T) {
	t.Parallel()
	t.Run("missing primary record", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		owner := common.UserID("user-1")
		// Index entry with no matching session record behind it.
		_, zaddErr := server.ZAdd(store.userSessionsKey(owner), 100, "missing-session")
		require.NoError(t, zaddErr)
		_, listErr := store.ListByUserID(context.Background(), owner)
		require.Error(t, listErr)
		assert.ErrorContains(t, listErr, "references missing session")
	})
	t.Run("wrong user id in primary record", func(t *testing.T) {
		t.Parallel()
		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		// Session actually owned by user-2 but indexed under user-1.
		session := activeSessionFixture("device-session-1", "user-2", time.Unix(100, 0).UTC())
		require.NoError(t, seedSessionRecord(t, server, store.sessionKey(session.ID), session))
		_, zaddErr := server.ZAdd(store.userSessionsKey(common.UserID("user-1")), createdAtScore(session.CreatedAt), session.ID.String())
		require.NoError(t, zaddErr)
		_, listErr := store.ListByUserID(context.Background(), common.UserID("user-1"))
		require.Error(t, listErr)
		assert.ErrorContains(t, listErr, `belongs to "user-2"`)
	})
}
// TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex verifies that bulk
// revocation fails loudly when the active-session index references a
// session whose primary record is already revoked.
func TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	// Plant a revoked session inside the *active* index to corrupt it.
	session := revokedSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
	require.NoError(t, seedSessionRecord(t, server, store.sessionKey(session.ID), session))
	_, zaddErr := server.ZAdd(store.userActiveSessionsKey(session.UserID), createdAtScore(session.CreatedAt), session.ID.String())
	require.NoError(t, zaddErr)
	_, revokeErr := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
		UserID: session.UserID,
		Revocation: devicesession.Revocation{
			At:         time.Unix(200, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonAdminRevoke,
			ActorType:  common.RevokeActorType("admin"),
		},
	})
	require.Error(t, revokeErr)
	assert.ErrorContains(t, revokeErr, `is "revoked"`)
}
// newTestStore constructs a Store against the given miniredis server,
// filling any zero-valued Config fields with test defaults and registering
// a cleanup that closes the store.
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
	t.Helper()
	// Apply string defaults table-style; each fallback is only written into
	// a field that is still empty.
	stringDefaults := []struct {
		field    *string
		fallback string
	}{
		{&cfg.Addr, server.Addr()},
		{&cfg.SessionKeyPrefix, "authsession:session:"},
		{&cfg.UserSessionsKeyPrefix, "authsession:user-sessions:"},
		{&cfg.UserActiveSessionsKeyPrefix, "authsession:user-active-sessions:"},
	}
	for _, d := range stringDefaults {
		if *d.field == "" {
			*d.field = d.fallback
		}
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}
	store, err := New(cfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, store.Close())
	})
	return store
}
// activeSessionFixture returns an active session owned by userID with a
// deterministic client public key (bytes 0..31).
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	// Same deterministic key material as before, generated instead of
	// spelled out: byte i at index i.
	rawKey := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for i := range rawKey {
		rawKey[i] = byte(i)
	}
	clientPublicKey, err := common.NewClientPublicKey(rawKey)
	if err != nil {
		panic(err)
	}
	return devicesession.Session{
		ID:              common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID(userID),
		ClientPublicKey: clientPublicKey,
		Status:          devicesession.StatusActive,
		CreatedAt:       createdAt,
	}
}
// revokedSessionFixture returns a session like activeSessionFixture but
// already revoked one minute after creation by a user actor.
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	session := activeSessionFixture(deviceSessionID, userID, createdAt)
	session.Status = devicesession.StatusRevoked
	session.Revocation = &devicesession.Revocation{
		At:         createdAt.Add(time.Minute),
		ReasonCode: devicesession.RevokeReasonDeviceLogout,
		ActorType:  common.RevokeActorType("user"),
		ActorID:    "user-actor",
	}
	return session
}
// seedSessionRecord writes record under key in raw serialized form,
// bypassing the store's Create path so tests can plant records (including
// inconsistent index states) directly in Redis.
//
// Encoding failures abort the test immediately. The returned error now
// surfaces any Redis write failure; previously the server.Set error was
// discarded and the function unconditionally returned nil.
func seedSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record devicesession.Session) error {
	t.Helper()
	stored, err := redisRecordFromSession(record)
	require.NoError(t, err)
	return server.Set(key, mustMarshalJSON(t, stored))
}
// mustMarshalJSON encodes value as compact JSON, failing the test on any
// marshaling error.
func mustMarshalJSON(t *testing.T, value any) string {
	t.Helper()
	encoded, err := json.Marshal(value)
	require.NoError(t, err)
	return string(encoded)
}