package grpcapi

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/ratelimit"
	"galaxy/gateway/internal/restapi"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"connectrpc.com/connect"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestExecuteCommandRateLimitsByIP verifies that the per-IP anti-abuse bucket
// is shared across device sessions and users. With a budget of one request
// per hour, a second ExecuteCommand from the same client is rejected with
// ResourceExhausted even though it uses a different session mapped to a
// different user. The rejected call never reaches the delegate service.
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")))
	require.NoError(t, err)

	_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2")))
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
	assert.Equal(t, "authenticated request rate limit exceeded", connectErrorMessage(t, err))
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestExecuteCommandRateLimitsBySession verifies that the per-session bucket
// is keyed by device session rather than by user. Both sessions below belong
// to "user-1". Exhausting "device-session-1" must not block
// "device-session-2".
func TestExecuteCommandRateLimitsBySession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-1"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")))
	require.NoError(t, err)

	// Same session again: its one-request budget is already spent.
	_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2")))
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))

	// Different session of the same user: it has its own bucket.
	_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3")))
	require.NoError(t, err)
	assert.Equal(t, 2, delegate.executeCalls)
}

// TestExecuteCommandRateLimitsByUser verifies that the per-user bucket is
// shared by every device session resolved to the same user ID. A session of
// an unrelated user is unaffected.
func TestExecuteCommandRateLimitsByUser(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service: delegate,
		SessionCache: userMappedSessionCache(map[string]string{
			"device-session-1": "user-shared",
			"device-session-2": "user-shared",
			"device-session-3": "user-other",
		}),
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")))
	require.NoError(t, err)

	// Second session of "user-shared": the user's budget is already spent.
	_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2")))
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))

	// "user-other" has an independent bucket.
	_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3")))
	require.NoError(t, err)
	assert.Equal(t, 2, delegate.executeCalls)
}

// TestExecuteCommandRateLimitsByMessageClass verifies the message-class
// bucket. Once "fleet.move" has been used, another "fleet.move" is rejected
// even from a different session and user. A different message type
// ("fleet.rename") from that same session is still allowed.
func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service: delegate,
		SessionCache: userMappedSessionCache(map[string]string{
			"device-session-1": "user-1",
			"device-session-2": "user-2",
		}),
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move")))
	require.NoError(t, err)

	_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move")))
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))

	_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename")))
	require.NoError(t, err)
	assert.Equal(t, 2, delegate.executeCalls)
}

// TestAuthenticatedPolicyHookReceivesVerifiedRequest verifies that the
// configured policy is invoked exactly once per ExecuteCommand. It checks that
// the policy sees the RPC method, peer IP, message class, envelope identifiers
// and the resolved session user.
func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
	t.Parallel()

	policy := &recordingAuthenticatedRequestPolicy{}
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:  staticReplayStore{},
		Policy:       policy,
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
	require.NoError(t, err)

	require.Len(t, policy.requests, 1)
	seen := policy.requests[0]
	assert.Equal(t, authenticatedRPCExecuteCommand, seen.RPCMethod)
	assert.Equal(t, "127.0.0.1", seen.PeerIP)
	assert.Equal(t, "fleet.move", seen.MessageClass)
	assert.Equal(t, "device-session-123", seen.Envelope.DeviceSessionID)
	assert.Equal(t, "request-123", seen.Envelope.RequestID)
	assert.Equal(t, "trace-123", seen.Envelope.TraceID)
	assert.Equal(t, "user-123", seen.Session.UserID)
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestExecuteCommandPolicyRejectMapsToPermissionDenied verifies that a policy
// error wrapping ErrAuthenticatedPolicyDenied surfaces as PermissionDenied
// with a fixed message. The delegate service is never called.
func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:  staticReplayStore{},
		Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error {
			return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied)
		}),
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
	require.Error(t, err)
	assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err))
	assert.Equal(t, "authenticated request rejected by edge policy", connectErrorMessage(t, err))
	assert.Zero(t, delegate.executeCalls)
}

