feat: edge gateway service
This commit is contained in:
@@ -0,0 +1,416 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/grpcapi"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
sessionCache := session.NewMemoryCache()
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))
|
||||
|
||||
pushHub := push.NewHub(4)
|
||||
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
|
||||
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
|
||||
defer running.stop(t)
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
|
||||
defer cancelTargetOne()
|
||||
targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")
|
||||
|
||||
targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
|
||||
defer cancelTargetTwo()
|
||||
targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")
|
||||
|
||||
unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
|
||||
defer cancelUnrelated()
|
||||
unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")
|
||||
|
||||
addClientEvent(t, server, "gateway:client_events", map[string]any{
|
||||
"user_id": "user-123",
|
||||
"event_type": "fleet.updated",
|
||||
"event_id": "event-123",
|
||||
"payload_bytes": []byte("payload-123"),
|
||||
"request_id": "request-123",
|
||||
"trace_id": "trace-123",
|
||||
})
|
||||
|
||||
assertSignedPushEvent(t, recvPushEvent(t, targetOne), push.Event{
|
||||
UserID: "user-123",
|
||||
EventType: "fleet.updated",
|
||||
EventID: "event-123",
|
||||
PayloadBytes: []byte("payload-123"),
|
||||
RequestID: "request-123",
|
||||
TraceID: "trace-123",
|
||||
})
|
||||
assertSignedPushEvent(t, recvPushEvent(t, targetTwo), push.Event{
|
||||
UserID: "user-123",
|
||||
EventType: "fleet.updated",
|
||||
EventID: "event-123",
|
||||
PayloadBytes: []byte("payload-123"),
|
||||
RequestID: "request-123",
|
||||
TraceID: "trace-123",
|
||||
})
|
||||
assertNoPushEvent(t, unrelated, cancelUnrelated)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
sessionCache := session.NewMemoryCache()
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
|
||||
|
||||
pushHub := push.NewHub(4)
|
||||
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
|
||||
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
|
||||
defer running.stop(t)
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
otherCtx, cancelOther := context.WithCancel(context.Background())
|
||||
defer cancelOther()
|
||||
otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")
|
||||
|
||||
targetCtx, cancelTarget := context.WithCancel(context.Background())
|
||||
defer cancelTarget()
|
||||
targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")
|
||||
|
||||
addClientEvent(t, server, "gateway:client_events", map[string]any{
|
||||
"user_id": "user-123",
|
||||
"device_session_id": "device-session-2",
|
||||
"event_type": "fleet.updated",
|
||||
"event_id": "event-456",
|
||||
"payload_bytes": []byte("payload-456"),
|
||||
})
|
||||
|
||||
assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
|
||||
UserID: "user-123",
|
||||
DeviceSessionID: "device-session-2",
|
||||
EventType: "fleet.updated",
|
||||
EventID: "event-456",
|
||||
PayloadBytes: []byte("payload-456"),
|
||||
})
|
||||
assertNoPushEvent(t, otherStream, cancelOther)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
sessionCache := session.NewMemoryCache()
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
|
||||
|
||||
pushHub := push.NewHub(4)
|
||||
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
|
||||
sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
|
||||
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
|
||||
defer running.stop(t)
|
||||
|
||||
select {
|
||||
case <-sessionSubscriber.started:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "session subscriber did not start")
|
||||
}
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
streamCtx, cancelStream := context.WithCancel(context.Background())
|
||||
defer cancelStream()
|
||||
|
||||
stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
|
||||
|
||||
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
||||
"device_session_id": "device-session-1",
|
||||
"user_id": "user-123",
|
||||
"client_public_key": pushClientPublicKeyBase64(),
|
||||
"status": string(session.StatusRevoked),
|
||||
"revoked_at_ms": "123456789",
|
||||
})
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
|
||||
return lookupErr == nil && record.Status == session.StatusRevoked
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
recvErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, recvErr := stream.Recv()
|
||||
recvErrCh <- recvErr
|
||||
}()
|
||||
|
||||
select {
|
||||
case recvErr := <-recvErrCh:
|
||||
require.Error(t, recvErr)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "stream did not close after revoke")
|
||||
}
|
||||
|
||||
reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2"))
|
||||
if err == nil {
|
||||
_, err = reopened.Recv()
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
sessionCache := session.NewMemoryCache()
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
|
||||
|
||||
pushHub := push.NewHub(4)
|
||||
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
|
||||
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
|
||||
defer running.stop(t)
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
|
||||
|
||||
recvErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, recvErr := stream.Recv()
|
||||
recvErrCh <- recvErr
|
||||
}()
|
||||
|
||||
running.cancel()
|
||||
|
||||
select {
|
||||
case recvErr := <-recvErrCh:
|
||||
require.Error(t, recvErr)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(recvErr))
|
||||
assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "stream did not close after gateway shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
|
||||
t.Helper()
|
||||
|
||||
addr := unusedTCPAddr(t)
|
||||
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
|
||||
grpcCfg.Addr = addr
|
||||
grpcCfg.FreshnessWindow = 5 * time.Minute
|
||||
|
||||
responseSigner := newTestResponseSigner(t)
|
||||
gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
|
||||
Service: grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
|
||||
ResponseSigner: responseSigner,
|
||||
SessionCache: sessionCache,
|
||||
ReplayStore: staticReplayStore{},
|
||||
Clock: fixedClock{now: testNow},
|
||||
PushHub: pushHub,
|
||||
})
|
||||
|
||||
components := []app.Component{gateway, clientSubscriber}
|
||||
components = append(components, extraComponents...)
|
||||
application := app.New(
|
||||
config.Config{
|
||||
ShutdownTimeout: time.Second,
|
||||
AuthenticatedGRPC: grpcCfg,
|
||||
},
|
||||
components...,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- application.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-clientSubscriber.started:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "client event subscriber did not start")
|
||||
}
|
||||
|
||||
return addr, runningAuthenticatedGateway{
|
||||
cancel: cancel,
|
||||
resultCh: resultCh,
|
||||
}
|
||||
}
|
||||
|
||||
func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
|
||||
return session.Record{
|
||||
DeviceSessionID: deviceSessionID,
|
||||
UserID: userID,
|
||||
ClientPublicKey: pushClientPublicKeyBase64(),
|
||||
Status: session.StatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
|
||||
payloadHash := sha256.Sum256(nil)
|
||||
traceID := "trace-" + deviceSessionID
|
||||
|
||||
req := &gatewayv1.SubscribeEventsRequest{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionId: deviceSessionID,
|
||||
MessageType: "gateway.subscribe",
|
||||
TimestampMs: testNow.UnixMilli(),
|
||||
RequestId: requestID,
|
||||
PayloadHash: payloadHash[:],
|
||||
TraceId: traceID,
|
||||
}
|
||||
req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
|
||||
ProtocolVersion: req.GetProtocolVersion(),
|
||||
DeviceSessionID: req.GetDeviceSessionId(),
|
||||
MessageType: req.GetMessageType(),
|
||||
TimestampMS: req.GetTimestampMs(),
|
||||
RequestID: req.GetRequestId(),
|
||||
PayloadHash: req.GetPayloadHash(),
|
||||
}))
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
|
||||
t.Helper()
|
||||
|
||||
event, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
return event
|
||||
}
|
||||
|
||||
func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, event)
|
||||
assert.Equal(t, "gateway.server_time", event.GetEventType())
|
||||
assert.Equal(t, wantRequestID, event.GetEventId())
|
||||
assert.Equal(t, wantRequestID, event.GetRequestId())
|
||||
assert.Equal(t, wantTraceID, event.GetTraceId())
|
||||
require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
|
||||
require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
|
||||
EventType: event.GetEventType(),
|
||||
EventID: event.GetEventId(),
|
||||
TimestampMS: event.GetTimestampMs(),
|
||||
RequestID: event.GetRequestId(),
|
||||
TraceID: event.GetTraceId(),
|
||||
PayloadHash: event.GetPayloadHash(),
|
||||
}))
|
||||
}
|
||||
|
||||
func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, event)
|
||||
assert.Equal(t, want.EventType, event.GetEventType())
|
||||
assert.Equal(t, want.EventID, event.GetEventId())
|
||||
assert.Equal(t, want.RequestID, event.GetRequestId())
|
||||
assert.Equal(t, want.TraceID, event.GetTraceId())
|
||||
assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
|
||||
require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
|
||||
require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
|
||||
EventType: event.GetEventType(),
|
||||
EventID: event.GetEventId(),
|
||||
TimestampMS: event.GetTimestampMs(),
|
||||
RequestID: event.GetRequestId(),
|
||||
TraceID: event.GetTraceId(),
|
||||
PayloadHash: event.GetPayloadHash(),
|
||||
}))
|
||||
}
|
||||
|
||||
func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
|
||||
t.Helper()
|
||||
|
||||
recvCh := make(chan *gatewayv1.GatewayEvent, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
recvCh <- event
|
||||
}()
|
||||
|
||||
select {
|
||||
case event := <-recvCh:
|
||||
require.FailNowf(t, "unexpected push event delivered", "%+v", event)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
cancel()
|
||||
case err := <-errCh:
|
||||
require.FailNowf(t, "stream closed unexpectedly", "%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func pushClientPrivateKey() ed25519.PrivateKey {
|
||||
seed := sha256.Sum256([]byte("gateway-push-grpc-test-client"))
|
||||
return ed25519.NewKeyFromSeed(seed[:])
|
||||
}
|
||||
|
||||
func pushClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(pushClientPrivateKey().Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
func pushResponseSignerPublicKey() ed25519.PublicKey {
|
||||
seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
|
||||
return ed25519.NewKeyFromSeed(seed[:]).Public().(ed25519.PublicKey)
|
||||
}
|
||||
Reference in New Issue
Block a user