package grpcapi

import (
	"context"
	"crypto/sha256"
	"fmt"
	"testing"
	"time"

	"galaxy/gateway/internal/authn"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/downstream"
	"galaxy/gateway/internal/testutil"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// TestExecuteCommandRoutesVerifiedCommandAndSignsResponse checks that a valid
// ExecuteCommand request reaches only the downstream client registered for its
// message type, that the downstream client receives the expected
// AuthenticatedCommand, and that the response carries a payload hash and a
// signature that verify against the test signer's public key.
func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
	t.Parallel()

	signer := newTestEd25519ResponseSigner()
	moveClient := &recordingDownstreamClient{
		executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			// assert (not require) is used because this callback is invoked by
			// the gateway rather than directly by the test function.
			assert.Equal(t, downstream.AuthenticatedCommand{
				ProtocolVersion: "v1",
				UserID:          "user-123",
				DeviceSessionID: "device-session-123",
				MessageType:     "fleet.move",
				TimestampMS:     testCurrentTime.UnixMilli(),
				RequestID:       "request-123",
				TraceID:         "trace-123",
				PayloadBytes:    []byte("payload"),
			}, command)

			return downstream.UnaryResult{
				ResultCode:   "accepted",
				PayloadBytes: []byte("downstream-response"),
			}, nil
		},
	}
	// A second route that must stay untouched proves routing is selective.
	renameClient := &recordingDownstreamClient{}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move":   moveClient,
			"fleet.rename": renameClient,
		}),
		ResponseSigner: signer,
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)

	assert.Equal(t, "v1", response.GetProtocolVersion())
	assert.Equal(t, "request-123", response.GetRequestId())
	assert.Equal(t, testCurrentTime.UnixMilli(), response.GetTimestampMs())
	assert.Equal(t, "accepted", response.GetResultCode())
	assert.Equal(t, []byte("downstream-response"), response.GetPayloadBytes())
	assert.Equal(t, 1, moveClient.executeCalls)
	assert.Zero(t, renameClient.executeCalls)

	// The payload hash must be the SHA-256 digest of the downstream payload.
	wantHash := sha256.Sum256([]byte("downstream-response"))
	assert.Equal(t, wantHash[:], response.GetPayloadHash())
	require.NoError(t, authn.VerifyPayloadHash(response.GetPayloadBytes(), response.GetPayloadHash()))
	require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.GetSignature(), authn.ResponseSigningFields{
		ProtocolVersion: response.GetProtocolVersion(),
		RequestID:       response.GetRequestId(),
		TimestampMS:     response.GetTimestampMs(),
		ResultCode:      response.GetResultCode(),
		PayloadHash:     response.GetPayloadHash(),
	}))
}

// TestExecuteCommandRouteMissReturnsUnimplemented checks that a request is
// rejected with codes.Unimplemented and the message "message_type is not
// routed" when the router has no routes configured.
func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router:         downstream.NewStaticRouter(nil),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		ResponseSigner: newTestResponseSigner(),
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unimplemented, status.Code(err))
	assert.Equal(t, "message_type is not routed", status.Convert(err).Message())
}

// TestExecuteCommandMapsDownstreamUnavailableToUnavailable checks that a
// downstream error wrapping downstream.ErrDownstreamUnavailable surfaces to the
// caller as codes.Unavailable with a fixed message, after exactly one
// downstream call.
func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
	t.Parallel()

	failingClient := &recordingDownstreamClient{
		executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			// The sentinel is wrapped so the test covers matching through an
			// error chain rather than direct equality.
			return downstream.UnaryResult{}, fmt.Errorf("rpc transport failed: %w", downstream.ErrDownstreamUnavailable)
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": failingClient,
		}),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		ResponseSigner: newTestResponseSigner(),
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
	assert.Equal(t, 1, failingClient.executeCalls)
}

// TestExecuteCommandMapsDownstreamTimeoutToUnavailable checks that a
// downstream call which blocks until its context is done, with the gateway
// configured for a 50ms DownstreamTimeout, surfaces to the caller as
// codes.Unavailable with the same fixed message.
func TestExecuteCommandMapsDownstreamTimeoutToUnavailable(t *testing.T) {
	t.Parallel()

	stallingClient := &recordingDownstreamClient{
		executeFunc: func(ctx context.Context, _ downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			// Block until the gateway gives up on the call, then report the
			// context's own error.
			<-ctx.Done()
			return downstream.UnaryResult{}, ctx.Err()
		},
	}

	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.DownstreamTimeout = 50 * time.Millisecond
	}), ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": stallingClient,
		}),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		ResponseSigner: newTestResponseSigner(),
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
	assert.Equal(t, 1, stallingClient.executeCalls)
}