// TestSubscribeEventsRateLimitRejectsStream verifies that SubscribeEvents is
// subject to the same per-IP bucket as ExecuteCommand. The first stream
// delivers its bootstrap event and ends cleanly. A second subscription from
// another session is rejected with ResourceExhausted, and the delegate only
// ever sees one subscribe call.
func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	client := newEdgeClient(t, addr)

	stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1")))
	require.NoError(t, err)

	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
	require.False(t, stream.Receive())
	require.NoError(t, stream.Err())

	err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
	assert.Equal(t, "authenticated request rate limit exceeded", connectErrorMessage(t, err))
	assert.Equal(t, 1, delegate.subscribeCalls)
}

// TestAuthenticatedRateLimitsStayIsolatedFromPublicREST runs the public REST
// server and the authenticated gRPC server in one app.
//
// Both servers are backed by a single shared in-memory limiter. The public
// send-email-code endpoint is driven into 429. A first authenticated
// ExecuteCommand must still succeed, which shows the two surfaces do not
// consume each other's budgets.
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
	t.Parallel()

	sharedLimiter := ratelimit.NewInMemory()

	publicCfg := config.DefaultPublicHTTPConfig()
	publicCfg.Addr = unusedTCPAddr(t)
	publicCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
		Requests: 1,
		Window:   time.Hour,
		Burst:    1,
	}
	// Keep the per-identity limit generous so only the PublicAuth limit trips.
	publicCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}

	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.Addr = unusedTCPAddr(t)
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	})

	restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
		AuthService: staticAuthServiceClient{},
		Limiter:     publicLimiterAdapter{limiter: sharedLimiter},
	})

	delegate := &recordingEdgeGatewayService{}
	grpcServer := NewServer(grpcCfg, ServerDependencies{
		Service:        delegate,
		Router:         executeCommandAdapterRouter{service: delegate},
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Limiter:        sharedLimiter,
		Clock:          fixedClock{now: testCurrentTime},
	})

	application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, grpcServer)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	runGateway := runningGateway{cancel: cancel, resultCh: resultCh}
	defer runGateway.stop(t)

	waitForHTTPHealthz(t, "http://"+publicCfg.Addr+"/healthz")
	addr := waitForListenAddr(t, grpcServer)

	// Exhaust the public budget: the second send-email-code call must be throttled.
	sendEmailCodeURL := "http://" + publicCfg.Addr + "/api/v1/public/auth/send-email-code"
	firstPublic := sendPublicAuthRequest(t, sendEmailCodeURL)
	secondPublic := sendPublicAuthRequest(t, sendEmailCodeURL)
	assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
	assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
	require.NoError(t, firstPublic.Body.Close())
	require.NoError(t, secondPublic.Body.Close())

	// The authenticated surface still has its full budget.
	client := newEdgeClient(t, addr)
	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
	require.NoError(t, err)
}

// newAuthenticatedGRPCConfigForTest returns the default authenticated gRPC
// config adjusted for tests.
//
// The returned config listens on an ephemeral loopback port, uses
// testFreshnessWindow, and sets every anti-abuse bucket (IP, session, user,
// message class) to 100 requests per hour with burst 100, so no limit trips
// unless a test lowers it. If mutate is non-nil it is applied last, letting a
// test tighten a single bucket or override the address.
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
	cfg := config.DefaultAuthenticatedGRPCConfig()
	cfg.Addr = "127.0.0.1:0"
	cfg.FreshnessWindow = testFreshnessWindow

	generous := config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.IP = generous
	cfg.AntiAbuse.Session = generous
	cfg.AntiAbuse.User = generous
	cfg.AntiAbuse.MessageClass = generous

	if mutate != nil {
		mutate(&cfg)
	}
	return cfg
}

// newValidExecuteCommandRequestWithMessageType builds a valid request for the
// given session and request ID with its MessageType replaced.
//
// The signature is recomputed over the updated fields so the request still
// passes signature verification.
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
	req := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
	req.MessageType = messageType
	req.Signature = signRequest(
		req.GetProtocolVersion(),
		req.GetDeviceSessionId(),
		req.GetMessageType(),
		req.GetTimestampMs(),
		req.GetRequestId(),
		req.GetPayloadHash(),
	)
	return req
}

