package projectionpublisher

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/gatewayprojection"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNew checks that New accepts a fully populated Config and rejects each
// missing or out-of-range field with a descriptive error.
func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:                  server.Addr(),
				DB:                    3,
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:                  server.Addr(),
				DB:                    -1,
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session cache key prefix",
			cfg: Config{
				Addr:                server.Addr(),
				SessionEventsStream: "gateway:session_events",
				StreamMaxLen:        1024,
				OperationTimeout:    250 * time.Millisecond,
			},
			wantErr: "session cache key prefix must not be empty",
		},
		{
			name: "empty session events stream",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "session events stream must not be empty",
		},
		{
			name: "non positive stream max len",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "stream max len must be positive",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
			},
			wantErr: "operation timeout must be positive",
		},
	}

	// No `tt := tt` copy: this module uses range-over-int (Go 1.22+), so each
	// iteration already gets its own tt for the parallel subtest closures.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			publisher, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, publisher.Close())
			})
		})
	}
}

// TestPublisherPing checks that Ping succeeds against a live server.
func TestPublisherPing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	require.NoError(t, publisher.Ping(context.Background()))
}

// TestPublisherPublishSessionActive checks that an active snapshot is cached
// under the raw (not base64-encoded) session ID with no TTL, and that exactly
// one matching event is appended to the session events stream.
func TestPublisherPublishSessionActive(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)
	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))

	key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
	assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key)
	assert.True(t, server.Exists(key))
	assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String())))

	payload, err := server.Get(key)
	require.NoError(t, err)

	record := decodeCachePayload(t, payload)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}, record)
	assert.Zero(t, server.TTL(key))

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 1)
	assert.Equal(t, map[string]string{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}, stringifyValues(entries[0].Values))
}

// TestPublisherPublishSessionRevoked checks that a revoked snapshot carries
// its revocation time, in Unix milliseconds, in both the cache record and the
// stream event.
func TestPublisherPublishSessionRevoked(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
	snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)
	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))

	key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
	payload, err := server.Get(key)
	require.NoError(t, err)

	record := decodeCachePayload(t, payload)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusRevoked,
		RevokedAtMS:     int64Pointer(revokedAt.UnixMilli()),
	}, record)

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 1)
	assert.Equal(t, map[string]string{
		"device_session_id": snapshot.DeviceSessionID.String(),
		"user_id":           snapshot.UserID.String(),
		"client_public_key": snapshot.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776000123456",
	}, stringifyValues(entries[0].Values))
}

// TestPublisherPublishSessionLaterSnapshotWinsInCache checks that publishing
// active then revoked for the same session leaves the revoked record in the
// cache while the stream keeps both events in publish order.
func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})

	deviceSessionID := "device-session-456"
	active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
	revokedAt := time.Unix(1_776_010_000, 0).UTC()
	revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)

	require.NoError(t, publisher.PublishSession(context.Background(), active))
	require.NoError(t, publisher.PublishSession(context.Background(), revoked))

	payload, err := server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID))
	require.NoError(t, err)

	record := decodeCachePayload(t, payload)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
	assert.Equal(t, gatewayprojection.StatusRevoked, record.Status)

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 2)
	assert.Equal(t, map[string]string{
		"device_session_id": active.DeviceSessionID.String(),
		"user_id":           active.UserID.String(),
		"client_public_key": active.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}, stringifyValues(entries[0].Values))
	assert.Equal(t, map[string]string{
		"device_session_id": revoked.DeviceSessionID.String(),
		"user_id":           revoked.UserID.String(),
		"client_public_key": revoked.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776010000000",
	}, stringifyValues(entries[1].Values))
}

// TestPublisherPublishSessionRepeatedPublishIsRetrySafe checks that
// publishing the same snapshot twice leaves an unchanged cache record and
// appends two identical stream events.
func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})

	snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)
	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	require.NoError(t, publisher.PublishSession(context.Background(), snapshot))

	payload, err := server.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID))
	require.NoError(t, err)

	record := decodeCachePayload(t, payload)
	assert.Equal(t, cacheRecord{
		DeviceSessionID: snapshot.DeviceSessionID.String(),
		UserID:          snapshot.UserID.String(),
		ClientPublicKey: snapshot.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}, record)

	entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, entries, 2)
	assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values))
}

// TestPublisherPublishSessionStreamMaxLenApprox checks that the stream is
// trimmed to at most StreamMaxLen entries after more publishes than that.
func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{StreamMaxLen: 2})

	for index := range 6 {
		// Distinct session IDs: device-session-a, device-session-b, ...
		snapshot := testSnapshot(
			"device-session-"+string(rune('a'+index)),
			gatewayprojection.StatusActive,
			nil,
		)
		require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
	}

	streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result()
	require.NoError(t, err)
	assert.LessOrEqual(t, streamLength, int64(2))
}

