package events

import (
	"context"
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/base64"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/authn"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/grpcapi"
	"galaxy/gateway/internal/push"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions verifies that a
// client event published with only a user_id is delivered, signed, to every
// open stream of that user and is not delivered to a stream of another user.
func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
	defer cancelTargetOne()
	targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")

	targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
	defer cancelTargetTwo()
	targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")

	unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
	defer cancelUnrelated()
	unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-123",
		"payload_bytes": []byte("payload-123"),
		"request_id":    "request-123",
		"trace_id":      "trace-123",
	})

	// Both sessions of user-123 must observe the same signed event;
	// assertSignedPushEvent only reads want, so sharing one value is safe.
	wantEvent := push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	}
	assertSignedPushEvent(t, recvPushEvent(t, targetOne), wantEvent)
	assertSignedPushEvent(t, recvPushEvent(t, targetTwo), wantEvent)
	assertNoPushEvent(t, unrelated, cancelUnrelated)
}

// TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession verifies
// that a client event carrying a device_session_id reaches only that session's
// stream, even though another session of the same user is also subscribed.
func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	otherCtx, cancelOther := context.WithCancel(context.Background())
	defer cancelOther()
	otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")

	targetCtx, cancelTarget := context.WithCancel(context.Background())
	defer cancelTarget()
	targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-2",
		"event_type":        "fleet.updated",
		"event_id":          "event-456",
		"payload_bytes":     []byte("payload-456"),
	})

	assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-456",
		PayloadBytes:    []byte("payload-456"),
	})
	assertNoPushEvent(t, otherStream, cancelOther)
}

// TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen verifies that a
// revocation published on the session events stream closes the open push
// stream with FailedPrecondition, and that a new subscription for the revoked
// session is rejected with the same status.
func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
	defer running.stop(t)

	select {
	case <-sessionSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	streamCtx, cancelStream := context.WithCancel(context.Background())
	defer cancelStream()
	stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-1",
		"user_id":           "user-123",
		"client_public_key": pushClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	// Wait until the subscriber has applied the revocation to the cache.
	require.Eventually(t, func() bool {
		record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)

	// Recv in a goroutine so a stream that never closes becomes a timed
	// failure instead of a hang.
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
		assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after revoke")
	}

	// Bound the reopen attempt so a regression that accepts the stream does
	// not leave it open (or block Recv) until the connection is closed.
	reopenCtx, cancelReopen := context.WithTimeout(context.Background(), time.Second)
	defer cancelReopen()
	reopened, err := client.SubscribeEvents(reopenCtx, newPushSubscribeEventsRequest("device-session-1", "request-2"))
	if err == nil {
		// Server-streaming RPCs usually surface handler errors on the first Recv.
		_, err = reopened.Recv()
	}
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
}

// TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown verifies that
// cancelling the application's run context closes an open push stream with
// Unavailable and the message "gateway is shutting down".
func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	running.cancel()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.Unavailable, status.Code(recvErr))
		assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after gateway shutdown")
	}
}

// runPushGateway starts an app.App on a fresh local address. The app hosts the
// authenticated gRPC gateway, wired with a fan-out push stream service over
// pushHub, together with clientSubscriber and any extraComponents.
//
// It blocks until clientSubscriber signals that it has started, then returns
// the listen address and a handle whose cancel/resultCh control the app.
//
// If the app exits early, or the subscriber does not start within a second,
// the run context is cancelled before the test fails so the app goroutine
// does not outlive the test.
func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
	t.Helper()

	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	grpcCfg.FreshnessWindow = 5 * time.Minute

	responseSigner := newTestResponseSigner(t)
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Service:        grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
		ResponseSigner: responseSigner,
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
		PushHub:        pushHub,
	})

	components := []app.Component{gateway, clientSubscriber}
	components = append(components, extraComponents...)
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		components...,
	)

	ctx, cancel := context.WithCancel(context.Background())
	// Buffered so the Run goroutine can always deliver its result and exit,
	// even if nobody is receiving.
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	select {
	case <-clientSubscriber.started:
	case runErr := <-resultCh:
		// The app stopped before startup completed (e.g. listener failure);
		// surface its error instead of waiting for the timeout.
		cancel()
		require.FailNowf(t, "gateway exited before client event subscriber started", "%v", runErr)
	case <-time.After(time.Second):
		cancel()
		require.FailNow(t, "client event subscriber did not start")
	}

	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

