feat: edge gateway service
This commit is contained in:
@@ -0,0 +1,145 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// commandRoutingService translates the verified authenticated request context
|
||||
// into an internal downstream command and signs successful unary responses.
|
||||
type commandRoutingService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
subscribeDelegate gatewayv1.EdgeGatewayServer
|
||||
router downstream.Router
|
||||
responseSigner authn.ResponseSigner
|
||||
clock clock.Clock
|
||||
downstreamTimeout time.Duration
|
||||
}
|
||||
|
||||
// ExecuteCommand builds a verified downstream command, routes it by exact
|
||||
// message_type, executes it, and signs the resulting unary response.
|
||||
func (s commandRoutingService) ExecuteCommand(ctx context.Context, _ *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
command, err := authenticatedCommandFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := s.router.Route(command.MessageType)
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, downstream.ErrRouteNotFound):
|
||||
return nil, status.Error(codes.Unimplemented, "message_type is not routed")
|
||||
case errors.Is(err, downstream.ErrDownstreamUnavailable):
|
||||
return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
|
||||
default:
|
||||
return nil, status.Error(codes.Internal, "downstream route resolution failed")
|
||||
}
|
||||
|
||||
downstreamCtx, cancel := context.WithTimeout(ctx, s.downstreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := client.ExecuteCommand(downstreamCtx, command)
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, downstream.ErrDownstreamUnavailable),
|
||||
errors.Is(err, context.DeadlineExceeded),
|
||||
errors.Is(err, context.Canceled):
|
||||
return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
|
||||
default:
|
||||
return nil, status.Error(codes.Internal, "downstream execution failed")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(result.ResultCode) == "" {
|
||||
return nil, status.Error(codes.Internal, "downstream response is invalid")
|
||||
}
|
||||
|
||||
responseTimestampMS := s.clock.Now().UTC().UnixMilli()
|
||||
payloadHash := sha256.Sum256(result.PayloadBytes)
|
||||
signature, err := s.responseSigner.SignResponse(authn.ResponseSigningFields{
|
||||
ProtocolVersion: command.ProtocolVersion,
|
||||
RequestID: command.RequestID,
|
||||
TimestampMS: responseTimestampMS,
|
||||
ResultCode: result.ResultCode,
|
||||
PayloadHash: payloadHash[:],
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unavailable, "response signer is unavailable")
|
||||
}
|
||||
|
||||
return &gatewayv1.ExecuteCommandResponse{
|
||||
ProtocolVersion: command.ProtocolVersion,
|
||||
RequestId: command.RequestID,
|
||||
TimestampMs: responseTimestampMS,
|
||||
ResultCode: result.ResultCode,
|
||||
PayloadBytes: bytes.Clone(result.PayloadBytes),
|
||||
PayloadHash: bytes.Clone(payloadHash[:]),
|
||||
Signature: signature,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SubscribeEvents delegates to the authenticated streaming service
|
||||
// implementation selected during server construction.
|
||||
func (s commandRoutingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
return s.subscribeDelegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newCommandRoutingService constructs the final authenticated service that
|
||||
// owns verified unary routing while preserving the delegated streaming path.
|
||||
func newCommandRoutingService(subscribeDelegate gatewayv1.EdgeGatewayServer, router downstream.Router, responseSigner authn.ResponseSigner, clk clock.Clock, downstreamTimeout time.Duration) gatewayv1.EdgeGatewayServer {
|
||||
return commandRoutingService{
|
||||
subscribeDelegate: subscribeDelegate,
|
||||
router: router,
|
||||
responseSigner: responseSigner,
|
||||
clock: clk,
|
||||
downstreamTimeout: downstreamTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func authenticatedCommandFromContext(ctx context.Context) (downstream.AuthenticatedCommand, error) {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, ok := resolvedSessionFromContext(ctx)
|
||||
if !ok {
|
||||
return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
return downstream.AuthenticatedCommand{
|
||||
ProtocolVersion: envelope.ProtocolVersion,
|
||||
UserID: record.UserID,
|
||||
DeviceSessionID: record.DeviceSessionID,
|
||||
MessageType: envelope.MessageType,
|
||||
TimestampMS: envelope.TimestampMS,
|
||||
RequestID: envelope.RequestID,
|
||||
TraceID: envelope.TraceID,
|
||||
PayloadBytes: bytes.Clone(envelope.PayloadBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type unavailableResponseSigner struct{}
|
||||
|
||||
func (unavailableResponseSigner) SignResponse(authn.ResponseSigningFields) ([]byte, error) {
|
||||
return nil, errors.New("response signer is unavailable")
|
||||
}
|
||||
|
||||
func (unavailableResponseSigner) SignEvent(authn.EventSigningFields) ([]byte, error) {
|
||||
return nil, errors.New("response signer is unavailable")
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = commandRoutingService{}
|
||||
@@ -0,0 +1,296 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := newTestEd25519ResponseSigner()
|
||||
moveClient := &recordingDownstreamClient{
|
||||
executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
assert.Equal(t, downstream.AuthenticatedCommand{
|
||||
ProtocolVersion: "v1",
|
||||
UserID: "user-123",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: testCurrentTime.UnixMilli(),
|
||||
RequestID: "request-123",
|
||||
TraceID: "trace-123",
|
||||
PayloadBytes: []byte("payload"),
|
||||
}, command)
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: "accepted",
|
||||
PayloadBytes: []byte("downstream-response"),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
renameClient := &recordingDownstreamClient{}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
||||
"fleet.move": moveClient,
|
||||
"fleet.rename": renameClient,
|
||||
}),
|
||||
ResponseSigner: signer,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v1", response.GetProtocolVersion())
|
||||
assert.Equal(t, "request-123", response.GetRequestId())
|
||||
assert.Equal(t, testCurrentTime.UnixMilli(), response.GetTimestampMs())
|
||||
assert.Equal(t, "accepted", response.GetResultCode())
|
||||
assert.Equal(t, []byte("downstream-response"), response.GetPayloadBytes())
|
||||
assert.Equal(t, 1, moveClient.executeCalls)
|
||||
assert.Zero(t, renameClient.executeCalls)
|
||||
|
||||
wantHash := sha256.Sum256([]byte("downstream-response"))
|
||||
assert.Equal(t, wantHash[:], response.GetPayloadHash())
|
||||
require.NoError(t, authn.VerifyPayloadHash(response.GetPayloadBytes(), response.GetPayloadHash()))
|
||||
require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.GetSignature(), authn.ResponseSigningFields{
|
||||
ProtocolVersion: response.GetProtocolVersion(),
|
||||
RequestID: response.GetRequestId(),
|
||||
TimestampMS: response.GetTimestampMs(),
|
||||
ResultCode: response.GetResultCode(),
|
||||
PayloadHash: response.GetPayloadHash(),
|
||||
}))
|
||||
}
|
||||
|
||||
func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Router: downstream.NewStaticRouter(nil),
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
ResponseSigner: newTestResponseSigner(),
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unimplemented, status.Code(err))
|
||||
assert.Equal(t, "message_type is not routed", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
failingClient := &recordingDownstreamClient{
|
||||
executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
return downstream.UnaryResult{}, fmt.Errorf("rpc transport failed: %w", downstream.ErrDownstreamUnavailable)
|
||||
},
|
||||
}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
||||
"fleet.move": failingClient,
|
||||
}),
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
ResponseSigner: newTestResponseSigner(),
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, failingClient.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := zap.NewNop()
|
||||
telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
|
||||
|
||||
var (
|
||||
seenSpanContext trace.SpanContext
|
||||
seenCommand downstream.AuthenticatedCommand
|
||||
)
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
||||
"fleet.move": &recordingDownstreamClient{
|
||||
executeFunc: func(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
seenSpanContext = trace.SpanContextFromContext(ctx)
|
||||
seenCommand = command
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: "accepted",
|
||||
PayloadBytes: []byte("downstream-response"),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}),
|
||||
ResponseSigner: newTestResponseSigner(),
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Logger: logger,
|
||||
Telemetry: telemetryRuntime,
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, seenSpanContext.IsValid())
|
||||
assert.Equal(t, "trace-123", seenCommand.TraceID)
|
||||
}
|
||||
|
||||
func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
||||
"fleet.move": &recordingDownstreamClient{
|
||||
executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
close(started)
|
||||
<-release
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: "accepted",
|
||||
PayloadBytes: []byte("downstream-response"),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}),
|
||||
ResponseSigner: newTestResponseSigner(),
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
resultCh <- err
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-started:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 10*time.Millisecond, "downstream execution did not start")
|
||||
|
||||
runGateway.cancel()
|
||||
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-resultCh:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond, "unary request returned before downstream release")
|
||||
|
||||
close(release)
|
||||
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case err = <-resultCh:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 10*time.Millisecond, "unary request did not drain before shutdown timeout")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, logBuffer := testutil.NewObservedLogger(t)
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
||||
"fleet.move": &recordingDownstreamClient{},
|
||||
}),
|
||||
ResponseSigner: newTestResponseSigner(),
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Logger: logger,
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
logOutput := logBuffer.String()
|
||||
assert.NotContains(t, logOutput, "payload_hash")
|
||||
assert.NotContains(t, logOutput, "signature")
|
||||
assert.NotContains(t, logOutput, `"payload"`)
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"buf.build/go/protovalidate"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const supportedProtocolVersion = "v1"
|
||||
|
||||
// parsedEnvelope captures the authenticated transport fields extracted from a
|
||||
// request envelope after validation succeeds. Later wrappers may enrich this
|
||||
// structure without changing the raw gRPC request types.
|
||||
type parsedEnvelope struct {
|
||||
ProtocolVersion string
|
||||
DeviceSessionID string
|
||||
MessageType string
|
||||
TimestampMS int64
|
||||
RequestID string
|
||||
TraceID string
|
||||
PayloadBytes []byte
|
||||
PayloadHash []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
// parsedEnvelopeFromContext returns the parsed envelope previously attached to
|
||||
// ctx by the envelope-validating gRPC service wrapper.
|
||||
func parsedEnvelopeFromContext(ctx context.Context) (parsedEnvelope, bool) {
|
||||
if ctx == nil {
|
||||
return parsedEnvelope{}, false
|
||||
}
|
||||
|
||||
envelope, ok := ctx.Value(parsedEnvelopeContextKey{}).(parsedEnvelope)
|
||||
if !ok {
|
||||
return parsedEnvelope{}, false
|
||||
}
|
||||
|
||||
return envelope, true
|
||||
}
|
||||
|
||||
// envelopeValidatingService applies envelope parsing and the protocol gate
|
||||
// before delegating to the configured service implementation.
|
||||
type envelopeValidatingService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
delegate gatewayv1.EdgeGatewayServer
|
||||
}
|
||||
|
||||
// ExecuteCommand validates req and only then forwards it to the configured
|
||||
// delegate with the parsed envelope attached to ctx.
|
||||
func (s envelopeValidatingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
envelope, err := parseExecuteCommandRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(context.WithValue(ctx, parsedEnvelopeContextKey{}, envelope), req)
|
||||
}
|
||||
|
||||
// SubscribeEvents validates req and only then forwards it to the configured
|
||||
// delegate with the parsed envelope attached to the stream context.
|
||||
func (s envelopeValidatingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
envelope, err := parseSubscribeEventsRequest(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, envelopeContextStream{
|
||||
ServerStreamingServer: stream,
|
||||
ctx: context.WithValue(stream.Context(), parsedEnvelopeContextKey{}, envelope),
|
||||
})
|
||||
}
|
||||
|
||||
// parseExecuteCommandRequest validates req according to the request-envelope
|
||||
// rules and returns a cloned parsed envelope suitable for later auth steps.
|
||||
func parseExecuteCommandRequest(req *gatewayv1.ExecuteCommandRequest) (parsedEnvelope, error) {
|
||||
if req == nil {
|
||||
return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
|
||||
}
|
||||
if err := protovalidate.Validate(req); err != nil {
|
||||
return parsedEnvelope{}, canonicalExecuteCommandValidationError(req)
|
||||
}
|
||||
if req.GetProtocolVersion() != supportedProtocolVersion {
|
||||
return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion())
|
||||
}
|
||||
|
||||
return parsedEnvelope{
|
||||
ProtocolVersion: req.GetProtocolVersion(),
|
||||
DeviceSessionID: req.GetDeviceSessionId(),
|
||||
MessageType: req.GetMessageType(),
|
||||
TimestampMS: req.GetTimestampMs(),
|
||||
RequestID: req.GetRequestId(),
|
||||
TraceID: req.GetTraceId(),
|
||||
PayloadBytes: bytes.Clone(req.GetPayloadBytes()),
|
||||
PayloadHash: bytes.Clone(req.GetPayloadHash()),
|
||||
Signature: bytes.Clone(req.GetSignature()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseSubscribeEventsRequest validates req according to the request-envelope
|
||||
// rules and returns a cloned parsed envelope suitable for later auth steps.
|
||||
func parseSubscribeEventsRequest(req *gatewayv1.SubscribeEventsRequest) (parsedEnvelope, error) {
|
||||
if req == nil {
|
||||
return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
|
||||
}
|
||||
if err := protovalidate.Validate(req); err != nil {
|
||||
return parsedEnvelope{}, canonicalSubscribeEventsValidationError(req)
|
||||
}
|
||||
if req.GetProtocolVersion() != supportedProtocolVersion {
|
||||
return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion())
|
||||
}
|
||||
|
||||
return parsedEnvelope{
|
||||
ProtocolVersion: req.GetProtocolVersion(),
|
||||
DeviceSessionID: req.GetDeviceSessionId(),
|
||||
MessageType: req.GetMessageType(),
|
||||
TimestampMS: req.GetTimestampMs(),
|
||||
RequestID: req.GetRequestId(),
|
||||
TraceID: req.GetTraceId(),
|
||||
PayloadBytes: bytes.Clone(req.GetPayloadBytes()),
|
||||
PayloadHash: bytes.Clone(req.GetPayloadHash()),
|
||||
Signature: bytes.Clone(req.GetSignature()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newEnvelopeValidatingService wraps delegate with the envelope-validation
|
||||
// gate.
|
||||
func newEnvelopeValidatingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
|
||||
return envelopeValidatingService{delegate: delegate}
|
||||
}
|
||||
|
||||
// canonicalExecuteCommandValidationError maps any ExecuteCommand validation
|
||||
// failure into the stable canonical error chosen by field order.
|
||||
func canonicalExecuteCommandValidationError(req *gatewayv1.ExecuteCommandRequest) error {
|
||||
switch {
|
||||
case req.GetProtocolVersion() == "":
|
||||
return newMalformedEnvelopeError("protocol_version must not be empty")
|
||||
case req.GetDeviceSessionId() == "":
|
||||
return newMalformedEnvelopeError("device_session_id must not be empty")
|
||||
case req.GetMessageType() == "":
|
||||
return newMalformedEnvelopeError("message_type must not be empty")
|
||||
case req.GetTimestampMs() <= 0:
|
||||
return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
|
||||
case req.GetRequestId() == "":
|
||||
return newMalformedEnvelopeError("request_id must not be empty")
|
||||
case len(req.GetPayloadBytes()) == 0:
|
||||
return newMalformedEnvelopeError("payload_bytes must not be empty")
|
||||
case len(req.GetPayloadHash()) == 0:
|
||||
return newMalformedEnvelopeError("payload_hash must not be empty")
|
||||
case len(req.GetSignature()) == 0:
|
||||
return newMalformedEnvelopeError("signature must not be empty")
|
||||
default:
|
||||
return newMalformedEnvelopeError("request envelope is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
// canonicalSubscribeEventsValidationError maps any SubscribeEvents validation
|
||||
// failure into the stable canonical error chosen by field order.
|
||||
func canonicalSubscribeEventsValidationError(req *gatewayv1.SubscribeEventsRequest) error {
|
||||
switch {
|
||||
case req.GetProtocolVersion() == "":
|
||||
return newMalformedEnvelopeError("protocol_version must not be empty")
|
||||
case req.GetDeviceSessionId() == "":
|
||||
return newMalformedEnvelopeError("device_session_id must not be empty")
|
||||
case req.GetMessageType() == "":
|
||||
return newMalformedEnvelopeError("message_type must not be empty")
|
||||
case req.GetTimestampMs() <= 0:
|
||||
return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
|
||||
case req.GetRequestId() == "":
|
||||
return newMalformedEnvelopeError("request_id must not be empty")
|
||||
case len(req.GetPayloadHash()) == 0:
|
||||
return newMalformedEnvelopeError("payload_hash must not be empty")
|
||||
case len(req.GetSignature()) == 0:
|
||||
return newMalformedEnvelopeError("signature must not be empty")
|
||||
default:
|
||||
return newMalformedEnvelopeError("request envelope is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
// newMalformedEnvelopeError returns the stable malformed-envelope reject used
|
||||
// before the gateway performs any auth or routing work.
|
||||
func newMalformedEnvelopeError(message string) error {
|
||||
return status.Error(codes.InvalidArgument, message)
|
||||
}
|
||||
|
||||
// newUnsupportedProtocolVersionError returns the stable reject for a non-empty
|
||||
// but unsupported protocol_version literal.
|
||||
func newUnsupportedProtocolVersionError(version string) error {
|
||||
return status.Error(codes.FailedPrecondition, fmt.Sprintf("unsupported protocol_version %q", version))
|
||||
}
|
||||
|
||||
type parsedEnvelopeContextKey struct{}
|
||||
|
||||
type envelopeContextStream struct {
|
||||
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s envelopeContextStream) Context() context.Context {
|
||||
if s.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = envelopeValidatingService{}
|
||||
@@ -0,0 +1,420 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestParseExecuteCommandRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(*gatewayv1.ExecuteCommandRequest)
|
||||
wantCode codes.Code
|
||||
wantMessage string
|
||||
assertValid func(*testing.T, *gatewayv1.ExecuteCommandRequest, parsedEnvelope)
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "request envelope must not be nil",
|
||||
},
|
||||
{
|
||||
name: "empty protocol version",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.ProtocolVersion = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "protocol_version must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty device session id",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.DeviceSessionId = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "device_session_id must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty message type",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.MessageType = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "message_type must not be empty",
|
||||
},
|
||||
{
|
||||
name: "zero timestamp",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.TimestampMs = 0
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "timestamp_ms must be greater than zero",
|
||||
},
|
||||
{
|
||||
name: "empty request id",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.RequestId = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "request_id must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty payload bytes",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.PayloadBytes = nil
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "payload_bytes must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty payload hash",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.PayloadHash = nil
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "payload_hash must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty signature",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.Signature = nil
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "signature must not be empty",
|
||||
},
|
||||
{
|
||||
name: "unsupported protocol version",
|
||||
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
|
||||
req.ProtocolVersion = "v2"
|
||||
},
|
||||
wantCode: codes.FailedPrecondition,
|
||||
wantMessage: `unsupported protocol_version "v2"`,
|
||||
},
|
||||
{
|
||||
name: "valid request",
|
||||
wantCode: codes.OK,
|
||||
assertValid: func(t *testing.T, req *gatewayv1.ExecuteCommandRequest, envelope parsedEnvelope) {
|
||||
t.Helper()
|
||||
|
||||
assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion)
|
||||
assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
|
||||
assert.Equal(t, req.GetMessageType(), envelope.MessageType)
|
||||
assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS)
|
||||
assert.Equal(t, req.GetRequestId(), envelope.RequestID)
|
||||
assert.Equal(t, req.GetTraceId(), envelope.TraceID)
|
||||
assert.Equal(t, req.GetPayloadBytes(), envelope.PayloadBytes)
|
||||
assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
|
||||
assert.Equal(t, req.GetSignature(), envelope.Signature)
|
||||
|
||||
originalPayloadBytes := append([]byte(nil), req.GetPayloadBytes()...)
|
||||
originalPayloadHash := append([]byte(nil), req.GetPayloadHash()...)
|
||||
originalSignature := append([]byte(nil), req.GetSignature()...)
|
||||
|
||||
envelope.PayloadBytes[0] = 'X'
|
||||
envelope.PayloadHash[0] = 'Y'
|
||||
envelope.Signature[0] = 'Z'
|
||||
|
||||
assert.Equal(t, originalPayloadBytes, req.GetPayloadBytes())
|
||||
assert.Equal(t, originalPayloadHash, req.GetPayloadHash())
|
||||
assert.Equal(t, originalSignature, req.GetSignature())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var req *gatewayv1.ExecuteCommandRequest
|
||||
if tt.name != "nil request" {
|
||||
req = newValidExecuteCommandRequest()
|
||||
if tt.mutate != nil {
|
||||
tt.mutate(req)
|
||||
}
|
||||
}
|
||||
|
||||
envelope, err := parseExecuteCommandRequest(req)
|
||||
if tt.wantCode != codes.OK {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.wantCode, status.Code(err))
|
||||
assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tt.assertValid)
|
||||
tt.assertValid(t, req, envelope)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSubscribeEventsRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(*gatewayv1.SubscribeEventsRequest)
|
||||
wantCode codes.Code
|
||||
wantMessage string
|
||||
assertValid func(*testing.T, *gatewayv1.SubscribeEventsRequest, parsedEnvelope)
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "request envelope must not be nil",
|
||||
},
|
||||
{
|
||||
name: "empty protocol version",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.ProtocolVersion = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "protocol_version must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty device session id",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.DeviceSessionId = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "device_session_id must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty message type",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.MessageType = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "message_type must not be empty",
|
||||
},
|
||||
{
|
||||
name: "zero timestamp",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.TimestampMs = 0
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "timestamp_ms must be greater than zero",
|
||||
},
|
||||
{
|
||||
name: "empty request id",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.RequestId = ""
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "request_id must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty payload hash",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.PayloadHash = nil
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "payload_hash must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty signature",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.Signature = nil
|
||||
},
|
||||
wantCode: codes.InvalidArgument,
|
||||
wantMessage: "signature must not be empty",
|
||||
},
|
||||
{
|
||||
name: "unsupported protocol version",
|
||||
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
|
||||
req.ProtocolVersion = "v2"
|
||||
},
|
||||
wantCode: codes.FailedPrecondition,
|
||||
wantMessage: `unsupported protocol_version "v2"`,
|
||||
},
|
||||
{
|
||||
name: "valid request with empty payload bytes",
|
||||
wantCode: codes.OK,
|
||||
assertValid: func(t *testing.T, req *gatewayv1.SubscribeEventsRequest, envelope parsedEnvelope) {
|
||||
t.Helper()
|
||||
|
||||
assert.Empty(t, req.GetPayloadBytes())
|
||||
assert.Empty(t, envelope.PayloadBytes)
|
||||
assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
|
||||
assert.Equal(t, req.GetSignature(), envelope.Signature)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var req *gatewayv1.SubscribeEventsRequest
|
||||
if tt.name != "nil request" {
|
||||
req = newValidSubscribeEventsRequest()
|
||||
if tt.mutate != nil {
|
||||
tt.mutate(req)
|
||||
}
|
||||
}
|
||||
|
||||
envelope, err := parseSubscribeEventsRequest(req)
|
||||
if tt.wantCode != codes.OK {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.wantCode, status.Code(err))
|
||||
assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tt.assertValid)
|
||||
tt.assertValid(t, req, envelope)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
_, err := service.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
err := service.SubscribeEvents(&gatewayv1.SubscribeEventsRequest{}, stubGatewayEventStream{})
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := newValidExecuteCommandRequest()
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, want.GetRequestId(), envelope.RequestID)
|
||||
assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
|
||||
assert.Equal(t, want.GetMessageType(), envelope.MessageType)
|
||||
assert.Equal(t, want.GetPayloadBytes(), envelope.PayloadBytes)
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
response, err := service.ExecuteCommand(context.Background(), want)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want.GetRequestId(), response.GetRequestId())
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := newValidSubscribeEventsRequest()
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(stream.Context())
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, want.GetRequestId(), envelope.RequestID)
|
||||
assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
|
||||
assert.Equal(t, want.GetMessageType(), envelope.MessageType)
|
||||
assert.Equal(t, want.GetPayloadHash(), envelope.PayloadHash)
|
||||
assert.Equal(t, want.GetSignature(), envelope.Signature)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
err := service.SubscribeEvents(want, stubGatewayEventStream{})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
type recordingEdgeGatewayService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
executeCalls int
|
||||
subscribeCalls int
|
||||
executeCommandFunc func(context.Context, *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error)
|
||||
subscribeEventsFunc func(*gatewayv1.SubscribeEventsRequest, grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error
|
||||
}
|
||||
|
||||
func (s *recordingEdgeGatewayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
s.executeCalls++
|
||||
if s.executeCommandFunc != nil {
|
||||
return s.executeCommandFunc(ctx, req)
|
||||
}
|
||||
|
||||
return &gatewayv1.ExecuteCommandResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *recordingEdgeGatewayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
s.subscribeCalls++
|
||||
if s.subscribeEventsFunc != nil {
|
||||
return s.subscribeEventsFunc(req, stream)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubGatewayEventStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s stubGatewayEventStream) Send(*gatewayv1.GatewayEvent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubGatewayEventStream) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubGatewayEventStream) SendHeader(metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubGatewayEventStream) SetTrailer(metadata.MD) {}
|
||||
|
||||
func (s stubGatewayEventStream) Context() context.Context {
|
||||
if s.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s stubGatewayEventStream) SendMsg(any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubGatewayEventStream) RecvMsg(any) error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/replay"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const minimumReplayReservationTTL = time.Millisecond
|
||||
|
||||
// freshnessAndReplayService applies freshness and anti-replay checks after
|
||||
// client-signature verification and before later policy or routing steps run.
|
||||
type freshnessAndReplayService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
delegate gatewayv1.EdgeGatewayServer
|
||||
clock clock.Clock
|
||||
replayStore replay.Store
|
||||
freshnessWindow time.Duration
|
||||
}
|
||||
|
||||
// ExecuteCommand verifies request freshness and replay protection before
|
||||
// delegating to the configured service implementation.
|
||||
func (s freshnessAndReplayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := s.verifyFreshnessAndReplay(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents verifies request freshness and replay protection before
|
||||
// delegating to the configured service implementation.
|
||||
func (s freshnessAndReplayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := s.verifyFreshnessAndReplay(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newFreshnessAndReplayService wraps delegate with the freshness and replay
|
||||
// gate.
|
||||
func newFreshnessAndReplayService(delegate gatewayv1.EdgeGatewayServer, clk clock.Clock, replayStore replay.Store, freshnessWindow time.Duration) gatewayv1.EdgeGatewayServer {
|
||||
return freshnessAndReplayService{
|
||||
delegate: delegate,
|
||||
clock: clk,
|
||||
replayStore: replayStore,
|
||||
freshnessWindow: freshnessWindow,
|
||||
}
|
||||
}
|
||||
|
||||
func (s freshnessAndReplayService) verifyFreshnessAndReplay(ctx context.Context) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
now := s.clock.Now().UTC()
|
||||
requestTime := time.UnixMilli(envelope.TimestampMS).UTC()
|
||||
if requestTime.Before(now.Add(-s.freshnessWindow)) || requestTime.After(now.Add(s.freshnessWindow)) {
|
||||
return status.Error(codes.FailedPrecondition, "request timestamp is outside the freshness window")
|
||||
}
|
||||
|
||||
ttl := requestTime.Add(s.freshnessWindow).Sub(now)
|
||||
if ttl < minimumReplayReservationTTL {
|
||||
ttl = minimumReplayReservationTTL
|
||||
}
|
||||
|
||||
err := s.replayStore.Reserve(ctx, envelope.DeviceSessionID, envelope.RequestID, ttl)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, replay.ErrDuplicate):
|
||||
return status.Error(codes.FailedPrecondition, "request replay detected")
|
||||
default:
|
||||
return status.Error(codes.Unavailable, "replay store is unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
type unavailableReplayStore struct{}
|
||||
|
||||
func (unavailableReplayStore) Reserve(context.Context, string, string, time.Duration) error {
|
||||
return errors.New("replay store is unavailable")
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = freshnessAndReplayService{}
|
||||
@@ -0,0 +1,509 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/replay"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsStaleTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timestampMS int64
|
||||
}{
|
||||
{
|
||||
name: "past window",
|
||||
timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
|
||||
},
|
||||
{
|
||||
name: "future window",
|
||||
timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsStaleTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timestampMS int64
|
||||
}{
|
||||
{
|
||||
name: "past window",
|
||||
timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
|
||||
},
|
||||
{
|
||||
name: "future window",
|
||||
timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsReplay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: replayDuplicateBySessionAndRequest(),
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
req := newValidExecuteCommandRequest()
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request replay detected", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsReplay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: replayDuplicateBySessionAndRequest(),
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
req := newValidSubscribeEventsRequest()
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
err = subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request replay detected", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
|
||||
return newActiveSessionRecordWithSessionID(deviceSessionID), nil
|
||||
},
|
||||
},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: replayDuplicateBySessionAndRequest(),
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-shared"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-456", "request-shared"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
|
||||
return newActiveSessionRecordWithSessionID(deviceSessionID), nil
|
||||
},
|
||||
},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: replayDuplicateBySessionAndRequest(),
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-shared"))
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
stream, err = client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-456", "request-shared"))
|
||||
require.NoError(t, err)
|
||||
event = recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
assert.Equal(t, 2, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsReplayStoreUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: func(context.Context, string, string, time.Duration) error {
|
||||
return errors.New("redis down")
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsReplayStoreUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: func(context.Context, string, string, time.Duration) error {
|
||||
return errors.New("redis down")
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
var reservedDeviceSessionID string
|
||||
var reservedRequestID string
|
||||
var reservedTTL time.Duration
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
reservedDeviceSessionID = deviceSessionID
|
||||
reservedRequestID = requestID
|
||||
reservedTTL = ttl
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "request-123", response.GetRequestId())
|
||||
assert.Equal(t, "device-session-123", reservedDeviceSessionID)
|
||||
assert.Equal(t, "request-123", reservedRequestID)
|
||||
assert.Equal(t, testFreshnessWindow, reservedTTL)
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var reservedTTL time.Duration
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
assert.Equal(t, "device-session-123", deviceSessionID)
|
||||
assert.Equal(t, "request-123", requestID)
|
||||
reservedTTL = ttl
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
assert.Equal(t, testFreshnessWindow, reservedTTL)
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandFutureSkewUsesExtendedReplayTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
var reservedTTL time.Duration
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
reservedTTL = ttl
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(
|
||||
context.Background(),
|
||||
newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(2*time.Minute).UnixMilli()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 7*time.Minute, reservedTTL)
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
var reservedTTL time.Duration
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{
|
||||
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
reservedTTL = ttl
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(
|
||||
context.Background(),
|
||||
newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(-testFreshnessWindow).UnixMilli()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, minimumReplayReservationTTL, reservedTTL)
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func replayDuplicateBySessionAndRequest() func(context.Context, string, string, time.Duration) error {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seen = make(map[string]struct{})
|
||||
)
|
||||
|
||||
return func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
key := deviceSessionID + "\x00" + requestID
|
||||
if _, ok := seen[key]; ok {
|
||||
return replay.ErrDuplicate
|
||||
}
|
||||
|
||||
seen[key] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func observabilityUnaryInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.UnaryServerInterceptor {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
start := time.Now()
|
||||
resp, err := handler(ctx, req)
|
||||
|
||||
recordGRPCRequest(logger, metrics, ctx, info.FullMethod, req, resp, err, time.Since(start), "unary")
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func observabilityStreamInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.StreamServerInterceptor {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
start := time.Now()
|
||||
wrapped := &observabilityServerStream{ServerStream: stream}
|
||||
err := handler(srv, wrapped)
|
||||
|
||||
recordGRPCRequest(logger, metrics, stream.Context(), info.FullMethod, wrapped.request, nil, err, time.Since(start), "stream")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
type observabilityServerStream struct {
|
||||
grpc.ServerStream
|
||||
request any
|
||||
}
|
||||
|
||||
func (s *observabilityServerStream) RecvMsg(m any) error {
|
||||
err := s.ServerStream.RecvMsg(m)
|
||||
if err == nil && s.request == nil {
|
||||
s.request = m
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx context.Context, fullMethod string, req any, resp any, err error, duration time.Duration, streamKind string) {
|
||||
rpcMethod := path.Base(fullMethod)
|
||||
messageType, requestID, traceID := grpcEnvelopeFields(req)
|
||||
resultCode := grpcResultCode(resp)
|
||||
grpcCode, grpcMessage, outcome := grpcOutcome(err)
|
||||
rejectReason := telemetry.RejectReason(outcome)
|
||||
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("rpc_method", rpcMethod),
|
||||
attribute.String("message_type", messageType),
|
||||
attribute.String("edge_outcome", string(outcome)),
|
||||
}
|
||||
if resultCode != "" {
|
||||
attrs = append(attrs, attribute.String("result_code", resultCode))
|
||||
}
|
||||
if rejectReason != "" {
|
||||
attrs = append(attrs, attribute.String("reject_reason", rejectReason))
|
||||
}
|
||||
metrics.RecordAuthenticatedGRPC(ctx, attrs, duration)
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "authenticated_grpc"),
|
||||
zap.String("transport", "grpc"),
|
||||
zap.String("stream_kind", streamKind),
|
||||
zap.String("rpc_method", rpcMethod),
|
||||
zap.String("message_type", messageType),
|
||||
zap.String("grpc_code", grpcCode.String()),
|
||||
zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("peer_ip", peerIPFromContext(ctx)),
|
||||
zap.String("edge_outcome", string(outcome)),
|
||||
}
|
||||
if resultCode != "" {
|
||||
fields = append(fields, zap.String("result_code", resultCode))
|
||||
}
|
||||
if rejectReason != "" {
|
||||
fields = append(fields, zap.String("reject_reason", rejectReason))
|
||||
}
|
||||
if grpcMessage != "" {
|
||||
fields = append(fields, zap.String("grpc_message", grpcMessage))
|
||||
}
|
||||
fields = append(fields, logging.TraceFieldsFromContext(ctx)...)
|
||||
|
||||
switch outcome {
|
||||
case telemetry.EdgeOutcomeSuccess:
|
||||
logger.Info("authenticated gRPC request completed", fields...)
|
||||
case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeDownstreamUnavailable, telemetry.EdgeOutcomeInternalError:
|
||||
logger.Error("authenticated gRPC request failed", fields...)
|
||||
default:
|
||||
logger.Warn("authenticated gRPC request rejected", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
func grpcEnvelopeFields(req any) (messageType string, requestID string, traceID string) {
|
||||
switch typed := req.(type) {
|
||||
case *gatewayv1.ExecuteCommandRequest:
|
||||
return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
|
||||
case *gatewayv1.SubscribeEventsRequest:
|
||||
return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
|
||||
default:
|
||||
return "", "", ""
|
||||
}
|
||||
}
|
||||
|
||||
func grpcResultCode(resp any) string {
|
||||
typed, ok := resp.(*gatewayv1.ExecuteCommandResponse)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return typed.GetResultCode()
|
||||
}
|
||||
|
||||
func grpcOutcome(err error) (codes.Code, string, telemetry.EdgeOutcome) {
|
||||
switch {
|
||||
case err == nil:
|
||||
return codes.OK, "", telemetry.EdgeOutcomeSuccess
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
return codes.Canceled, err.Error(), telemetry.EdgeOutcomeSuccess
|
||||
default:
|
||||
grpcStatus := status.Convert(err)
|
||||
return grpcStatus.Code(), grpcStatus.Message(), telemetry.OutcomeFromGRPCStatus(grpcStatus.Code(), grpcStatus.Message())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// payloadHashVerifyingService applies payload-hash verification after session
|
||||
// lookup and before any later auth or routing step runs.
|
||||
type payloadHashVerifyingService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
delegate gatewayv1.EdgeGatewayServer
|
||||
}
|
||||
|
||||
// ExecuteCommand verifies req payload integrity before delegating to the
|
||||
// configured service implementation.
|
||||
func (s payloadHashVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := verifyPayloadHash(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents verifies req payload integrity before delegating to the
|
||||
// configured service implementation.
|
||||
func (s payloadHashVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := verifyPayloadHash(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newPayloadHashVerifyingService wraps delegate with the payload-hash
|
||||
// verification gate.
|
||||
func newPayloadHashVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
|
||||
return payloadHashVerifyingService{delegate: delegate}
|
||||
}
|
||||
|
||||
func verifyPayloadHash(ctx context.Context) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
err := authn.VerifyPayloadHash(envelope.PayloadBytes, envelope.PayloadHash)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, authn.ErrInvalidPayloadHash), errors.Is(err, authn.ErrPayloadHashMismatch):
|
||||
return status.Error(codes.InvalidArgument, err.Error())
|
||||
default:
|
||||
return status.Error(codes.Internal, "payload hash verification failed")
|
||||
}
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = payloadHashVerifyingService{}
|
||||
@@ -0,0 +1,125 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsPayloadHashWithInvalidLength(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
req := newValidExecuteCommandRequest()
|
||||
req.PayloadHash = []byte("short")
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsPayloadHashMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
req := newValidExecuteCommandRequest()
|
||||
sum := sha256.Sum256([]byte("other"))
|
||||
req.PayloadHash = sum[:]
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsPayloadHashWithInvalidLength(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
req := newValidSubscribeEventsRequest()
|
||||
req.PayloadHash = []byte("short")
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsPayloadHashMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
req := newValidSubscribeEventsRequest()
|
||||
sum := sha256.Sum256([]byte("other"))
|
||||
req.PayloadHash = sum[:]
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// NewFanOutPushStreamService constructs the authenticated SubscribeEvents tail
|
||||
// service that registers active streams in hub and forwards client-facing
|
||||
// events after the bootstrap event has been sent.
|
||||
func NewFanOutPushStreamService(hub *push.Hub, responseSigner authn.ResponseSigner, clk clock.Clock, logger *zap.Logger) gatewayv1.EdgeGatewayServer {
|
||||
if responseSigner == nil {
|
||||
responseSigner = unavailableResponseSigner{}
|
||||
}
|
||||
if clk == nil {
|
||||
clk = clock.System{}
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return fanOutPushStreamService{
|
||||
hub: hub,
|
||||
responseSigner: responseSigner,
|
||||
clock: clk,
|
||||
logger: logger.Named("push_stream"),
|
||||
}
|
||||
}
|
||||
|
||||
// fanOutPushStreamService owns the post-bootstrap authenticated push-stream
|
||||
// lifecycle backed by the in-memory push hub.
|
||||
type fanOutPushStreamService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
hub *push.Hub
|
||||
responseSigner authn.ResponseSigner
|
||||
clock clock.Clock
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// SubscribeEvents registers the verified stream in the push hub and forwards
|
||||
// matching client-facing events until the stream ends.
|
||||
func (s fanOutPushStreamService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
binding, ok := authenticatedStreamBindingFromContext(stream.Context())
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
if s.hub == nil {
|
||||
return status.Error(codes.Internal, "push hub is unavailable")
|
||||
}
|
||||
|
||||
subscription, err := s.hub.Register(push.StreamBinding{
|
||||
UserID: binding.UserID,
|
||||
DeviceSessionID: binding.DeviceSessionID,
|
||||
})
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, "push stream registration failed")
|
||||
}
|
||||
defer subscription.Close()
|
||||
|
||||
openFields := []zap.Field{
|
||||
zap.String("component", "authenticated_grpc"),
|
||||
zap.String("transport", "grpc"),
|
||||
zap.String("rpc_method", authenticatedRPCSubscribeEvents),
|
||||
zap.String("message_type", binding.MessageType),
|
||||
zap.String("request_id", binding.RequestID),
|
||||
zap.String("trace_id", binding.TraceID),
|
||||
zap.String("device_session_id", binding.DeviceSessionID),
|
||||
zap.String("user_id", binding.UserID),
|
||||
}
|
||||
openFields = append(openFields, logging.TraceFieldsFromContext(stream.Context())...)
|
||||
s.logger.Info("push stream opened", openFields...)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
s.logger.Info("push stream closed", append(openFields, zap.String("edge_outcome", string(mapSubscriptionOutcome(stream.Context().Err()))))...)
|
||||
return stream.Context().Err()
|
||||
case <-subscription.Done():
|
||||
subscriptionErr := subscription.Err()
|
||||
s.logger.Warn("push stream closed", append(openFields,
|
||||
zap.String("edge_outcome", string(mapSubscriptionOutcome(subscriptionErr))),
|
||||
zap.String("reject_reason", string(mapSubscriptionOutcome(subscriptionErr))),
|
||||
)...)
|
||||
return mapSubscriptionError(subscriptionErr)
|
||||
case event := <-subscription.Events():
|
||||
signedEvent, err := s.buildGatewayEvent(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := stream.Send(signedEvent); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s fanOutPushStreamService) buildGatewayEvent(event push.Event) (*gatewayv1.GatewayEvent, error) {
|
||||
timestampMS := s.clock.Now().UTC().UnixMilli()
|
||||
payloadHash := sha256.Sum256(event.PayloadBytes)
|
||||
|
||||
signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{
|
||||
EventType: event.EventType,
|
||||
EventID: event.EventID,
|
||||
TimestampMS: timestampMS,
|
||||
RequestID: event.RequestID,
|
||||
TraceID: event.TraceID,
|
||||
PayloadHash: payloadHash[:],
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unavailable, "response signer is unavailable")
|
||||
}
|
||||
|
||||
return &gatewayv1.GatewayEvent{
|
||||
EventType: event.EventType,
|
||||
EventId: event.EventID,
|
||||
TimestampMs: timestampMS,
|
||||
PayloadBytes: bytes.Clone(event.PayloadBytes),
|
||||
PayloadHash: bytes.Clone(payloadHash[:]),
|
||||
Signature: signature,
|
||||
RequestId: event.RequestID,
|
||||
TraceId: event.TraceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mapSubscriptionError(err error) error {
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, push.ErrSubscriptionRevoked):
|
||||
return status.Error(codes.FailedPrecondition, "device session is revoked")
|
||||
case errors.Is(err, push.ErrSubscriptionOverflow):
|
||||
return status.Error(codes.ResourceExhausted, "push stream overflowed")
|
||||
case errors.Is(err, push.ErrHubShuttingDown):
|
||||
return status.Error(codes.Unavailable, "gateway is shutting down")
|
||||
default:
|
||||
return status.Error(codes.Internal, "push stream closed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
func mapSubscriptionOutcome(err error) telemetry.EdgeOutcome {
|
||||
switch {
|
||||
case err == nil:
|
||||
return telemetry.EdgeOutcomeSuccess
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
return telemetry.EdgeOutcomeSuccess
|
||||
case errors.Is(err, push.ErrSubscriptionRevoked):
|
||||
return telemetry.EdgeOutcomeRevokedSession
|
||||
case errors.Is(err, push.ErrSubscriptionOverflow):
|
||||
return telemetry.EdgeOutcomeRateLimited
|
||||
case errors.Is(err, push.ErrHubShuttingDown):
|
||||
return telemetry.EdgeOutcomeGatewayShuttingDown
|
||||
default:
|
||||
return telemetry.EdgeOutcomeInternalError
|
||||
}
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = fanOutPushStreamService{}
|
||||
@@ -0,0 +1,164 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
gatewayfbs "galaxy/schema/fbs/gateway"
|
||||
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const serverTimeEventType = "gateway.server_time"
|
||||
|
||||
// authenticatedStreamBinding captures the verified identity bound to one
|
||||
// authenticated SubscribeEvents stream after the full ingress pipeline
|
||||
// succeeds.
|
||||
type authenticatedStreamBinding struct {
|
||||
UserID string
|
||||
DeviceSessionID string
|
||||
MessageType string
|
||||
RequestID string
|
||||
TraceID string
|
||||
}
|
||||
|
||||
// authenticatedStreamBindingFromContext returns the verified stream binding
|
||||
// previously attached to ctx by the authenticated push-stream service.
|
||||
func authenticatedStreamBindingFromContext(ctx context.Context) (authenticatedStreamBinding, bool) {
|
||||
if ctx == nil {
|
||||
return authenticatedStreamBinding{}, false
|
||||
}
|
||||
|
||||
binding, ok := ctx.Value(authenticatedStreamBindingContextKey{}).(authenticatedStreamBinding)
|
||||
if !ok {
|
||||
return authenticatedStreamBinding{}, false
|
||||
}
|
||||
|
||||
return binding, true
|
||||
}
|
||||
|
||||
// authenticatedPushStreamService owns SubscribeEvents bootstrap behavior:
|
||||
// bind the authenticated stream, send the initial signed server-time event,
|
||||
// and then hand the stream lifecycle to the configured tail delegate.
|
||||
type authenticatedPushStreamService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
tailDelegate gatewayv1.EdgeGatewayServer
|
||||
responseSigner authn.ResponseSigner
|
||||
clock clock.Clock
|
||||
}
|
||||
|
||||
// SubscribeEvents binds the verified stream identity, sends the initial signed
|
||||
// server-time event, and then delegates the remaining lifecycle.
|
||||
func (s authenticatedPushStreamService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(stream.Context())
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, ok := resolvedSessionFromContext(stream.Context())
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
binding := authenticatedStreamBinding{
|
||||
UserID: record.UserID,
|
||||
DeviceSessionID: record.DeviceSessionID,
|
||||
MessageType: envelope.MessageType,
|
||||
RequestID: envelope.RequestID,
|
||||
TraceID: envelope.TraceID,
|
||||
}
|
||||
boundStream := authenticatedStreamContextStream{
|
||||
ServerStreamingServer: stream,
|
||||
ctx: context.WithValue(
|
||||
stream.Context(),
|
||||
authenticatedStreamBindingContextKey{},
|
||||
binding,
|
||||
),
|
||||
}
|
||||
|
||||
serverTimeMS := s.clock.Now().UTC().UnixMilli()
|
||||
payloadBytes := buildServerTimeEventPayload(serverTimeMS)
|
||||
payloadHash := sha256.Sum256(payloadBytes)
|
||||
signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{
|
||||
EventType: serverTimeEventType,
|
||||
EventID: envelope.RequestID,
|
||||
TimestampMS: serverTimeMS,
|
||||
RequestID: envelope.RequestID,
|
||||
TraceID: envelope.TraceID,
|
||||
PayloadHash: payloadHash[:],
|
||||
})
|
||||
if err != nil {
|
||||
return status.Error(codes.Unavailable, "response signer is unavailable")
|
||||
}
|
||||
|
||||
if err := boundStream.Send(&gatewayv1.GatewayEvent{
|
||||
EventType: serverTimeEventType,
|
||||
EventId: envelope.RequestID,
|
||||
TimestampMs: serverTimeMS,
|
||||
PayloadBytes: bytes.Clone(payloadBytes),
|
||||
PayloadHash: bytes.Clone(payloadHash[:]),
|
||||
Signature: signature,
|
||||
RequestId: envelope.RequestID,
|
||||
TraceId: envelope.TraceID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.tailDelegate.SubscribeEvents(req, boundStream)
|
||||
}
|
||||
|
||||
func newAuthenticatedPushStreamService(tailDelegate gatewayv1.EdgeGatewayServer, responseSigner authn.ResponseSigner, clk clock.Clock) gatewayv1.EdgeGatewayServer {
|
||||
if tailDelegate == nil {
|
||||
tailDelegate = holdOpenSubscribeEventsService{}
|
||||
}
|
||||
|
||||
return authenticatedPushStreamService{
|
||||
tailDelegate: tailDelegate,
|
||||
responseSigner: responseSigner,
|
||||
clock: clk,
|
||||
}
|
||||
}
|
||||
|
||||
func buildServerTimeEventPayload(serverTimeMS int64) []byte {
|
||||
builder := flatbuffers.NewBuilder(32)
|
||||
gatewayfbs.ServerTimeEventStart(builder)
|
||||
gatewayfbs.ServerTimeEventAddServerTimeMs(builder, serverTimeMS)
|
||||
eventOffset := gatewayfbs.ServerTimeEventEnd(builder)
|
||||
gatewayfbs.FinishServerTimeEventBuffer(builder, eventOffset)
|
||||
|
||||
return bytes.Clone(builder.FinishedBytes())
|
||||
}
|
||||
|
||||
type authenticatedStreamBindingContextKey struct{}
|
||||
|
||||
type authenticatedStreamContextStream struct {
|
||||
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s authenticatedStreamContextStream) Context() context.Context {
|
||||
if s.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
type holdOpenSubscribeEventsService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
}
|
||||
|
||||
func (holdOpenSubscribeEventsService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
<-stream.Context().Done()
|
||||
return stream.Context().Err()
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = authenticatedPushStreamService{}
|
||||
@@ -0,0 +1,286 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
authenticatedGRPCBaseBucketKeyPrefix = "authenticated_grpc/"
|
||||
|
||||
authenticatedGRPCIPBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "ip="
|
||||
authenticatedGRPCSessionBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "session="
|
||||
authenticatedGRPCUserBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "user="
|
||||
authenticatedGRPCMessageClassBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "message_class="
|
||||
|
||||
unknownAuthenticatedPeerIP = "unknown"
|
||||
|
||||
authenticatedRPCExecuteCommand = "ExecuteCommand"
|
||||
authenticatedRPCSubscribeEvents = "SubscribeEvents"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrAuthenticatedPolicyDenied reports that the authenticated request was
|
||||
// rejected by later edge policy after transport authenticity succeeded.
|
||||
ErrAuthenticatedPolicyDenied = errors.New("authenticated request rejected by edge policy")
|
||||
|
||||
// ErrAuthenticatedPolicyUnavailable reports that authenticated policy could
|
||||
// not be evaluated because its backing dependency is unavailable.
|
||||
ErrAuthenticatedPolicyUnavailable = errors.New("authenticated request policy is unavailable")
|
||||
)
|
||||
|
||||
// AuthenticatedRequestLimiter applies authenticated gRPC rate-limit policy to
|
||||
// one concrete bucket key.
|
||||
type AuthenticatedRequestLimiter interface {
|
||||
// Reserve evaluates key under policy and reports whether the request may
|
||||
// proceed immediately.
|
||||
Reserve(key string, policy ratelimit.Policy) ratelimit.Decision
|
||||
}
|
||||
|
||||
// AuthenticatedRequest describes the authenticated request metadata exposed to
|
||||
// the edge-policy hook.
|
||||
type AuthenticatedRequest struct {
|
||||
// RPCMethod identifies the public gRPC method being processed.
|
||||
RPCMethod string
|
||||
|
||||
// PeerIP is the transport peer IP derived from the gRPC connection.
|
||||
PeerIP string
|
||||
|
||||
// MessageClass is the stable rate-limit and policy class. The gateway uses
|
||||
// the full message_type literal because the v1 transport does not yet define
|
||||
// a coarser authenticated class taxonomy.
|
||||
MessageClass string
|
||||
|
||||
// Envelope contains the verified transport envelope fields used by later
|
||||
// edge policy.
|
||||
Envelope AuthenticatedRequestEnvelope
|
||||
|
||||
// Session contains the authenticated identity resolved from SessionCache.
|
||||
Session session.Record
|
||||
}
|
||||
|
||||
// AuthenticatedRequestEnvelope describes the verified request envelope fields
|
||||
// exposed to the edge-policy hook.
|
||||
type AuthenticatedRequestEnvelope struct {
|
||||
// ProtocolVersion is the supported transport protocol version literal.
|
||||
ProtocolVersion string
|
||||
|
||||
// DeviceSessionID is the authenticated device-session identifier.
|
||||
DeviceSessionID string
|
||||
|
||||
// MessageType is the verified downstream routing key supplied by the client.
|
||||
MessageType string
|
||||
|
||||
// TimestampMS is the client timestamp that already passed freshness checks.
|
||||
TimestampMS int64
|
||||
|
||||
// RequestID is the authenticated transport request identifier.
|
||||
RequestID string
|
||||
|
||||
// TraceID is the optional client-supplied correlation identifier.
|
||||
TraceID string
|
||||
}
|
||||
|
||||
// AuthenticatedRequestPolicy evaluates later authenticated edge policy after
|
||||
// transport authenticity and rate-limit checks succeed.
|
||||
type AuthenticatedRequestPolicy interface {
|
||||
// Evaluate returns nil when the authenticated request may proceed. It should
|
||||
// wrap ErrAuthenticatedPolicyDenied for stable reject mapping and
|
||||
// ErrAuthenticatedPolicyUnavailable when its backing dependency is
|
||||
// temporarily unavailable.
|
||||
Evaluate(ctx context.Context, request AuthenticatedRequest) error
|
||||
}
|
||||
|
||||
type authenticatedRateLimitService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
delegate gatewayv1.EdgeGatewayServer
|
||||
limiter AuthenticatedRequestLimiter
|
||||
policy AuthenticatedRequestPolicy
|
||||
cfg config.AuthenticatedGRPCAntiAbuseConfig
|
||||
}
|
||||
|
||||
// ExecuteCommand applies authenticated rate limits and edge policy before
|
||||
// delegating to the configured service implementation.
|
||||
func (s authenticatedRateLimitService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := s.applyRateLimitsAndPolicy(ctx, authenticatedRPCExecuteCommand); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents applies authenticated rate limits and edge policy before
|
||||
// delegating to the configured service implementation.
|
||||
func (s authenticatedRateLimitService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := s.applyRateLimitsAndPolicy(stream.Context(), authenticatedRPCSubscribeEvents); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newAuthenticatedRateLimitService wraps delegate with the authenticated
|
||||
// rate-limit and edge-policy gate.
|
||||
func newAuthenticatedRateLimitService(delegate gatewayv1.EdgeGatewayServer, limiter AuthenticatedRequestLimiter, policy AuthenticatedRequestPolicy, cfg config.AuthenticatedGRPCAntiAbuseConfig) gatewayv1.EdgeGatewayServer {
|
||||
return authenticatedRateLimitService{
|
||||
delegate: delegate,
|
||||
limiter: limiter,
|
||||
policy: policy,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s authenticatedRateLimitService) applyRateLimitsAndPolicy(ctx context.Context, rpcMethod string) error {
|
||||
request, err := authenticatedRequestFromContext(ctx, rpcMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.applyRateLimits(request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.applyPolicy(ctx, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s authenticatedRateLimitService) applyRateLimits(request AuthenticatedRequest) error {
|
||||
checks := []struct {
|
||||
key string
|
||||
policy config.AuthenticatedRateLimitConfig
|
||||
}{
|
||||
{
|
||||
key: authenticatedGRPCIPBucketKey(request.PeerIP),
|
||||
policy: s.cfg.IP,
|
||||
},
|
||||
{
|
||||
key: authenticatedGRPCSessionBucketKey(request.Envelope.DeviceSessionID),
|
||||
policy: s.cfg.Session,
|
||||
},
|
||||
{
|
||||
key: authenticatedGRPCUserBucketKey(request.Session.UserID),
|
||||
policy: s.cfg.User,
|
||||
},
|
||||
{
|
||||
key: authenticatedGRPCMessageClassBucketKey(request.MessageClass),
|
||||
policy: s.cfg.MessageClass,
|
||||
},
|
||||
}
|
||||
|
||||
for _, check := range checks {
|
||||
decision := s.limiter.Reserve(check.key, ratelimit.Policy{
|
||||
Requests: check.policy.Requests,
|
||||
Window: check.policy.Window,
|
||||
Burst: check.policy.Burst,
|
||||
})
|
||||
if !decision.Allowed {
|
||||
return status.Error(codes.ResourceExhausted, "authenticated request rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s authenticatedRateLimitService) applyPolicy(ctx context.Context, request AuthenticatedRequest) error {
|
||||
err := s.policy.Evaluate(ctx, request)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, ErrAuthenticatedPolicyDenied):
|
||||
return status.Error(codes.PermissionDenied, "authenticated request rejected by edge policy")
|
||||
case errors.Is(err, ErrAuthenticatedPolicyUnavailable):
|
||||
return status.Error(codes.Unavailable, "authenticated request policy is unavailable")
|
||||
default:
|
||||
return status.Error(codes.Internal, "authenticated request policy evaluation failed")
|
||||
}
|
||||
}
|
||||
|
||||
func authenticatedRequestFromContext(ctx context.Context, rpcMethod string) (AuthenticatedRequest, error) {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, ok := resolvedSessionFromContext(ctx)
|
||||
if !ok {
|
||||
return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
return AuthenticatedRequest{
|
||||
RPCMethod: rpcMethod,
|
||||
PeerIP: peerIPFromContext(ctx),
|
||||
MessageClass: authenticatedMessageClass(envelope.MessageType),
|
||||
Envelope: AuthenticatedRequestEnvelope{
|
||||
ProtocolVersion: envelope.ProtocolVersion,
|
||||
DeviceSessionID: envelope.DeviceSessionID,
|
||||
MessageType: envelope.MessageType,
|
||||
TimestampMS: envelope.TimestampMS,
|
||||
RequestID: envelope.RequestID,
|
||||
TraceID: envelope.TraceID,
|
||||
},
|
||||
Session: record,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authenticatedGRPCIPBucketKey(peerIP string) string {
|
||||
return authenticatedGRPCIPBucketKeySegment + peerIP
|
||||
}
|
||||
|
||||
func authenticatedGRPCSessionBucketKey(deviceSessionID string) string {
|
||||
return authenticatedGRPCSessionBucketKeySegment + deviceSessionID
|
||||
}
|
||||
|
||||
func authenticatedGRPCUserBucketKey(userID string) string {
|
||||
return authenticatedGRPCUserBucketKeySegment + userID
|
||||
}
|
||||
|
||||
func authenticatedGRPCMessageClassBucketKey(messageClass string) string {
|
||||
return authenticatedGRPCMessageClassBucketKeySegment + messageClass
|
||||
}
|
||||
|
||||
func authenticatedMessageClass(messageType string) string {
|
||||
return messageType
|
||||
}
|
||||
|
||||
func peerIPFromContext(ctx context.Context) string {
|
||||
peerInfo, ok := peer.FromContext(ctx)
|
||||
if !ok || peerInfo.Addr == nil {
|
||||
return unknownAuthenticatedPeerIP
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(peerInfo.Addr.String())
|
||||
if value == "" {
|
||||
return unknownAuthenticatedPeerIP
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(value)
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
type noopAuthenticatedRequestPolicy struct{}
|
||||
|
||||
func (noopAuthenticatedRequestPolicy) Evaluate(context.Context, AuthenticatedRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = authenticatedRateLimitService{}
|
||||
@@ -0,0 +1,497 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
"galaxy/gateway/internal/restapi"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRateLimitsBySession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-1"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRateLimitsByUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{
|
||||
"device-session-1": "user-shared",
|
||||
"device-session-2": "user-shared",
|
||||
"device-session-3": "user-other",
|
||||
}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{
|
||||
"device-session-1": "user-1",
|
||||
"device-session-2": "user-2",
|
||||
}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
policy := &recordingAuthenticatedRequestPolicy{}
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Policy: policy,
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, policy.requests, 1)
|
||||
assert.Equal(t, authenticatedRPCExecuteCommand, policy.requests[0].RPCMethod)
|
||||
assert.Equal(t, "127.0.0.1", policy.requests[0].PeerIP)
|
||||
assert.Equal(t, "fleet.move", policy.requests[0].MessageClass)
|
||||
assert.Equal(t, "device-session-123", policy.requests[0].Envelope.DeviceSessionID)
|
||||
assert.Equal(t, "request-123", policy.requests[0].Envelope.RequestID)
|
||||
assert.Equal(t, "trace-123", policy.requests[0].Envelope.TraceID)
|
||||
assert.Equal(t, "user-123", policy.requests[0].Session.UserID)
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error {
|
||||
return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied)
|
||||
}),
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rejected by edge policy", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
}), ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sharedLimiter := ratelimit.NewInMemory()
|
||||
|
||||
publicCfg := config.DefaultPublicHTTPConfig()
|
||||
publicCfg.Addr = unusedTCPAddr(t)
|
||||
publicCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
publicCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
|
||||
grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
||||
cfg.Addr = unusedTCPAddr(t)
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
})
|
||||
|
||||
restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
|
||||
AuthService: staticAuthServiceClient{},
|
||||
Limiter: publicLimiterAdapter{limiter: sharedLimiter},
|
||||
})
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
grpcServer := NewServer(grpcCfg, ServerDependencies{
|
||||
Service: delegate,
|
||||
Router: executeCommandAdapterRouter{service: delegate},
|
||||
ResponseSigner: newTestResponseSigner(),
|
||||
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
||||
ReplayStore: staticReplayStore{},
|
||||
Limiter: sharedLimiter,
|
||||
Clock: fixedClock{now: testCurrentTime},
|
||||
})
|
||||
|
||||
application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, grpcServer)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- application.Run(ctx)
|
||||
}()
|
||||
runGateway := runningGateway{cancel: cancel, resultCh: resultCh}
|
||||
defer runGateway.stop(t)
|
||||
|
||||
waitForHTTPHealthz(t, "http://"+publicCfg.Addr+"/healthz")
|
||||
addr := waitForListenAddr(t, grpcServer)
|
||||
|
||||
firstPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
|
||||
secondPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
|
||||
|
||||
assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
|
||||
assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
|
||||
require.NoError(t, firstPublic.Body.Close())
|
||||
require.NoError(t, secondPublic.Body.Close())
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
|
||||
cfg := config.DefaultAuthenticatedGRPCConfig()
|
||||
cfg.Addr = "127.0.0.1:0"
|
||||
cfg.FreshnessWindow = testFreshnessWindow
|
||||
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Hour,
|
||||
Burst: 100,
|
||||
}
|
||||
|
||||
if mutate != nil {
|
||||
mutate(&cfg)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
|
||||
req := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
|
||||
req.MessageType = messageType
|
||||
req.Signature = signRequest(
|
||||
req.GetProtocolVersion(),
|
||||
req.GetDeviceSessionId(),
|
||||
req.GetMessageType(),
|
||||
req.GetTimestampMs(),
|
||||
req.GetRequestId(),
|
||||
req.GetPayloadHash(),
|
||||
)
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func userMappedSessionCache(users map[string]string) staticSessionCache {
|
||||
return staticSessionCache{
|
||||
lookupFunc: func(_ context.Context, deviceSessionID string) (session.Record, error) {
|
||||
userID, ok := users[deviceSessionID]
|
||||
if !ok {
|
||||
return session.Record{}, session.ErrNotFound
|
||||
}
|
||||
|
||||
record := newActiveSessionRecordWithSessionID(deviceSessionID)
|
||||
record.UserID = userID
|
||||
return record, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error
|
||||
|
||||
func (f authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
|
||||
return f(ctx, request)
|
||||
}
|
||||
|
||||
type recordingAuthenticatedRequestPolicy struct {
|
||||
requests []AuthenticatedRequest
|
||||
}
|
||||
|
||||
func (p *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
|
||||
p.requests = append(p.requests, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
type publicLimiterAdapter struct {
|
||||
limiter ratelimit.Limiter
|
||||
}
|
||||
|
||||
func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
|
||||
decision := a.limiter.Reserve(key, ratelimit.Policy{
|
||||
Requests: policy.Requests,
|
||||
Window: policy.Window,
|
||||
Burst: policy.Burst,
|
||||
})
|
||||
|
||||
return restapi.PublicRateLimitDecision{
|
||||
Allowed: decision.Allowed,
|
||||
RetryAfter: decision.RetryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
type staticAuthServiceClient struct{}
|
||||
|
||||
func (staticAuthServiceClient) SendEmailCode(context.Context, restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
|
||||
return restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}, nil
|
||||
}
|
||||
|
||||
func (staticAuthServiceClient) ConfirmEmailCode(context.Context, restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
|
||||
return restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, nil
|
||||
}
|
||||
|
||||
func waitForHTTPHealthz(t *testing.T, url string) {
|
||||
t.Helper()
|
||||
|
||||
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||
require.Eventually(t, func() bool {
|
||||
response, err := client.Get(url)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
require.NoError(t, response.Body.Close())
|
||||
|
||||
return response.StatusCode == http.StatusOK
|
||||
}, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
|
||||
}
|
||||
|
||||
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
request, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"email":"pilot@example.com"}`))
|
||||
require.NoError(t, err)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
response, err := (&http.Client{Timeout: time.Second}).Do(request)
|
||||
require.NoError(t, err)
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func unusedTCPAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
require.NoError(t, listener.Close())
|
||||
|
||||
return addr
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
// Package grpcapi exposes the authenticated gRPC surface of the gateway.
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
"galaxy/gateway/internal/replay"
|
||||
"galaxy/gateway/internal/session"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// ServerDependencies describes the optional collaborators used by the
|
||||
// authenticated gRPC server. The zero value is valid and keeps the process
|
||||
// runnable with the built-in unimplemented service stub.
|
||||
type ServerDependencies struct {
|
||||
// Service optionally handles the post-bootstrap SubscribeEvents lifecycle
|
||||
// after the initial authenticated service event has been sent. When nil, the
|
||||
// gateway keeps authenticated SubscribeEvents streams open until the client
|
||||
// cancels them, the server shuts down, or a later stream send fails.
|
||||
Service gatewayv1.EdgeGatewayServer
|
||||
|
||||
// Router resolves the exact downstream unary client for the verified
|
||||
// message_type value. When nil, the authenticated unary surface uses an
|
||||
// empty exact-match router and returns UNIMPLEMENTED for unrouted commands.
|
||||
Router downstream.Router
|
||||
|
||||
// ResponseSigner signs authenticated unary responses after downstream
|
||||
// execution succeeds. When nil, the unary surface fails closed once it needs
|
||||
// to sign a routed response.
|
||||
ResponseSigner authn.ResponseSigner
|
||||
|
||||
// SessionCache resolves authenticated device sessions after the envelope
|
||||
// gate succeeds. When nil, the authenticated gRPC surface remains runnable
|
||||
// but valid envelopes fail closed as session-cache unavailable.
|
||||
SessionCache session.Cache
|
||||
|
||||
// Clock provides current server time for freshness checks. When nil, the
|
||||
// authenticated gRPC surface uses the system clock.
|
||||
Clock clock.Clock
|
||||
|
||||
// ReplayStore reserves authenticated request identifiers after signature
|
||||
// verification. When nil, valid requests fail closed as replay-store
|
||||
// unavailable.
|
||||
ReplayStore replay.Store
|
||||
|
||||
// Limiter applies authenticated rate limits after the request passes the
|
||||
// transport authenticity checks. When nil, the authenticated gRPC surface
|
||||
// uses a process-local in-memory limiter.
|
||||
Limiter AuthenticatedRequestLimiter
|
||||
|
||||
// Policy evaluates later authenticated edge policy after rate limits pass.
|
||||
// When nil, the authenticated gRPC surface applies a no-op allow policy.
|
||||
Policy AuthenticatedRequestPolicy
|
||||
|
||||
// Logger writes structured logs for authenticated gRPC traffic.
|
||||
Logger *zap.Logger
|
||||
|
||||
// Telemetry records low-cardinality gRPC metrics.
|
||||
Telemetry *telemetry.Runtime
|
||||
|
||||
// PushHub is the active authenticated push-stream hub. When present, the
|
||||
// server closes active streams before GracefulStop during shutdown.
|
||||
PushHub *push.Hub
|
||||
}
|
||||
|
||||
// Server owns the authenticated gRPC listener exposed by the gateway.
|
||||
type Server struct {
|
||||
cfg config.AuthenticatedGRPCConfig
|
||||
service gatewayv1.EdgeGatewayServer
|
||||
logger *zap.Logger
|
||||
pushHub *push.Hub
|
||||
metrics *telemetry.Runtime
|
||||
|
||||
stateMu sync.RWMutex
|
||||
server *grpc.Server
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
// NewServer constructs an authenticated gRPC server for the supplied listener
|
||||
// configuration and dependency bundle. Nil dependencies are replaced with safe
|
||||
// defaults so the gateway can expose the documented transport surface with the
|
||||
// full auth pipeline wired from built-in fallbacks.
|
||||
func NewServer(cfg config.AuthenticatedGRPCConfig, deps ServerDependencies) *Server {
|
||||
deps = normalizeServerDependencies(deps)
|
||||
|
||||
finalService := newCommandRoutingService(
|
||||
newAuthenticatedPushStreamService(deps.Service, deps.ResponseSigner, deps.Clock),
|
||||
deps.Router,
|
||||
deps.ResponseSigner,
|
||||
deps.Clock,
|
||||
cfg.DownstreamTimeout,
|
||||
)
|
||||
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
service: newEnvelopeValidatingService(
|
||||
newSessionLookupService(
|
||||
newPayloadHashVerifyingService(
|
||||
newSignatureVerifyingService(
|
||||
newFreshnessAndReplayService(
|
||||
newAuthenticatedRateLimitService(
|
||||
finalService,
|
||||
deps.Limiter,
|
||||
deps.Policy,
|
||||
cfg.AntiAbuse,
|
||||
),
|
||||
deps.Clock,
|
||||
deps.ReplayStore,
|
||||
cfg.FreshnessWindow,
|
||||
),
|
||||
),
|
||||
),
|
||||
deps.SessionCache,
|
||||
),
|
||||
),
|
||||
logger: deps.Logger.Named("authenticated_grpc"),
|
||||
pushHub: deps.PushHub,
|
||||
metrics: deps.Telemetry,
|
||||
}
|
||||
}
|
||||
|
||||
// Run binds the configured listener and serves the authenticated gRPC surface
|
||||
// until Shutdown closes the server.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("run authenticated gRPC server: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.cfg.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("run authenticated gRPC server: listen on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
grpc.ConnectionTimeout(s.cfg.ConnectionTimeout),
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
grpc.ChainUnaryInterceptor(observabilityUnaryInterceptor(s.logger, s.metrics)),
|
||||
grpc.ChainStreamInterceptor(observabilityStreamInterceptor(s.logger, s.metrics)),
|
||||
)
|
||||
gatewayv1.RegisterEdgeGatewayServer(grpcServer, s.service)
|
||||
|
||||
s.stateMu.Lock()
|
||||
s.server = grpcServer
|
||||
s.listener = listener
|
||||
s.stateMu.Unlock()
|
||||
|
||||
s.logger.Info("authenticated gRPC server started", zap.String("addr", listener.Addr().String()))
|
||||
|
||||
defer func() {
|
||||
s.stateMu.Lock()
|
||||
s.server = nil
|
||||
s.listener = nil
|
||||
s.stateMu.Unlock()
|
||||
}()
|
||||
|
||||
err = grpcServer.Serve(listener)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, grpc.ErrServerStopped):
|
||||
s.logger.Info("authenticated gRPC server stopped")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("run authenticated gRPC server: serve on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the authenticated gRPC server within ctx. When the
|
||||
// graceful stop exceeds ctx, the server is force-stopped before returning the
|
||||
// timeout to the caller.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown authenticated gRPC server: nil context")
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
server := s.server
|
||||
s.stateMu.RUnlock()
|
||||
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.pushHub != nil {
|
||||
s.pushHub.Shutdown()
|
||||
}
|
||||
|
||||
stopped := make(chan struct{})
|
||||
go func() {
|
||||
server.GracefulStop()
|
||||
close(stopped)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stopped:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
server.Stop()
|
||||
<-stopped
|
||||
return fmt.Errorf("shutdown authenticated gRPC server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listenAddr() string {
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
|
||||
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
|
||||
if deps.Router == nil {
|
||||
deps.Router = downstream.NewStaticRouter(nil)
|
||||
}
|
||||
if deps.ResponseSigner == nil {
|
||||
deps.ResponseSigner = unavailableResponseSigner{}
|
||||
}
|
||||
if deps.SessionCache == nil {
|
||||
deps.SessionCache = unavailableSessionCache{}
|
||||
}
|
||||
if deps.Clock == nil {
|
||||
deps.Clock = clock.System{}
|
||||
}
|
||||
if deps.ReplayStore == nil {
|
||||
deps.ReplayStore = unavailableReplayStore{}
|
||||
}
|
||||
if deps.Limiter == nil {
|
||||
deps.Limiter = ratelimit.NewInMemory()
|
||||
}
|
||||
if deps.Policy == nil {
|
||||
deps.Policy = noopAuthenticatedRequestPolicy{}
|
||||
}
|
||||
if deps.Logger == nil {
|
||||
deps.Logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return deps
|
||||
}
|
||||
@@ -0,0 +1,332 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{
|
||||
ProtocolVersion: "v2",
|
||||
DeviceSessionId: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMs: 123456789,
|
||||
RequestId: "request-123",
|
||||
PayloadBytes: []byte("payload"),
|
||||
PayloadHash: []byte("hash"),
|
||||
Signature: []byte("signature"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
|
||||
}
|
||||
|
||||
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return newActiveSessionRecord(), nil
|
||||
},
|
||||
},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unimplemented, status.Code(err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return newActiveSessionRecord(), nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return newActiveSessionRecord(), nil
|
||||
},
|
||||
},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
|
||||
recvResult := make(chan error, 1)
|
||||
go func() {
|
||||
_, recvErr := stream.Recv()
|
||||
recvResult <- recvErr
|
||||
}()
|
||||
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-recvResult:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")
|
||||
|
||||
cancel()
|
||||
|
||||
var recvErr error
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case recvErr = <-recvResult:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 10*time.Millisecond, "stream did not stop after client cancellation")
|
||||
require.Error(t, recvErr)
|
||||
assert.Equal(t, codes.Canceled, status.Code(recvErr))
|
||||
}
|
||||
|
||||
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return newActiveSessionRecord(), nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
func TestServerLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
require.NoError(t, conn.Close())
|
||||
|
||||
runGateway.stop(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
type runningGateway struct {
|
||||
cancel context.CancelFunc
|
||||
resultCh chan error
|
||||
}
|
||||
|
||||
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
|
||||
t.Helper()
|
||||
|
||||
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
|
||||
grpcCfg.Addr = "127.0.0.1:0"
|
||||
grpcCfg.FreshnessWindow = testFreshnessWindow
|
||||
|
||||
return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
|
||||
}
|
||||
|
||||
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.Config{
|
||||
ShutdownTimeout: time.Second,
|
||||
AuthenticatedGRPC: grpcCfg,
|
||||
}
|
||||
|
||||
if deps.Clock == nil {
|
||||
deps.Clock = fixedClock{now: testCurrentTime}
|
||||
}
|
||||
if deps.ResponseSigner == nil {
|
||||
deps.ResponseSigner = newTestResponseSigner()
|
||||
}
|
||||
if deps.Router == nil && deps.Service != nil {
|
||||
deps.Router = executeCommandAdapterRouter{service: deps.Service}
|
||||
}
|
||||
|
||||
server := NewServer(cfg.AuthenticatedGRPC, deps)
|
||||
application := app.New(cfg, server)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- application.Run(ctx)
|
||||
}()
|
||||
|
||||
return server, runningGateway{
|
||||
cancel: cancel,
|
||||
resultCh: resultCh,
|
||||
}
|
||||
}
|
||||
|
||||
func (g runningGateway) stop(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
g.cancel()
|
||||
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case err = <-g.resultCh:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 2*time.Second, 10*time.Millisecond, "gateway did not stop after cancellation")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func waitForListenAddr(t *testing.T, server *Server) string {
|
||||
t.Helper()
|
||||
|
||||
var addr string
|
||||
require.Eventually(t, func() bool {
|
||||
addr = server.listenAddr()
|
||||
return addr != ""
|
||||
}, time.Second, 10*time.Millisecond, "server did not start listening")
|
||||
return addr
|
||||
}
|
||||
|
||||
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return conn
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// resolvedSessionFromContext returns the session record previously attached to
|
||||
// ctx by the session-lookup gateway wrapper.
|
||||
func resolvedSessionFromContext(ctx context.Context) (session.Record, bool) {
|
||||
if ctx == nil {
|
||||
return session.Record{}, false
|
||||
}
|
||||
|
||||
record, ok := ctx.Value(resolvedSessionContextKey{}).(session.Record)
|
||||
if !ok {
|
||||
return session.Record{}, false
|
||||
}
|
||||
|
||||
return cloneSessionRecord(record), true
|
||||
}
|
||||
|
||||
// sessionLookupService resolves the authenticated session from SessionCache
|
||||
// after envelope parsing succeeds and before later auth steps run.
|
||||
type sessionLookupService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
delegate gatewayv1.EdgeGatewayServer
|
||||
cache session.Cache
|
||||
}
|
||||
|
||||
// ExecuteCommand resolves the cached session for req and only then forwards it
|
||||
// to the configured delegate with the resolved session attached to ctx.
|
||||
func (s sessionLookupService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
record, err := s.lookupSession(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(context.WithValue(ctx, resolvedSessionContextKey{}, cloneSessionRecord(record)), req)
|
||||
}
|
||||
|
||||
// SubscribeEvents resolves the cached session for req and only then forwards it
|
||||
// to the configured delegate with the resolved session attached to the stream
|
||||
// context.
|
||||
func (s sessionLookupService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
record, err := s.lookupSession(stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, resolvedSessionContextStream{
|
||||
ServerStreamingServer: stream,
|
||||
ctx: context.WithValue(stream.Context(), resolvedSessionContextKey{}, cloneSessionRecord(record)),
|
||||
})
|
||||
}
|
||||
|
||||
// newSessionLookupService wraps delegate with the session-cache lookup gate.
|
||||
func newSessionLookupService(delegate gatewayv1.EdgeGatewayServer, cache session.Cache) gatewayv1.EdgeGatewayServer {
|
||||
return sessionLookupService{
|
||||
delegate: delegate,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (s sessionLookupService) lookupSession(ctx context.Context) (session.Record, error) {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return session.Record{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, err := s.cache.Lookup(ctx, envelope.DeviceSessionID)
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, session.ErrNotFound):
|
||||
return session.Record{}, status.Error(codes.Unauthenticated, "unknown device session")
|
||||
default:
|
||||
return session.Record{}, status.Error(codes.Unavailable, "session cache is unavailable")
|
||||
}
|
||||
|
||||
if record.Status == session.StatusRevoked {
|
||||
return session.Record{}, status.Error(codes.FailedPrecondition, "device session is revoked")
|
||||
}
|
||||
|
||||
return cloneSessionRecord(record), nil
|
||||
}
|
||||
|
||||
func cloneSessionRecord(record session.Record) session.Record {
|
||||
cloned := record
|
||||
if record.RevokedAtMS != nil {
|
||||
value := *record.RevokedAtMS
|
||||
cloned.RevokedAtMS = &value
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
type resolvedSessionContextKey struct{}
|
||||
|
||||
type resolvedSessionContextStream struct {
|
||||
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s resolvedSessionContextStream) Context() context.Context {
|
||||
if s.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
type unavailableSessionCache struct{}
|
||||
|
||||
func (unavailableSessionCache) Lookup(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, errors.New("session cache is unavailable")
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = sessionLookupService{}
|
||||
@@ -0,0 +1,294 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsUnknownSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, session.ErrNotFound
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "unknown device session", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsUnknownSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, session.ErrNotFound
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "unknown device session", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsRevokedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsRevokedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsSessionCacheUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, errors.New("redis down")
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsSessionCacheUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, errors.New("redis down")
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandAttachesResolvedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
record, ok := resolvedSessionFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, newActiveSessionRecord(), record)
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "request-123", response.GetRequestId())
|
||||
}
|
||||
|
||||
func TestSubscribeEventsAttachesResolvedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
record, ok := resolvedSessionFromContext(stream.Context())
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, newActiveSessionRecord(), record)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
binding, ok := authenticatedStreamBindingFromContext(stream.Context())
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, authenticatedStreamBinding{
|
||||
UserID: "user-123",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "gateway.subscribe",
|
||||
RequestID: "request-123",
|
||||
TraceID: "trace-123",
|
||||
}, binding)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
|
||||
require.NoError(t, err)
|
||||
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
}
|
||||
|
||||
type staticSessionCache struct {
|
||||
lookupFunc func(context.Context, string) (session.Record, error)
|
||||
}
|
||||
|
||||
func (c staticSessionCache) Lookup(ctx context.Context, deviceSessionID string) (session.Record, error) {
|
||||
return c.lookupFunc(ctx, deviceSessionID)
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// signatureVerifyingService applies client-signature verification after
|
||||
// payload integrity checks and before later auth or routing steps run.
|
||||
type signatureVerifyingService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
delegate gatewayv1.EdgeGatewayServer
|
||||
}
|
||||
|
||||
// ExecuteCommand verifies req client signature before delegating to the
|
||||
// configured service implementation.
|
||||
func (s signatureVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := verifyRequestSignature(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents verifies req client signature before delegating to the
|
||||
// configured service implementation.
|
||||
func (s signatureVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := verifyRequestSignature(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newSignatureVerifyingService wraps delegate with the client-signature
|
||||
// verification gate.
|
||||
func newSignatureVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
|
||||
return signatureVerifyingService{delegate: delegate}
|
||||
}
|
||||
|
||||
func verifyRequestSignature(ctx context.Context) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, ok := resolvedSessionFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
err := authn.VerifyRequestSignature(record.ClientPublicKey, envelope.Signature, authn.RequestSigningFields{
|
||||
ProtocolVersion: envelope.ProtocolVersion,
|
||||
DeviceSessionID: envelope.DeviceSessionID,
|
||||
MessageType: envelope.MessageType,
|
||||
TimestampMS: envelope.TimestampMS,
|
||||
RequestID: envelope.RequestID,
|
||||
PayloadHash: envelope.PayloadHash,
|
||||
})
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, authn.ErrInvalidClientPublicKey):
|
||||
return status.Error(codes.Unavailable, "session cache is unavailable")
|
||||
case errors.Is(err, authn.ErrInvalidRequestSignature):
|
||||
return status.Error(codes.Unauthenticated, "invalid request signature")
|
||||
default:
|
||||
return status.Error(codes.Internal, "request signature verification failed")
|
||||
}
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = signatureVerifyingService{}
|
||||
@@ -0,0 +1,188 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsInvalidSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
req := newValidExecuteCommandRequest()
|
||||
req.Signature[0] ^= 0xff
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsWrongKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
record := newActiveSessionRecord()
|
||||
record.ClientPublicKey = alternateTestClientPublicKeyBase64()
|
||||
return record, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsInvalidCachedPublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
record := newActiveSessionRecord()
|
||||
record.ClientPublicKey = "%%%not-base64%%%"
|
||||
return record, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsInvalidSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
req := newValidSubscribeEventsRequest()
|
||||
req.Signature[0] ^= 0xff
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsWrongKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
record := newActiveSessionRecord()
|
||||
record.ClientPublicKey = alternateTestClientPublicKeyBase64()
|
||||
return record, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsInvalidCachedPublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
record := newActiveSessionRecord()
|
||||
record.ClientPublicKey = "%%%not-base64%%%"
|
||||
return record, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
gatewayfbs "galaxy/schema/fbs/gateway"
|
||||
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
testCurrentTime = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
|
||||
testFreshnessWindow = 5 * time.Minute
|
||||
)
|
||||
|
||||
func newValidExecuteCommandRequest() *gatewayv1.ExecuteCommandRequest {
|
||||
return newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-123")
|
||||
}
|
||||
|
||||
func newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.ExecuteCommandRequest {
|
||||
return newValidExecuteCommandRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
|
||||
}
|
||||
|
||||
func newValidExecuteCommandRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.ExecuteCommandRequest {
|
||||
payloadBytes := []byte("payload")
|
||||
payloadHash := sha256.Sum256(payloadBytes)
|
||||
|
||||
req := &gatewayv1.ExecuteCommandRequest{
|
||||
ProtocolVersion: supportedProtocolVersion,
|
||||
DeviceSessionId: deviceSessionID,
|
||||
MessageType: "fleet.move",
|
||||
TimestampMs: timestampMS,
|
||||
RequestId: requestID,
|
||||
PayloadBytes: payloadBytes,
|
||||
PayloadHash: payloadHash[:],
|
||||
TraceId: "trace-123",
|
||||
}
|
||||
req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func newValidSubscribeEventsRequest() *gatewayv1.SubscribeEventsRequest {
|
||||
return newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-123")
|
||||
}
|
||||
|
||||
func newValidSubscribeEventsRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
|
||||
return newValidSubscribeEventsRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
|
||||
}
|
||||
|
||||
func newValidSubscribeEventsRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.SubscribeEventsRequest {
|
||||
payloadHash := sha256.Sum256(nil)
|
||||
|
||||
req := &gatewayv1.SubscribeEventsRequest{
|
||||
ProtocolVersion: supportedProtocolVersion,
|
||||
DeviceSessionId: deviceSessionID,
|
||||
MessageType: "gateway.subscribe",
|
||||
TimestampMs: timestampMS,
|
||||
RequestId: requestID,
|
||||
PayloadHash: payloadHash[:],
|
||||
TraceId: "trace-123",
|
||||
}
|
||||
req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func newActiveSessionRecord() session.Record {
|
||||
return newActiveSessionRecordWithSessionID("device-session-123")
|
||||
}
|
||||
|
||||
func newActiveSessionRecordWithSessionID(deviceSessionID string) session.Record {
|
||||
return session.Record{
|
||||
DeviceSessionID: deviceSessionID,
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: testClientPublicKeyBase64(),
|
||||
Status: session.StatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
func newRevokedSessionRecord() session.Record {
|
||||
revokedAtMS := int64(123456789)
|
||||
|
||||
return session.Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: testClientPublicKeyBase64(),
|
||||
Status: session.StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}
|
||||
}
|
||||
|
||||
func alternateTestClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(newTestPrivateKey("alternate").Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
func testClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(newTestPrivateKey("primary").Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
func signRequest(protocolVersion, deviceSessionID, messageType string, timestampMS int64, requestID string, payloadHash []byte) []byte {
|
||||
return ed25519.Sign(newTestPrivateKey("primary"), authn.BuildRequestSigningInput(authn.RequestSigningFields{
|
||||
ProtocolVersion: protocolVersion,
|
||||
DeviceSessionID: deviceSessionID,
|
||||
MessageType: messageType,
|
||||
TimestampMS: timestampMS,
|
||||
RequestID: requestID,
|
||||
PayloadHash: payloadHash,
|
||||
}))
|
||||
}
|
||||
|
||||
func newTestPrivateKey(label string) ed25519.PrivateKey {
|
||||
seed := sha256.Sum256([]byte("gateway-grpcapi-signature-test-" + label))
|
||||
return ed25519.NewKeyFromSeed(seed[:])
|
||||
}
|
||||
|
||||
func newTestEd25519ResponseSigner() *authn.Ed25519ResponseSigner {
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: mustMarshalPKCS8PrivateKey(newTestPrivateKey("response-signer")),
|
||||
})
|
||||
|
||||
signer, err := authn.ParseEd25519ResponseSignerPEM(pemBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return signer
|
||||
}
|
||||
|
||||
func newTestResponseSigner() authn.ResponseSigner {
|
||||
return newTestEd25519ResponseSigner()
|
||||
}
|
||||
|
||||
func newTestResponseSignerPublicKey() ed25519.PublicKey {
|
||||
return newTestEd25519ResponseSigner().PublicKey()
|
||||
}
|
||||
|
||||
func mustMarshalPKCS8PrivateKey(privateKey ed25519.PrivateKey) []byte {
|
||||
encoded, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoded
|
||||
}
|
||||
|
||||
type fixedClock struct {
|
||||
now time.Time
|
||||
}
|
||||
|
||||
func (c fixedClock) Now() time.Time {
|
||||
return c.now
|
||||
}
|
||||
|
||||
func recvBootstrapEvent(t interface {
|
||||
require.TestingT
|
||||
Helper()
|
||||
}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
|
||||
t.Helper()
|
||||
|
||||
event, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
|
||||
return event
|
||||
}
|
||||
|
||||
func subscribeEventsError(t interface {
|
||||
require.TestingT
|
||||
Helper()
|
||||
}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
|
||||
t.Helper()
|
||||
|
||||
stream, err := client.SubscribeEvents(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = stream.Recv()
|
||||
return err
|
||||
}
|
||||
|
||||
func assertServerTimeBootstrapEvent(t interface {
|
||||
require.TestingT
|
||||
Helper()
|
||||
}, event *gatewayv1.GatewayEvent, publicKey ed25519.PublicKey, wantRequestID string, wantTraceID string, wantTimestampMS int64) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, event)
|
||||
assert.Equal(t, serverTimeEventType, event.GetEventType())
|
||||
assert.Equal(t, wantRequestID, event.GetEventId())
|
||||
assert.Equal(t, wantRequestID, event.GetRequestId())
|
||||
assert.Equal(t, wantTraceID, event.GetTraceId())
|
||||
assert.Equal(t, wantTimestampMS, event.GetTimestampMs())
|
||||
require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
|
||||
require.NoError(t, authn.VerifyEventSignature(publicKey, event.GetSignature(), authn.EventSigningFields{
|
||||
EventType: event.GetEventType(),
|
||||
EventID: event.GetEventId(),
|
||||
TimestampMS: event.GetTimestampMs(),
|
||||
RequestID: event.GetRequestId(),
|
||||
TraceID: event.GetTraceId(),
|
||||
PayloadHash: event.GetPayloadHash(),
|
||||
}))
|
||||
|
||||
payload := gatewayfbs.GetRootAsServerTimeEvent(event.GetPayloadBytes(), flatbuffers.UOffsetT(0))
|
||||
assert.Equal(t, wantTimestampMS, payload.ServerTimeMs())
|
||||
}
|
||||
|
||||
type staticReplayStore struct {
|
||||
reserveFunc func(context.Context, string, string, time.Duration) error
|
||||
}
|
||||
|
||||
func (s staticReplayStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
if s.reserveFunc != nil {
|
||||
return s.reserveFunc(ctx, deviceSessionID, requestID, ttl)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type executeCommandAdapterRouter struct {
|
||||
service gatewayv1.EdgeGatewayServer
|
||||
}
|
||||
|
||||
func (r executeCommandAdapterRouter) Route(string) (downstream.Client, error) {
|
||||
return executeCommandAdapterClient{service: r.service}, nil
|
||||
}
|
||||
|
||||
type executeCommandAdapterClient struct {
|
||||
service gatewayv1.EdgeGatewayServer
|
||||
}
|
||||
|
||||
func (c executeCommandAdapterClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
response, err := c.service.ExecuteCommand(ctx, &gatewayv1.ExecuteCommandRequest{
|
||||
ProtocolVersion: command.ProtocolVersion,
|
||||
DeviceSessionId: command.DeviceSessionID,
|
||||
MessageType: command.MessageType,
|
||||
TimestampMs: command.TimestampMS,
|
||||
RequestId: command.RequestID,
|
||||
PayloadBytes: command.PayloadBytes,
|
||||
TraceId: command.TraceID,
|
||||
})
|
||||
if err != nil {
|
||||
return downstream.UnaryResult{}, err
|
||||
}
|
||||
|
||||
resultCode := response.GetResultCode()
|
||||
if resultCode == "" {
|
||||
resultCode = "ok"
|
||||
}
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: resultCode,
|
||||
PayloadBytes: response.GetPayloadBytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type recordingDownstreamClient struct {
|
||||
executeCalls int
|
||||
commands []downstream.AuthenticatedCommand
|
||||
executeFunc func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error)
|
||||
}
|
||||
|
||||
func (c *recordingDownstreamClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
c.executeCalls++
|
||||
c.commands = append(c.commands, downstream.AuthenticatedCommand{
|
||||
ProtocolVersion: command.ProtocolVersion,
|
||||
UserID: command.UserID,
|
||||
DeviceSessionID: command.DeviceSessionID,
|
||||
MessageType: command.MessageType,
|
||||
TimestampMS: command.TimestampMS,
|
||||
RequestID: command.RequestID,
|
||||
TraceID: command.TraceID,
|
||||
PayloadBytes: append([]byte(nil), command.PayloadBytes...),
|
||||
})
|
||||
if c.executeFunc != nil {
|
||||
return c.executeFunc(ctx, command)
|
||||
}
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: "ok",
|
||||
PayloadBytes: []byte("response"),
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user