package contracttest import ( "context" "crypto/ed25519" "testing" "time" "galaxy/authsession/internal/domain/common" "galaxy/authsession/internal/domain/devicesession" "galaxy/authsession/internal/ports" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // SessionStoreFactory constructs a fresh SessionStore instance suitable for // one isolated contract subtest. type SessionStoreFactory func(t *testing.T) ports.SessionStore // RunSessionStoreContractTests executes the backend-agnostic SessionStore // contract suite against newStore. func RunSessionStoreContractTests(t *testing.T, newStore SessionStoreFactory) { t.Helper() t.Run("create and get", func(t *testing.T) { t.Parallel() store := newStore(t) record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC()) require.NoError(t, store.Create(context.Background(), record)) got, err := store.Get(context.Background(), record.ID) require.NoError(t, err) assert.Equal(t, record, got) }) t.Run("create conflict", func(t *testing.T) { t.Parallel() store := newStore(t) record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_050, 0).UTC()) require.NoError(t, store.Create(context.Background(), record)) err := store.Create(context.Background(), record) require.Error(t, err) assert.ErrorIs(t, err, ports.ErrConflict) }) t.Run("get not found", func(t *testing.T) { t.Parallel() store := newStore(t) _, err := store.Get(context.Background(), common.DeviceSessionID("missing-session")) require.Error(t, err) assert.ErrorIs(t, err, ports.ErrNotFound) }) t.Run("list by user id returns newest first", func(t *testing.T) { t.Parallel() store := newStore(t) older := contractActiveSession(t, "device-session-old", "user-1", time.Unix(10, 0).UTC()) newer := contractActiveSession(t, "device-session-new", "user-1", time.Unix(20, 0).UTC()) revoked := contractRevokedSession(t, "device-session-revoked", "user-1", time.Unix(15, 0).UTC()) otherUser := contractActiveSession(t, "device-session-other", "user-2", time.Unix(30, 0).UTC()) for _, record := range []devicesession.Session{older, newer, revoked, otherUser} { require.NoError(t, store.Create(context.Background(), record)) } got, err := store.ListByUserID(context.Background(), common.UserID("user-1")) require.NoError(t, err) require.Len(t, got, 3) assert.Equal( t, []common.DeviceSessionID{newer.ID, revoked.ID, older.ID}, []common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID}, ) }) t.Run("list by user id returns empty slice for unknown user", func(t *testing.T) { t.Parallel() store := newStore(t) got, err := store.ListByUserID(context.Background(), common.UserID("unknown-user")) require.NoError(t, err) require.NotNil(t, got) assert.Empty(t, got) }) t.Run("count active by user id", func(t *testing.T) { t.Parallel() store := newStore(t) activeOne := contractActiveSession(t, "device-session-1", "user-1", time.Unix(40, 0).UTC()) activeTwo := contractActiveSession(t, "device-session-2", "user-1", time.Unix(50, 0).UTC()) revoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(60, 0).UTC()) otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(70, 0).UTC()) for _, record := range []devicesession.Session{activeOne, activeTwo, revoked, otherUser} { require.NoError(t, store.Create(context.Background(), record)) } count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1")) require.NoError(t, err) assert.Equal(t, 2, count) }) t.Run("revoke active session", func(t *testing.T) { t.Parallel() store := newStore(t) record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC()) require.NoError(t, store.Create(context.Background(), record)) revocation := contractRevocation(time.Unix(200, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", "") result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{ DeviceSessionID: record.ID, Revocation: revocation, }) require.NoError(t, err) assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome) require.NotNil(t, result.Session.Revocation) assert.Equal(t, revocation, *result.Session.Revocation) count, err := store.CountActiveByUserID(context.Background(), record.UserID) require.NoError(t, err) assert.Zero(t, count) }) t.Run("revoke already revoked preserves stored revocation", func(t *testing.T) { t.Parallel() store := newStore(t) record := contractRevokedSession(t, "device-session-2", "user-1", time.Unix(110, 0).UTC()) require.NoError(t, store.Create(context.Background(), record)) result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{ DeviceSessionID: record.ID, Revocation: contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1"), }) require.NoError(t, err) assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome) require.NotNil(t, result.Session.Revocation) assert.Equal(t, *record.Revocation, *result.Session.Revocation) }) t.Run("revoke not found", func(t *testing.T) { t.Parallel() store := newStore(t) _, err := store.Revoke(context.Background(), ports.RevokeSessionInput{ DeviceSessionID: common.DeviceSessionID("missing-session"), Revocation: contractRevocation(time.Unix(210, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", ""), }) require.Error(t, err) assert.ErrorIs(t, err, ports.ErrNotFound) }) t.Run("revoke all by user id revokes active sessions newest first", func(t *testing.T) { t.Parallel() store := newStore(t) older := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC()) newer := contractActiveSession(t, "device-session-2", "user-1", time.Unix(200, 0).UTC()) alreadyRevoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(150, 0).UTC()) otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(250, 0).UTC()) for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} { require.NoError(t, store.Create(context.Background(), record)) } revocation := contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1") result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{ UserID: common.UserID("user-1"), Revocation: revocation, }) require.NoError(t, err) assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome) require.Len(t, result.Sessions, 2) assert.Equal( t, []common.DeviceSessionID{newer.ID, older.ID}, []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID}, ) assert.Equal(t, revocation, *result.Sessions[0].Revocation) assert.Equal(t, revocation, *result.Sessions[1].Revocation) count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1")) require.NoError(t, err) assert.Zero(t, count) }) t.Run("revoke all by user id reports no active sessions", func(t *testing.T) { t.Parallel() store := newStore(t) record := contractRevokedSession(t, "device-session-5", "user-1", time.Unix(120, 0).UTC()) require.NoError(t, store.Create(context.Background(), record)) result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{ UserID: common.UserID("user-1"), Revocation: contractRevocation(time.Unix(400, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", ""), }) require.NoError(t, err) assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome) require.NotNil(t, result.Sessions) assert.Empty(t, result.Sessions) }) t.Run("get returns defensive copies", func(t *testing.T) { t.Parallel() store := newStore(t) record := contractRevokedSession(t, "device-session-copy", "user-1", time.Unix(130, 0).UTC()) require.NoError(t, store.Create(context.Background(), record)) got, err := store.Get(context.Background(), record.ID) require.NoError(t, err) require.NotNil(t, got.Revocation) got.Revocation.ActorID = "mutated" again, err := store.Get(context.Background(), record.ID) require.NoError(t, err) require.NotNil(t, again.Revocation) assert.Equal(t, record, again) }) } func contractActiveSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { t.Helper() clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, }) require.NoError(t, err) record := devicesession.Session{ ID: common.DeviceSessionID(deviceSessionID), UserID: common.UserID(userID), ClientPublicKey: clientPublicKey, Status: devicesession.StatusActive, CreatedAt: createdAt, } require.NoError(t, record.Validate()) return record } func contractRevokedSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session { t.Helper() record := contractActiveSession(t, deviceSessionID, userID, createdAt) revocation := contractRevocation(createdAt.Add(time.Minute), devicesession.RevokeReasonDeviceLogout, "user", "user-actor") record.Status = devicesession.StatusRevoked record.Revocation = &revocation require.NoError(t, record.Validate()) return record } func contractRevocation(at time.Time, reasonCode common.RevokeReasonCode, actorType string, actorID string) devicesession.Revocation { record := devicesession.Revocation{ At: at, ReasonCode: reasonCode, ActorType: common.RevokeActorType(actorType), ActorID: actorID, } if err := record.Validate(); err != nil { panic(err) } return record }