Files
galaxy-game/gateway/internal/grpcapi/rate_limit_integration_test.go
Ilia Denisov 118f7c17a2 phase 4: connectrpc on the gateway authenticated edge
Replace the native-gRPC server bootstrap with a single
`connectrpc.com/connect` HTTP/h2c listener. Connect-Go natively
serves Connect, gRPC, and gRPC-Web on the same port, so browsers can
now reach the authenticated surface without giving up the gRPC
framing native and desktop clients may use later. The decorator
stack (envelope → session → payload-hash → signature →
freshness/replay → rate-limit → routing/push) is reused unchanged
behind a small Connect → gRPC adapter and a `grpc.ServerStream`
shim around `*connect.ServerStream`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-07 11:49:28 +02:00

456 lines
16 KiB
Go

package grpcapi
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/ratelimit"
"galaxy/gateway/internal/restapi"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestExecuteCommandRateLimitsByIP checks that once the per-IP budget
// (1 request, burst 1, one-hour window) is spent, a second ExecuteCommand call
// is rejected with ResourceExhausted even though it uses a different device
// session mapped to a different user, and that the recording delegate has
// counted exactly one execution.
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
	t.Parallel()

	service := &recordingEdgeGatewayService{}
	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		// Tighten only the IP bucket; the other buckets keep the helper's values.
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{Requests: 1, Window: time.Hour, Burst: 1}
	})
	sessions := map[string]string{
		"device-session-1": "user-1",
		"device-session-2": "user-2",
	}
	server, gateway := newTestGatewayWithGRPCConfig(t, grpcCfg, ServerDependencies{
		Service:      service,
		SessionCache: userMappedSessionCache(sessions),
		ReplayStore:  staticReplayStore{},
	})
	defer gateway.stop(t)

	client := newEdgeClient(t, waitForListenAddr(t, server))
	// execute sends one signed command and reports only the call error.
	execute := func(deviceSessionID, requestID string) error {
		request := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
		_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(request))
		return err
	}

	require.NoError(t, execute("device-session-1", "request-1"))

	err := execute("device-session-2", "request-2")
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
	assert.Equal(t, "authenticated request rate limit exceeded", connectErrorMessage(t, err))
	assert.Equal(t, 1, service.executeCalls)
}
// TestExecuteCommandRateLimitsBySession checks that the per-session budget
// (1 request, burst 1, one-hour window) rejects a repeated call on the same
// device session with ResourceExhausted while a second device session of the
// same user is still admitted, leaving the delegate with two executions.
func TestExecuteCommandRateLimitsBySession(t *testing.T) {
	t.Parallel()

	service := &recordingEdgeGatewayService{}
	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		// Tighten only the session bucket.
		cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{Requests: 1, Window: time.Hour, Burst: 1}
	})
	// Both sessions resolve to the same user.
	sessions := map[string]string{
		"device-session-1": "user-1",
		"device-session-2": "user-1",
	}
	server, gateway := newTestGatewayWithGRPCConfig(t, grpcCfg, ServerDependencies{
		Service:      service,
		SessionCache: userMappedSessionCache(sessions),
		ReplayStore:  staticReplayStore{},
	})
	defer gateway.stop(t)

	client := newEdgeClient(t, waitForListenAddr(t, server))
	// execute sends one signed command and reports only the call error.
	execute := func(deviceSessionID, requestID string) error {
		request := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
		_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(request))
		return err
	}

	require.NoError(t, execute("device-session-1", "request-1"))

	err := execute("device-session-1", "request-2")
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))

	require.NoError(t, execute("device-session-2", "request-3"))
	assert.Equal(t, 2, service.executeCalls)
}
// TestExecuteCommandRateLimitsByUser checks that the per-user budget
// (1 request, burst 1, one-hour window) is shared by every device session
// mapped to the same user: the second session of "user-shared" is rejected with
// ResourceExhausted, while a session of "user-other" is still admitted, leaving
// the delegate with two executions.
func TestExecuteCommandRateLimitsByUser(t *testing.T) {
	t.Parallel()

	service := &recordingEdgeGatewayService{}
	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		// Tighten only the user bucket.
		cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{Requests: 1, Window: time.Hour, Burst: 1}
	})
	sessions := map[string]string{
		"device-session-1": "user-shared",
		"device-session-2": "user-shared",
		"device-session-3": "user-other",
	}
	server, gateway := newTestGatewayWithGRPCConfig(t, grpcCfg, ServerDependencies{
		Service:      service,
		SessionCache: userMappedSessionCache(sessions),
		ReplayStore:  staticReplayStore{},
	})
	defer gateway.stop(t)

	client := newEdgeClient(t, waitForListenAddr(t, server))
	// execute sends one signed command and reports only the call error.
	execute := func(deviceSessionID, requestID string) error {
		request := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
		_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(request))
		return err
	}

	require.NoError(t, execute("device-session-1", "request-1"))

	err := execute("device-session-2", "request-2")
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))

	require.NoError(t, execute("device-session-3", "request-3"))
	assert.Equal(t, 2, service.executeCalls)
}
// TestExecuteCommandRateLimitsByMessageClass checks the per-message-class
// budget (1 request, burst 1, one-hour window): a second "fleet.move" command
// sent from another session is rejected with ResourceExhausted, whereas a
// "fleet.rename" command from that same session is admitted, leaving the
// delegate with two executions.
func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
	t.Parallel()

	service := &recordingEdgeGatewayService{}
	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		// Tighten only the message-class bucket.
		cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{Requests: 1, Window: time.Hour, Burst: 1}
	})
	sessions := map[string]string{
		"device-session-1": "user-1",
		"device-session-2": "user-2",
	}
	server, gateway := newTestGatewayWithGRPCConfig(t, grpcCfg, ServerDependencies{
		Service:      service,
		SessionCache: userMappedSessionCache(sessions),
		ReplayStore:  staticReplayStore{},
	})
	defer gateway.stop(t)

	client := newEdgeClient(t, waitForListenAddr(t, server))
	// execute sends one signed command of the given message type and reports
	// only the call error.
	execute := func(deviceSessionID, requestID, messageType string) error {
		request := newValidExecuteCommandRequestWithMessageType(deviceSessionID, requestID, messageType)
		_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(request))
		return err
	}

	require.NoError(t, execute("device-session-1", "request-1", "fleet.move"))

	err := execute("device-session-2", "request-2", "fleet.move")
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))

	require.NoError(t, execute("device-session-2", "request-3", "fleet.rename"))
	assert.Equal(t, 2, service.executeCalls)
}
// TestAuthenticatedPolicyHookReceivesVerifiedRequest installs a recording
// policy, sends one valid ExecuteCommand call, and asserts that the policy saw
// exactly one request carrying the expected RPC method, peer IP, message class,
// envelope identifiers, and resolved user ID, and that the delegate ran once.
func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
	t.Parallel()

	hook := &recordingAuthenticatedRequestPolicy{}
	service := &recordingEdgeGatewayService{}
	sessions := map[string]string{"device-session-123": "user-123"}
	server, gateway := newTestGateway(t, ServerDependencies{
		Service:      service,
		SessionCache: userMappedSessionCache(sessions),
		ReplayStore:  staticReplayStore{},
		Policy:       hook,
	})
	defer gateway.stop(t)

	client := newEdgeClient(t, waitForListenAddr(t, server))
	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
	require.NoError(t, err)

	// Stop here if nothing was recorded so the index below cannot panic.
	require.Len(t, hook.requests, 1)
	seen := hook.requests[0]

	assert.Equal(t, authenticatedRPCExecuteCommand, seen.RPCMethod)
	assert.Equal(t, "127.0.0.1", seen.PeerIP)
	assert.Equal(t, "fleet.move", seen.MessageClass)
	assert.Equal(t, "device-session-123", seen.Envelope.DeviceSessionID)
	assert.Equal(t, "request-123", seen.Envelope.RequestID)
	assert.Equal(t, "trace-123", seen.Envelope.TraceID)
	assert.Equal(t, "user-123", seen.Session.UserID)
	assert.Equal(t, 1, service.executeCalls)
}
func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error {
return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied)
}),
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
client := newEdgeClient(t, addr)
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
require.Error(t, err)
assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err))
assert.Equal(t, "authenticated request rejected by edge policy", connectErrorMessage(t, err))
assert.Zero(t, delegate.executeCalls)
}
// TestSubscribeEventsRateLimitRejectsStream checks that the per-IP budget
// (1 request, burst 1, one-hour window) also applies to SubscribeEvents: the
// first stream delivers its bootstrap event and ends without error, and a
// second subscription from another session fails with ResourceExhausted while
// the delegate has counted a single subscribe call.
func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
	t.Parallel()

	service := &recordingEdgeGatewayService{}
	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		// Tighten only the IP bucket.
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{Requests: 1, Window: time.Hour, Burst: 1}
	})
	sessions := map[string]string{
		"device-session-1": "user-1",
		"device-session-2": "user-2",
	}
	server, gateway := newTestGatewayWithGRPCConfig(t, grpcCfg, ServerDependencies{
		Service:      service,
		SessionCache: userMappedSessionCache(sessions),
		ReplayStore:  staticReplayStore{},
	})
	defer gateway.stop(t)

	client := newEdgeClient(t, waitForListenAddr(t, server))

	// First subscription: admitted, yields the bootstrap event, then ends.
	firstRequest := newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1")
	stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(firstRequest))
	require.NoError(t, err)
	bootstrap := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, bootstrap, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
	require.False(t, stream.Receive())
	require.NoError(t, stream.Err())

	// Second subscription: rejected by the exhausted IP bucket.
	secondRequest := newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2")
	err = subscribeEventsError(t, context.Background(), client, secondRequest)
	require.Error(t, err)
	assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
	assert.Equal(t, "authenticated request rate limit exceeded", connectErrorMessage(t, err))
	assert.Equal(t, 1, service.subscribeCalls)
}
// TestAuthenticatedRateLimitsStayIsolatedFromPublicREST runs the public REST
// server and the authenticated server on top of one shared in-memory limiter.
// It exhausts the public-auth budget (second send-email-code call returns 429)
// and then asserts that an authenticated ExecuteCommand call still succeeds
// even though the authenticated IP budget is also a single request.
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
	t.Parallel()

	limiter := ratelimit.NewInMemory()

	// Public REST: one public-auth request allowed; the identity bucket is wide
	// enough that it cannot be the one that rejects.
	restCfg := config.DefaultPublicHTTPConfig()
	restCfg.Addr = unusedTCPAddr(t)
	restCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{Requests: 1, Window: time.Hour, Burst: 1}
	restCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{Requests: 100, Window: time.Hour, Burst: 100}

	// Authenticated edge: one request per IP on an explicitly reserved address.
	edgeCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.Addr = unusedTCPAddr(t)
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{Requests: 1, Window: time.Hour, Burst: 1}
	})

	restServer := restapi.NewServer(restCfg, restapi.ServerDependencies{
		AuthService: staticAuthServiceClient{},
		Limiter:     publicLimiterAdapter{limiter: limiter},
	})

	service := &recordingEdgeGatewayService{}
	edgeServer := NewServer(edgeCfg, ServerDependencies{
		Service:        service,
		Router:         executeCommandAdapterRouter{service: service},
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Limiter:        limiter,
		Clock:          fixedClock{now: testCurrentTime},
	})

	// Run both servers under one application until the test cancels the context.
	application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, edgeServer)
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() {
		done <- application.Run(ctx)
	}()
	gateway := runningGateway{cancel: cancel, resultCh: done}
	defer gateway.stop(t)

	publicBase := "http://" + restCfg.Addr
	waitForHTTPHealthz(t, publicBase+"/healthz")
	edgeAddr := waitForListenAddr(t, edgeServer)

	sendCodeURL := publicBase + "/api/v1/public/auth/send-email-code"
	firstPublic := sendPublicAuthRequest(t, sendCodeURL)
	secondPublic := sendPublicAuthRequest(t, sendCodeURL)
	assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
	assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
	require.NoError(t, firstPublic.Body.Close())
	require.NoError(t, secondPublic.Body.Close())

	// The public bucket is now exhausted; the authenticated call must still pass.
	client := newEdgeClient(t, edgeAddr)
	_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
	require.NoError(t, err)
}
// newAuthenticatedGRPCConfigForTest returns the default authenticated config
// adjusted for tests: it listens on "127.0.0.1:0", uses testFreshnessWindow, and
// sets every anti-abuse bucket (IP, session, user, message class) to 100
// requests per hour with burst 100. When mutate is non-nil it is applied last,
// so it can override any of these values.
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
	// The same limit value is assigned to each bucket below.
	generous := config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}

	cfg := config.DefaultAuthenticatedGRPCConfig()
	cfg.Addr = "127.0.0.1:0"
	cfg.FreshnessWindow = testFreshnessWindow
	cfg.AntiAbuse.IP = generous
	cfg.AntiAbuse.Session = generous
	cfg.AntiAbuse.User = generous
	cfg.AntiAbuse.MessageClass = generous

	if mutate == nil {
		return cfg
	}
	mutate(&cfg)
	return cfg
}
// newValidExecuteCommandRequestWithMessageType builds a valid ExecuteCommand
// request for the given session and request ID, overrides its message type, and
// recomputes the signature over the updated envelope fields.
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
	request := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
	request.MessageType = messageType

	// The message type is one of the signed inputs, so the signature is
	// regenerated after the override.
	signature := signRequest(
		request.GetProtocolVersion(),
		request.GetDeviceSessionId(),
		request.GetMessageType(),
		request.GetTimestampMs(),
		request.GetRequestId(),
		request.GetPayloadHash(),
	)
	request.Signature = signature
	return request
}
// userMappedSessionCache returns a staticSessionCache whose lookup resolves a
// device session ID through users (device session ID -> user ID). Known IDs
// produce an active session record with UserID set from the map; unknown IDs
// produce an empty record and session.ErrNotFound.
func userMappedSessionCache(users map[string]string) staticSessionCache {
	lookup := func(_ context.Context, deviceSessionID string) (session.Record, error) {
		if userID, known := users[deviceSessionID]; known {
			record := newActiveSessionRecordWithSessionID(deviceSessionID)
			record.UserID = userID
			return record, nil
		}
		return session.Record{}, session.ErrNotFound
	}
	return staticSessionCache{lookupFunc: lookup}
}
// authenticatedRequestPolicyFunc adapts a plain function so it can be used
// wherever a value with an Evaluate method is expected.
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error

