package challengestore import ( "context" "crypto/ed25519" "encoding/json" "testing" "time" "galaxy/authsession/internal/adapters/contracttest" "galaxy/authsession/internal/domain/challenge" "galaxy/authsession/internal/domain/common" "galaxy/authsession/internal/ports" "github.com/alicebob/miniredis/v2" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func newRedisClient(t *testing.T, server *miniredis.Miniredis) *redis.Client { t.Helper() client := redis.NewClient(&redis.Options{ Addr: server.Addr(), Protocol: 2, DisableIdentity: true, }) t.Cleanup(func() { assert.NoError(t, client.Close()) }) return client } func TestStoreContract(t *testing.T) { t.Parallel() contracttest.RunChallengeStoreContractTests(t, func(t *testing.T) ports.ChallengeStore { t.Helper() server := miniredis.RunT(t) return newTestStore(t, server, Config{}) }) } func TestNew(t *testing.T) { t.Parallel() server := miniredis.RunT(t) client := newRedisClient(t, server) tests := []struct { name string client *redis.Client cfg Config wantErr string }{ { name: "valid config", client: client, cfg: Config{KeyPrefix: "authsession:challenge:", OperationTimeout: 250 * time.Millisecond}, }, { name: "nil client", client: nil, cfg: Config{KeyPrefix: "authsession:challenge:", OperationTimeout: 250 * time.Millisecond}, wantErr: "nil redis client", }, { name: "empty key prefix", client: client, cfg: Config{OperationTimeout: 250 * time.Millisecond}, wantErr: "redis key prefix must not be empty", }, { name: "non-positive operation timeout", client: client, cfg: Config{KeyPrefix: "authsession:challenge:"}, wantErr: "operation timeout must be positive", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() store, err := New(tt.client, tt.cfg) if tt.wantErr != "" { require.Error(t, err) assert.ErrorContains(t, err, tt.wantErr) return } require.NoError(t, err) require.NotNil(t, store) }) } } func TestStoreCreateAndGet(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) now := time.Unix(1_775_130_000, 0).UTC() record := testChallenge(now) require.NoError(t, store.Create(context.Background(), record)) got, err := store.Get(context.Background(), record.ID) require.NoError(t, err) assert.Equal(t, record, got) got.CodeHash[0] = 0xFF keyBytes := got.Confirmation.ClientPublicKey.PublicKey() keyBytes[0] = 0xFE again, err := store.Get(context.Background(), record.ID) require.NoError(t, err) assert.Equal(t, record.CodeHash, again.CodeHash) require.NotNil(t, again.Confirmation) assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String()) } func TestStoreCreateAndGetPendingChallenge(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) now := time.Unix(1_775_130_100, 0).UTC() record := testPendingChallenge(now) require.NoError(t, store.Create(context.Background(), record)) got, err := store.Get(context.Background(), record.ID) require.NoError(t, err) assert.Equal(t, record, got) } func TestStoreCreateAndGetThrottledChallenge(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) now := time.Unix(1_775_130_150, 0).UTC() record := testPendingChallenge(now) record.Status = challenge.StatusDeliveryThrottled record.DeliveryState = challenge.DeliveryThrottled record.Attempts.Send = 1 record.Abuse.LastAttemptAt = timePointer(now) require.NoError(t, record.Validate()) require.NoError(t, store.Create(context.Background(), record)) got, err := store.Get(context.Background(), record.ID) require.NoError(t, err) assert.Equal(t, record, got) } func TestStoreGetStrictDecode(t *testing.T) { t.Parallel() now := time.Unix(1_775_130_200, 0).UTC() baseRecord := testChallenge(now) baseStored, err := redisRecordFromChallenge(baseRecord) require.NoError(t, err) tests := []struct { name string mutate func(redisRecord) string wantErrText string }{ { name: "malformed json", mutate: func(_ redisRecord) string { return "{" }, wantErrText: "decode redis challenge record", }, { name: "trailing json input", mutate: func(record redisRecord) string { return mustMarshalJSON(t, record) + "{}" }, wantErrText: "unexpected trailing JSON input", }, { name: "unknown field", mutate: func(record redisRecord) string { payload := map[string]any{ "challenge_id": record.ChallengeID, "email": record.Email, "code_hash_base64": record.CodeHashBase64, "status": record.Status, "delivery_state": record.DeliveryState, "created_at": record.CreatedAt, "expires_at": record.ExpiresAt, "send_attempt_count": record.SendAttemptCount, "confirm_attempt_count": record.ConfirmAttemptCount, "last_attempt_at": record.LastAttemptAt, "confirmed_session_id": record.ConfirmedSessionID, "confirmed_client_public_key": record.ConfirmedClientPublicKey, "confirmed_at": record.ConfirmedAt, "unexpected": true, } return mustMarshalJSON(t, payload) }, wantErrText: "unknown field", }, { name: "unsupported status", mutate: func(record redisRecord) string { record.Status = challenge.Status("paused") return mustMarshalJSON(t, record) }, wantErrText: `status "paused" is unsupported`, }, { name: "unsupported delivery state", mutate: func(record redisRecord) string { record.DeliveryState = challenge.DeliveryState("queued") return mustMarshalJSON(t, record) }, wantErrText: `delivery state "queued" is unsupported`, }, { name: "missing required email", mutate: func(record redisRecord) string { record.Email = "" return mustMarshalJSON(t, record) }, wantErrText: "challenge email", }, { name: "challenge id mismatch", mutate: func(record redisRecord) string { record.ChallengeID = "other-challenge" return mustMarshalJSON(t, record) }, wantErrText: `does not match requested`, }, { name: "non canonical utc timestamp", mutate: func(record redisRecord) string { record.CreatedAt = "2026-04-04T12:00:00+03:00" return mustMarshalJSON(t, record) }, wantErrText: "canonical UTC RFC3339Nano timestamp", }, { name: "partial confirmation metadata", mutate: func(record redisRecord) string { record.ConfirmedAt = nil return mustMarshalJSON(t, record) }, wantErrText: "confirmation metadata must be either fully present or fully absent", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) server.Set(store.lookupKey(baseRecord.ID), tt.mutate(baseStored)) _, err := store.Get(context.Background(), baseRecord.ID) require.Error(t, err) assert.ErrorContains(t, err, tt.wantErrText) }) } } func TestStoreKeySchemeAndTTL(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{KeyPrefix: "authsession:challenge:"}) now := time.Now().UTC() prefixed := testPendingChallenge(now) prefixed.ID = common.ChallengeID("challenge:opaque/id?value") require.NoError(t, store.Create(context.Background(), prefixed)) key := store.lookupKey(prefixed.ID) assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key) assert.True(t, server.Exists(key)) freshTTL := server.TTL(key) assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod) assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second) expired := testPendingChallenge(now.Add(-10 * time.Minute)) expired.ID = common.ChallengeID("expired-challenge") expired.CreatedAt = now.Add(-20 * time.Minute) expired.ExpiresAt = now.Add(-1 * time.Minute) require.NoError(t, store.Create(context.Background(), expired)) expiredTTL := server.TTL(store.lookupKey(expired.ID)) assert.LessOrEqual(t, expiredTTL, expirationGracePeriod) assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second) } func TestStoreCreateConflict(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC()) require.NoError(t, store.Create(context.Background(), record)) err := store.Create(context.Background(), record) require.Error(t, err) assert.ErrorIs(t, err, ports.ErrConflict) } func TestStoreGetNotFound(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) _, err := store.Get(context.Background(), common.ChallengeID("missing-challenge")) require.Error(t, err) assert.ErrorIs(t, err, ports.ErrNotFound) } func TestStoreCompareAndSwap(t *testing.T) { t.Parallel() now := time.Unix(1_775_130_400, 0).UTC() t.Run("success", func(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) previous := testPendingChallenge(now) next := previous next.Status = challenge.StatusSent next.DeliveryState = challenge.DeliverySent next.Attempts.Send = 1 next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute)) require.NoError(t, store.Create(context.Background(), previous)) require.NoError(t, store.CompareAndSwap(context.Background(), previous, next)) got, err := store.Get(context.Background(), previous.ID) require.NoError(t, err) assert.Equal(t, next, got) }) t.Run("conflict when stored record differs", func(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) stored := testPendingChallenge(now) previous := stored previous.Attempts.Send = 99 next := stored next.Status = challenge.StatusSent next.DeliveryState = challenge.DeliverySent require.NoError(t, store.Create(context.Background(), stored)) err := store.CompareAndSwap(context.Background(), previous, next) require.Error(t, err) assert.ErrorIs(t, err, ports.ErrConflict) }) t.Run("not found", func(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) previous := testPendingChallenge(now) next := previous next.Status = challenge.StatusSent next.DeliveryState = challenge.DeliverySent err := store.CompareAndSwap(context.Background(), previous, next) require.Error(t, err) assert.ErrorIs(t, err, ports.ErrNotFound) }) t.Run("corrupt stored record returns adapter error", func(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) previous := testPendingChallenge(now) next := previous next.Status = challenge.StatusSent next.DeliveryState = challenge.DeliverySent server.Set(store.lookupKey(previous.ID), "{") err := store.CompareAndSwap(context.Background(), previous, next) require.Error(t, err) assert.NotErrorIs(t, err, ports.ErrConflict) assert.ErrorContains(t, err, "decode redis challenge record") }) } func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store { t.Helper() if cfg.KeyPrefix == "" { cfg.KeyPrefix = "authsession:challenge:" } if cfg.OperationTimeout == 0 { cfg.OperationTimeout = 250 * time.Millisecond } store, err := New(newRedisClient(t, server), cfg) require.NoError(t, err) return store } func testPendingChallenge(now time.Time) challenge.Challenge { return challenge.Challenge{ ID: common.ChallengeID("challenge-pending"), Email: common.Email("pilot@example.com"), CodeHash: []byte("hashed-pending-code"), PreferredLanguage: "en", Status: challenge.StatusPendingSend, DeliveryState: challenge.DeliveryPending, CreatedAt: now, ExpiresAt: now.Add(challenge.InitialTTL), } } func testChallenge(now time.Time) challenge.Challenge { clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, }) if err != nil { panic(err) } return challenge.Challenge{ ID: common.ChallengeID("challenge-confirmed"), Email: common.Email("pilot@example.com"), CodeHash: []byte("hashed-code"), PreferredLanguage: "en", Status: challenge.StatusConfirmedPendingExpire, DeliveryState: challenge.DeliverySent, CreatedAt: now, ExpiresAt: now.Add(challenge.ConfirmedRetention), Attempts: challenge.AttemptCounters{ Send: 1, Confirm: 2, }, Abuse: challenge.AbuseMetadata{ LastAttemptAt: timePointer(now.Add(30 * time.Second)), }, Confirmation: &challenge.Confirmation{ SessionID: common.DeviceSessionID("device-session-1"), ClientPublicKey: clientPublicKey, ConfirmedAt: now.Add(1 * time.Minute), }, } } func TestStoreGetDefaultsMissingPreferredLanguageToEnglish(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) now := time.Unix(1_775_130_250, 0).UTC() record := testPendingChallenge(now) stored, err := redisRecordFromChallenge(record) require.NoError(t, err) stored.PreferredLanguage = "" payload := mustMarshalJSON(t, map[string]any{ "challenge_id": stored.ChallengeID, "email": stored.Email, "code_hash_base64": stored.CodeHashBase64, "status": stored.Status, "delivery_state": stored.DeliveryState, "created_at": stored.CreatedAt, "expires_at": stored.ExpiresAt, "send_attempt_count": stored.SendAttemptCount, "confirm_attempt_count": stored.ConfirmAttemptCount, }) server.Set(store.lookupKey(record.ID), payload) got, err := store.Get(context.Background(), record.ID) require.NoError(t, err) assert.Equal(t, "en", got.PreferredLanguage) } func timePointer(value time.Time) *time.Time { return &value } func mustMarshalJSON(t *testing.T, value any) string { t.Helper() payload, err := json.Marshal(value) require.NoError(t, err) return string(payload) } func TestStoreGetNilContext(t *testing.T) { t.Parallel() server := miniredis.RunT(t) store := newTestStore(t, server, Config{}) _, err := store.Get(nil, common.ChallengeID("challenge")) require.Error(t, err) assert.ErrorContains(t, err, "nil context") }