package sessionstore

import (
	"context"
	"crypto/ed25519"
	"encoding/json"
	"testing"
	"time"

	"galaxy/authsession/internal/adapters/contracttest"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/ports"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestStoreContract runs the shared ports.SessionStore contract suite against
// Store. Every invocation of the factory starts its own miniredis server.
func TestStoreContract(t *testing.T) {
	t.Parallel()

	contracttest.RunSessionStoreContractTests(t, func(t *testing.T) ports.SessionStore {
		t.Helper()

		server := miniredis.RunT(t)
		return newTestStore(t, server, Config{})
	})
}

// TestNew checks that New accepts a fully populated Config and rejects each
// individually missing or invalid field with a descriptive error.
func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:                        server.Addr(),
				DB:                          1,
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:                        server.Addr(),
				DB:                          -1,
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session prefix",
			cfg: Config{
				Addr:                        server.Addr(),
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "session key prefix must not be empty",
		},
		{
			name: "empty all sessions prefix",
			cfg: Config{
				Addr:                        server.Addr(),
				SessionKeyPrefix:            "authsession:session:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
				OperationTimeout:            250 * time.Millisecond,
			},
			wantErr: "user sessions key prefix must not be empty",
		},
		{
			name: "empty active sessions prefix",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionKeyPrefix:      "authsession:session:",
				UserSessionsKeyPrefix: "authsession:user-sessions:",
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "user active sessions key prefix must not be empty",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:                        server.Addr(),
				SessionKeyPrefix:            "authsession:session:",
				UserSessionsKeyPrefix:       "authsession:user-sessions:",
				UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
			},
			wantErr: "operation timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt // keep per-iteration copy for pre-Go 1.22 loop semantics
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			store, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}

// TestStorePing checks that Ping succeeds against a running miniredis server.
func TestStorePing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	require.NoError(t, store.Ping(context.Background()))
}

// TestStoreCreateAndGetActive round-trips an active session and checks that
// mutating the value returned by Get does not affect what a later Get returns.
func TestStoreCreateAndGetActive(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
	require.NoError(t, store.Create(context.Background(), record))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)

	// Mutate the returned copy; the stored record must stay unchanged.
	got.Revocation = &devicesession.Revocation{
		At:         got.CreatedAt.Add(time.Minute),
		ReasonCode: devicesession.RevokeReasonAdminRevoke,
		ActorType:  common.RevokeActorType("admin"),
	}

	again, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Nil(t, again.Revocation)
	assert.Equal(t, record, again)
}

// TestStoreCreateAndGetRevoked round-trips a revoked session and checks that
// it is not counted as active for its user.
func TestStoreCreateAndGetRevoked(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	record := revokedSessionFixture("device-session-2", "user-1", time.Unix(1_775_240_100, 0).UTC())
	require.NoError(t, store.Create(context.Background(), record))

	got, err := store.Get(context.Background(), record.ID)
	require.NoError(t, err)
	assert.Equal(t, record, got)

	count, err := store.CountActiveByUserID(context.Background(), record.UserID)
	require.NoError(t, err)
	assert.Zero(t, count)
}

// TestStoreGetNotFound checks that Get on an unknown ID yields ports.ErrNotFound.
func TestStoreGetNotFound(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrNotFound)
}

// TestStoreCreateConflict checks that creating the same session twice yields
// ports.ErrConflict on the second attempt.
func TestStoreCreateConflict(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_200, 0).UTC())
	require.NoError(t, store.Create(context.Background(), record))

	err := store.Create(context.Background(), record)
	require.Error(t, err)
	assert.ErrorIs(t, err, ports.ErrConflict)
}

// TestStoreIndexesAndOrdering checks that ListByUserID returns only the
// user's sessions ordered newest CreatedAt first, that CountActiveByUserID
// ignores revoked sessions, and that an unknown user yields an empty list.
func TestStoreIndexesAndOrdering(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	older := activeSessionFixture("device-session-old", "user-1", time.Unix(10, 0).UTC())
	newer := activeSessionFixture("device-session-new", "user-1", time.Unix(20, 0).UTC())
	revoked := revokedSessionFixture("device-session-revoked", "user-1", time.Unix(15, 0).UTC())
	otherUser := activeSessionFixture("device-session-other", "user-2", time.Unix(30, 0).UTC())

	for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
		require.NoError(t, store.Create(context.Background(), record))
	}

	got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
	require.NoError(t, err)
	require.Len(t, got, 3)
	assert.Equal(t,
		[]common.DeviceSessionID{newer.ID, revoked.ID, older.ID},
		[]common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID},
	)

	count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
	require.NoError(t, err)
	assert.Equal(t, 2, count)

	unknown, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
	require.NoError(t, err)
	assert.Empty(t, unknown)
}

