package grpcapi

import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"net/http"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
	"galaxy/gateway/proto/galaxy/gateway/v1/gatewayv1connect"

	"connectrpc.com/connect"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/http2"
)

// TestExecuteCommandRejectsMalformedEnvelope checks that an empty
// ExecuteCommandRequest is rejected with CodeInvalidArgument.
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(&gatewayv1.ExecuteCommandRequest{}))
	require.Error(t, err)
	assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
}

// TestSubscribeEventsRejectsMalformedEnvelope checks that an empty
// SubscribeEventsRequest is rejected with CodeInvalidArgument.
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
	require.Error(t, err)
	assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
}

// TestExecuteCommandRejectsUnsupportedProtocolVersion checks that an
// otherwise populated envelope carrying protocol_version "v2" is rejected
// with CodeFailedPrecondition and a message naming the offending version.
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(&gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v2",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     123456789,
		RequestId:       "request-123",
		PayloadBytes:    []byte("payload"),
		PayloadHash:     []byte("hash"),
		Signature:       []byte("signature"),
	}))
	require.Error(t, err)
	assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
	assert.Equal(t, `unsupported protocol_version "v2"`, connectErrorMessage(t, err))
}

// TestExecuteCommandValidEnvelopeStillReturnsUnimplemented checks that a
// valid envelope for an active session, with a replay store configured, is
// answered with CodeUnimplemented.
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
	require.Error(t, err)
	assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err))
}

// TestExecuteCommandMissingReplayStoreFailsClosed checks that, without a
// replay store, a valid ExecuteCommand is refused with CodeUnavailable.
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
	require.Error(t, err)
	assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
	assert.Equal(t, "replay store is unavailable", connectErrorMessage(t, err))
}

// TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation
// checks that a valid subscription first receives the signed server-time
// bootstrap event. It then checks that the stream stays open until the
// client cancels, at which point it ends with CodeCanceled.
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	stream, err := client.SubscribeEvents(ctx, connect.NewRequest(newValidSubscribeEventsRequest()))
	require.NoError(t, err)
	// Best-effort teardown: the stream has normally already ended via cancel.
	t.Cleanup(func() { _ = stream.Close() })

	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	// Receive blocks until the server sends or the stream ends, so run it in
	// the background and report the outcome through a buffered channel.
	recvResult := make(chan error, 1)
	go func() {
		if stream.Receive() {
			recvResult <- errors.New("stream produced unexpected event")
			return
		}
		recvResult <- stream.Err()
	}()

	require.Never(t, func() bool {
		select {
		case <-recvResult:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")

	cancel()

	select {
	case recvErr := <-recvResult:
		require.Error(t, recvErr)
		assert.Equal(t, connect.CodeCanceled, connect.CodeOf(recvErr))
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not stop after client cancellation")
	}
}

// TestSubscribeEventsMissingReplayStoreFailsClosed checks that, without a
// replay store, a valid subscription is refused with CodeUnavailable.
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
	assert.Equal(t, "replay store is unavailable", connectErrorMessage(t, err))
}

// TestSubscribeEventsFailsClosedWhenResponseSignerUnavailable checks that a
// valid subscription is refused with CodeUnavailable when the configured
// response signer is unavailable.
func TestSubscribeEventsFailsClosedWhenResponseSignerUnavailable(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		ResponseSigner: unavailableResponseSigner{},
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
	assert.Equal(t, "response signer is unavailable", connectErrorMessage(t, err))
}

// TestServerLifecycle checks that the listener accepts TCP connections while
// running and refuses them after the gateway has been stopped.
func TestServerLifecycle(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{})
	// stop drains a result channel that is written exactly once, so it must
	// not run twice. The deferred call only fires when an assertion aborts
	// the test before the explicit stop below, so the gateway goroutine and
	// its listener are not leaked.
	stopped := false
	defer func() {
		if !stopped {
			runGateway.stop(t)
		}
	}()

	addr := waitForListenAddr(t, server)

	// Probe the listener before shutdown so we know it accepted at
	// least one TCP connection.
	probe, err := net.DialTimeout("tcp", addr, time.Second)
	require.NoError(t, err)
	require.NoError(t, probe.Close())

	stopped = true
	runGateway.stop(t)

	// After shutdown the listener must refuse new TCP connections.
	dialCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	dialer := &net.Dialer{}
	closedConn, err := dialer.DialContext(dialCtx, "tcp", addr)
	if err == nil {
		_ = closedConn.Close()
		t.Fatalf("expected dial to %s to fail after shutdown", addr)
	}
}

