531 lines
15 KiB
Go
531 lines
15 KiB
Go
package challengestore
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"encoding/json"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/authsession/internal/adapters/contracttest"
|
|
"galaxy/authsession/internal/domain/challenge"
|
|
"galaxy/authsession/internal/domain/common"
|
|
"galaxy/authsession/internal/ports"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
"github.com/redis/go-redis/v9"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// newRedisClient returns a go-redis client connected to the given miniredis
// instance. It selects protocol version 2 and sets DisableIdentity. The client
// is closed via t.Cleanup when the test ends, and a close error is reported
// through assert.
func newRedisClient(t *testing.T, server *miniredis.Miniredis) *redis.Client {
	t.Helper()

	opts := &redis.Options{
		Addr:            server.Addr(),
		Protocol:        2,
		DisableIdentity: true,
	}
	c := redis.NewClient(opts)

	t.Cleanup(func() {
		assert.NoError(t, c.Close())
	})

	return c
}
|
|
|
|
// TestStoreContract runs the shared ports.ChallengeStore contract suite against
// this adapter. Each store the factory builds is backed by its own fresh
// miniredis instance bound to the t the suite passes in.
func TestStoreContract(t *testing.T) {
	t.Parallel()

	factory := func(t *testing.T) ports.ChallengeStore {
		t.Helper()

		return newTestStore(t, miniredis.RunT(t), Config{})
	}

	contracttest.RunChallengeStoreContractTests(t, factory)
}
|
|
|
|
// TestNew checks New's config validation. A fully populated config yields a
// non-nil store. A nil client, an empty key prefix, or an unset (zero)
// operation timeout are each rejected with an error containing the listed text.
func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	client := newRedisClient(t, server)

	tests := []struct {
		name    string
		client  *redis.Client
		cfg     Config
		wantErr string
	}{
		{
			name:   "valid config",
			client: client,
			cfg:    Config{KeyPrefix: "authsession:challenge:", OperationTimeout: 250 * time.Millisecond},
		},
		{
			name:    "nil client",
			client:  nil,
			cfg:     Config{KeyPrefix: "authsession:challenge:", OperationTimeout: 250 * time.Millisecond},
			wantErr: "nil redis client",
		},
		{
			name:    "empty key prefix",
			client:  client,
			cfg:     Config{OperationTimeout: 250 * time.Millisecond},
			wantErr: "redis key prefix must not be empty",
		},
		{
			name:    "non-positive operation timeout",
			client:  client,
			cfg:     Config{KeyPrefix: "authsession:challenge:"},
			wantErr: "operation timeout must be positive",
		},
	}

	for _, tt := range tests {
		// Rebind per iteration. The subtests below run in parallel, and under
		// pre-Go 1.22 loop semantics every closure would otherwise observe the
		// final table entry. This matches TestStoreGetStrictDecode.
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			store, err := New(tt.client, tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			require.NotNil(t, store)
		})
	}
}
|
|
|
|
// TestStoreCreateAndGet verifies two things for a confirmed challenge:
//   - it round-trips through Create/Get unchanged;
//   - mutating the returned record's code hash and client public key bytes
//     does not change what a subsequent Get returns.
func TestStoreCreateAndGet(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	now := time.Unix(1_775_130_000, 0).UTC()

	record := testChallenge(now)

	require.NoError(t, store.Create(ctx, record))

	got, err := store.Get(ctx, record.ID)
	require.NoError(t, err)
	// require (not assert): the mutations below index got.CodeHash and
	// dereference got.Confirmation. On a mismatch they could panic instead of
	// reporting a clean failure.
	require.Equal(t, record, got)

	got.CodeHash[0] = 0xFF
	keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
	keyBytes[0] = 0xFE

	again, err := store.Get(ctx, record.ID)
	require.NoError(t, err)
	assert.Equal(t, record.CodeHash, again.CodeHash)
	require.NotNil(t, again.Confirmation)
	assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
}
|
|
|
|
// TestStoreCreateAndGetPendingChallenge checks that a freshly built
// pending-send challenge, with no attempts and no confirmation, round-trips
// through Create/Get unchanged.
func TestStoreCreateAndGetPendingChallenge(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	store := newTestStore(t, miniredis.RunT(t), Config{})
	want := testPendingChallenge(time.Unix(1_775_130_100, 0).UTC())

	require.NoError(t, store.Create(ctx, want))

	got, err := store.Get(ctx, want.ID)
	require.NoError(t, err)
	assert.Equal(t, want, got)
}
|
|
|
|
// TestStoreCreateAndGetThrottledChallenge checks that a challenge in the
// delivery-throttled state round-trips through Create/Get unchanged. The
// challenge carries one send attempt and a last-attempt timestamp.
func TestStoreCreateAndGetThrottledChallenge(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	store := newTestStore(t, miniredis.RunT(t), Config{})
	now := time.Unix(1_775_130_150, 0).UTC()

	throttled := testPendingChallenge(now)
	throttled.Status = challenge.StatusDeliveryThrottled
	throttled.DeliveryState = challenge.DeliveryThrottled
	throttled.Attempts.Send = 1
	throttled.Abuse.LastAttemptAt = timePointer(now)
	// Validate the fixture first so any later failure points at the store,
	// not at an invalid test record.
	require.NoError(t, throttled.Validate())

	require.NoError(t, store.Create(ctx, throttled))

	got, err := store.Get(ctx, throttled.ID)
	require.NoError(t, err)
	assert.Equal(t, throttled, got)
}
|
|
|
|
// TestStoreGetStrictDecode writes malformed or inconsistent JSON payloads
// directly into miniredis. It then checks that Get rejects each one with an
// error containing the expected text. Most cases start from the redisRecord
// encoding of a valid confirmed challenge and change a single field.
func TestStoreGetStrictDecode(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_775_130_200, 0).UTC()
	baseRecord := testChallenge(now)
	baseStored, err := redisRecordFromChallenge(baseRecord)
	require.NoError(t, err)

	// mutate receives the running subtest's t. Helper failures (for example in
	// mustMarshalJSON) then call FailNow on the goroutine that owns that t,
	// rather than on the parent test from inside a parallel subtest.
	tests := []struct {
		name        string
		mutate      func(t *testing.T, record redisRecord) string
		wantErrText string
	}{
		{
			name: "malformed json",
			mutate: func(_ *testing.T, _ redisRecord) string {
				return "{"
			},
			wantErrText: "decode redis challenge record",
		},
		{
			name: "trailing json input",
			mutate: func(t *testing.T, record redisRecord) string {
				return mustMarshalJSON(t, record) + "{}"
			},
			wantErrText: "unexpected trailing JSON input",
		},
		{
			name: "unknown field",
			mutate: func(t *testing.T, record redisRecord) string {
				payload := map[string]any{
					"challenge_id":                record.ChallengeID,
					"email":                       record.Email,
					"code_hash_base64":            record.CodeHashBase64,
					"status":                      record.Status,
					"delivery_state":              record.DeliveryState,
					"created_at":                  record.CreatedAt,
					"expires_at":                  record.ExpiresAt,
					"send_attempt_count":          record.SendAttemptCount,
					"confirm_attempt_count":       record.ConfirmAttemptCount,
					"last_attempt_at":             record.LastAttemptAt,
					"confirmed_session_id":        record.ConfirmedSessionID,
					"confirmed_client_public_key": record.ConfirmedClientPublicKey,
					"confirmed_at":                record.ConfirmedAt,
					"unexpected":                  true,
				}

				return mustMarshalJSON(t, payload)
			},
			wantErrText: "unknown field",
		},
		{
			name: "unsupported status",
			mutate: func(t *testing.T, record redisRecord) string {
				record.Status = challenge.Status("paused")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			name: "unsupported delivery state",
			mutate: func(t *testing.T, record redisRecord) string {
				record.DeliveryState = challenge.DeliveryState("queued")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `delivery state "queued" is unsupported`,
		},
		{
			name: "missing required email",
			mutate: func(t *testing.T, record redisRecord) string {
				record.Email = ""
				return mustMarshalJSON(t, record)
			},
			wantErrText: "challenge email",
		},
		{
			name: "challenge id mismatch",
			mutate: func(t *testing.T, record redisRecord) string {
				record.ChallengeID = "other-challenge"
				return mustMarshalJSON(t, record)
			},
			wantErrText: `does not match requested`,
		},
		{
			name: "non canonical utc timestamp",
			mutate: func(t *testing.T, record redisRecord) string {
				record.CreatedAt = "2026-04-04T12:00:00+03:00"
				return mustMarshalJSON(t, record)
			},
			wantErrText: "canonical UTC RFC3339Nano timestamp",
		},
		{
			name: "partial confirmation metadata",
			mutate: func(t *testing.T, record redisRecord) string {
				record.ConfirmedAt = nil
				return mustMarshalJSON(t, record)
			},
			wantErrText: "confirmation metadata must be either fully present or fully absent",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := miniredis.RunT(t)
			store := newTestStore(t, server, Config{})
			// Check the seed write; a silently failed Set would make the test
			// exercise the not-found path instead of the decoder.
			require.NoError(t, server.Set(store.lookupKey(baseRecord.ID), tt.mutate(t, baseStored)))

			_, err := store.Get(context.Background(), baseRecord.ID)
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErrText)
		})
	}
}
|
|
|
|
func TestStoreKeySchemeAndTTL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := newTestStore(t, server, Config{KeyPrefix: "authsession:challenge:"})
|
|
now := time.Now().UTC()
|
|
|
|
prefixed := testPendingChallenge(now)
|
|
prefixed.ID = common.ChallengeID("challenge:opaque/id?value")
|
|
require.NoError(t, store.Create(context.Background(), prefixed))
|
|
|
|
key := store.lookupKey(prefixed.ID)
|
|
assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key)
|
|
assert.True(t, server.Exists(key))
|
|
|
|
freshTTL := server.TTL(key)
|
|
assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod)
|
|
assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second)
|
|
|
|
expired := testPendingChallenge(now.Add(-10 * time.Minute))
|
|
expired.ID = common.ChallengeID("expired-challenge")
|
|
expired.CreatedAt = now.Add(-20 * time.Minute)
|
|
expired.ExpiresAt = now.Add(-1 * time.Minute)
|
|
require.NoError(t, store.Create(context.Background(), expired))
|
|
|
|
expiredTTL := server.TTL(store.lookupKey(expired.ID))
|
|
assert.LessOrEqual(t, expiredTTL, expirationGracePeriod)
|
|
assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second)
|
|
}
|
|
|
|
// TestStoreCreateConflict checks that a second Create for an already stored
// challenge ID fails with ports.ErrConflict.
func TestStoreCreateConflict(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	store := newTestStore(t, miniredis.RunT(t), Config{})
	record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC())

	require.NoError(t, store.Create(ctx, record))

	dupErr := store.Create(ctx, record)
	require.Error(t, dupErr)
	assert.ErrorIs(t, dupErr, ports.ErrConflict)
}
|
|
|
|
// TestStoreGetNotFound checks that Get on an ID that was never stored fails
// with ports.ErrNotFound.
func TestStoreGetNotFound(t *testing.T) {
	t.Parallel()

	store := newTestStore(t, miniredis.RunT(t), Config{})
	missing := common.ChallengeID("missing-challenge")

	_, err := store.Get(context.Background(), missing)
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrNotFound)
}
|
|
|
|
// TestStoreCompareAndSwap covers the CompareAndSwap outcomes:
//   - success: a matching previous record is replaced by next;
//   - conflict: a previous that differs from the stored record yields
//     ports.ErrConflict;
//   - not found: a missing key yields ports.ErrNotFound;
//   - corrupt record: an undecodable stored payload surfaces as a decode
//     error, not as a conflict.
func TestStoreCompareAndSwap(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_775_130_400, 0).UTC()

	// markSent returns c (received by value) moved to the sent status and the
	// sent delivery state.
	markSent := func(c challenge.Challenge) challenge.Challenge {
		c.Status = challenge.StatusSent
		c.DeliveryState = challenge.DeliverySent
		return c
	}

	t.Run("success", func(t *testing.T) {
		t.Parallel()

		ctx := context.Background()
		store := newTestStore(t, miniredis.RunT(t), Config{})
		previous := testPendingChallenge(now)
		next := markSent(previous)
		next.Attempts.Send = 1
		next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute))

		require.NoError(t, store.Create(ctx, previous))
		require.NoError(t, store.CompareAndSwap(ctx, previous, next))

		got, err := store.Get(ctx, previous.ID)
		require.NoError(t, err)
		assert.Equal(t, next, got)
	})

	t.Run("conflict when stored record differs", func(t *testing.T) {
		t.Parallel()

		ctx := context.Background()
		store := newTestStore(t, miniredis.RunT(t), Config{})
		stored := testPendingChallenge(now)
		// previous is a deliberately stale view of what is actually stored.
		previous := stored
		previous.Attempts.Send = 99
		next := markSent(stored)

		require.NoError(t, store.Create(ctx, stored))

		swapErr := store.CompareAndSwap(ctx, previous, next)
		require.Error(t, swapErr)
		assert.ErrorIs(t, swapErr, ports.ErrConflict)
	})

	t.Run("not found", func(t *testing.T) {
		t.Parallel()

		store := newTestStore(t, miniredis.RunT(t), Config{})
		previous := testPendingChallenge(now)
		next := markSent(previous)

		swapErr := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, swapErr)
		assert.ErrorIs(t, swapErr, ports.ErrNotFound)
	})

	t.Run("corrupt stored record returns adapter error", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})
		previous := testPendingChallenge(now)
		next := markSent(previous)

		server.Set(store.lookupKey(previous.ID), "{")

		swapErr := store.CompareAndSwap(context.Background(), previous, next)
		require.Error(t, swapErr)
		assert.NotErrorIs(t, swapErr, ports.ErrConflict)
		assert.ErrorContains(t, swapErr, "decode redis challenge record")
	})
}
|
|
|
|
// newTestStore builds a Store on top of server. Zero-valued fields in cfg are
// filled with test defaults: key prefix "authsession:challenge:" and a 250ms
// operation timeout. A construction error fails the test immediately.
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
	t.Helper()

	if cfg.KeyPrefix == "" {
		cfg.KeyPrefix = "authsession:challenge:"
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}

	client := newRedisClient(t, server)
	s, err := New(client, cfg)
	require.NoError(t, err)

	return s
}
|
|
|
|
func testPendingChallenge(now time.Time) challenge.Challenge {
|
|
return challenge.Challenge{
|
|
ID: common.ChallengeID("challenge-pending"),
|
|
Email: common.Email("pilot@example.com"),
|
|
CodeHash: []byte("hashed-pending-code"),
|
|
PreferredLanguage: "en",
|
|
Status: challenge.StatusPendingSend,
|
|
DeliveryState: challenge.DeliveryPending,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(challenge.InitialTTL),
|
|
}
|
|
}
|
|
|
|
// testChallenge returns the fixture "challenge-confirmed": a challenge in the
// confirmed-pending-expire state, created at now and expiring
// challenge.ConfirmedRetention later. It carries:
//   - attempt counters (send 1, confirm 2);
//   - a last-attempt time of now+30s;
//   - a confirmation for session "device-session-1", confirmed at now+1m.
//
// The confirmation's client public key is the 32 bytes 0..31. The function
// panics if that key is rejected, since it takes no *testing.T.
func testChallenge(now time.Time) challenge.Challenge {
	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for i := range raw {
		raw[i] = byte(i)
	}

	clientPublicKey, err := common.NewClientPublicKey(raw)
	if err != nil {
		panic(err)
	}

	attempts := challenge.AttemptCounters{Send: 1, Confirm: 2}
	abuse := challenge.AbuseMetadata{LastAttemptAt: timePointer(now.Add(30 * time.Second))}
	confirmation := &challenge.Confirmation{
		SessionID:       common.DeviceSessionID("device-session-1"),
		ClientPublicKey: clientPublicKey,
		ConfirmedAt:     now.Add(1 * time.Minute),
	}

	return challenge.Challenge{
		ID:                common.ChallengeID("challenge-confirmed"),
		Email:             common.Email("pilot@example.com"),
		CodeHash:          []byte("hashed-code"),
		PreferredLanguage: "en",
		Status:            challenge.StatusConfirmedPendingExpire,
		DeliveryState:     challenge.DeliverySent,
		CreatedAt:         now,
		ExpiresAt:         now.Add(challenge.ConfirmedRetention),
		Attempts:          attempts,
		Abuse:             abuse,
		Confirmation:      confirmation,
	}
}
|
|
|
|
// TestStoreGetDefaultsMissingPreferredLanguageToEnglish checks that a stored
// payload with no preferred_language key decodes with PreferredLanguage set
// to "en".
func TestStoreGetDefaultsMissingPreferredLanguageToEnglish(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})
	record := testPendingChallenge(time.Unix(1_775_130_250, 0).UTC())

	stored, err := redisRecordFromChallenge(record)
	require.NoError(t, err)

	// Build the payload by hand so the preferred_language key is absent
	// altogether, rather than present with an empty value.
	withoutLanguage := map[string]any{
		"challenge_id":          stored.ChallengeID,
		"email":                 stored.Email,
		"code_hash_base64":      stored.CodeHashBase64,
		"status":                stored.Status,
		"delivery_state":        stored.DeliveryState,
		"created_at":            stored.CreatedAt,
		"expires_at":            stored.ExpiresAt,
		"send_attempt_count":    stored.SendAttemptCount,
		"confirm_attempt_count": stored.ConfirmAttemptCount,
	}
	server.Set(store.lookupKey(record.ID), mustMarshalJSON(t, withoutLanguage))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, "en", got.PreferredLanguage)
}
|
|
|
|
// timePointer returns a pointer to a fresh copy of value. It is used to fill
// optional *time.Time fields in test fixtures.
func timePointer(value time.Time) *time.Time {
	ptr := new(time.Time)
	*ptr = value
	return ptr
}
|
|
|
|
// mustMarshalJSON encodes value with encoding/json and returns the bytes as a
// string. If encoding fails, it fails the given test immediately.
func mustMarshalJSON(t *testing.T, value any) string {
	t.Helper()

	encoded, err := json.Marshal(value)
	require.NoError(t, err)

	return string(encoded)
}
|
|
|
|
func TestStoreGetNilContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := newTestStore(t, server, Config{})
|
|
|
|
_, err := store.Get(nil, common.ChallengeID("challenge"))
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "nil context")
|
|
}
|