// TestStoreKeyPrefixesAndEncodedPrimaryKey checks that custom key prefixes are
// honoured, that IDs are passed through encodeKeyComponent when building keys,
// and that the index sorted sets store the raw (unencoded) session ID.
func TestStoreKeyPrefixesAndEncodedPrimaryKey(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{
		SessionKeyPrefix:            "custom:session:",
		UserSessionsKeyPrefix:       "custom:user-sessions:",
		UserActiveSessionsKeyPrefix: "custom:user-active-sessions:",
	})

	// IDs deliberately contain key-separator-like characters.
	record := activeSessionFixture("device/session:opaque?1", "user/opaque:1", time.Unix(40, 0).UTC())
	require.NoError(t, store.Create(context.Background(), record))

	primaryKey := store.sessionKey(record.ID)
	assert.Equal(t, "custom:session:"+encodeKeyComponent(record.ID.String()), primaryKey)
	assert.True(t, server.Exists(primaryKey))

	allSessionsKey := store.userSessionsKey(record.UserID)
	activeSessionsKey := store.userActiveSessionsKey(record.UserID)
	assert.Equal(t, "custom:user-sessions:"+encodeKeyComponent(record.UserID.String()), allSessionsKey)
	assert.Equal(t, "custom:user-active-sessions:"+encodeKeyComponent(record.UserID.String()), activeSessionsKey)

	allMembers, err := server.ZMembers(allSessionsKey)
	require.NoError(t, err)
	assert.Equal(t, []string{record.ID.String()}, allMembers)

	activeMembers, err := server.ZMembers(activeSessionsKey)
	require.NoError(t, err)
	assert.Equal(t, []string{record.ID.String()}, activeMembers)
}

// TestStoreRevoke covers Revoke for an active session, an already revoked
// session (whose original revocation must be preserved), and an unknown ID.
func TestStoreRevoke(t *testing.T) {
	t.Parallel()

	t.Run("active session", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		record := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		revocation := devicesession.Revocation{
			At:         time.Unix(200, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonLogoutAll,
			ActorType:  common.RevokeActorType("system"),
		}

		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: record.ID,
			Revocation:      revocation,
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
		require.NotNil(t, result.Session.Revocation)
		assert.Equal(t, revocation, *result.Session.Revocation)

		count, err := store.CountActiveByUserID(context.Background(), record.UserID)
		require.NoError(t, err)
		assert.Zero(t, count)
	})

	t.Run("already revoked keeps stored revocation", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		record := revokedSessionFixture("device-session-2", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: record.ID,
			Revocation: devicesession.Revocation{
				At:         time.Unix(300, 0).UTC(),
				ReasonCode: devicesession.RevokeReasonAdminRevoke,
				ActorType:  common.RevokeActorType("admin"),
				ActorID:    "admin-1",
			},
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
		require.NotNil(t, result.Session.Revocation)
		assert.Equal(t, *record.Revocation, *result.Session.Revocation)
	})

	t.Run("unknown session", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
			DeviceSessionID: common.DeviceSessionID("missing-session"),
			Revocation: devicesession.Revocation{
				At:         time.Unix(200, 0).UTC(),
				ReasonCode: devicesession.RevokeReasonLogoutAll,
				ActorType:  common.RevokeActorType("system"),
			},
		})
		require.Error(t, err)
		assert.ErrorIs(t, err, ports.ErrNotFound)
	})
}

