package grpcapi

import (
	"context"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
)

// TestExecuteCommandRejectsMalformedEnvelope checks that an empty
// ExecuteCommandRequest is rejected with codes.InvalidArgument.
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(newRPCTestContext(t), &gatewayv1.ExecuteCommandRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
}

// TestSubscribeEventsRejectsMalformedEnvelope checks that an empty
// SubscribeEventsRequest is rejected with codes.InvalidArgument.
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	err := subscribeEventsError(t, newRPCTestContext(t), client, &gatewayv1.SubscribeEventsRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
}

// TestExecuteCommandRejectsUnsupportedProtocolVersion checks that an otherwise
// populated envelope carrying protocol_version "v2" is rejected with
// codes.FailedPrecondition and a message naming the offending version.
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(newRPCTestContext(t), &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v2",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     123456789,
		RequestId:       "request-123",
		PayloadBytes:    []byte("payload"),
		PayloadHash:     []byte("hash"),
		Signature:       []byte("signature"),
	})
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
}

// TestExecuteCommandValidEnvelopeStillReturnsUnimplemented checks that a valid
// command for an active session, with a replay store but no Router or Service
// configured, is answered with codes.Unimplemented.
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(newRPCTestContext(t), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unimplemented, status.Code(err))
}

// TestExecuteCommandMissingReplayStoreFailsClosed checks that, without a
// ReplayStore, a valid command is refused with codes.Unavailable rather than
// being processed.
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(newRPCTestContext(t), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}

// TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation
// checks that a valid subscription first receives a server-time bootstrap
// event, then stays open until the client cancels, after which Recv reports
// codes.Canceled.
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
	t.Parallel()

	// An explicit cancel (not a deadline) is used here because the test
	// asserts codes.Canceled after cancelling the stream itself.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
	require.NoError(t, err)

	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	// Wait for the next message in the background. The channel is buffered so
	// the goroutine can always deliver its result and exit, even if the test
	// has already failed.
	recvResult := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvResult <- recvErr
	}()

	// While the client context is live the server must keep the stream open.
	require.Never(t, func() bool {
		select {
		case <-recvResult:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")

	cancel()

	var recvErr error
	select {
	case recvErr = <-recvResult:
	case <-time.After(time.Second):
		t.Fatal("stream did not stop after client cancellation")
	}
	require.Error(t, recvErr)
	assert.Equal(t, codes.Canceled, status.Code(recvErr))
}

// TestSubscribeEventsMissingReplayStoreFailsClosed checks that, without a
// ReplayStore, a valid subscription is refused with codes.Unavailable.
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	err := subscribeEventsError(t, newRPCTestContext(t), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}

// TestServerLifecycle checks that the gateway accepts connections while
// running and that a blocking dial to the same address fails once the
// application has been stopped.
func TestServerLifecycle(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	require.NoError(t, conn.Close())

	runGateway.stop(t)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// NOTE(review): another parallel test could in principle bind the freed
	// ephemeral port in the meantime; this has not been observed but would
	// make this dial succeed spuriously.
	lateConn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	if err == nil {
		// Unexpected success: release the connection before failing below.
		_ = lateConn.Close()
	}
	require.Error(t, err)
}

// runningGateway tracks an application started by
// newTestGatewayWithGRPCConfig: cancel cancels the context passed to
// application.Run, and resultCh receives Run's return value.
type runningGateway struct {
	cancel   context.CancelFunc
	resultCh chan error
}

// newTestGateway starts a gateway listening on an ephemeral loopback port
// (127.0.0.1:0) using the default authenticated gRPC config, with
// FreshnessWindow set to testFreshnessWindow.
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()

	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = "127.0.0.1:0"
	grpcCfg.FreshnessWindow = testFreshnessWindow

	return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
}

// newTestGatewayWithGRPCConfig builds a Server from grpcCfg and deps and runs
// it inside app.New(...).Run on a background goroutine.
//
// Unset dependencies are filled with test defaults: Clock becomes a fixedClock
// at testCurrentTime, ResponseSigner becomes newTestResponseSigner(), and,
// when a Service is supplied without a Router, Router becomes an
// executeCommandAdapterRouter over that Service.
//
// The returned runningGateway's stop method should be called to shut the
// application down and check its result. The application context is also
// cancelled on test cleanup, so it never outlives the test even if the
// caller aborts before reaching stop.
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()

	cfg := config.Config{
		ShutdownTimeout:   time.Second,
		AuthenticatedGRPC: grpcCfg,
	}

	if deps.Clock == nil {
		deps.Clock = fixedClock{now: testCurrentTime}
	}
	if deps.ResponseSigner == nil {
		deps.ResponseSigner = newTestResponseSigner()
	}
	if deps.Router == nil && deps.Service != nil {
		deps.Router = executeCommandAdapterRouter{service: deps.Service}
	}

	server := NewServer(cfg.AuthenticatedGRPC, deps)
	application := app.New(cfg, server)

	ctx, cancel := context.WithCancel(context.Background())
	// Safety net for tests that fail before stop runs; a second cancel after
	// stop is a no-op.
	t.Cleanup(cancel)

	// Buffered so Run's result can be delivered even if nobody calls stop.
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	return server, runningGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

// stop cancels the application and requires that Run returns a nil error
// within two seconds.
func (g runningGateway) stop(t *testing.T) {
	t.Helper()

	g.cancel()

	select {
	case err := <-g.resultCh:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		t.Fatal("gateway did not stop after cancellation")
	}
}

// waitForListenAddr polls server.listenAddr for up to one second and returns
// the first non-empty address it reports.
func waitForListenAddr(t *testing.T, server *Server) string {
	t.Helper()

	var addr string
	require.Eventually(t, func() bool {
		addr = server.listenAddr()
		return addr != ""
	}, time.Second, 10*time.Millisecond, "server did not start listening")

	return addr
}

// dialGatewayClient opens a plaintext (insecure credentials) client connection
// to addr, blocking for up to one second until it is established.
//
// NOTE(review): grpc.DialContext and grpc.WithBlock are deprecated in newer
// grpc-go releases in favour of grpc.NewClient; they are kept here because the
// blocking dial is what surfaces connection failures up front.
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	conn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.NoError(t, err)

	return conn
}

// newRPCTestContext returns a context for a single client call that expires
// after five seconds, so a server that never answers fails the test instead of
// hanging until the `go test` timeout. It is cancelled on test cleanup.
func newRPCTestContext(t *testing.T) context.Context {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	t.Cleanup(cancel)

	return ctx
}