feat: edge gateway service
This commit is contained in:
@@ -0,0 +1,497 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
"galaxy/gateway/internal/restapi"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRateLimitsBySession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-1"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRateLimitsByUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{
|
||||
"device-session-1": "user-shared",
|
||||
"device-session-2": "user-shared",
|
||||
"device-session-3": "user-other",
|
||||
}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{
|
||||
"device-session-1": "user-1",
|
||||
"device-session-2": "user-2",
|
||||
}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
policy := &recordingAuthenticatedRequestPolicy{}
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Policy: policy,
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, policy.requests, 1)
|
||||
assert.Equal(t, authenticatedRPCExecuteCommand, policy.requests[0].RPCMethod)
|
||||
assert.Equal(t, "127.0.0.1", policy.requests[0].PeerIP)
|
||||
assert.Equal(t, "fleet.move", policy.requests[0].MessageClass)
|
||||
assert.Equal(t, "device-session-123", policy.requests[0].Envelope.DeviceSessionID)
|
||||
assert.Equal(t, "request-123", policy.requests[0].Envelope.RequestID)
|
||||
assert.Equal(t, "trace-123", policy.requests[0].Envelope.TraceID)
|
||||
assert.Equal(t, "user-123", policy.requests[0].Session.UserID)
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error {
|
||||
return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied)
|
||||
}),
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rejected by edge policy", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sharedLimiter := ratelimit.NewInMemory()
|
||||
|
||||
publicCfg := config.DefaultPublicHTTPConfig()
|
||||
publicCfg.Addr = unusedTCPAddr(t)
|
||||
publicCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
publicCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
|
||||
grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.Addr = unusedTCPAddr(t)
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
})
|
||||
|
||||
restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
|
||||
AuthService: staticAuthServiceClient{},
|
||||
Limiter: publicLimiterAdapter{limiter: sharedLimiter},
|
||||
})
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
grpcServer := NewServer(grpcCfg, ServerDependencies{
|
||||
Service: delegate,
|
||||
Router: executeCommandAdapterRouter{service: delegate},
|
||||
ResponseSigner: newTestResponseSigner(),
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Limiter: sharedLimiter,
|
||||
Clock: fixedClock{now: testCurrentTime},
|
||||
})
|
||||
|
||||
application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, grpcServer)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- application.Run(ctx)
|
||||
}()
|
||||
runGateway := runningGateway{cancel: cancel, resultCh: resultCh}
|
||||
defer runGateway.stop(t)
|
||||
|
||||
waitForHTTPHealthz(t, "http://"+publicCfg.Addr+"/healthz")
|
||||
addr := waitForListenAddr(t, grpcServer)
|
||||
|
||||
firstPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
|
||||
secondPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
|
||||
|
||||
assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
|
||||
assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
|
||||
require.NoError(t, firstPublic.Body.Close())
|
||||
require.NoError(t, secondPublic.Body.Close())
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
|
||||
cfg := config.DefaultAuthenticatedGRPCConfig()
|
||||
cfg.Addr = "127.0.0.1:0"
|
||||
cfg.FreshnessWindow = testFreshnessWindow
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
|
||||
if mutate != nil {
|
||||
mutate(&cfg)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
|
||||
req := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
|
||||
req.MessageType = messageType
|
||||
req.Signature = signRequest(
|
||||
req.GetProtocolVersion(),
|
||||
req.GetDeviceSessionId(),
|
||||
req.GetMessageType(),
|
||||
req.GetTimestampMs(),
|
||||
req.GetRequestId(),
|
||||
req.GetPayloadHash(),
|
||||
)
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func userMappedSessionCache(users map[string]string) staticSessionCache {
|
||||
return staticSessionCache{
|
||||
lookupFunc: func(_ context.Context, deviceSessionID string) (session.Record, error) {
|
||||
userID, ok := users[deviceSessionID]
|
||||
if !ok {
|
||||
return session.Record{}, session.ErrNotFound
|
||||
}
|
||||
|
||||
record := newActiveSessionRecordWithSessionID(deviceSessionID)
|
||||
record.UserID = userID
|
||||
return record, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error
|
||||
|
||||
func (f authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
|
||||
return f(ctx, request)
|
||||
}
|
||||
|
||||
type recordingAuthenticatedRequestPolicy struct {
|
||||
requests []AuthenticatedRequest
|
||||
}
|
||||
|
||||
func (p *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
|
||||
p.requests = append(p.requests, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
type publicLimiterAdapter struct {
|
||||
limiter ratelimit.Limiter
|
||||
}
|
||||
|
||||
func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
|
||||
decision := a.limiter.Reserve(key, ratelimit.Policy{
|
||||
Requests: policy.Requests,
|
||||
Window: policy.Window,
|
||||
Burst: policy.Burst,
|
||||
})
|
||||
|
||||
return restapi.PublicRateLimitDecision{
|
||||
Allowed: decision.Allowed,
|
||||
RetryAfter: decision.RetryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
type staticAuthServiceClient struct{}
|
||||
|
||||
func (staticAuthServiceClient) SendEmailCode(context.Context, restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
|
||||
return restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}, nil
|
||||
}
|
||||
|
||||
func (staticAuthServiceClient) ConfirmEmailCode(context.Context, restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
|
||||
return restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, nil
|
||||
}
|
||||
|
||||
func waitForHTTPHealthz(t *testing.T, url string) {
|
||||
t.Helper()
|
||||
|
||||
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||
require.Eventually(t, func() bool {
|
||||
response, err := client.Get(url)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
require.NoError(t, response.Body.Close())
|
||||
|
||||
return response.StatusCode == http.StatusOK
|
||||
}, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
|
||||
}
|
||||
|
||||
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
request, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"email":"pilot@example.com"}`))
|
||||
require.NoError(t, err)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
response, err := (&http.Client{Timeout: time.Second}).Do(request)
|
||||
require.NoError(t, err)
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func unusedTCPAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
require.NoError(t, listener.Close())
|
||||
|
||||
return addr
|
||||
}
|
||||
Reference in New Issue
Block a user