package notificationgateway_test import ( "bytes" "context" "crypto/ed25519" "crypto/sha256" "encoding/base64" "encoding/json" "errors" "io" "net/http" "path/filepath" "testing" "time" gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" contractsgatewayv1 "galaxy/integration/internal/contracts/gatewayv1" "galaxy/integration/internal/harness" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) const ( notificationGatewayClientEventsStream = "gateway:client_events" notificationGatewayIntentsStream = "notification:intents" ) func TestNotificationGatewayFanOutsAllUserPushTypesToAllUserSessions(t *testing.T) { h := newNotificationGatewayHarness(t) recipient := h.ensureUser(t, "pilot@example.com", "fr-FR") firstPrivateKey := newClientPrivateKey("first") secondPrivateKey := newClientPrivateKey("second") unrelatedPrivateKey := newClientPrivateKey("unrelated") h.seedGatewaySession(t, "device-session-1", recipient.UserID, firstPrivateKey) h.seedGatewaySession(t, "device-session-2", recipient.UserID, secondPrivateKey) h.seedGatewaySession(t, "device-session-3", "user-unrelated", unrelatedPrivateKey) conn := h.dialGateway(t) client := gatewayv1.NewEdgeGatewayClient(conn) firstCtx, cancelFirst := context.WithCancel(context.Background()) defer cancelFirst() firstStream, err := client.SubscribeEvents(firstCtx, newSubscribeEventsRequest("device-session-1", "request-1", firstPrivateKey)) require.NoError(t, err) assertBootstrapEvent(t, recvGatewayEvent(t, firstStream), h.responseSignerPublicKey, "request-1") secondCtx, cancelSecond := context.WithCancel(context.Background()) defer cancelSecond() secondStream, err := client.SubscribeEvents(secondCtx, newSubscribeEventsRequest("device-session-2", "request-2", secondPrivateKey)) require.NoError(t, err) assertBootstrapEvent(t, recvGatewayEvent(t, secondStream), h.responseSignerPublicKey, "request-2") unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background()) defer cancelUnrelated() unrelatedStream, err := client.SubscribeEvents(unrelatedCtx, newSubscribeEventsRequest("device-session-3", "request-3", unrelatedPrivateKey)) require.NoError(t, err) assertBootstrapEvent(t, recvGatewayEvent(t, unrelatedStream), h.responseSignerPublicKey, "request-3") cases := []pushIntentCase{ { notificationType: "game.turn.ready", producer: "game_master", payloadJSON: `{"game_id":"game-123","game_name":"Nebula Clash","turn_number":54}`, }, { notificationType: "game.finished", producer: "game_master", payloadJSON: `{"game_id":"game-123","game_name":"Nebula Clash","final_turn_number":55}`, }, { notificationType: "lobby.application.submitted", producer: "game_lobby", payloadJSON: `{"game_id":"game-123","game_name":"Nebula Clash","applicant_user_id":"applicant-1","applicant_name":"Nova Pilot"}`, }, { notificationType: "lobby.membership.approved", producer: "game_lobby", payloadJSON: `{"game_id":"game-123","game_name":"Nebula Clash"}`, }, { notificationType: "lobby.membership.rejected", producer: "game_lobby", payloadJSON: `{"game_id":"game-123","game_name":"Nebula Clash"}`, }, { notificationType: "lobby.invite.created", producer: "game_lobby", payloadJSON: `{"game_id":"game-123","game_name":"Nebula Clash","inviter_user_id":"owner-1","inviter_name":"Owner Pilot"}`, }, { notificationType: "lobby.invite.redeemed", producer: "game_lobby", payloadJSON: `{"game_id":"game-123","game_name":"Nebula Clash","invitee_user_id":"invitee-1","invitee_name":"Nova Pilot"}`, }, } for index, tc := range cases { messageID := h.publishPushIntent(t, tc, recipient.UserID, index) firstEvent := recvGatewayEvent(t, firstStream) assertNotificationPushEvent(t, firstEvent, h.responseSignerPublicKey, tc.notificationType, messageID, recipient.UserID, index) secondEvent := recvGatewayEvent(t, secondStream) assertNotificationPushEvent(t, secondEvent, h.responseSignerPublicKey, tc.notificationType, messageID, recipient.UserID, index) } assertNoGatewayEvent(t, unrelatedStream, cancelUnrelated) messages, err := h.redis.XRange(context.Background(), notificationGatewayClientEventsStream, "-", "+").Result() require.NoError(t, err) require.Len(t, messages, len(cases)) for index, message := range messages { require.Equal(t, recipient.UserID, message.Values["user_id"]) require.Equal(t, cases[index].notificationType, message.Values["event_type"]) require.NotContains(t, message.Values, "device_session_id") } } type notificationGatewayHarness struct { redis *redis.Client userServiceURL string gatewayGRPCAddr string responseSignerPublicKey ed25519.PublicKey notificationProcess *harness.Process gatewayProcess *harness.Process userServiceProcess *harness.Process } type pushIntentCase struct { notificationType string producer string payloadJSON string } type ensureByEmailResponse struct { Outcome string `json:"outcome"` UserID string `json:"user_id"` } func newNotificationGatewayHarness(t *testing.T) *notificationGatewayHarness { t.Helper() redisRuntime := harness.StartRedisContainer(t) redisClient := redis.NewClient(&redis.Options{ Addr: redisRuntime.Addr, Protocol: 2, DisableIdentity: true, }) t.Cleanup(func() { require.NoError(t, redisClient.Close()) }) responseSignerPath, responseSignerPublicKey := harness.WriteResponseSignerPEM(t, t.Name()) userServiceAddr := harness.FreeTCPAddress(t) notificationInternalAddr := harness.FreeTCPAddress(t) gatewayPublicAddr := harness.FreeTCPAddress(t) gatewayGRPCAddr := harness.FreeTCPAddress(t) userServiceBinary := harness.BuildBinary(t, "userservice", "./user/cmd/userservice") notificationBinary := harness.BuildBinary(t, "notification", "./notification/cmd/notification") gatewayBinary := harness.BuildBinary(t, "gateway", "./gateway/cmd/gateway") userServiceEnv := harness.StartUserServicePersistence(t, redisRuntime.Addr).Env userServiceEnv["USERSERVICE_LOG_LEVEL"] = "info" userServiceEnv["USERSERVICE_INTERNAL_HTTP_ADDR"] = userServiceAddr userServiceEnv["OTEL_TRACES_EXPORTER"] = "none" userServiceEnv["OTEL_METRICS_EXPORTER"] = "none" userServiceProcess := harness.StartProcess(t, "userservice", userServiceBinary, userServiceEnv) waitForUserServiceReady(t, userServiceProcess, "http://"+userServiceAddr) notificationEnv := harness.StartNotificationServicePersistence(t, redisRuntime.Addr).Env notificationEnv["NOTIFICATION_LOG_LEVEL"] = "info" notificationEnv["NOTIFICATION_INTERNAL_HTTP_ADDR"] = notificationInternalAddr notificationEnv["NOTIFICATION_USER_SERVICE_BASE_URL"] = "http://" + userServiceAddr notificationEnv["NOTIFICATION_USER_SERVICE_TIMEOUT"] = time.Second.String() notificationEnv["NOTIFICATION_INTENTS_READ_BLOCK_TIMEOUT"] = "100ms" notificationEnv["NOTIFICATION_ROUTE_BACKOFF_MIN"] = "100ms" notificationEnv["NOTIFICATION_ROUTE_BACKOFF_MAX"] = "100ms" notificationEnv["NOTIFICATION_GATEWAY_CLIENT_EVENTS_STREAM"] = notificationGatewayClientEventsStream notificationEnv["OTEL_TRACES_EXPORTER"] = "none" notificationEnv["OTEL_METRICS_EXPORTER"] = "none" notificationProcess := harness.StartProcess(t, "notification", notificationBinary, notificationEnv) harness.WaitForHTTPStatus(t, notificationProcess, "http://"+notificationInternalAddr+"/readyz", http.StatusOK) gatewayProcess := harness.StartProcess(t, "gateway", gatewayBinary, map[string]string{ "GATEWAY_LOG_LEVEL": "info", "GATEWAY_PUBLIC_HTTP_ADDR": gatewayPublicAddr, "GATEWAY_AUTHENTICATED_GRPC_ADDR": gatewayGRPCAddr, "GATEWAY_REDIS_MASTER_ADDR": redisRuntime.Addr, "GATEWAY_REDIS_PASSWORD": "integration", "GATEWAY_SESSION_CACHE_REDIS_KEY_PREFIX": "gateway:session:", "GATEWAY_SESSION_EVENTS_REDIS_STREAM": "gateway:session_events", "GATEWAY_CLIENT_EVENTS_REDIS_STREAM": notificationGatewayClientEventsStream, "GATEWAY_CLIENT_EVENTS_REDIS_READ_BLOCK_TIMEOUT": "100ms", "GATEWAY_REPLAY_REDIS_KEY_PREFIX": "gateway:replay:", "GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH": filepath.Clean(responseSignerPath), "OTEL_TRACES_EXPORTER": "none", "OTEL_METRICS_EXPORTER": "none", }) harness.WaitForHTTPStatus(t, gatewayProcess, "http://"+gatewayPublicAddr+"/healthz", http.StatusOK) harness.WaitForTCP(t, gatewayProcess, gatewayGRPCAddr) return ¬ificationGatewayHarness{ redis: redisClient, userServiceURL: "http://" + userServiceAddr, gatewayGRPCAddr: gatewayGRPCAddr, responseSignerPublicKey: responseSignerPublicKey, notificationProcess: notificationProcess, gatewayProcess: gatewayProcess, userServiceProcess: userServiceProcess, } } func (h *notificationGatewayHarness) ensureUser(t *testing.T, email string, preferredLanguage string) ensureByEmailResponse { t.Helper() response := postJSONValue(t, h.userServiceURL+"/api/v1/internal/users/ensure-by-email", map[string]any{ "email": email, "registration_context": map[string]string{ "preferred_language": preferredLanguage, "time_zone": "Europe/Kaliningrad", }, }) var body ensureByEmailResponse requireJSONStatus(t, response, http.StatusOK, &body) require.Equal(t, "created", body.Outcome) require.NotEmpty(t, body.UserID) return body } func (h *notificationGatewayHarness) dialGateway(t *testing.T) *grpc.ClientConn { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := grpc.DialContext( ctx, h.gatewayGRPCAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, conn.Close()) }) return conn } func (h *notificationGatewayHarness) seedGatewaySession(t *testing.T, deviceSessionID string, userID string, clientPrivateKey ed25519.PrivateKey) { t.Helper() record := gatewaySessionRecord{ DeviceSessionID: deviceSessionID, UserID: userID, ClientPublicKey: base64.StdEncoding.EncodeToString(clientPrivateKey.Public().(ed25519.PublicKey)), Status: "active", } payload, err := json.Marshal(record) require.NoError(t, err) require.NoError(t, h.redis.Set(context.Background(), "gateway:session:"+deviceSessionID, payload, 0).Err()) } func (h *notificationGatewayHarness) publishPushIntent(t *testing.T, tc pushIntentCase, recipientUserID string, index int) string { t.Helper() messageID, err := h.redis.XAdd(context.Background(), &redis.XAddArgs{ Stream: notificationGatewayIntentsStream, Values: map[string]any{ "notification_type": tc.notificationType, "producer": tc.producer, "audience_kind": "user", "recipient_user_ids_json": `["` + recipientUserID + `"]`, "idempotency_key": tc.notificationType + ":gateway:" + string(rune('a'+index)), "occurred_at_ms": "1775121700000", "request_id": pushRequestID(index), "trace_id": pushTraceID(index), "payload_json": tc.payloadJSON, }, }).Result() require.NoError(t, err) return messageID } type gatewaySessionRecord struct { DeviceSessionID string `json:"device_session_id"` UserID string `json:"user_id"` ClientPublicKey string `json:"client_public_key"` Status string `json:"status"` RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"` } type httpResponse struct { StatusCode int Body string Header http.Header } func postJSONValue(t *testing.T, targetURL string, body any) httpResponse { t.Helper() payload, err := json.Marshal(body) require.NoError(t, err) request, err := http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(payload)) require.NoError(t, err) request.Header.Set("Content-Type", "application/json") return doRequest(t, request) } func requireJSONStatus(t *testing.T, response httpResponse, wantStatus int, target any) { t.Helper() require.Equal(t, wantStatus, response.StatusCode, "response body: %s", response.Body) require.NoError(t, decodeStrictJSONPayload([]byte(response.Body), target)) } func doRequest(t *testing.T, request *http.Request) httpResponse { t.Helper() client := &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ DisableKeepAlives: true, }, } t.Cleanup(client.CloseIdleConnections) response, err := client.Do(request) require.NoError(t, err) defer response.Body.Close() payload, err := io.ReadAll(response.Body) require.NoError(t, err) return httpResponse{ StatusCode: response.StatusCode, Body: string(payload), Header: response.Header.Clone(), } } func decodeStrictJSONPayload(payload []byte, target any) error { decoder := json.NewDecoder(bytes.NewReader(payload)) decoder.DisallowUnknownFields() if err := decoder.Decode(target); err != nil { return err } if err := decoder.Decode(&struct{}{}); err != io.EOF { if err == nil { return errors.New("unexpected trailing JSON input") } return err } return nil } func waitForUserServiceReady(t *testing.T, process *harness.Process, baseURL string) { t.Helper() client := &http.Client{Timeout: 250 * time.Millisecond} t.Cleanup(client.CloseIdleConnections) deadline := time.Now().Add(10 * time.Second) for time.Now().Before(deadline) { request, err := http.NewRequest(http.MethodGet, baseURL+"/api/v1/internal/users/user-missing/exists", nil) require.NoError(t, err) response, err := client.Do(request) if err == nil { _, _ = io.Copy(io.Discard, response.Body) response.Body.Close() if response.StatusCode == http.StatusOK { return } } time.Sleep(25 * time.Millisecond) } t.Fatalf("wait for userservice readiness: timeout\n%s", process.Logs()) } func newClientPrivateKey(label string) ed25519.PrivateKey { seed := sha256.Sum256([]byte("galaxy-integration-notification-gateway-client-" + label)) return ed25519.NewKeyFromSeed(seed[:]) } func newSubscribeEventsRequest(deviceSessionID string, requestID string, clientPrivateKey ed25519.PrivateKey) *gatewayv1.SubscribeEventsRequest { payloadHash := contractsgatewayv1.ComputePayloadHash(nil) request := &gatewayv1.SubscribeEventsRequest{ ProtocolVersion: contractsgatewayv1.ProtocolVersionV1, DeviceSessionId: deviceSessionID, MessageType: contractsgatewayv1.SubscribeMessageType, TimestampMs: time.Now().UnixMilli(), RequestId: requestID, PayloadHash: payloadHash, TraceId: "trace-" + requestID, } request.Signature = contractsgatewayv1.SignRequest(clientPrivateKey, contractsgatewayv1.RequestSigningFields{ ProtocolVersion: request.GetProtocolVersion(), DeviceSessionID: request.GetDeviceSessionId(), MessageType: request.GetMessageType(), TimestampMS: request.GetTimestampMs(), RequestID: request.GetRequestId(), PayloadHash: request.GetPayloadHash(), }) return request } func recvGatewayEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent { t.Helper() eventCh := make(chan *gatewayv1.GatewayEvent, 1) errCh := make(chan error, 1) go func() { event, err := stream.Recv() if err != nil { errCh <- err return } eventCh <- event }() select { case event := <-eventCh: return event case err := <-errCh: require.NoError(t, err) case <-time.After(5 * time.Second): require.FailNow(t, "timed out waiting for gateway event") } return nil } func assertBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, responseSignerPublicKey ed25519.PublicKey, wantRequestID string) { t.Helper() require.Equal(t, contractsgatewayv1.ServerTimeEventType, event.GetEventType()) require.Equal(t, wantRequestID, event.GetEventId()) require.Equal(t, wantRequestID, event.GetRequestId()) require.NoError(t, contractsgatewayv1.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash())) require.NoError(t, contractsgatewayv1.VerifyEventSignature(responseSignerPublicKey, event.GetSignature(), contractsgatewayv1.EventSigningFields{ EventType: event.GetEventType(), EventID: event.GetEventId(), TimestampMS: event.GetTimestampMs(), RequestID: event.GetRequestId(), TraceID: event.GetTraceId(), PayloadHash: event.GetPayloadHash(), })) } func assertNotificationPushEvent( t *testing.T, event *gatewayv1.GatewayEvent, responseSignerPublicKey ed25519.PublicKey, notificationType string, notificationID string, userID string, index int, ) { t.Helper() require.Equal(t, notificationType, event.GetEventType()) require.Equal(t, notificationID+"/push:user:"+userID, event.GetEventId()) require.Equal(t, pushRequestID(index), event.GetRequestId()) require.Equal(t, pushTraceID(index), event.GetTraceId()) require.NotEmpty(t, event.GetPayloadBytes()) require.NoError(t, contractsgatewayv1.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash())) require.NoError(t, contractsgatewayv1.VerifyEventSignature(responseSignerPublicKey, event.GetSignature(), contractsgatewayv1.EventSigningFields{ EventType: event.GetEventType(), EventID: event.GetEventId(), TimestampMS: event.GetTimestampMs(), RequestID: event.GetRequestId(), TraceID: event.GetTraceId(), PayloadHash: event.GetPayloadHash(), })) } func assertNoGatewayEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) { t.Helper() eventCh := make(chan *gatewayv1.GatewayEvent, 1) errCh := make(chan error, 1) go func() { event, err := stream.Recv() if err != nil { errCh <- err return } eventCh <- event }() select { case event := <-eventCh: require.FailNowf(t, "unexpected gateway event delivered", "%+v", event) case <-time.After(200 * time.Millisecond): cancel() case err := <-errCh: require.FailNowf(t, "stream closed unexpectedly", "%v", err) } } func pushRequestID(index int) string { return "notification-request-" + string(rune('a'+index)) } func pushTraceID(index int) string { return "notification-trace-" + string(rune('a'+index)) }