// TestStoreRevokeAllByUserID covers bulk revocation of a user's active sessions
// and the outcome reported when the user has no active sessions.
func TestStoreRevokeAllByUserID(t *testing.T) {
	t.Parallel()

	t.Run("revokes active sessions newest first and clears active index", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		older := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
		newer := activeSessionFixture("device-session-2", "user-1", time.Unix(200, 0).UTC())
		alreadyRevoked := revokedSessionFixture("device-session-3", "user-1", time.Unix(150, 0).UTC())
		otherUser := activeSessionFixture("device-session-4", "user-2", time.Unix(250, 0).UTC())

		for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
			require.NoError(t, store.Create(context.Background(), record))
		}

		revocation := devicesession.Revocation{
			At:         time.Unix(300, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonAdminRevoke,
			ActorType:  common.RevokeActorType("admin"),
			ActorID:    "admin-1",
		}

		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID:     common.UserID("user-1"),
			Revocation: revocation,
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
		require.Len(t, result.Sessions, 2)
		assert.Equal(t,
			[]common.DeviceSessionID{newer.ID, older.ID},
			[]common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID},
		)
		// Guard the dereferences below so a missing revocation fails the test
		// instead of panicking.
		require.NotNil(t, result.Sessions[0].Revocation)
		require.NotNil(t, result.Sessions[1].Revocation)
		assert.Equal(t, revocation, *result.Sessions[0].Revocation)
		assert.Equal(t, revocation, *result.Sessions[1].Revocation)

		count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
		require.NoError(t, err)
		assert.Zero(t, count)

		otherCount, err := store.CountActiveByUserID(context.Background(), common.UserID("user-2"))
		require.NoError(t, err)
		assert.Equal(t, 1, otherCount)
	})

	t.Run("no active sessions", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		record := revokedSessionFixture("device-session-5", "user-1", time.Unix(100, 0).UTC())
		require.NoError(t, store.Create(context.Background(), record))

		result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
			UserID: common.UserID("user-1"),
			Revocation: devicesession.Revocation{
				At:         time.Unix(400, 0).UTC(),
				ReasonCode: devicesession.RevokeReasonAdminRevoke,
				ActorType:  common.RevokeActorType("admin"),
			},
		})
		require.NoError(t, err)
		assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
		assert.Empty(t, result.Sessions)
	})
}

// TestStoreStrictDecodeCorruption writes deliberately malformed payloads
// directly into Redis and checks that Get rejects each with a specific error.
func TestStoreStrictDecodeCorruption(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_775_240_300, 0).UTC()
	baseRecord := revokedSessionFixture("device-session-corrupt", "user-1", now)
	stored, err := redisRecordFromSession(baseRecord)
	require.NoError(t, err)

	// mutate receives the running subtest's *testing.T so that any assertion
	// failure inside it is reported on (and aborts) the subtest's goroutine,
	// not the parent test's.
	tests := []struct {
		name        string
		mutate      func(t *testing.T, record redisRecord) string
		wantErrText string
	}{
		{
			name: "malformed json",
			mutate: func(_ *testing.T, _ redisRecord) string {
				return "{"
			},
			wantErrText: "decode redis session record",
		},
		{
			name: "trailing json input",
			mutate: func(t *testing.T, record redisRecord) string {
				return mustMarshalJSON(t, record) + "{}"
			},
			wantErrText: "unexpected trailing JSON input",
		},
		{
			name: "unknown field",
			mutate: func(t *testing.T, record redisRecord) string {
				payload := map[string]any{
					"device_session_id":        record.DeviceSessionID,
					"user_id":                  record.UserID,
					"client_public_key_base64": record.ClientPublicKeyBase64,
					"status":                   record.Status,
					"created_at":               record.CreatedAt,
					"revoked_at":               record.RevokedAt,
					"revoke_reason_code":       record.RevokeReasonCode,
					"revoke_actor_type":        record.RevokeActorType,
					"revoke_actor_id":          record.RevokeActorID,
					"unexpected":               true,
				}
				return mustMarshalJSON(t, payload)
			},
			wantErrText: "unknown field",
		},
		{
			name: "unsupported status",
			mutate: func(t *testing.T, record redisRecord) string {
				record.Status = devicesession.Status("paused")
				return mustMarshalJSON(t, record)
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			name: "non canonical timestamp",
			mutate: func(t *testing.T, record redisRecord) string {
				record.CreatedAt = "2026-04-04T12:00:00+03:00"
				return mustMarshalJSON(t, record)
			},
			wantErrText: "canonical UTC RFC3339Nano timestamp",
		},
		{
			name: "incomplete revocation metadata",
			mutate: func(t *testing.T, record redisRecord) string {
				record.RevokeActorType = ""
				return mustMarshalJSON(t, record)
			},
			wantErrText: "revocation metadata must be either fully present or fully absent",
		},
	}

	for _, tt := range tests {
		tt := tt // keep per-iteration copy for pre-Go 1.22 loop semantics
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := miniredis.RunT(t)
			store := newTestStore(t, server, Config{})

			// stored is passed by value, so each mutate works on its own copy.
			require.NoError(t, server.Set(store.sessionKey(baseRecord.ID), tt.mutate(t, stored)))

			_, err := store.Get(context.Background(), baseRecord.ID)
			require.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErrText)
		})
	}
}