// Evaluate invokes the wrapped function with the given context and request and
// returns its error unchanged.
func (fn authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
	err := fn(ctx, request)
	return err
}
// recordingAuthenticatedRequestPolicy is a policy test double that remembers
// every request it is asked to evaluate and never rejects one.
type recordingAuthenticatedRequestPolicy struct {
	// requests holds the evaluated requests in call order.
	requests []AuthenticatedRequest
}

// Evaluate appends request to the recorded list and always returns nil.
func (rec *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
	recorded := append(rec.requests, request)
	rec.requests = recorded
	return nil
}
// publicLimiterAdapter lets a ratelimit.Limiter be used through the public REST
// limiter shape by translating the policy and decision types.
type publicLimiterAdapter struct {
	// limiter is the underlying limiter every reservation is forwarded to.
	limiter ratelimit.Limiter
}

// Reserve converts policy into a ratelimit.Policy field by field, reserves
// against the wrapped limiter under key, and returns the outcome's Allowed and
// RetryAfter values as a restapi.PublicRateLimitDecision.
func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
	translated := ratelimit.Policy{
		Requests: policy.Requests,
		Window:   policy.Window,
		Burst:    policy.Burst,
	}
	outcome := a.limiter.Reserve(key, translated)

	var decision restapi.PublicRateLimitDecision
	decision.Allowed = outcome.Allowed
	decision.RetryAfter = outcome.RetryAfter
	return decision
}
// staticAuthServiceClient is an auth-service test double that ignores its
// inputs and returns fixed, successful results.
type staticAuthServiceClient struct{}