// runningGateway is a handle to a gateway started by newTestGateway. cancel
// stops the application's Run context, and resultCh receives Run's return
// value once.
type runningGateway struct {
	cancel   context.CancelFunc
	resultCh chan error
}

// newTestGateway starts a gateway on an ephemeral loopback port (127.0.0.1:0)
// using the default authenticated gRPC config with the test freshness window.
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()

	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = "127.0.0.1:0"
	grpcCfg.FreshnessWindow = testFreshnessWindow

	return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
}

// newTestGatewayWithGRPCConfig builds a Server from grpcCfg and deps and runs
// it inside app.New in a background goroutine. It fills in defaults for
// missing dependencies:
//   - Clock becomes a fixed clock at testCurrentTime.
//   - ResponseSigner becomes the test signer.
//   - Router becomes an adapter over Service when only Service is set.
//
// The caller must call stop on the returned runningGateway.
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()

	cfg := config.Config{
		ShutdownTimeout:   time.Second,
		AuthenticatedGRPC: grpcCfg,
	}

	if deps.Clock == nil {
		deps.Clock = fixedClock{now: testCurrentTime}
	}
	if deps.ResponseSigner == nil {
		deps.ResponseSigner = newTestResponseSigner()
	}
	if deps.Router == nil && deps.Service != nil {
		deps.Router = executeCommandAdapterRouter{service: deps.Service}
	}

	server := NewServer(cfg.AuthenticatedGRPC, deps)
	application := app.New(cfg, server)

	ctx, cancel := context.WithCancel(context.Background())
	// Buffered so the Run goroutine can always deliver its result and exit,
	// even if nobody ever calls stop.
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	return server, runningGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

// stop cancels the gateway and waits up to two seconds for Run to return.
// It fails the test if Run does not return in time or returns an error.
// stop consumes the single value sent on resultCh, so call it only once per
// gateway.
func (g runningGateway) stop(t *testing.T) {
	t.Helper()

	g.cancel()

	select {
	case err := <-g.resultCh:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		require.FailNow(t, "gateway did not stop after cancellation")
	}
}

// waitForListenAddr polls server.listenAddr until it reports a non-empty
// address, failing the test after one second.
func waitForListenAddr(t *testing.T, server *Server) string {
	t.Helper()

	var addr string
	require.Eventually(t, func() bool {
		addr = server.listenAddr()
		return addr != ""
	}, time.Second, 10*time.Millisecond, "server did not start listening")

	return addr
}

// newEdgeClient returns a Connect client speaking HTTP/2 cleartext to the
// authenticated edge listener. AllowHTTP forces the client to issue plain
// HTTP/2 requests (h2c) instead of attempting TLS, which the gateway's
// in-process test bootstrap does not configure. The transport's idle
// connections are released when the test finishes.
func newEdgeClient(t *testing.T, addr string) gatewayv1connect.EdgeGatewayClient {
	t.Helper()

	transport := &http2.Transport{
		AllowHTTP: true,
		// Dial plain TCP even though http2.Transport asks for a TLS
		// connection; this is what makes the h2c setup work.
		DialTLSContext: func(ctx context.Context, network, target string, _ *tls.Config) (net.Conn, error) {
			return (&net.Dialer{}).DialContext(ctx, network, target)
		},
	}
	// Every test builds its own transport; release its pooled connections
	// when the test ends instead of leaving them open indefinitely.
	t.Cleanup(transport.CloseIdleConnections)

	httpClient := &http.Client{Transport: transport}

	return gatewayv1connect.NewEdgeGatewayClient(httpClient, "http://"+addr)
}

// connectErrorMessage extracts the *connect.Error message from err. It
// fails the test if err is not a *connect.Error so the caller's expected
// message comparison doesn't accidentally match the wrapped Go error
// string instead of the protocol-level message.
func connectErrorMessage(t require.TestingT, err error) string {
	if helper, ok := t.(interface{ Helper() }); ok {
		helper.Helper()
	}

	var connectErr *connect.Error
	if !errors.As(err, &connectErr) {
		require.FailNowf(t, "expected *connect.Error", "got %T: %v", err, err)
		// FailNow halts the goroutine for *testing.T, but other TestingT
		// implementations may return. Never dereference the nil connectErr.
		return ""
	}

	return connectErr.Message()
}