feat: authsession service
This commit is contained in:
@@ -0,0 +1,484 @@
|
||||
// Package challengestore implements ports.ChallengeStore with Redis-backed
|
||||
// strict JSON challenge records.
|
||||
package challengestore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const expirationGracePeriod = 5 * time.Minute
|
||||
|
||||
// Config configures one Redis-backed challenge store instance.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// KeyPrefix is the namespace prefix applied to every challenge key.
|
||||
KeyPrefix string
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Store persists challenges as one strict JSON value per Redis key.
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
keyPrefix string
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
type redisRecord struct {
|
||||
ChallengeID string `json:"challenge_id"`
|
||||
Email string `json:"email"`
|
||||
CodeHashBase64 string `json:"code_hash_base64"`
|
||||
Status challenge.Status `json:"status"`
|
||||
DeliveryState challenge.DeliveryState `json:"delivery_state"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
SendAttemptCount int `json:"send_attempt_count"`
|
||||
ConfirmAttemptCount int `json:"confirm_attempt_count"`
|
||||
LastAttemptAt *string `json:"last_attempt_at,omitempty"`
|
||||
ConfirmedSessionID string `json:"confirmed_session_id,omitempty"`
|
||||
ConfirmedClientPublicKey string `json:"confirmed_client_public_key,omitempty"`
|
||||
ConfirmedAt *string `json:"confirmed_at,omitempty"`
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed challenge store from cfg.
|
||||
func New(cfg Config) (*Store, error) {
|
||||
if strings.TrimSpace(cfg.Addr) == "" {
|
||||
return nil, errors.New("new redis challenge store: redis addr must not be empty")
|
||||
}
|
||||
if cfg.DB < 0 {
|
||||
return nil, errors.New("new redis challenge store: redis db must not be negative")
|
||||
}
|
||||
if strings.TrimSpace(cfg.KeyPrefix) == "" {
|
||||
return nil, errors.New("new redis challenge store: redis key prefix must not be empty")
|
||||
}
|
||||
if cfg.OperationTimeout <= 0 {
|
||||
return nil, errors.New("new redis challenge store: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Store{
|
||||
client: redis.NewClient(options),
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *Store) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (s *Store) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "ping redis challenge store")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis challenge store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the stored challenge for challengeID.
|
||||
func (s *Store) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
|
||||
if err := challengeID.Validate(); err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "get challenge from redis")
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
payload, err := s.client.Get(operationCtx, s.lookupKey(challengeID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, ports.ErrNotFound)
|
||||
case err != nil:
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
|
||||
}
|
||||
|
||||
record, err := decodeChallengeRecord(challengeID, payload)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Create persists record as a new challenge.
|
||||
func (s *Store) Create(ctx context.Context, record challenge.Challenge) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
payload, err := marshalChallengeRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "create challenge in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
created, err := s.client.SetNX(operationCtx, s.lookupKey(record.ID), payload, redisTTL(record.ExpiresAt)).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create challenge %q in redis: %w", record.ID, err)
|
||||
}
|
||||
if !created {
|
||||
return fmt.Errorf("create challenge %q in redis: %w", record.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompareAndSwap replaces previous with next when the currently stored
|
||||
// challenge matches previous exactly in canonical Redis representation.
|
||||
func (s *Store) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
|
||||
if err := ports.ValidateComparableChallenges(previous, next); err != nil {
|
||||
return fmt.Errorf("compare and swap challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
nextPayload, err := marshalChallengeRecord(next)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "compare and swap challenge in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
key := s.lookupKey(previous.ID)
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
payload, err := tx.Get(operationCtx, key).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrNotFound)
|
||||
case err != nil:
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
|
||||
current, err := decodeChallengeRecord(previous.ID, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
|
||||
matches, err := equalStoredChallenges(current, previous)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
if !matches {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, key, nextPayload, redisTTL(next.ExpiresAt))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, key)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if s == nil || s.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil store", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func (s *Store) lookupKey(challengeID common.ChallengeID) string {
|
||||
return s.keyPrefix + encodeKeyComponent(challengeID.String())
|
||||
}
|
||||
|
||||
func encodeKeyComponent(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func marshalChallengeRecord(record challenge.Challenge) ([]byte, error) {
|
||||
stored, err := redisRecordFromChallenge(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(stored)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func redisRecordFromChallenge(record challenge.Challenge) (redisRecord, error) {
|
||||
if err := record.Validate(); err != nil {
|
||||
return redisRecord{}, fmt.Errorf("encode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
stored := redisRecord{
|
||||
ChallengeID: record.ID.String(),
|
||||
Email: record.Email.String(),
|
||||
CodeHashBase64: base64.StdEncoding.EncodeToString(record.CodeHash),
|
||||
Status: record.Status,
|
||||
DeliveryState: record.DeliveryState,
|
||||
CreatedAt: formatTimestamp(record.CreatedAt),
|
||||
ExpiresAt: formatTimestamp(record.ExpiresAt),
|
||||
SendAttemptCount: record.Attempts.Send,
|
||||
ConfirmAttemptCount: record.Attempts.Confirm,
|
||||
LastAttemptAt: formatOptionalTimestamp(record.Abuse.LastAttemptAt),
|
||||
}
|
||||
if record.Confirmation != nil {
|
||||
stored.ConfirmedSessionID = record.Confirmation.SessionID.String()
|
||||
stored.ConfirmedClientPublicKey = record.Confirmation.ClientPublicKey.String()
|
||||
stored.ConfirmedAt = formatOptionalTimestamp(&record.Confirmation.ConfirmedAt)
|
||||
}
|
||||
|
||||
return stored, nil
|
||||
}
|
||||
|
||||
func decodeChallengeRecord(expectedChallengeID common.ChallengeID, payload []byte) (challenge.Challenge, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return challenge.Challenge{}, errors.New("decode redis challenge record: unexpected trailing JSON input")
|
||||
}
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
record, err := challengeFromRedisRecord(stored)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
if record.ID != expectedChallengeID {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: challenge_id %q does not match requested %q", record.ID, expectedChallengeID)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func challengeFromRedisRecord(stored redisRecord) (challenge.Challenge, error) {
|
||||
createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
expiresAt, err := parseTimestamp("expires_at", stored.ExpiresAt)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
lastAttemptAt, err := parseOptionalTimestamp("last_attempt_at", stored.LastAttemptAt)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
|
||||
codeHash, err := base64.StdEncoding.Strict().DecodeString(stored.CodeHashBase64)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: code_hash_base64: %w", err)
|
||||
}
|
||||
|
||||
record := challenge.Challenge{
|
||||
ID: common.ChallengeID(stored.ChallengeID),
|
||||
Email: common.Email(stored.Email),
|
||||
CodeHash: codeHash,
|
||||
Status: stored.Status,
|
||||
DeliveryState: stored.DeliveryState,
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: expiresAt,
|
||||
Attempts: challenge.AttemptCounters{
|
||||
Send: stored.SendAttemptCount,
|
||||
Confirm: stored.ConfirmAttemptCount,
|
||||
},
|
||||
Abuse: challenge.AbuseMetadata{
|
||||
LastAttemptAt: lastAttemptAt,
|
||||
},
|
||||
}
|
||||
|
||||
confirmation, err := parseConfirmation(stored)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
record.Confirmation = confirmation
|
||||
|
||||
if err := record.Validate(); err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func parseConfirmation(stored redisRecord) (*challenge.Confirmation, error) {
|
||||
hasSessionID := strings.TrimSpace(stored.ConfirmedSessionID) != ""
|
||||
hasClientPublicKey := strings.TrimSpace(stored.ConfirmedClientPublicKey) != ""
|
||||
hasConfirmedAt := stored.ConfirmedAt != nil
|
||||
|
||||
if !hasSessionID && !hasClientPublicKey && !hasConfirmedAt {
|
||||
return nil, nil
|
||||
}
|
||||
if !hasSessionID || !hasClientPublicKey || !hasConfirmedAt {
|
||||
return nil, errors.New("decode redis challenge record: confirmation metadata must be either fully present or fully absent")
|
||||
}
|
||||
|
||||
confirmedAt, err := parseTimestamp("confirmed_at", *stored.ConfirmedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ConfirmedClientPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
|
||||
}
|
||||
clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
|
||||
}
|
||||
|
||||
return &challenge.Confirmation{
|
||||
SessionID: common.DeviceSessionID(stored.ConfirmedSessionID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
ConfirmedAt: confirmedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseOptionalTimestamp(fieldName string, value *string) (*time.Time, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
parsed, err := parseTimestamp(fieldName, *value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func parseTimestamp(fieldName string, value string) (time.Time, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must not be empty", fieldName)
|
||||
}
|
||||
|
||||
parsed, err := time.Parse(time.RFC3339Nano, value)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("decode redis challenge record: %s: %w", fieldName, err)
|
||||
}
|
||||
|
||||
canonical := parsed.UTC().Format(time.RFC3339Nano)
|
||||
if value != canonical {
|
||||
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
|
||||
}
|
||||
|
||||
return parsed.UTC(), nil
|
||||
}
|
||||
|
||||
func formatTimestamp(value time.Time) string {
|
||||
return value.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func formatOptionalTimestamp(value *time.Time) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
formatted := formatTimestamp(*value)
|
||||
return &formatted
|
||||
}
|
||||
|
||||
func redisTTL(expiresAt time.Time) time.Duration {
|
||||
ttl := time.Until(expiresAt.UTC())
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
|
||||
return ttl + expirationGracePeriod
|
||||
}
|
||||
|
||||
func equalStoredChallenges(left challenge.Challenge, right challenge.Challenge) (bool, error) {
|
||||
leftRecord, err := redisRecordFromChallenge(left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
rightRecord, err := redisRecordFromChallenge(right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(leftRecord, rightRecord), nil
|
||||
}
|
||||
|
||||
var _ ports.ChallengeStore = (*Store)(nil)
|
||||
@@ -0,0 +1,531 @@
|
||||
package challengestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/adapters/contracttest"
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStoreContract(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
contracttest.RunChallengeStoreContractTests(t, func(t *testing.T) ports.ChallengeStore {
|
||||
t.Helper()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
return newTestStore(t, server, Config{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 2,
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty key prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non-positive operation timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorePing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
require.NoError(t, store.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
now := time.Unix(1_775_130_000, 0).UTC()
|
||||
|
||||
record := testChallenge(now)
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
|
||||
got.CodeHash[0] = 0xFF
|
||||
keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
|
||||
keyBytes[0] = 0xFE
|
||||
|
||||
again, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.CodeHash, again.CodeHash)
|
||||
require.NotNil(t, again.Confirmation)
|
||||
assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetPendingChallenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
now := time.Unix(1_775_130_100, 0).UTC()
|
||||
|
||||
record := testPendingChallenge(now)
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetThrottledChallenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
now := time.Unix(1_775_130_150, 0).UTC()
|
||||
|
||||
record := testPendingChallenge(now)
|
||||
record.Status = challenge.StatusDeliveryThrottled
|
||||
record.DeliveryState = challenge.DeliveryThrottled
|
||||
record.Attempts.Send = 1
|
||||
record.Abuse.LastAttemptAt = timePointer(now)
|
||||
require.NoError(t, record.Validate())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
}
|
||||
|
||||
func TestStoreGetStrictDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_130_200, 0).UTC()
|
||||
baseRecord := testChallenge(now)
|
||||
baseStored, err := redisRecordFromChallenge(baseRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(redisRecord) string
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "malformed json",
|
||||
mutate: func(_ redisRecord) string {
|
||||
return "{"
|
||||
},
|
||||
wantErrText: "decode redis challenge record",
|
||||
},
|
||||
{
|
||||
name: "trailing json input",
|
||||
mutate: func(record redisRecord) string {
|
||||
return mustMarshalJSON(t, record) + "{}"
|
||||
},
|
||||
wantErrText: "unexpected trailing JSON input",
|
||||
},
|
||||
{
|
||||
name: "unknown field",
|
||||
mutate: func(record redisRecord) string {
|
||||
payload := map[string]any{
|
||||
"challenge_id": record.ChallengeID,
|
||||
"email": record.Email,
|
||||
"code_hash_base64": record.CodeHashBase64,
|
||||
"status": record.Status,
|
||||
"delivery_state": record.DeliveryState,
|
||||
"created_at": record.CreatedAt,
|
||||
"expires_at": record.ExpiresAt,
|
||||
"send_attempt_count": record.SendAttemptCount,
|
||||
"confirm_attempt_count": record.ConfirmAttemptCount,
|
||||
"last_attempt_at": record.LastAttemptAt,
|
||||
"confirmed_session_id": record.ConfirmedSessionID,
|
||||
"confirmed_client_public_key": record.ConfirmedClientPublicKey,
|
||||
"confirmed_at": record.ConfirmedAt,
|
||||
"unexpected": true,
|
||||
}
|
||||
return mustMarshalJSON(t, payload)
|
||||
},
|
||||
wantErrText: "unknown field",
|
||||
},
|
||||
{
|
||||
name: "unsupported status",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.Status = challenge.Status("paused")
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `status "paused" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "unsupported delivery state",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.DeliveryState = challenge.DeliveryState("queued")
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `delivery state "queued" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "missing required email",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.Email = ""
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "challenge email",
|
||||
},
|
||||
{
|
||||
name: "challenge id mismatch",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.ChallengeID = "other-challenge"
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `does not match requested`,
|
||||
},
|
||||
{
|
||||
name: "non canonical utc timestamp",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.CreatedAt = "2026-04-04T12:00:00+03:00"
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "canonical UTC RFC3339Nano timestamp",
|
||||
},
|
||||
{
|
||||
name: "partial confirmation metadata",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.ConfirmedAt = nil
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "confirmation metadata must be either fully present or fully absent",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
server.Set(store.lookupKey(baseRecord.ID), tt.mutate(baseStored))
|
||||
|
||||
_, err := store.Get(context.Background(), baseRecord.ID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreKeySchemeAndTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{KeyPrefix: "authsession:challenge:"})
|
||||
now := time.Now().UTC()
|
||||
|
||||
prefixed := testPendingChallenge(now)
|
||||
prefixed.ID = common.ChallengeID("challenge:opaque/id?value")
|
||||
require.NoError(t, store.Create(context.Background(), prefixed))
|
||||
|
||||
key := store.lookupKey(prefixed.ID)
|
||||
assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key)
|
||||
assert.True(t, server.Exists(key))
|
||||
|
||||
freshTTL := server.TTL(key)
|
||||
assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod)
|
||||
assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second)
|
||||
|
||||
expired := testPendingChallenge(now.Add(-10 * time.Minute))
|
||||
expired.ID = common.ChallengeID("expired-challenge")
|
||||
expired.CreatedAt = now.Add(-20 * time.Minute)
|
||||
expired.ExpiresAt = now.Add(-1 * time.Minute)
|
||||
require.NoError(t, store.Create(context.Background(), expired))
|
||||
|
||||
expiredTTL := server.TTL(store.lookupKey(expired.ID))
|
||||
assert.LessOrEqual(t, expiredTTL, expirationGracePeriod)
|
||||
assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second)
|
||||
}
|
||||
|
||||
func TestStoreCreateConflict(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
err := store.Create(context.Background(), record)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
}
|
||||
|
||||
func TestStoreGetNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestStoreCompareAndSwap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_130_400, 0).UTC()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
previous := testPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
next.Attempts.Send = 1
|
||||
next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute))
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), previous))
|
||||
require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))
|
||||
|
||||
got, err := store.Get(context.Background(), previous.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, next, got)
|
||||
})
|
||||
|
||||
t.Run("conflict when stored record differs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
stored := testPendingChallenge(now)
|
||||
previous := stored
|
||||
previous.Attempts.Send = 99
|
||||
next := stored
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), stored))
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
previous := testPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
|
||||
t.Run("corrupt stored record returns adapter error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
previous := testPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
server.Set(store.lookupKey(previous.ID), "{")
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.NotErrorIs(t, err, ports.ErrConflict)
|
||||
assert.ErrorContains(t, err, "decode redis challenge record")
|
||||
})
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "authsession:challenge:"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func testPendingChallenge(now time.Time) challenge.Challenge {
|
||||
return challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-pending"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-pending-code"),
|
||||
Status: challenge.StatusPendingSend,
|
||||
DeliveryState: challenge.DeliveryPending,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.InitialTTL),
|
||||
}
|
||||
}
|
||||
|
||||
func testChallenge(now time.Time) challenge.Challenge {
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-confirmed"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-code"),
|
||||
Status: challenge.StatusConfirmedPendingExpire,
|
||||
DeliveryState: challenge.DeliverySent,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.ConfirmedRetention),
|
||||
Attempts: challenge.AttemptCounters{
|
||||
Send: 1,
|
||||
Confirm: 2,
|
||||
},
|
||||
Abuse: challenge.AbuseMetadata{
|
||||
LastAttemptAt: timePointer(now.Add(30 * time.Second)),
|
||||
},
|
||||
Confirmation: &challenge.Confirmation{
|
||||
SessionID: common.DeviceSessionID("device-session-1"),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
ConfirmedAt: now.Add(1 * time.Minute),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func timePointer(value time.Time) *time.Time {
|
||||
return &value
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func TestStorePingNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
err := store.Ping(nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func TestStoreGetNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Get(nil, common.ChallengeID("challenge"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
Reference in New Issue
Block a user