// newPushActiveSessionRecord returns an active session record for the given
// device session and user. The record is bound to the shared test client
// public key.
func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
	return session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          userID,
		ClientPublicKey: pushClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
}

// newPushSubscribeEventsRequest builds a gateway.subscribe request for
// deviceSessionID, signed with the test client key.
//
// The payload hash is the SHA-256 of an empty payload. The trace ID is
// "trace-" + deviceSessionID. The timestamp is testNow.
//
// The signature covers the fields passed to authn.RequestSigningFields below;
// TraceId is not among them.
func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	payloadHash := sha256.Sum256(nil)
	traceID := "trace-" + deviceSessionID

	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadHash:     payloadHash[:],
		TraceId:         traceID,
	}
	req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))

	return req
}

// recvPushEvent receives the next event from stream and fails the test on
// error. It blocks until an event or an error arrives.
func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()

	event, err := stream.Recv()
	require.NoError(t, err)

	return event
}

// assertPushBootstrapEvent checks that event is the initial
// "gateway.server_time" event, that it echoes the subscribe request's ID
// (as both event ID and request ID) and trace ID, and that its payload hash
// and response signature verify.
func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, "gateway.server_time", event.GetEventType())
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}

// assertSignedPushEvent checks that event matches the routable fields of want
// (event type, event ID, request ID, trace ID, payload) and that its payload
// hash and response signature verify.
//
// want.UserID and want.DeviceSessionID are not compared against the event.
func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, want.EventType, event.GetEventType())
	assert.Equal(t, want.EventID, event.GetEventId())
	assert.Equal(t, want.RequestID, event.GetRequestId())
	assert.Equal(t, want.TraceID, event.GetTraceId())
	assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}

// assertNoPushEvent fails if stream delivers an event or closes within 100ms.
// On silence it calls cancel, which is expected to cancel the stream's
// context; that unblocks the background Recv, whose result is then dropped
// into a buffered channel so the goroutine can exit.
func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
	t.Helper()

	recvCh := make(chan *gatewayv1.GatewayEvent, 1)
	errCh := make(chan error, 1)
	go func() {
		event, err := stream.Recv()
		if err != nil {
			errCh <- err
			return
		}
		recvCh <- event
	}()

	select {
	case event := <-recvCh:
		require.FailNowf(t, "unexpected push event delivered", "%+v", event)
	case <-time.After(100 * time.Millisecond):
		cancel()
	case err := <-errCh:
		require.FailNowf(t, "stream closed unexpectedly", "%v", err)
	}
}

// pushClientPrivateKey returns the deterministic Ed25519 client key used to
// sign subscribe requests, derived from a fixed SHA-256 seed.
func pushClientPrivateKey() ed25519.PrivateKey {
	seed := sha256.Sum256([]byte("gateway-push-grpc-test-client"))
	return ed25519.NewKeyFromSeed(seed[:])
}

// pushClientPublicKeyBase64 returns the standard-base64 encoding of the test
// client public key, as stored in session records and session events.
func pushClientPublicKeyBase64() string {
	return base64.StdEncoding.EncodeToString(pushClientPrivateKey().Public().(ed25519.PublicKey))
}

// pushResponseSignerPublicKey returns the public key used to verify gateway
// event signatures.
// NOTE(review): the seed presumably has to match the one used by
// newTestResponseSigner (defined elsewhere) — keep the two in sync.
func pushResponseSignerPublicKey() ed25519.PublicKey {
	seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	return ed25519.NewKeyFromSeed(seed[:]).Public().(ed25519.PublicKey)
}