367 lines
10 KiB
Go
367 lines
10 KiB
Go
package events
|
|
|
|
import (
	"context"
	"sort"
	"sync"
	"testing"
	"time"

	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/session"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
|
|
|
|
func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := session.NewMemoryCache()
|
|
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
|
running := runTestSubscriber(t, subscriber)
|
|
defer running.stop(t)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": "public-key-123",
|
|
"status": string(session.StatusActive),
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
record, err := store.Lookup(context.Background(), "device-session-123")
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
return record.UserID == "user-123" && record.Status == session.StatusActive
|
|
}, time.Second, 10*time.Millisecond)
|
|
}
|
|
|
|
func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := session.NewMemoryCache()
|
|
require.NoError(t, store.Upsert(session.Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: session.StatusActive,
|
|
}))
|
|
|
|
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
|
running := runTestSubscriber(t, subscriber)
|
|
defer running.stop(t)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": "public-key-123",
|
|
"status": string(session.StatusRevoked),
|
|
"revoked_at_ms": "123456789",
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
record, err := store.Lookup(context.Background(), "device-session-123")
|
|
if err != nil || record.RevokedAtMS == nil {
|
|
return false
|
|
}
|
|
|
|
return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789
|
|
}, time.Second, 10*time.Millisecond)
|
|
}
|
|
|
|
func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := session.NewMemoryCache()
|
|
handler := &recordingSessionRevocationHandler{}
|
|
subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
|
|
running := runTestSubscriber(t, subscriber)
|
|
defer running.stop(t)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": "public-key-123",
|
|
"status": string(session.StatusRevoked),
|
|
"revoked_at_ms": "123456789",
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
record, err := store.Lookup(context.Background(), "device-session-123")
|
|
if err != nil || record.Status != session.StatusRevoked {
|
|
return false
|
|
}
|
|
|
|
return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations())
|
|
}, time.Second, 10*time.Millisecond)
|
|
}
|
|
|
|
func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := session.NewMemoryCache()
|
|
handler := &recordingSessionRevocationHandler{}
|
|
subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
|
|
running := runTestSubscriber(t, subscriber)
|
|
defer running.stop(t)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": "public-key-123",
|
|
"status": string(session.StatusActive),
|
|
})
|
|
|
|
assert.Never(t, func() bool {
|
|
return len(handler.revocations()) != 0
|
|
}, 100*time.Millisecond, 10*time.Millisecond)
|
|
}
|
|
|
|
func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
handler := &recordingSessionRevocationHandler{}
|
|
subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, failingSnapshotStore{}, handler)
|
|
running := runTestSubscriber(t, subscriber)
|
|
defer running.stop(t)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": "public-key-123",
|
|
"status": string(session.StatusRevoked),
|
|
"revoked_at_ms": "123456789",
|
|
})
|
|
|
|
assert.Never(t, func() bool {
|
|
return len(handler.revocations()) != 0
|
|
}, 100*time.Millisecond, 10*time.Millisecond)
|
|
}
|
|
|
|
func TestRedisSessionSubscriberLaterEventWins(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := session.NewMemoryCache()
|
|
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
|
running := runTestSubscriber(t, subscriber)
|
|
defer running.stop(t)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": "public-key-123",
|
|
"status": string(session.StatusActive),
|
|
})
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-456",
|
|
"client_public_key": "public-key-456",
|
|
"status": string(session.StatusActive),
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
record, err := store.Lookup(context.Background(), "device-session-123")
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456"
|
|
}, time.Second, 10*time.Millisecond)
|
|
}
|
|
|
|
func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := session.NewMemoryCache()
|
|
require.NoError(t, store.Upsert(session.Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: session.StatusActive,
|
|
}))
|
|
|
|
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
|
running := runTestSubscriber(t, subscriber)
|
|
defer running.stop(t)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": "public-key-123",
|
|
"status": "paused",
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
_, err := store.Lookup(context.Background(), "device-session-123")
|
|
return err != nil
|
|
}, time.Second, 10*time.Millisecond)
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-456",
|
|
"client_public_key": "public-key-456",
|
|
"status": string(session.StatusActive),
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
record, err := store.Lookup(context.Background(), "device-session-123")
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
return record.UserID == "user-456" && record.Status == session.StatusActive
|
|
}, time.Second, 10*time.Millisecond)
|
|
}
|
|
|
|
func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
store := session.NewMemoryCache()
|
|
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
resultCh := make(chan error, 1)
|
|
go func() {
|
|
resultCh <- subscriber.Run(ctx)
|
|
}()
|
|
|
|
select {
|
|
case <-subscriber.started:
|
|
case <-time.After(time.Second):
|
|
require.FailNow(t, "subscriber did not start")
|
|
}
|
|
|
|
cancel()
|
|
require.NoError(t, subscriber.Shutdown(context.Background()))
|
|
|
|
select {
|
|
case err := <-resultCh:
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
case <-time.After(time.Second):
|
|
require.FailNow(t, "subscriber did not stop after shutdown")
|
|
}
|
|
}
|
|
|
|
func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber {
|
|
t.Helper()
|
|
|
|
return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil)
|
|
}
|
|
|
|
func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber {
|
|
t.Helper()
|
|
|
|
subscriber, err := NewRedisSessionSubscriberWithRevocationHandler(
|
|
config.SessionCacheRedisConfig{
|
|
Addr: server.Addr(),
|
|
LookupTimeout: 250 * time.Millisecond,
|
|
},
|
|
config.SessionEventsRedisConfig{
|
|
Stream: "gateway:session_events",
|
|
ReadBlockTimeout: 25 * time.Millisecond,
|
|
},
|
|
store,
|
|
revocationHandler,
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
t.Cleanup(func() {
|
|
assert.NoError(t, subscriber.Close())
|
|
})
|
|
|
|
return subscriber
|
|
}
|
|
|
|
// recordingSessionRevocationHandler is a SessionRevocationHandler test double.
// It remembers, in call order, every device session ID it was asked to revoke.
// A mutex guards the list. Recording presumably happens on the subscriber's Run
// goroutine while the test goroutine reads.
type recordingSessionRevocationHandler struct {
	mu         sync.Mutex
	revokedIDs []string
}

// RevokeDeviceSession appends deviceSessionID to the recorded list.
func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.revokedIDs = append(h.revokedIDs, deviceSessionID)
}