// TestPublisherPublishSessionInvalidSnapshot checks that a snapshot without a
// client public key is rejected and nothing is written to Redis.
func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	snapshot := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID("device-session-123"),
		UserID:          common.UserID("user-123"),
		Status:          gatewayprojection.StatusActive,
	}

	err := publisher.PublishSession(context.Background(), snapshot)
	require.Error(t, err)
	assert.ErrorContains(t, err, "gateway projection client public key")
	assert.Empty(t, server.Keys())
}

// TestPublisherPublishSessionNilContext checks that PublishSession rejects a
// nil context with an error instead of panicking.
func TestPublisherPublishSessionNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	// Deliberately nil to exercise the guard; a typed variable avoids the
	// staticcheck SA1012 warning that a literal nil would trigger.
	var nilContext context.Context

	err := publisher.PublishSession(nilContext, testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}

// TestPublisherPublishSessionBackendFailure checks that a Redis failure is
// surfaced as a wrapped "publish session projection" error.
func TestPublisherPublishSessionBackendFailure(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})
	// Stop the server so the publish cannot reach Redis.
	server.Close()

	err := publisher.PublishSession(context.Background(), testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
	require.Error(t, err)
	assert.ErrorContains(t, err, "publish session projection")
}

// TestPublisherPingNilContext checks that Ping rejects a nil context with an
// error instead of panicking.
func TestPublisherPingNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := newTestPublisher(t, server, Config{})

	// Deliberately nil to exercise the guard (see SA1012 note above).
	var nilContext context.Context

	err := publisher.Ping(nilContext)
	require.Error(t, err)
	assert.ErrorContains(t, err, "nil context")
}

// newTestPublisher builds a Publisher against server, filling every zero-valued
// field of cfg with a test default, and registers Close as a test cleanup.
func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
	t.Helper()

	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	if cfg.SessionCacheKeyPrefix == "" {
		cfg.SessionCacheKeyPrefix = "gateway:session:"
	}
	if cfg.SessionEventsStream == "" {
		cfg.SessionEventsStream = "gateway:session_events"
	}
	if cfg.StreamMaxLen == 0 {
		cfg.StreamMaxLen = 1024
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}

	publisher, err := New(cfg)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, publisher.Close())
	})

	return publisher
}

// testSnapshot returns a snapshot for user-123 whose client public key is a
// deterministic ed25519-sized key (bytes 1..32), standard-base64 encoded.
// Revoked snapshots also get a revoke reason code and actor type.
func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
	raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for index := range raw {
		raw[index] = byte(index + 1)
	}

	snapshot := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID("user-123"),
		ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
		Status:          status,
		RevokedAt:       revokedAt,
	}
	if status == gatewayprojection.StatusRevoked {
		snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
		snapshot.RevokeActorType = common.RevokeActorType("system")
	}

	return snapshot
}

// decodeCachePayload strictly decodes a cached session record. It fails the
// test unless the payload is exactly one JSON value with no fields unknown to
// cacheRecord, and its top-level field set is exactly the expected one.
// revoked_at_ms is expected only when the decoded record has it set.
func decodeCachePayload(t *testing.T, payload string) cacheRecord {
	t.Helper()

	decoder := json.NewDecoder(bytes.NewReader([]byte(payload)))
	decoder.DisallowUnknownFields()

	var record cacheRecord
	require.NoError(t, decoder.Decode(&record))
	// A second Decode must report EOF; a nil error or any other error means
	// the payload carries trailing data.
	require.ErrorIs(t, decoder.Decode(&struct{}{}), io.EOF, "expected cache payload EOF after first JSON value")

	var fieldSet map[string]json.RawMessage
	require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet))

	expectedFields := map[string]struct{}{
		"device_session_id": {},
		"user_id":           {},
		"client_public_key": {},
		"status":            {},
	}
	if record.RevokedAtMS != nil {
		expectedFields["revoked_at_ms"] = struct{}{}
	}

	assert.Equal(t, len(expectedFields), len(fieldSet))
	for field := range fieldSet {
		_, ok := expectedFields[field]
		assert.Truef(t, ok, "unexpected cache payload field %q", field)
	}

	return record
}

// stringifyValues renders each stream entry value with fmt.Sprint so entries
// can be compared against plain string maps.
func stringifyValues(values map[string]any) map[string]string {
	stringified := make(map[string]string, len(values))
	for key, value := range values {
		stringified[key] = fmt.Sprint(value)
	}

	return stringified
}

// encodeBase64URL returns value encoded as unpadded URL-safe base64.
func encodeBase64URL(value string) string {
	return base64.RawURLEncoding.EncodeToString([]byte(value))
}

// int64Pointer returns a pointer to a copy of value.
func int64Pointer(value int64) *int64 {
	return &value
}