// TestExecuteCommandFailsClosedWhenResponseSignerUnavailable checks that when
// the response signer is an unavailableResponseSigner the caller receives
// codes.Unavailable instead of a response, even though the downstream call
// itself succeeded and was made exactly once.
func TestExecuteCommandFailsClosedWhenResponseSignerUnavailable(t *testing.T) {
	t.Parallel()

	successClient := &recordingDownstreamClient{
		executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			return downstream.UnaryResult{
				ResultCode:   "accepted",
				PayloadBytes: []byte("downstream-response"),
			}, nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": successClient,
		}),
		ResponseSigner: unavailableResponseSigner{},
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "response signer is unavailable", status.Convert(err).Message())
	assert.Equal(t, 1, successClient.executeCalls)
}

// TestExecuteCommandPropagatesOTelSpanContextToDownstream checks that, with a
// telemetry runtime configured, the context handed to the downstream client
// carries a valid OpenTelemetry span context and that the command keeps the
// request's trace ID.
func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) {
	t.Parallel()

	logger := zap.NewNop()
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)

	// Captured by the downstream callback and inspected once the RPC returns.
	var (
		seenSpanContext trace.SpanContext
		seenCommand     downstream.AuthenticatedCommand
	)

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{
				executeFunc: func(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
					seenSpanContext = trace.SpanContextFromContext(ctx)
					seenCommand = command

					return downstream.UnaryResult{
						ResultCode:   "accepted",
						PayloadBytes: []byte("downstream-response"),
					}, nil
				},
			},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Logger:         logger,
		Telemetry:      telemetryRuntime,
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	assert.True(t, seenSpanContext.IsValid())
	assert.Equal(t, "trace-123", seenCommand.TraceID)
}

// TestExecuteCommandDrainsInFlightUnaryDuringShutdown checks that a unary
// request already executing downstream when the gateway run is cancelled is
// neither failed nor returned early: it completes successfully only once the
// downstream call is released.
func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) {
	t.Parallel()

	started := make(chan struct{})
	release := make(chan struct{})

	// releaseDownstream unblocks the downstream handler at most once. It is
	// only ever called from the test goroutine (directly or via defer), so the
	// plain bool needs no synchronization.
	released := false
	releaseDownstream := func() {
		if released {
			return
		}
		released = true
		close(release)
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{
				executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
					// Signal that the call is in flight, then hold it open
					// until the test releases it.
					close(started)
					<-release

					return downstream.UnaryResult{
						ResultCode:   "accepted",
						PayloadBytes: []byte("downstream-response"),
					}, nil
				},
			},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
	})
	defer runGateway.stop(t)
	// Registered after stop so it runs before it: if an assertion below aborts
	// the test early, the blocked downstream handler is still released rather
	// than left stuck while the gateway is being stopped.
	defer releaseDownstream()

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	// Buffered so the RPC goroutine can always deliver its result and exit,
	// even if the test has already stopped receiving.
	resultCh := make(chan error, 1)
	go func() {
		_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
		resultCh <- err
	}()

	require.Eventually(t, func() bool {
		select {
		case <-started:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "downstream execution did not start")

	// Cancel the gateway run (presumably triggering shutdown) while the
	// downstream call is still blocked.
	runGateway.cancel()

	require.Never(t, func() bool {
		select {
		case <-resultCh:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "unary request returned before downstream release")

	releaseDownstream()

	var err error
	require.Eventually(t, func() bool {
		select {
		case err = <-resultCh:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "unary request did not drain before shutdown timeout")
	require.NoError(t, err)
}

// TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial checks that
// after a successful ExecuteCommand the captured log output contains none of
// the substrings "payload_hash", "signature", or the quoted key "payload".
func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T) {
	t.Parallel()

	logger, logBuffer := testutil.NewObservedLogger(t)

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Logger:         logger,
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)

	logOutput := logBuffer.String()
	assert.NotContains(t, logOutput, "payload_hash")
	assert.NotContains(t, logOutput, "signature")
	assert.NotContains(t, logOutput, `"payload"`)
}