// revocations returns a copy of the IDs recorded so far, or nil when there are
// none. Callers can inspect the copy without holding the lock and without
// racing later appends.
func (h *recordingSessionRevocationHandler) revocations() []string {
	h.mu.Lock()
	snapshot := append([]string(nil), h.revokedIDs...)
	h.mu.Unlock()

	return snapshot
}
|
|
|
|
type failingSnapshotStore struct{}
|
|
|
|
func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) {
|
|
return session.Record{}, session.ErrNotFound
|
|
}
|
|
|
|
func (failingSnapshotStore) Upsert(session.Record) error {
|
|
return context.DeadlineExceeded
|
|
}
|
|
|
|
func (failingSnapshotStore) Delete(string) {}
|
|
|
|
func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) {
|
|
t.Helper()
|
|
|
|
values := make([]string, 0, len(fields)*2)
|
|
for key, value := range fields {
|
|
values = append(values, key, value)
|
|
}
|
|
|
|
_, err := server.XAdd(stream, "*", values)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// runningSubscriber is the handle for a subscriber whose Run loop executes in a
// background goroutine.
//
// The former stopOnce field was removed: nothing set or read it, and it could
// not act as a once-guard because stop uses a value receiver, so any write
// would land on a copy.
type runningSubscriber struct {
	// cancel cancels the context passed to Run.
	cancel context.CancelFunc
	// resultCh receives Run's return value. runTestSubscriber creates it with
	// capacity 1, so the Run goroutine can deliver its result without waiting
	// for a reader.
	resultCh chan error
}
|
|
|
|
func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber {
|
|
t.Helper()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
resultCh := make(chan error, 1)
|
|
go func() {
|
|
resultCh <- subscriber.Run(ctx)
|
|
}()
|
|
|
|
select {
|
|
case <-subscriber.started:
|
|
case <-time.After(time.Second):
|
|
require.FailNow(t, "subscriber did not start")
|
|
}
|
|
|
|
return runningSubscriber{
|
|
cancel: cancel,
|
|
resultCh: resultCh,
|
|
}
|
|
}
|
|
|
|
func (r runningSubscriber) stop(t *testing.T) {
|
|
t.Helper()
|
|
|
|
r.cancel()
|
|
|
|
select {
|
|
case err := <-r.resultCh:
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
case <-time.After(time.Second):
|
|
require.FailNow(t, "subscriber did not stop")
|
|
}
|
|
}
|