// userMappedSessionCache returns a session cache whose lookups resolve a
// device session ID to an active session record owned by users[deviceSessionID].
// Unknown IDs yield session.ErrNotFound.
func userMappedSessionCache(users map[string]string) staticSessionCache {
	return staticSessionCache{
		lookupFunc: func(_ context.Context, deviceSessionID string) (session.Record, error) {
			userID, ok := users[deviceSessionID]
			if !ok {
				return session.Record{}, session.ErrNotFound
			}

			record := newActiveSessionRecordWithSessionID(deviceSessionID)
			record.UserID = userID
			return record, nil
		},
	}
}

// authenticatedRequestPolicyFunc adapts a plain function into a policy value
// with an Evaluate method.
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error

// Evaluate calls f with the given context and request.
func (f authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
	return f(ctx, request)
}

// recordingAuthenticatedRequestPolicy accepts every request and records each
// one it is asked to evaluate, in call order.
//
// It has no locking. Tests read requests only after the RPC that triggered
// Evaluate has returned.
type recordingAuthenticatedRequestPolicy struct {
	requests []AuthenticatedRequest
}

// Evaluate records request and always allows it.
func (p *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
	p.requests = append(p.requests, request)
	return nil
}

// publicLimiterAdapter exposes a ratelimit.Limiter through the Reserve
// signature used for the REST server's Limiter dependency.
//
// Each config.PublicRateLimitConfig is translated into an equivalent
// ratelimit.Policy. Keys are passed through unchanged.
type publicLimiterAdapter struct {
	limiter ratelimit.Limiter
}

// Reserve forwards key and policy to the wrapped limiter and converts its
// decision into a restapi.PublicRateLimitDecision.
func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
	decision := a.limiter.Reserve(key, ratelimit.Policy{
		Requests: policy.Requests,
		Window:   policy.Window,
		Burst:    policy.Burst,
	})
	return restapi.PublicRateLimitDecision{
		Allowed:    decision.Allowed,
		RetryAfter: decision.RetryAfter,
	}
}

// staticAuthServiceClient is an auth-service stub that always succeeds with
// fixed identifiers.
type staticAuthServiceClient struct{}

// SendEmailCode ignores its input and returns challenge "challenge-123".
func (staticAuthServiceClient) SendEmailCode(context.Context, restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
	return restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}, nil
}

// ConfirmEmailCode ignores its input and returns device session
// "device-session-123".
func (staticAuthServiceClient) ConfirmEmailCode(context.Context, restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
	return restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, nil
}

// waitForHTTPHealthz polls url every 10ms until it answers 200 OK. It fails
// the test if that does not happen within two seconds.
//
// require.Eventually runs the condition on its own goroutine, where calling
// t.FailNow (directly or via require) is not allowed. Transport or body-close
// errors are therefore treated as "not healthy yet" and retried, instead of
// being asserted inside the condition.
func waitForHTTPHealthz(t *testing.T, url string) {
	t.Helper()

	client := &http.Client{Timeout: 200 * time.Millisecond}
	require.Eventually(t, func() bool {
		response, err := client.Get(url)
		if err != nil {
			return false
		}

		healthy := response.StatusCode == http.StatusOK
		if closeErr := response.Body.Close(); closeErr != nil {
			return false
		}
		return healthy
	}, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
}

// sendPublicAuthRequest POSTs a fixed JSON send-email-code body to url with a
// one-second client timeout and returns the raw response.
//
// The caller owns the response and must close its Body.
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
	t.Helper()

	request, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"email":"pilot@example.com"}`))
	require.NoError(t, err)
	request.Header.Set("Content-Type", "application/json")

	response, err := (&http.Client{Timeout: time.Second}).Do(request)
	require.NoError(t, err)
	return response
}

// unusedTCPAddr returns a loopback "host:port" that was free a moment ago.
//
// It binds 127.0.0.1:0 to let the OS pick a port, records the address, and
// closes the listener again. Another listener could claim the port before the
// caller binds it, so this is best effort.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	addr := listener.Addr().String()
	require.NoError(t, listener.Close())
	return addr
}