// SendEmailCode always succeeds with ChallengeID "challenge-123".
func (staticAuthServiceClient) SendEmailCode(_ context.Context, _ restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
	result := restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}
	return result, nil
}

// ConfirmEmailCode always succeeds with DeviceSessionID "device-session-123".
func (staticAuthServiceClient) ConfirmEmailCode(_ context.Context, _ restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
	result := restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}
	return result, nil
}
// waitForHTTPHealthz polls url every 10ms for up to 2s and fails the test if no
// probe returns HTTP 200 in that time. Each probe uses a 200ms client timeout.
//
// A probe counts as unhealthy when the GET fails, when the response body cannot
// be closed, or when the status is not 200.
func waitForHTTPHealthz(t *testing.T, url string) {
	t.Helper()

	client := &http.Client{Timeout: 200 * time.Millisecond}
	probe := func() bool {
		response, err := client.Get(url)
		if err != nil {
			return false
		}
		status := response.StatusCode
		// require.Eventually evaluates this condition on its own goroutine, and
		// FailNow-based assertions (require.*) must only run on the test
		// goroutine. A close failure is therefore reported as "not healthy yet"
		// instead of aborting from here.
		if closeErr := response.Body.Close(); closeErr != nil {
			return false
		}
		return status == http.StatusOK
	}

	require.Eventually(t, probe, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
}
// sendPublicAuthRequest POSTs a fixed JSON email payload to url with a
// one-second client timeout and returns the response. Building or sending the
// request must succeed or the test stops. The response body is returned open;
// closing it is left to the caller.
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
	t.Helper()

	const payload = `{"email":"pilot@example.com"}`
	httpClient := &http.Client{Timeout: time.Second}

	req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(payload))
	require.NoError(t, err)
	req.Header.Set("Content-Type", "application/json")

	resp, err := httpClient.Do(req)
	require.NoError(t, err)
	return resp
}
// unusedTCPAddr asks the OS for a free loopback TCP port by listening on
// "127.0.0.1:0", records the assigned address, closes the listener, and returns
// the address as a string. The listener is already closed when the function
// returns, so the address is not reserved for the caller.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()

	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	free := probe.Addr().String()
	require.NoError(t, probe.Close())
	return free
}