package grpcapi import ( "context" "errors" "io" "testing" "galaxy/gateway/internal/session" gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func TestExecuteCommandRejectsUnknownSession(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{} server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{ lookupFunc: func(context.Context, string) (session.Record, error) { return session.Record{}, session.ErrNotFound }, }, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) require.Error(t, err) assert.Equal(t, codes.Unauthenticated, status.Code(err)) assert.Equal(t, "unknown device session", status.Convert(err).Message()) assert.Zero(t, delegate.executeCalls) } func TestSubscribeEventsRejectsUnknownSession(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{} server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{ lookupFunc: func(context.Context, string) (session.Record, error) { return session.Record{}, session.ErrNotFound }, }, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest()) require.Error(t, err) assert.Equal(t, codes.Unauthenticated, status.Code(err)) assert.Equal(t, "unknown device session", status.Convert(err).Message()) assert.Zero(t, delegate.subscribeCalls) } func TestExecuteCommandRejectsRevokedSession(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{} server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }}, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) require.Error(t, err) assert.Equal(t, codes.FailedPrecondition, status.Code(err)) assert.Equal(t, "device session is revoked", status.Convert(err).Message()) assert.Zero(t, delegate.executeCalls) } func TestSubscribeEventsRejectsRevokedSession(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{} server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }}, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest()) require.Error(t, err) assert.Equal(t, codes.FailedPrecondition, status.Code(err)) assert.Equal(t, "device session is revoked", status.Convert(err).Message()) assert.Zero(t, delegate.subscribeCalls) } func TestExecuteCommandRejectsSessionCacheUnavailable(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{} server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{ lookupFunc: func(context.Context, string) (session.Record, error) { return session.Record{}, errors.New("redis down") }, }, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) _, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) require.Error(t, err) assert.Equal(t, codes.Unavailable, status.Code(err)) assert.Equal(t, "session cache is unavailable", status.Convert(err).Message()) assert.Zero(t, delegate.executeCalls) } func TestSubscribeEventsRejectsSessionCacheUnavailable(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{} server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{ lookupFunc: func(context.Context, string) (session.Record, error) { return session.Record{}, errors.New("redis down") }, }, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest()) require.Error(t, err) assert.Equal(t, codes.Unavailable, status.Code(err)) assert.Equal(t, "session cache is unavailable", status.Convert(err).Message()) assert.Zero(t, delegate.subscribeCalls) } func TestExecuteCommandAttachesResolvedSession(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{ executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) { record, ok := resolvedSessionFromContext(ctx) require.True(t, ok) assert.Equal(t, newActiveSessionRecord(), record) return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil }, } server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, ReplayStore: staticReplayStore{}, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest()) require.NoError(t, err) assert.Equal(t, "request-123", response.GetRequestId()) } func TestSubscribeEventsAttachesResolvedSession(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{ subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { record, ok := resolvedSessionFromContext(stream.Context()) require.True(t, ok) assert.Equal(t, newActiveSessionRecord(), record) return nil }, } server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, ReplayStore: staticReplayStore{}, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest()) require.NoError(t, err) event := recvBootstrapEvent(t, stream) assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli()) _, err = stream.Recv() require.ErrorIs(t, err, io.EOF) } func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) { t.Parallel() delegate := &recordingEdgeGatewayService{ subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error { binding, ok := authenticatedStreamBindingFromContext(stream.Context()) require.True(t, ok) assert.Equal(t, authenticatedStreamBinding{ UserID: "user-123", DeviceSessionID: "device-session-123", MessageType: "gateway.subscribe", RequestID: "request-123", TraceID: "trace-123", }, binding) return nil }, } server, runGateway := newTestGateway(t, ServerDependencies{ Service: delegate, SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }}, ReplayStore: staticReplayStore{}, }) defer runGateway.stop(t) addr := waitForListenAddr(t, server) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest()) require.NoError(t, err) event := recvBootstrapEvent(t, stream) assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli()) _, err = stream.Recv() require.ErrorIs(t, err, io.EOF) } type staticSessionCache struct { lookupFunc func(context.Context, string) (session.Record, error) } func (c staticSessionCache) Lookup(ctx context.Context, deviceSessionID string) (session.Record, error) { return c.lookupFunc(ctx, deviceSessionID) } func (staticSessionCache) MarkRevoked(string) {} func (staticSessionCache) MarkAllRevokedForUser(string) {}