// TestStoreListByUserIDDetectsCorruptIndexes checks that ListByUserID fails
// when the user index references a missing primary record or a record owned
// by a different user.
func TestStoreListByUserIDDetectsCorruptIndexes(t *testing.T) {
	t.Parallel()

	t.Run("missing primary record", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		userID := common.UserID("user-1")
		_, err := server.ZAdd(store.userSessionsKey(userID), 100, "missing-session")
		require.NoError(t, err)

		_, err = store.ListByUserID(context.Background(), userID)
		require.Error(t, err)
		assert.ErrorContains(t, err, "references missing session")
	})

	t.Run("wrong user id in primary record", func(t *testing.T) {
		t.Parallel()

		server := miniredis.RunT(t)
		store := newTestStore(t, server, Config{})

		record := activeSessionFixture("device-session-1", "user-2", time.Unix(100, 0).UTC())
		require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))

		// Index the user-2 record under user-1 to simulate a corrupt index.
		_, err := server.ZAdd(store.userSessionsKey(common.UserID("user-1")), createdAtScore(record.CreatedAt), record.ID.String())
		require.NoError(t, err)

		_, err = store.ListByUserID(context.Background(), common.UserID("user-1"))
		require.Error(t, err)
		assert.ErrorContains(t, err, `belongs to "user-2"`)
	})
}

// TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex checks that
// RevokeAllByUserID fails when the active index lists a revoked session.
func TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestStore(t, server, Config{})

	record := revokedSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
	require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))

	_, err := server.ZAdd(store.userActiveSessionsKey(record.UserID), createdAtScore(record.CreatedAt), record.ID.String())
	require.NoError(t, err)

	_, err = store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
		UserID: record.UserID,
		Revocation: devicesession.Revocation{
			At:         time.Unix(200, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonAdminRevoke,
			ActorType:  common.RevokeActorType("admin"),
		},
	})
	require.Error(t, err)
	assert.ErrorContains(t, err, `is "revoked"`)
}

// newTestStore builds a Store pointed at server, filling any zero-valued
// Config field (Addr, the three key prefixes, OperationTimeout) with a test
// default. The store is closed automatically when t finishes.
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
	t.Helper()

	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	if cfg.SessionKeyPrefix == "" {
		cfg.SessionKeyPrefix = "authsession:session:"
	}
	if cfg.UserSessionsKeyPrefix == "" {
		cfg.UserSessionsKeyPrefix = "authsession:user-sessions:"
	}
	if cfg.UserActiveSessionsKeyPrefix == "" {
		cfg.UserActiveSessionsKeyPrefix = "authsession:user-active-sessions:"
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}

	store, err := New(cfg)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, store.Close())
	})

	return store
}

// activeSessionFixture returns an active session with a fixed 32-byte Ed25519
// public key (bytes 0..31) and no revocation. It panics if the key is
// rejected, since it has no *testing.T to report through.
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
		0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
		16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
	})
	if err != nil {
		panic(err)
	}

	return devicesession.Session{
		ID:              common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID(userID),
		ClientPublicKey: clientPublicKey,
		Status:          devicesession.StatusActive,
		CreatedAt:       createdAt,
	}
}

// revokedSessionFixture returns the active fixture switched to StatusRevoked,
// with a device-logout revocation recorded one minute after createdAt.
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	record := activeSessionFixture(deviceSessionID, userID, createdAt)
	record.Status = devicesession.StatusRevoked
	record.Revocation = &devicesession.Revocation{
		At:         createdAt.Add(time.Minute),
		ReasonCode: devicesession.RevokeReasonDeviceLogout,
		ActorType:  common.RevokeActorType("user"),
		ActorID:    "user-actor",
	}

	return record
}

// seedSessionRecord writes record's JSON encoding directly under key,
// bypassing Store.Create (and therefore the user indexes). It returns the
// error from the underlying miniredis Set call.
func seedSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record devicesession.Session) error {
	t.Helper()

	stored, err := redisRecordFromSession(record)
	require.NoError(t, err)

	return server.Set(key, mustMarshalJSON(t, stored))
}

// mustMarshalJSON encodes value with encoding/json and fails t on error.
func mustMarshalJSON(t *testing.T, value any) string {
	t.Helper()

	payload, err := json.Marshal(value)
	require.NoError(t, err)

	return string(payload)
}