feat: edge gateway service

This commit is contained in:
Ilia Denisov
2026-04-02 19:18:42 +02:00
committed by GitHub
parent 8cde99936c
commit 436c97a38b
95 changed files with 20504 additions and 57 deletions
+145
View File
@@ -0,0 +1,145 @@
package grpcapi
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"strings"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/downstream"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// commandRoutingService translates the verified authenticated request context
// into an internal downstream command and signs successful unary responses.
type commandRoutingService struct {
gatewayv1.UnimplementedEdgeGatewayServer
subscribeDelegate gatewayv1.EdgeGatewayServer
router downstream.Router
responseSigner authn.ResponseSigner
clock clock.Clock
downstreamTimeout time.Duration
}
// ExecuteCommand builds a verified downstream command, routes it by exact
// message_type, executes it, and signs the resulting unary response.
func (s commandRoutingService) ExecuteCommand(ctx context.Context, _ *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
command, err := authenticatedCommandFromContext(ctx)
if err != nil {
return nil, err
}
client, err := s.router.Route(command.MessageType)
switch {
case err == nil:
case errors.Is(err, downstream.ErrRouteNotFound):
return nil, status.Error(codes.Unimplemented, "message_type is not routed")
case errors.Is(err, downstream.ErrDownstreamUnavailable):
return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
default:
return nil, status.Error(codes.Internal, "downstream route resolution failed")
}
downstreamCtx, cancel := context.WithTimeout(ctx, s.downstreamTimeout)
defer cancel()
result, err := client.ExecuteCommand(downstreamCtx, command)
switch {
case err == nil:
case errors.Is(err, downstream.ErrDownstreamUnavailable),
errors.Is(err, context.DeadlineExceeded),
errors.Is(err, context.Canceled):
return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
default:
return nil, status.Error(codes.Internal, "downstream execution failed")
}
if strings.TrimSpace(result.ResultCode) == "" {
return nil, status.Error(codes.Internal, "downstream response is invalid")
}
responseTimestampMS := s.clock.Now().UTC().UnixMilli()
payloadHash := sha256.Sum256(result.PayloadBytes)
signature, err := s.responseSigner.SignResponse(authn.ResponseSigningFields{
ProtocolVersion: command.ProtocolVersion,
RequestID: command.RequestID,
TimestampMS: responseTimestampMS,
ResultCode: result.ResultCode,
PayloadHash: payloadHash[:],
})
if err != nil {
return nil, status.Error(codes.Unavailable, "response signer is unavailable")
}
return &gatewayv1.ExecuteCommandResponse{
ProtocolVersion: command.ProtocolVersion,
RequestId: command.RequestID,
TimestampMs: responseTimestampMS,
ResultCode: result.ResultCode,
PayloadBytes: bytes.Clone(result.PayloadBytes),
PayloadHash: bytes.Clone(payloadHash[:]),
Signature: signature,
}, nil
}
// SubscribeEvents delegates to the authenticated streaming service
// implementation selected during server construction.
func (s commandRoutingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
return s.subscribeDelegate.SubscribeEvents(req, stream)
}
// newCommandRoutingService constructs the final authenticated service that
// owns verified unary routing while preserving the delegated streaming path.
func newCommandRoutingService(subscribeDelegate gatewayv1.EdgeGatewayServer, router downstream.Router, responseSigner authn.ResponseSigner, clk clock.Clock, downstreamTimeout time.Duration) gatewayv1.EdgeGatewayServer {
return commandRoutingService{
subscribeDelegate: subscribeDelegate,
router: router,
responseSigner: responseSigner,
clock: clk,
downstreamTimeout: downstreamTimeout,
}
}
func authenticatedCommandFromContext(ctx context.Context) (downstream.AuthenticatedCommand, error) {
envelope, ok := parsedEnvelopeFromContext(ctx)
if !ok {
return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
}
record, ok := resolvedSessionFromContext(ctx)
if !ok {
return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
}
return downstream.AuthenticatedCommand{
ProtocolVersion: envelope.ProtocolVersion,
UserID: record.UserID,
DeviceSessionID: record.DeviceSessionID,
MessageType: envelope.MessageType,
TimestampMS: envelope.TimestampMS,
RequestID: envelope.RequestID,
TraceID: envelope.TraceID,
PayloadBytes: bytes.Clone(envelope.PayloadBytes),
}, nil
}
type unavailableResponseSigner struct{}
func (unavailableResponseSigner) SignResponse(authn.ResponseSigningFields) ([]byte, error) {
return nil, errors.New("response signer is unavailable")
}
func (unavailableResponseSigner) SignEvent(authn.EventSigningFields) ([]byte, error) {
return nil, errors.New("response signer is unavailable")
}
var _ gatewayv1.EdgeGatewayServer = commandRoutingService{}
@@ -0,0 +1,296 @@
package grpcapi
import (
"context"
"crypto/sha256"
"fmt"
"testing"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/testutil"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
t.Parallel()
signer := newTestEd25519ResponseSigner()
moveClient := &recordingDownstreamClient{
executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
assert.Equal(t, downstream.AuthenticatedCommand{
ProtocolVersion: "v1",
UserID: "user-123",
DeviceSessionID: "device-session-123",
MessageType: "fleet.move",
TimestampMS: testCurrentTime.UnixMilli(),
RequestID: "request-123",
TraceID: "trace-123",
PayloadBytes: []byte("payload"),
}, command)
return downstream.UnaryResult{
ResultCode: "accepted",
PayloadBytes: []byte("downstream-response"),
}, nil
},
}
renameClient := &recordingDownstreamClient{}
server, runGateway := newTestGateway(t, ServerDependencies{
Router: downstream.NewStaticRouter(map[string]downstream.Client{
"fleet.move": moveClient,
"fleet.rename": renameClient,
}),
ResponseSigner: signer,
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.NoError(t, err)
assert.Equal(t, "v1", response.GetProtocolVersion())
assert.Equal(t, "request-123", response.GetRequestId())
assert.Equal(t, testCurrentTime.UnixMilli(), response.GetTimestampMs())
assert.Equal(t, "accepted", response.GetResultCode())
assert.Equal(t, []byte("downstream-response"), response.GetPayloadBytes())
assert.Equal(t, 1, moveClient.executeCalls)
assert.Zero(t, renameClient.executeCalls)
wantHash := sha256.Sum256([]byte("downstream-response"))
assert.Equal(t, wantHash[:], response.GetPayloadHash())
require.NoError(t, authn.VerifyPayloadHash(response.GetPayloadBytes(), response.GetPayloadHash()))
require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.GetSignature(), authn.ResponseSigningFields{
ProtocolVersion: response.GetProtocolVersion(),
RequestID: response.GetRequestId(),
TimestampMS: response.GetTimestampMs(),
ResultCode: response.GetResultCode(),
PayloadHash: response.GetPayloadHash(),
}))
}
func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{
Router: downstream.NewStaticRouter(nil),
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
ResponseSigner: newTestResponseSigner(),
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unimplemented, status.Code(err))
assert.Equal(t, "message_type is not routed", status.Convert(err).Message())
}
func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
t.Parallel()
failingClient := &recordingDownstreamClient{
executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
return downstream.UnaryResult{}, fmt.Errorf("rpc transport failed: %w", downstream.ErrDownstreamUnavailable)
},
}
server, runGateway := newTestGateway(t, ServerDependencies{
Router: downstream.NewStaticRouter(map[string]downstream.Client{
"fleet.move": failingClient,
}),
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
ResponseSigner: newTestResponseSigner(),
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
assert.Equal(t, 1, failingClient.executeCalls)
}
func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) {
t.Parallel()
logger := zap.NewNop()
telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
var (
seenSpanContext trace.SpanContext
seenCommand downstream.AuthenticatedCommand
)
server, runGateway := newTestGateway(t, ServerDependencies{
Router: downstream.NewStaticRouter(map[string]downstream.Client{
"fleet.move": &recordingDownstreamClient{
executeFunc: func(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
seenSpanContext = trace.SpanContextFromContext(ctx)
seenCommand = command
return downstream.UnaryResult{
ResultCode: "accepted",
PayloadBytes: []byte("downstream-response"),
}, nil
},
},
}),
ResponseSigner: newTestResponseSigner(),
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
Logger: logger,
Telemetry: telemetryRuntime,
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.NoError(t, err)
assert.True(t, seenSpanContext.IsValid())
assert.Equal(t, "trace-123", seenCommand.TraceID)
}
func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) {
t.Parallel()
started := make(chan struct{})
release := make(chan struct{})
server, runGateway := newTestGateway(t, ServerDependencies{
Router: downstream.NewStaticRouter(map[string]downstream.Client{
"fleet.move": &recordingDownstreamClient{
executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
close(started)
<-release
return downstream.UnaryResult{
ResultCode: "accepted",
PayloadBytes: []byte("downstream-response"),
}, nil
},
},
}),
ResponseSigner: newTestResponseSigner(),
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
resultCh := make(chan error, 1)
go func() {
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
resultCh <- err
}()
require.Eventually(t, func() bool {
select {
case <-started:
return true
default:
return false
}
}, time.Second, 10*time.Millisecond, "downstream execution did not start")
runGateway.cancel()
require.Never(t, func() bool {
select {
case <-resultCh:
return true
default:
return false
}
}, 100*time.Millisecond, 10*time.Millisecond, "unary request returned before downstream release")
close(release)
var err error
require.Eventually(t, func() bool {
select {
case err = <-resultCh:
return true
default:
return false
}
}, time.Second, 10*time.Millisecond, "unary request did not drain before shutdown timeout")
require.NoError(t, err)
}
func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T) {
t.Parallel()
logger, logBuffer := testutil.NewObservedLogger(t)
server, runGateway := newTestGateway(t, ServerDependencies{
Router: downstream.NewStaticRouter(map[string]downstream.Client{
"fleet.move": &recordingDownstreamClient{},
}),
ResponseSigner: newTestResponseSigner(),
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
Logger: logger,
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.NoError(t, err)
logOutput := logBuffer.String()
assert.NotContains(t, logOutput, "payload_hash")
assert.NotContains(t, logOutput, "signature")
assert.NotContains(t, logOutput, `"payload"`)
}
+214
View File
@@ -0,0 +1,214 @@
package grpcapi
import (
"bytes"
"context"
"fmt"
"galaxy/gateway/proto/galaxy/gateway/v1"
"buf.build/go/protovalidate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const supportedProtocolVersion = "v1"
// parsedEnvelope captures the authenticated transport fields extracted from a
// request envelope after validation succeeds. Later wrappers may enrich this
// structure without changing the raw gRPC request types.
type parsedEnvelope struct {
ProtocolVersion string
DeviceSessionID string
MessageType string
TimestampMS int64
RequestID string
TraceID string
PayloadBytes []byte
PayloadHash []byte
Signature []byte
}
// parsedEnvelopeFromContext returns the parsed envelope previously attached to
// ctx by the envelope-validating gRPC service wrapper.
func parsedEnvelopeFromContext(ctx context.Context) (parsedEnvelope, bool) {
if ctx == nil {
return parsedEnvelope{}, false
}
envelope, ok := ctx.Value(parsedEnvelopeContextKey{}).(parsedEnvelope)
if !ok {
return parsedEnvelope{}, false
}
return envelope, true
}
// envelopeValidatingService applies envelope parsing and the protocol gate
// before delegating to the configured service implementation.
type envelopeValidatingService struct {
gatewayv1.UnimplementedEdgeGatewayServer
delegate gatewayv1.EdgeGatewayServer
}
// ExecuteCommand validates req and only then forwards it to the configured
// delegate with the parsed envelope attached to ctx.
func (s envelopeValidatingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
envelope, err := parseExecuteCommandRequest(req)
if err != nil {
return nil, err
}
return s.delegate.ExecuteCommand(context.WithValue(ctx, parsedEnvelopeContextKey{}, envelope), req)
}
// SubscribeEvents validates req and only then forwards it to the configured
// delegate with the parsed envelope attached to the stream context.
func (s envelopeValidatingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
envelope, err := parseSubscribeEventsRequest(req)
if err != nil {
return err
}
return s.delegate.SubscribeEvents(req, envelopeContextStream{
ServerStreamingServer: stream,
ctx: context.WithValue(stream.Context(), parsedEnvelopeContextKey{}, envelope),
})
}
// parseExecuteCommandRequest validates req according to the request-envelope
// rules and returns a cloned parsed envelope suitable for later auth steps.
func parseExecuteCommandRequest(req *gatewayv1.ExecuteCommandRequest) (parsedEnvelope, error) {
if req == nil {
return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
}
if err := protovalidate.Validate(req); err != nil {
return parsedEnvelope{}, canonicalExecuteCommandValidationError(req)
}
if req.GetProtocolVersion() != supportedProtocolVersion {
return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion())
}
return parsedEnvelope{
ProtocolVersion: req.GetProtocolVersion(),
DeviceSessionID: req.GetDeviceSessionId(),
MessageType: req.GetMessageType(),
TimestampMS: req.GetTimestampMs(),
RequestID: req.GetRequestId(),
TraceID: req.GetTraceId(),
PayloadBytes: bytes.Clone(req.GetPayloadBytes()),
PayloadHash: bytes.Clone(req.GetPayloadHash()),
Signature: bytes.Clone(req.GetSignature()),
}, nil
}
// parseSubscribeEventsRequest validates req according to the request-envelope
// rules and returns a cloned parsed envelope suitable for later auth steps.
func parseSubscribeEventsRequest(req *gatewayv1.SubscribeEventsRequest) (parsedEnvelope, error) {
if req == nil {
return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
}
if err := protovalidate.Validate(req); err != nil {
return parsedEnvelope{}, canonicalSubscribeEventsValidationError(req)
}
if req.GetProtocolVersion() != supportedProtocolVersion {
return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion())
}
return parsedEnvelope{
ProtocolVersion: req.GetProtocolVersion(),
DeviceSessionID: req.GetDeviceSessionId(),
MessageType: req.GetMessageType(),
TimestampMS: req.GetTimestampMs(),
RequestID: req.GetRequestId(),
TraceID: req.GetTraceId(),
PayloadBytes: bytes.Clone(req.GetPayloadBytes()),
PayloadHash: bytes.Clone(req.GetPayloadHash()),
Signature: bytes.Clone(req.GetSignature()),
}, nil
}
// newEnvelopeValidatingService wraps delegate with the envelope-validation
// gate.
func newEnvelopeValidatingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
return envelopeValidatingService{delegate: delegate}
}
// canonicalExecuteCommandValidationError maps any ExecuteCommand validation
// failure into the stable canonical error chosen by field order.
func canonicalExecuteCommandValidationError(req *gatewayv1.ExecuteCommandRequest) error {
switch {
case req.GetProtocolVersion() == "":
return newMalformedEnvelopeError("protocol_version must not be empty")
case req.GetDeviceSessionId() == "":
return newMalformedEnvelopeError("device_session_id must not be empty")
case req.GetMessageType() == "":
return newMalformedEnvelopeError("message_type must not be empty")
case req.GetTimestampMs() <= 0:
return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
case req.GetRequestId() == "":
return newMalformedEnvelopeError("request_id must not be empty")
case len(req.GetPayloadBytes()) == 0:
return newMalformedEnvelopeError("payload_bytes must not be empty")
case len(req.GetPayloadHash()) == 0:
return newMalformedEnvelopeError("payload_hash must not be empty")
case len(req.GetSignature()) == 0:
return newMalformedEnvelopeError("signature must not be empty")
default:
return newMalformedEnvelopeError("request envelope is invalid")
}
}
// canonicalSubscribeEventsValidationError maps any SubscribeEvents validation
// failure into the stable canonical error chosen by field order.
func canonicalSubscribeEventsValidationError(req *gatewayv1.SubscribeEventsRequest) error {
switch {
case req.GetProtocolVersion() == "":
return newMalformedEnvelopeError("protocol_version must not be empty")
case req.GetDeviceSessionId() == "":
return newMalformedEnvelopeError("device_session_id must not be empty")
case req.GetMessageType() == "":
return newMalformedEnvelopeError("message_type must not be empty")
case req.GetTimestampMs() <= 0:
return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
case req.GetRequestId() == "":
return newMalformedEnvelopeError("request_id must not be empty")
case len(req.GetPayloadHash()) == 0:
return newMalformedEnvelopeError("payload_hash must not be empty")
case len(req.GetSignature()) == 0:
return newMalformedEnvelopeError("signature must not be empty")
default:
return newMalformedEnvelopeError("request envelope is invalid")
}
}
// newMalformedEnvelopeError returns the stable malformed-envelope reject used
// before the gateway performs any auth or routing work.
func newMalformedEnvelopeError(message string) error {
return status.Error(codes.InvalidArgument, message)
}
// newUnsupportedProtocolVersionError returns the stable reject for a non-empty
// but unsupported protocol_version literal.
func newUnsupportedProtocolVersionError(version string) error {
return status.Error(codes.FailedPrecondition, fmt.Sprintf("unsupported protocol_version %q", version))
}
type parsedEnvelopeContextKey struct{}
type envelopeContextStream struct {
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
ctx context.Context
}
func (s envelopeContextStream) Context() context.Context {
if s.ctx == nil {
return context.Background()
}
return s.ctx
}
var _ gatewayv1.EdgeGatewayServer = envelopeValidatingService{}
+420
View File
@@ -0,0 +1,420 @@
package grpcapi
import (
"context"
"testing"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
func TestParseExecuteCommandRequest(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mutate func(*gatewayv1.ExecuteCommandRequest)
wantCode codes.Code
wantMessage string
assertValid func(*testing.T, *gatewayv1.ExecuteCommandRequest, parsedEnvelope)
}{
{
name: "nil request",
wantCode: codes.InvalidArgument,
wantMessage: "request envelope must not be nil",
},
{
name: "empty protocol version",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.ProtocolVersion = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "protocol_version must not be empty",
},
{
name: "empty device session id",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.DeviceSessionId = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "device_session_id must not be empty",
},
{
name: "empty message type",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.MessageType = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "message_type must not be empty",
},
{
name: "zero timestamp",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.TimestampMs = 0
},
wantCode: codes.InvalidArgument,
wantMessage: "timestamp_ms must be greater than zero",
},
{
name: "empty request id",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.RequestId = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "request_id must not be empty",
},
{
name: "empty payload bytes",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.PayloadBytes = nil
},
wantCode: codes.InvalidArgument,
wantMessage: "payload_bytes must not be empty",
},
{
name: "empty payload hash",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.PayloadHash = nil
},
wantCode: codes.InvalidArgument,
wantMessage: "payload_hash must not be empty",
},
{
name: "empty signature",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.Signature = nil
},
wantCode: codes.InvalidArgument,
wantMessage: "signature must not be empty",
},
{
name: "unsupported protocol version",
mutate: func(req *gatewayv1.ExecuteCommandRequest) {
req.ProtocolVersion = "v2"
},
wantCode: codes.FailedPrecondition,
wantMessage: `unsupported protocol_version "v2"`,
},
{
name: "valid request",
wantCode: codes.OK,
assertValid: func(t *testing.T, req *gatewayv1.ExecuteCommandRequest, envelope parsedEnvelope) {
t.Helper()
assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion)
assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
assert.Equal(t, req.GetMessageType(), envelope.MessageType)
assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS)
assert.Equal(t, req.GetRequestId(), envelope.RequestID)
assert.Equal(t, req.GetTraceId(), envelope.TraceID)
assert.Equal(t, req.GetPayloadBytes(), envelope.PayloadBytes)
assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
assert.Equal(t, req.GetSignature(), envelope.Signature)
originalPayloadBytes := append([]byte(nil), req.GetPayloadBytes()...)
originalPayloadHash := append([]byte(nil), req.GetPayloadHash()...)
originalSignature := append([]byte(nil), req.GetSignature()...)
envelope.PayloadBytes[0] = 'X'
envelope.PayloadHash[0] = 'Y'
envelope.Signature[0] = 'Z'
assert.Equal(t, originalPayloadBytes, req.GetPayloadBytes())
assert.Equal(t, originalPayloadHash, req.GetPayloadHash())
assert.Equal(t, originalSignature, req.GetSignature())
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var req *gatewayv1.ExecuteCommandRequest
if tt.name != "nil request" {
req = newValidExecuteCommandRequest()
if tt.mutate != nil {
tt.mutate(req)
}
}
envelope, err := parseExecuteCommandRequest(req)
if tt.wantCode != codes.OK {
require.Error(t, err)
assert.Equal(t, tt.wantCode, status.Code(err))
assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
return
}
require.NoError(t, err)
require.NotNil(t, tt.assertValid)
tt.assertValid(t, req, envelope)
})
}
}
func TestParseSubscribeEventsRequest(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mutate func(*gatewayv1.SubscribeEventsRequest)
wantCode codes.Code
wantMessage string
assertValid func(*testing.T, *gatewayv1.SubscribeEventsRequest, parsedEnvelope)
}{
{
name: "nil request",
wantCode: codes.InvalidArgument,
wantMessage: "request envelope must not be nil",
},
{
name: "empty protocol version",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.ProtocolVersion = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "protocol_version must not be empty",
},
{
name: "empty device session id",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.DeviceSessionId = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "device_session_id must not be empty",
},
{
name: "empty message type",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.MessageType = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "message_type must not be empty",
},
{
name: "zero timestamp",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.TimestampMs = 0
},
wantCode: codes.InvalidArgument,
wantMessage: "timestamp_ms must be greater than zero",
},
{
name: "empty request id",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.RequestId = ""
},
wantCode: codes.InvalidArgument,
wantMessage: "request_id must not be empty",
},
{
name: "empty payload hash",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.PayloadHash = nil
},
wantCode: codes.InvalidArgument,
wantMessage: "payload_hash must not be empty",
},
{
name: "empty signature",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.Signature = nil
},
wantCode: codes.InvalidArgument,
wantMessage: "signature must not be empty",
},
{
name: "unsupported protocol version",
mutate: func(req *gatewayv1.SubscribeEventsRequest) {
req.ProtocolVersion = "v2"
},
wantCode: codes.FailedPrecondition,
wantMessage: `unsupported protocol_version "v2"`,
},
{
name: "valid request with empty payload bytes",
wantCode: codes.OK,
assertValid: func(t *testing.T, req *gatewayv1.SubscribeEventsRequest, envelope parsedEnvelope) {
t.Helper()
assert.Empty(t, req.GetPayloadBytes())
assert.Empty(t, envelope.PayloadBytes)
assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
assert.Equal(t, req.GetSignature(), envelope.Signature)
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var req *gatewayv1.SubscribeEventsRequest
if tt.name != "nil request" {
req = newValidSubscribeEventsRequest()
if tt.mutate != nil {
tt.mutate(req)
}
}
envelope, err := parseSubscribeEventsRequest(req)
if tt.wantCode != codes.OK {
require.Error(t, err)
assert.Equal(t, tt.wantCode, status.Code(err))
assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
return
}
require.NoError(t, err)
require.NotNil(t, tt.assertValid)
tt.assertValid(t, req, envelope)
})
}
}
func TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
service := newEnvelopeValidatingService(delegate)
_, err := service.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assert.Zero(t, delegate.executeCalls)
}
func TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
service := newEnvelopeValidatingService(delegate)
err := service.SubscribeEvents(&gatewayv1.SubscribeEventsRequest{}, stubGatewayEventStream{})
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assert.Zero(t, delegate.subscribeCalls)
}
func TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope(t *testing.T) {
t.Parallel()
want := newValidExecuteCommandRequest()
delegate := &recordingEdgeGatewayService{
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
envelope, ok := parsedEnvelopeFromContext(ctx)
require.True(t, ok)
assert.Equal(t, want.GetRequestId(), envelope.RequestID)
assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
assert.Equal(t, want.GetMessageType(), envelope.MessageType)
assert.Equal(t, want.GetPayloadBytes(), envelope.PayloadBytes)
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
},
}
service := newEnvelopeValidatingService(delegate)
response, err := service.ExecuteCommand(context.Background(), want)
require.NoError(t, err)
assert.Equal(t, want.GetRequestId(), response.GetRequestId())
assert.Equal(t, 1, delegate.executeCalls)
}
func TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope(t *testing.T) {
t.Parallel()
want := newValidSubscribeEventsRequest()
delegate := &recordingEdgeGatewayService{
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
envelope, ok := parsedEnvelopeFromContext(stream.Context())
require.True(t, ok)
assert.Equal(t, want.GetRequestId(), envelope.RequestID)
assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
assert.Equal(t, want.GetMessageType(), envelope.MessageType)
assert.Equal(t, want.GetPayloadHash(), envelope.PayloadHash)
assert.Equal(t, want.GetSignature(), envelope.Signature)
return nil
},
}
service := newEnvelopeValidatingService(delegate)
err := service.SubscribeEvents(want, stubGatewayEventStream{})
require.NoError(t, err)
assert.Equal(t, 1, delegate.subscribeCalls)
}
type recordingEdgeGatewayService struct {
gatewayv1.UnimplementedEdgeGatewayServer
executeCalls int
subscribeCalls int
executeCommandFunc func(context.Context, *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error)
subscribeEventsFunc func(*gatewayv1.SubscribeEventsRequest, grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error
}
func (s *recordingEdgeGatewayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
s.executeCalls++
if s.executeCommandFunc != nil {
return s.executeCommandFunc(ctx, req)
}
return &gatewayv1.ExecuteCommandResponse{}, nil
}
func (s *recordingEdgeGatewayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
s.subscribeCalls++
if s.subscribeEventsFunc != nil {
return s.subscribeEventsFunc(req, stream)
}
return nil
}
type stubGatewayEventStream struct {
grpc.ServerStream
ctx context.Context
}
func (s stubGatewayEventStream) Send(*gatewayv1.GatewayEvent) error {
return nil
}
func (s stubGatewayEventStream) SetHeader(metadata.MD) error {
return nil
}
func (s stubGatewayEventStream) SendHeader(metadata.MD) error {
return nil
}
func (s stubGatewayEventStream) SetTrailer(metadata.MD) {}
func (s stubGatewayEventStream) Context() context.Context {
if s.ctx == nil {
return context.Background()
}
return s.ctx
}
func (s stubGatewayEventStream) SendMsg(any) error {
return nil
}
func (s stubGatewayEventStream) RecvMsg(any) error {
return nil
}
@@ -0,0 +1,95 @@
package grpcapi
import (
"context"
"errors"
"time"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/replay"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const minimumReplayReservationTTL = time.Millisecond
// freshnessAndReplayService applies freshness and anti-replay checks after
// client-signature verification and before later policy or routing steps run.
type freshnessAndReplayService struct {
gatewayv1.UnimplementedEdgeGatewayServer
delegate gatewayv1.EdgeGatewayServer
clock clock.Clock
replayStore replay.Store
freshnessWindow time.Duration
}
// ExecuteCommand verifies request freshness and replay protection before
// delegating to the configured service implementation.
func (s freshnessAndReplayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
if err := s.verifyFreshnessAndReplay(ctx); err != nil {
return nil, err
}
return s.delegate.ExecuteCommand(ctx, req)
}
// SubscribeEvents verifies request freshness and replay protection before
// delegating to the configured service implementation.
func (s freshnessAndReplayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
if err := s.verifyFreshnessAndReplay(stream.Context()); err != nil {
return err
}
return s.delegate.SubscribeEvents(req, stream)
}
// newFreshnessAndReplayService wraps delegate with the freshness and replay
// gate.
func newFreshnessAndReplayService(delegate gatewayv1.EdgeGatewayServer, clk clock.Clock, replayStore replay.Store, freshnessWindow time.Duration) gatewayv1.EdgeGatewayServer {
return freshnessAndReplayService{
delegate: delegate,
clock: clk,
replayStore: replayStore,
freshnessWindow: freshnessWindow,
}
}
func (s freshnessAndReplayService) verifyFreshnessAndReplay(ctx context.Context) error {
envelope, ok := parsedEnvelopeFromContext(ctx)
if !ok {
return status.Error(codes.Internal, "authenticated request context is incomplete")
}
now := s.clock.Now().UTC()
requestTime := time.UnixMilli(envelope.TimestampMS).UTC()
if requestTime.Before(now.Add(-s.freshnessWindow)) || requestTime.After(now.Add(s.freshnessWindow)) {
return status.Error(codes.FailedPrecondition, "request timestamp is outside the freshness window")
}
ttl := requestTime.Add(s.freshnessWindow).Sub(now)
if ttl < minimumReplayReservationTTL {
ttl = minimumReplayReservationTTL
}
err := s.replayStore.Reserve(ctx, envelope.DeviceSessionID, envelope.RequestID, ttl)
switch {
case err == nil:
return nil
case errors.Is(err, replay.ErrDuplicate):
return status.Error(codes.FailedPrecondition, "request replay detected")
default:
return status.Error(codes.Unavailable, "replay store is unavailable")
}
}
type unavailableReplayStore struct{}
func (unavailableReplayStore) Reserve(context.Context, string, string, time.Duration) error {
return errors.New("replay store is unavailable")
}
var _ gatewayv1.EdgeGatewayServer = freshnessAndReplayService{}
@@ -0,0 +1,509 @@
package grpcapi
import (
"context"
"errors"
"io"
"sync"
"testing"
"time"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestExecuteCommandRejectsStaleTimestamp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
timestampMS int64
}{
{
name: "past window",
timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
},
{
name: "future window",
timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
})
}
}
func TestSubscribeEventsRejectsStaleTimestamp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
timestampMS int64
}{
{
name: "past window",
timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
},
{
name: "future window",
timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
})
}
}
func TestExecuteCommandRejectsReplay(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: replayDuplicateBySessionAndRequest(),
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
req := newValidExecuteCommandRequest()
_, err := client.ExecuteCommand(context.Background(), req)
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), req)
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "request replay detected", status.Convert(err).Message())
assert.Equal(t, 1, delegate.executeCalls)
}
func TestSubscribeEventsRejectsReplay(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: replayDuplicateBySessionAndRequest(),
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
req := newValidSubscribeEventsRequest()
stream, err := client.SubscribeEvents(context.Background(), req)
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
err = subscribeEventsError(t, context.Background(), client, req)
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "request replay detected", status.Convert(err).Message())
assert.Equal(t, 1, delegate.subscribeCalls)
}
func TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
},
}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
return newActiveSessionRecordWithSessionID(deviceSessionID), nil
},
},
ReplayStore: staticReplayStore{
reserveFunc: replayDuplicateBySessionAndRequest(),
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-shared"))
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-456", "request-shared"))
require.NoError(t, err)
assert.Equal(t, 2, delegate.executeCalls)
}
func TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
return nil
},
}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
return newActiveSessionRecordWithSessionID(deviceSessionID), nil
},
},
ReplayStore: staticReplayStore{
reserveFunc: replayDuplicateBySessionAndRequest(),
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-shared"))
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
stream, err = client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-456", "request-shared"))
require.NoError(t, err)
event = recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
assert.Equal(t, 2, delegate.subscribeCalls)
}
func TestExecuteCommandRejectsReplayStoreUnavailable(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: func(context.Context, string, string, time.Duration) error {
return errors.New("redis down")
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestSubscribeEventsRejectsReplayStoreUnavailable(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: func(context.Context, string, string, time.Duration) error {
return errors.New("redis down")
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
func TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
},
}
var reservedDeviceSessionID string
var reservedRequestID string
var reservedTTL time.Duration
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
reservedDeviceSessionID = deviceSessionID
reservedRequestID = requestID
reservedTTL = ttl
return nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.NoError(t, err)
assert.Equal(t, "request-123", response.GetRequestId())
assert.Equal(t, "device-session-123", reservedDeviceSessionID)
assert.Equal(t, "request-123", reservedRequestID)
assert.Equal(t, testFreshnessWindow, reservedTTL)
assert.Equal(t, 1, delegate.executeCalls)
}
func TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
return nil
},
}
var reservedTTL time.Duration
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
assert.Equal(t, "device-session-123", deviceSessionID)
assert.Equal(t, "request-123", requestID)
reservedTTL = ttl
return nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
assert.Equal(t, testFreshnessWindow, reservedTTL)
assert.Equal(t, 1, delegate.subscribeCalls)
}
func TestExecuteCommandFutureSkewUsesExtendedReplayTTL(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
},
}
var reservedTTL time.Duration
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
reservedTTL = ttl
return nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(
context.Background(),
newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(2*time.Minute).UnixMilli()),
)
require.NoError(t, err)
assert.Equal(t, 7*time.Minute, reservedTTL)
assert.Equal(t, 1, delegate.executeCalls)
}
func TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
},
}
var reservedTTL time.Duration
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{
reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
reservedTTL = ttl
return nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(
context.Background(),
newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(-testFreshnessWindow).UnixMilli()),
)
require.NoError(t, err)
assert.Equal(t, minimumReplayReservationTTL, reservedTTL)
assert.Equal(t, 1, delegate.executeCalls)
}
func replayDuplicateBySessionAndRequest() func(context.Context, string, string, time.Duration) error {
var (
mu sync.Mutex
seen = make(map[string]struct{})
)
return func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
mu.Lock()
defer mu.Unlock()
key := deviceSessionID + "\x00" + requestID
if _, ok := seen[key]; ok {
return replay.ErrDuplicate
}
seen[key] = struct{}{}
return nil
}
}
+147
View File
@@ -0,0 +1,147 @@
package grpcapi
import (
"context"
"errors"
"path"
"time"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/telemetry"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func observabilityUnaryInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.UnaryServerInterceptor {
if logger == nil {
logger = zap.NewNop()
}
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
start := time.Now()
resp, err := handler(ctx, req)
recordGRPCRequest(logger, metrics, ctx, info.FullMethod, req, resp, err, time.Since(start), "unary")
return resp, err
}
}
func observabilityStreamInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.StreamServerInterceptor {
if logger == nil {
logger = zap.NewNop()
}
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
start := time.Now()
wrapped := &observabilityServerStream{ServerStream: stream}
err := handler(srv, wrapped)
recordGRPCRequest(logger, metrics, stream.Context(), info.FullMethod, wrapped.request, nil, err, time.Since(start), "stream")
return err
}
}
type observabilityServerStream struct {
grpc.ServerStream
request any
}
func (s *observabilityServerStream) RecvMsg(m any) error {
err := s.ServerStream.RecvMsg(m)
if err == nil && s.request == nil {
s.request = m
}
return err
}
func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx context.Context, fullMethod string, req any, resp any, err error, duration time.Duration, streamKind string) {
rpcMethod := path.Base(fullMethod)
messageType, requestID, traceID := grpcEnvelopeFields(req)
resultCode := grpcResultCode(resp)
grpcCode, grpcMessage, outcome := grpcOutcome(err)
rejectReason := telemetry.RejectReason(outcome)
attrs := []attribute.KeyValue{
attribute.String("rpc_method", rpcMethod),
attribute.String("message_type", messageType),
attribute.String("edge_outcome", string(outcome)),
}
if resultCode != "" {
attrs = append(attrs, attribute.String("result_code", resultCode))
}
if rejectReason != "" {
attrs = append(attrs, attribute.String("reject_reason", rejectReason))
}
metrics.RecordAuthenticatedGRPC(ctx, attrs, duration)
fields := []zap.Field{
zap.String("component", "authenticated_grpc"),
zap.String("transport", "grpc"),
zap.String("stream_kind", streamKind),
zap.String("rpc_method", rpcMethod),
zap.String("message_type", messageType),
zap.String("grpc_code", grpcCode.String()),
zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
zap.String("request_id", requestID),
zap.String("trace_id", traceID),
zap.String("peer_ip", peerIPFromContext(ctx)),
zap.String("edge_outcome", string(outcome)),
}
if resultCode != "" {
fields = append(fields, zap.String("result_code", resultCode))
}
if rejectReason != "" {
fields = append(fields, zap.String("reject_reason", rejectReason))
}
if grpcMessage != "" {
fields = append(fields, zap.String("grpc_message", grpcMessage))
}
fields = append(fields, logging.TraceFieldsFromContext(ctx)...)
switch outcome {
case telemetry.EdgeOutcomeSuccess:
logger.Info("authenticated gRPC request completed", fields...)
case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeDownstreamUnavailable, telemetry.EdgeOutcomeInternalError:
logger.Error("authenticated gRPC request failed", fields...)
default:
logger.Warn("authenticated gRPC request rejected", fields...)
}
}
func grpcEnvelopeFields(req any) (messageType string, requestID string, traceID string) {
switch typed := req.(type) {
case *gatewayv1.ExecuteCommandRequest:
return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
case *gatewayv1.SubscribeEventsRequest:
return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
default:
return "", "", ""
}
}
func grpcResultCode(resp any) string {
typed, ok := resp.(*gatewayv1.ExecuteCommandResponse)
if !ok {
return ""
}
return typed.GetResultCode()
}
func grpcOutcome(err error) (codes.Code, string, telemetry.EdgeOutcome) {
switch {
case err == nil:
return codes.OK, "", telemetry.EdgeOutcomeSuccess
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
return codes.Canceled, err.Error(), telemetry.EdgeOutcomeSuccess
default:
grpcStatus := status.Convert(err)
return grpcStatus.Code(), grpcStatus.Message(), telemetry.OutcomeFromGRPCStatus(grpcStatus.Code(), grpcStatus.Message())
}
}
+66
View File
@@ -0,0 +1,66 @@
package grpcapi
import (
"context"
"errors"
"galaxy/gateway/internal/authn"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// payloadHashVerifyingService applies payload-hash verification after session
// lookup and before any later auth or routing step runs.
type payloadHashVerifyingService struct {
gatewayv1.UnimplementedEdgeGatewayServer
delegate gatewayv1.EdgeGatewayServer
}
// ExecuteCommand verifies req payload integrity before delegating to the
// configured service implementation.
func (s payloadHashVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
if err := verifyPayloadHash(ctx); err != nil {
return nil, err
}
return s.delegate.ExecuteCommand(ctx, req)
}
// SubscribeEvents verifies req payload integrity before delegating to the
// configured service implementation.
func (s payloadHashVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
if err := verifyPayloadHash(stream.Context()); err != nil {
return err
}
return s.delegate.SubscribeEvents(req, stream)
}
// newPayloadHashVerifyingService wraps delegate with the payload-hash
// verification gate.
func newPayloadHashVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
return payloadHashVerifyingService{delegate: delegate}
}
func verifyPayloadHash(ctx context.Context) error {
envelope, ok := parsedEnvelopeFromContext(ctx)
if !ok {
return status.Error(codes.Internal, "authenticated request context is incomplete")
}
err := authn.VerifyPayloadHash(envelope.PayloadBytes, envelope.PayloadHash)
switch {
case err == nil:
return nil
case errors.Is(err, authn.ErrInvalidPayloadHash), errors.Is(err, authn.ErrPayloadHashMismatch):
return status.Error(codes.InvalidArgument, err.Error())
default:
return status.Error(codes.Internal, "payload hash verification failed")
}
}
var _ gatewayv1.EdgeGatewayServer = payloadHashVerifyingService{}
@@ -0,0 +1,125 @@
package grpcapi
import (
"context"
"crypto/sha256"
"testing"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestExecuteCommandRejectsPayloadHashWithInvalidLength(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
req := newValidExecuteCommandRequest()
req.PayloadHash = []byte("short")
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), req)
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestExecuteCommandRejectsPayloadHashMismatch(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
req := newValidExecuteCommandRequest()
sum := sha256.Sum256([]byte("other"))
req.PayloadHash = sum[:]
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), req)
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestSubscribeEventsRejectsPayloadHashWithInvalidLength(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
req := newValidSubscribeEventsRequest()
req.PayloadHash = []byte("short")
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, req)
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
func TestSubscribeEventsRejectsPayloadHashMismatch(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
req := newValidSubscribeEventsRequest()
sum := sha256.Sum256([]byte("other"))
req.PayloadHash = sum[:]
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, req)
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
+172
View File
@@ -0,0 +1,172 @@
package grpcapi
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// NewFanOutPushStreamService constructs the authenticated SubscribeEvents tail
// service that registers active streams in hub and forwards client-facing
// events after the bootstrap event has been sent.
func NewFanOutPushStreamService(hub *push.Hub, responseSigner authn.ResponseSigner, clk clock.Clock, logger *zap.Logger) gatewayv1.EdgeGatewayServer {
if responseSigner == nil {
responseSigner = unavailableResponseSigner{}
}
if clk == nil {
clk = clock.System{}
}
if logger == nil {
logger = zap.NewNop()
}
return fanOutPushStreamService{
hub: hub,
responseSigner: responseSigner,
clock: clk,
logger: logger.Named("push_stream"),
}
}
// fanOutPushStreamService owns the post-bootstrap authenticated push-stream
// lifecycle backed by the in-memory push hub.
type fanOutPushStreamService struct {
gatewayv1.UnimplementedEdgeGatewayServer
hub *push.Hub
responseSigner authn.ResponseSigner
clock clock.Clock
logger *zap.Logger
}
// SubscribeEvents registers the verified stream in the push hub and forwards
// matching client-facing events until the stream ends.
func (s fanOutPushStreamService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
binding, ok := authenticatedStreamBindingFromContext(stream.Context())
if !ok {
return status.Error(codes.Internal, "authenticated request context is incomplete")
}
if s.hub == nil {
return status.Error(codes.Internal, "push hub is unavailable")
}
subscription, err := s.hub.Register(push.StreamBinding{
UserID: binding.UserID,
DeviceSessionID: binding.DeviceSessionID,
})
if err != nil {
return status.Error(codes.Internal, "push stream registration failed")
}
defer subscription.Close()
openFields := []zap.Field{
zap.String("component", "authenticated_grpc"),
zap.String("transport", "grpc"),
zap.String("rpc_method", authenticatedRPCSubscribeEvents),
zap.String("message_type", binding.MessageType),
zap.String("request_id", binding.RequestID),
zap.String("trace_id", binding.TraceID),
zap.String("device_session_id", binding.DeviceSessionID),
zap.String("user_id", binding.UserID),
}
openFields = append(openFields, logging.TraceFieldsFromContext(stream.Context())...)
s.logger.Info("push stream opened", openFields...)
for {
select {
case <-stream.Context().Done():
s.logger.Info("push stream closed", append(openFields, zap.String("edge_outcome", string(mapSubscriptionOutcome(stream.Context().Err()))))...)
return stream.Context().Err()
case <-subscription.Done():
subscriptionErr := subscription.Err()
s.logger.Warn("push stream closed", append(openFields,
zap.String("edge_outcome", string(mapSubscriptionOutcome(subscriptionErr))),
zap.String("reject_reason", string(mapSubscriptionOutcome(subscriptionErr))),
)...)
return mapSubscriptionError(subscriptionErr)
case event := <-subscription.Events():
signedEvent, err := s.buildGatewayEvent(event)
if err != nil {
return err
}
if err := stream.Send(signedEvent); err != nil {
return err
}
}
}
}
func (s fanOutPushStreamService) buildGatewayEvent(event push.Event) (*gatewayv1.GatewayEvent, error) {
timestampMS := s.clock.Now().UTC().UnixMilli()
payloadHash := sha256.Sum256(event.PayloadBytes)
signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{
EventType: event.EventType,
EventID: event.EventID,
TimestampMS: timestampMS,
RequestID: event.RequestID,
TraceID: event.TraceID,
PayloadHash: payloadHash[:],
})
if err != nil {
return nil, status.Error(codes.Unavailable, "response signer is unavailable")
}
return &gatewayv1.GatewayEvent{
EventType: event.EventType,
EventId: event.EventID,
TimestampMs: timestampMS,
PayloadBytes: bytes.Clone(event.PayloadBytes),
PayloadHash: bytes.Clone(payloadHash[:]),
Signature: signature,
RequestId: event.RequestID,
TraceId: event.TraceID,
}, nil
}
func mapSubscriptionError(err error) error {
switch {
case err == nil:
return nil
case errors.Is(err, push.ErrSubscriptionRevoked):
return status.Error(codes.FailedPrecondition, "device session is revoked")
case errors.Is(err, push.ErrSubscriptionOverflow):
return status.Error(codes.ResourceExhausted, "push stream overflowed")
case errors.Is(err, push.ErrHubShuttingDown):
return status.Error(codes.Unavailable, "gateway is shutting down")
default:
return status.Error(codes.Internal, "push stream closed unexpectedly")
}
}
func mapSubscriptionOutcome(err error) telemetry.EdgeOutcome {
switch {
case err == nil:
return telemetry.EdgeOutcomeSuccess
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
return telemetry.EdgeOutcomeSuccess
case errors.Is(err, push.ErrSubscriptionRevoked):
return telemetry.EdgeOutcomeRevokedSession
case errors.Is(err, push.ErrSubscriptionOverflow):
return telemetry.EdgeOutcomeRateLimited
case errors.Is(err, push.ErrHubShuttingDown):
return telemetry.EdgeOutcomeGatewayShuttingDown
default:
return telemetry.EdgeOutcomeInternalError
}
}
var _ gatewayv1.EdgeGatewayServer = fanOutPushStreamService{}
+164
View File
@@ -0,0 +1,164 @@
package grpcapi
import (
"bytes"
"context"
"crypto/sha256"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
gatewayfbs "galaxy/schema/fbs/gateway"
flatbuffers "github.com/google/flatbuffers/go"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const serverTimeEventType = "gateway.server_time"
// authenticatedStreamBinding captures the verified identity bound to one
// authenticated SubscribeEvents stream after the full ingress pipeline
// succeeds.
type authenticatedStreamBinding struct {
UserID string
DeviceSessionID string
MessageType string
RequestID string
TraceID string
}
// authenticatedStreamBindingFromContext returns the verified stream binding
// previously attached to ctx by the authenticated push-stream service.
func authenticatedStreamBindingFromContext(ctx context.Context) (authenticatedStreamBinding, bool) {
if ctx == nil {
return authenticatedStreamBinding{}, false
}
binding, ok := ctx.Value(authenticatedStreamBindingContextKey{}).(authenticatedStreamBinding)
if !ok {
return authenticatedStreamBinding{}, false
}
return binding, true
}
// authenticatedPushStreamService owns SubscribeEvents bootstrap behavior:
// bind the authenticated stream, send the initial signed server-time event,
// and then hand the stream lifecycle to the configured tail delegate.
type authenticatedPushStreamService struct {
gatewayv1.UnimplementedEdgeGatewayServer
tailDelegate gatewayv1.EdgeGatewayServer
responseSigner authn.ResponseSigner
clock clock.Clock
}
// SubscribeEvents binds the verified stream identity, sends the initial signed
// server-time event, and then delegates the remaining lifecycle.
func (s authenticatedPushStreamService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
envelope, ok := parsedEnvelopeFromContext(stream.Context())
if !ok {
return status.Error(codes.Internal, "authenticated request context is incomplete")
}
record, ok := resolvedSessionFromContext(stream.Context())
if !ok {
return status.Error(codes.Internal, "authenticated request context is incomplete")
}
binding := authenticatedStreamBinding{
UserID: record.UserID,
DeviceSessionID: record.DeviceSessionID,
MessageType: envelope.MessageType,
RequestID: envelope.RequestID,
TraceID: envelope.TraceID,
}
boundStream := authenticatedStreamContextStream{
ServerStreamingServer: stream,
ctx: context.WithValue(
stream.Context(),
authenticatedStreamBindingContextKey{},
binding,
),
}
serverTimeMS := s.clock.Now().UTC().UnixMilli()
payloadBytes := buildServerTimeEventPayload(serverTimeMS)
payloadHash := sha256.Sum256(payloadBytes)
signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{
EventType: serverTimeEventType,
EventID: envelope.RequestID,
TimestampMS: serverTimeMS,
RequestID: envelope.RequestID,
TraceID: envelope.TraceID,
PayloadHash: payloadHash[:],
})
if err != nil {
return status.Error(codes.Unavailable, "response signer is unavailable")
}
if err := boundStream.Send(&gatewayv1.GatewayEvent{
EventType: serverTimeEventType,
EventId: envelope.RequestID,
TimestampMs: serverTimeMS,
PayloadBytes: bytes.Clone(payloadBytes),
PayloadHash: bytes.Clone(payloadHash[:]),
Signature: signature,
RequestId: envelope.RequestID,
TraceId: envelope.TraceID,
}); err != nil {
return err
}
return s.tailDelegate.SubscribeEvents(req, boundStream)
}
func newAuthenticatedPushStreamService(tailDelegate gatewayv1.EdgeGatewayServer, responseSigner authn.ResponseSigner, clk clock.Clock) gatewayv1.EdgeGatewayServer {
if tailDelegate == nil {
tailDelegate = holdOpenSubscribeEventsService{}
}
return authenticatedPushStreamService{
tailDelegate: tailDelegate,
responseSigner: responseSigner,
clock: clk,
}
}
func buildServerTimeEventPayload(serverTimeMS int64) []byte {
builder := flatbuffers.NewBuilder(32)
gatewayfbs.ServerTimeEventStart(builder)
gatewayfbs.ServerTimeEventAddServerTimeMs(builder, serverTimeMS)
eventOffset := gatewayfbs.ServerTimeEventEnd(builder)
gatewayfbs.FinishServerTimeEventBuffer(builder, eventOffset)
return bytes.Clone(builder.FinishedBytes())
}
type authenticatedStreamBindingContextKey struct{}
type authenticatedStreamContextStream struct {
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
ctx context.Context
}
func (s authenticatedStreamContextStream) Context() context.Context {
if s.ctx == nil {
return context.Background()
}
return s.ctx
}
type holdOpenSubscribeEventsService struct {
gatewayv1.UnimplementedEdgeGatewayServer
}
func (holdOpenSubscribeEventsService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
<-stream.Context().Done()
return stream.Context().Err()
}
var _ gatewayv1.EdgeGatewayServer = authenticatedPushStreamService{}
+286
View File
@@ -0,0 +1,286 @@
package grpcapi
import (
"context"
"errors"
"net"
"strings"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/ratelimit"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
const (
authenticatedGRPCBaseBucketKeyPrefix = "authenticated_grpc/"
authenticatedGRPCIPBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "ip="
authenticatedGRPCSessionBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "session="
authenticatedGRPCUserBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "user="
authenticatedGRPCMessageClassBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "message_class="
unknownAuthenticatedPeerIP = "unknown"
authenticatedRPCExecuteCommand = "ExecuteCommand"
authenticatedRPCSubscribeEvents = "SubscribeEvents"
)
var (
// ErrAuthenticatedPolicyDenied reports that the authenticated request was
// rejected by later edge policy after transport authenticity succeeded.
ErrAuthenticatedPolicyDenied = errors.New("authenticated request rejected by edge policy")
// ErrAuthenticatedPolicyUnavailable reports that authenticated policy could
// not be evaluated because its backing dependency is unavailable.
ErrAuthenticatedPolicyUnavailable = errors.New("authenticated request policy is unavailable")
)
// AuthenticatedRequestLimiter applies authenticated gRPC rate-limit policy to
// one concrete bucket key.
type AuthenticatedRequestLimiter interface {
// Reserve evaluates key under policy and reports whether the request may
// proceed immediately.
Reserve(key string, policy ratelimit.Policy) ratelimit.Decision
}
// AuthenticatedRequest describes the authenticated request metadata exposed to
// the edge-policy hook.
type AuthenticatedRequest struct {
// RPCMethod identifies the public gRPC method being processed.
RPCMethod string
// PeerIP is the transport peer IP derived from the gRPC connection.
PeerIP string
// MessageClass is the stable rate-limit and policy class. The gateway uses
// the full message_type literal because the v1 transport does not yet define
// a coarser authenticated class taxonomy.
MessageClass string
// Envelope contains the verified transport envelope fields used by later
// edge policy.
Envelope AuthenticatedRequestEnvelope
// Session contains the authenticated identity resolved from SessionCache.
Session session.Record
}
// AuthenticatedRequestEnvelope describes the verified request envelope fields
// exposed to the edge-policy hook.
type AuthenticatedRequestEnvelope struct {
// ProtocolVersion is the supported transport protocol version literal.
ProtocolVersion string
// DeviceSessionID is the authenticated device-session identifier.
DeviceSessionID string
// MessageType is the verified downstream routing key supplied by the client.
MessageType string
// TimestampMS is the client timestamp that already passed freshness checks.
TimestampMS int64
// RequestID is the authenticated transport request identifier.
RequestID string
// TraceID is the optional client-supplied correlation identifier.
TraceID string
}
// AuthenticatedRequestPolicy evaluates later authenticated edge policy after
// transport authenticity and rate-limit checks succeed.
type AuthenticatedRequestPolicy interface {
// Evaluate returns nil when the authenticated request may proceed. It should
// wrap ErrAuthenticatedPolicyDenied for stable reject mapping and
// ErrAuthenticatedPolicyUnavailable when its backing dependency is
// temporarily unavailable.
Evaluate(ctx context.Context, request AuthenticatedRequest) error
}
type authenticatedRateLimitService struct {
gatewayv1.UnimplementedEdgeGatewayServer
delegate gatewayv1.EdgeGatewayServer
limiter AuthenticatedRequestLimiter
policy AuthenticatedRequestPolicy
cfg config.AuthenticatedGRPCAntiAbuseConfig
}
// ExecuteCommand applies authenticated rate limits and edge policy before
// delegating to the configured service implementation.
func (s authenticatedRateLimitService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
if err := s.applyRateLimitsAndPolicy(ctx, authenticatedRPCExecuteCommand); err != nil {
return nil, err
}
return s.delegate.ExecuteCommand(ctx, req)
}
// SubscribeEvents applies authenticated rate limits and edge policy before
// delegating to the configured service implementation.
func (s authenticatedRateLimitService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
if err := s.applyRateLimitsAndPolicy(stream.Context(), authenticatedRPCSubscribeEvents); err != nil {
return err
}
return s.delegate.SubscribeEvents(req, stream)
}
// newAuthenticatedRateLimitService wraps delegate with the authenticated
// rate-limit and edge-policy gate.
func newAuthenticatedRateLimitService(delegate gatewayv1.EdgeGatewayServer, limiter AuthenticatedRequestLimiter, policy AuthenticatedRequestPolicy, cfg config.AuthenticatedGRPCAntiAbuseConfig) gatewayv1.EdgeGatewayServer {
return authenticatedRateLimitService{
delegate: delegate,
limiter: limiter,
policy: policy,
cfg: cfg,
}
}
func (s authenticatedRateLimitService) applyRateLimitsAndPolicy(ctx context.Context, rpcMethod string) error {
request, err := authenticatedRequestFromContext(ctx, rpcMethod)
if err != nil {
return err
}
if err := s.applyRateLimits(request); err != nil {
return err
}
if err := s.applyPolicy(ctx, request); err != nil {
return err
}
return nil
}
func (s authenticatedRateLimitService) applyRateLimits(request AuthenticatedRequest) error {
checks := []struct {
key string
policy config.AuthenticatedRateLimitConfig
}{
{
key: authenticatedGRPCIPBucketKey(request.PeerIP),
policy: s.cfg.IP,
},
{
key: authenticatedGRPCSessionBucketKey(request.Envelope.DeviceSessionID),
policy: s.cfg.Session,
},
{
key: authenticatedGRPCUserBucketKey(request.Session.UserID),
policy: s.cfg.User,
},
{
key: authenticatedGRPCMessageClassBucketKey(request.MessageClass),
policy: s.cfg.MessageClass,
},
}
for _, check := range checks {
decision := s.limiter.Reserve(check.key, ratelimit.Policy{
Requests: check.policy.Requests,
Window: check.policy.Window,
Burst: check.policy.Burst,
})
if !decision.Allowed {
return status.Error(codes.ResourceExhausted, "authenticated request rate limit exceeded")
}
}
return nil
}
func (s authenticatedRateLimitService) applyPolicy(ctx context.Context, request AuthenticatedRequest) error {
err := s.policy.Evaluate(ctx, request)
switch {
case err == nil:
return nil
case errors.Is(err, ErrAuthenticatedPolicyDenied):
return status.Error(codes.PermissionDenied, "authenticated request rejected by edge policy")
case errors.Is(err, ErrAuthenticatedPolicyUnavailable):
return status.Error(codes.Unavailable, "authenticated request policy is unavailable")
default:
return status.Error(codes.Internal, "authenticated request policy evaluation failed")
}
}
func authenticatedRequestFromContext(ctx context.Context, rpcMethod string) (AuthenticatedRequest, error) {
envelope, ok := parsedEnvelopeFromContext(ctx)
if !ok {
return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
}
record, ok := resolvedSessionFromContext(ctx)
if !ok {
return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
}
return AuthenticatedRequest{
RPCMethod: rpcMethod,
PeerIP: peerIPFromContext(ctx),
MessageClass: authenticatedMessageClass(envelope.MessageType),
Envelope: AuthenticatedRequestEnvelope{
ProtocolVersion: envelope.ProtocolVersion,
DeviceSessionID: envelope.DeviceSessionID,
MessageType: envelope.MessageType,
TimestampMS: envelope.TimestampMS,
RequestID: envelope.RequestID,
TraceID: envelope.TraceID,
},
Session: record,
}, nil
}
func authenticatedGRPCIPBucketKey(peerIP string) string {
return authenticatedGRPCIPBucketKeySegment + peerIP
}
func authenticatedGRPCSessionBucketKey(deviceSessionID string) string {
return authenticatedGRPCSessionBucketKeySegment + deviceSessionID
}
func authenticatedGRPCUserBucketKey(userID string) string {
return authenticatedGRPCUserBucketKeySegment + userID
}
func authenticatedGRPCMessageClassBucketKey(messageClass string) string {
return authenticatedGRPCMessageClassBucketKeySegment + messageClass
}
func authenticatedMessageClass(messageType string) string {
return messageType
}
func peerIPFromContext(ctx context.Context) string {
peerInfo, ok := peer.FromContext(ctx)
if !ok || peerInfo.Addr == nil {
return unknownAuthenticatedPeerIP
}
value := strings.TrimSpace(peerInfo.Addr.String())
if value == "" {
return unknownAuthenticatedPeerIP
}
host, _, err := net.SplitHostPort(value)
if err == nil && host != "" {
return host
}
return value
}
type noopAuthenticatedRequestPolicy struct{}
func (noopAuthenticatedRequestPolicy) Evaluate(context.Context, AuthenticatedRequest) error {
return nil
}
var _ gatewayv1.EdgeGatewayServer = authenticatedRateLimitService{}
@@ -0,0 +1,497 @@
package grpcapi
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/ratelimit"
"galaxy/gateway/internal/restapi"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
}), ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
require.Error(t, err)
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
assert.Equal(t, 1, delegate.executeCalls)
}
func TestExecuteCommandRateLimitsBySession(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
}), ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-1"}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2"))
require.Error(t, err)
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3"))
require.NoError(t, err)
assert.Equal(t, 2, delegate.executeCalls)
}
func TestExecuteCommandRateLimitsByUser(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
}), ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{
"device-session-1": "user-shared",
"device-session-2": "user-shared",
"device-session-3": "user-other",
}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
require.Error(t, err)
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3"))
require.NoError(t, err)
assert.Equal(t, 2, delegate.executeCalls)
}
func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
}), ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{
"device-session-1": "user-1",
"device-session-2": "user-2",
}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move"))
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move"))
require.Error(t, err)
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename"))
require.NoError(t, err)
assert.Equal(t, 2, delegate.executeCalls)
}
func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
t.Parallel()
policy := &recordingAuthenticatedRequestPolicy{}
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
Policy: policy,
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.NoError(t, err)
require.Len(t, policy.requests, 1)
assert.Equal(t, authenticatedRPCExecuteCommand, policy.requests[0].RPCMethod)
assert.Equal(t, "127.0.0.1", policy.requests[0].PeerIP)
assert.Equal(t, "fleet.move", policy.requests[0].MessageClass)
assert.Equal(t, "device-session-123", policy.requests[0].Envelope.DeviceSessionID)
assert.Equal(t, "request-123", policy.requests[0].Envelope.RequestID)
assert.Equal(t, "trace-123", policy.requests[0].Envelope.TraceID)
assert.Equal(t, "user-123", policy.requests[0].Session.UserID)
assert.Equal(t, 1, delegate.executeCalls)
}
func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error {
return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied)
}),
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
assert.Equal(t, "authenticated request rejected by edge policy", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
}), ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1"))
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
require.Error(t, err)
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
assert.Equal(t, 1, delegate.subscribeCalls)
}
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
t.Parallel()
sharedLimiter := ratelimit.NewInMemory()
publicCfg := config.DefaultPublicHTTPConfig()
publicCfg.Addr = unusedTCPAddr(t)
publicCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
publicCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
cfg.Addr = unusedTCPAddr(t)
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
})
restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
AuthService: staticAuthServiceClient{},
Limiter: publicLimiterAdapter{limiter: sharedLimiter},
})
delegate := &recordingEdgeGatewayService{}
grpcServer := NewServer(grpcCfg, ServerDependencies{
Service: delegate,
Router: executeCommandAdapterRouter{service: delegate},
ResponseSigner: newTestResponseSigner(),
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
ReplayStore: staticReplayStore{},
Limiter: sharedLimiter,
Clock: fixedClock{now: testCurrentTime},
})
application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, grpcServer)
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- application.Run(ctx)
}()
runGateway := runningGateway{cancel: cancel, resultCh: resultCh}
defer runGateway.stop(t)
waitForHTTPHealthz(t, "http://"+publicCfg.Addr+"/healthz")
addr := waitForListenAddr(t, grpcServer)
firstPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
secondPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
require.NoError(t, firstPublic.Body.Close())
require.NoError(t, secondPublic.Body.Close())
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.NoError(t, err)
}
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
cfg := config.DefaultAuthenticatedGRPCConfig()
cfg.Addr = "127.0.0.1:0"
cfg.FreshnessWindow = testFreshnessWindow
cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
Requests: 100,
Window: time.Hour,
Burst: 100,
}
if mutate != nil {
mutate(&cfg)
}
return cfg
}
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
req := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
req.MessageType = messageType
req.Signature = signRequest(
req.GetProtocolVersion(),
req.GetDeviceSessionId(),
req.GetMessageType(),
req.GetTimestampMs(),
req.GetRequestId(),
req.GetPayloadHash(),
)
return req
}
func userMappedSessionCache(users map[string]string) staticSessionCache {
return staticSessionCache{
lookupFunc: func(_ context.Context, deviceSessionID string) (session.Record, error) {
userID, ok := users[deviceSessionID]
if !ok {
return session.Record{}, session.ErrNotFound
}
record := newActiveSessionRecordWithSessionID(deviceSessionID)
record.UserID = userID
return record, nil
},
}
}
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error
func (f authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
return f(ctx, request)
}
type recordingAuthenticatedRequestPolicy struct {
requests []AuthenticatedRequest
}
func (p *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
p.requests = append(p.requests, request)
return nil
}
type publicLimiterAdapter struct {
limiter ratelimit.Limiter
}
func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
decision := a.limiter.Reserve(key, ratelimit.Policy{
Requests: policy.Requests,
Window: policy.Window,
Burst: policy.Burst,
})
return restapi.PublicRateLimitDecision{
Allowed: decision.Allowed,
RetryAfter: decision.RetryAfter,
}
}
type staticAuthServiceClient struct{}
func (staticAuthServiceClient) SendEmailCode(context.Context, restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
return restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}, nil
}
func (staticAuthServiceClient) ConfirmEmailCode(context.Context, restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
return restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, nil
}
func waitForHTTPHealthz(t *testing.T, url string) {
t.Helper()
client := &http.Client{Timeout: 200 * time.Millisecond}
require.Eventually(t, func() bool {
response, err := client.Get(url)
if err != nil {
return false
}
require.NoError(t, response.Body.Close())
return response.StatusCode == http.StatusOK
}, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
}
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
t.Helper()
request, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"email":"pilot@example.com"}`))
require.NoError(t, err)
request.Header.Set("Content-Type", "application/json")
response, err := (&http.Client{Timeout: time.Second}).Do(request)
require.NoError(t, err)
return response
}
func unusedTCPAddr(t *testing.T) string {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := listener.Addr().String()
require.NoError(t, listener.Close())
return addr
}
+260
View File
@@ -0,0 +1,260 @@
// Package grpcapi exposes the authenticated gRPC surface of the gateway.
package grpcapi
import (
"context"
"errors"
"fmt"
"net"
"sync"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/ratelimit"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
"galaxy/gateway/internal/telemetry"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/zap"
"google.golang.org/grpc"
)
// ServerDependencies describes the optional collaborators used by the
// authenticated gRPC server. The zero value is valid and keeps the process
// runnable with the built-in unimplemented service stub.
type ServerDependencies struct {
// Service optionally handles the post-bootstrap SubscribeEvents lifecycle
// after the initial authenticated service event has been sent. When nil, the
// gateway keeps authenticated SubscribeEvents streams open until the client
// cancels them, the server shuts down, or a later stream send fails.
Service gatewayv1.EdgeGatewayServer
// Router resolves the exact downstream unary client for the verified
// message_type value. When nil, the authenticated unary surface uses an
// empty exact-match router and returns UNIMPLEMENTED for unrouted commands.
Router downstream.Router
// ResponseSigner signs authenticated unary responses after downstream
// execution succeeds. When nil, the unary surface fails closed once it needs
// to sign a routed response.
ResponseSigner authn.ResponseSigner
// SessionCache resolves authenticated device sessions after the envelope
// gate succeeds. When nil, the authenticated gRPC surface remains runnable
// but valid envelopes fail closed as session-cache unavailable.
SessionCache session.Cache
// Clock provides current server time for freshness checks. When nil, the
// authenticated gRPC surface uses the system clock.
Clock clock.Clock
// ReplayStore reserves authenticated request identifiers after signature
// verification. When nil, valid requests fail closed as replay-store
// unavailable.
ReplayStore replay.Store
// Limiter applies authenticated rate limits after the request passes the
// transport authenticity checks. When nil, the authenticated gRPC surface
// uses a process-local in-memory limiter.
Limiter AuthenticatedRequestLimiter
// Policy evaluates later authenticated edge policy after rate limits pass.
// When nil, the authenticated gRPC surface applies a no-op allow policy.
Policy AuthenticatedRequestPolicy
// Logger writes structured logs for authenticated gRPC traffic.
Logger *zap.Logger
// Telemetry records low-cardinality gRPC metrics.
Telemetry *telemetry.Runtime
// PushHub is the active authenticated push-stream hub. When present, the
// server closes active streams before GracefulStop during shutdown.
PushHub *push.Hub
}
// Server owns the authenticated gRPC listener exposed by the gateway.
type Server struct {
cfg config.AuthenticatedGRPCConfig
service gatewayv1.EdgeGatewayServer
logger *zap.Logger
pushHub *push.Hub
metrics *telemetry.Runtime
stateMu sync.RWMutex
server *grpc.Server
listener net.Listener
}
// NewServer constructs an authenticated gRPC server for the supplied listener
// configuration and dependency bundle. Nil dependencies are replaced with safe
// defaults so the gateway can expose the documented transport surface with the
// full auth pipeline wired from built-in fallbacks.
func NewServer(cfg config.AuthenticatedGRPCConfig, deps ServerDependencies) *Server {
deps = normalizeServerDependencies(deps)
finalService := newCommandRoutingService(
newAuthenticatedPushStreamService(deps.Service, deps.ResponseSigner, deps.Clock),
deps.Router,
deps.ResponseSigner,
deps.Clock,
cfg.DownstreamTimeout,
)
return &Server{
cfg: cfg,
service: newEnvelopeValidatingService(
newSessionLookupService(
newPayloadHashVerifyingService(
newSignatureVerifyingService(
newFreshnessAndReplayService(
newAuthenticatedRateLimitService(
finalService,
deps.Limiter,
deps.Policy,
cfg.AntiAbuse,
),
deps.Clock,
deps.ReplayStore,
cfg.FreshnessWindow,
),
),
),
deps.SessionCache,
),
),
logger: deps.Logger.Named("authenticated_grpc"),
pushHub: deps.PushHub,
metrics: deps.Telemetry,
}
}
// Run binds the configured listener and serves the authenticated gRPC surface
// until Shutdown closes the server.
func (s *Server) Run(ctx context.Context) error {
if ctx == nil {
return errors.New("run authenticated gRPC server: nil context")
}
if err := ctx.Err(); err != nil {
return err
}
listener, err := net.Listen("tcp", s.cfg.Addr)
if err != nil {
return fmt.Errorf("run authenticated gRPC server: listen on %q: %w", s.cfg.Addr, err)
}
grpcServer := grpc.NewServer(
grpc.ConnectionTimeout(s.cfg.ConnectionTimeout),
grpc.StatsHandler(otelgrpc.NewServerHandler()),
grpc.ChainUnaryInterceptor(observabilityUnaryInterceptor(s.logger, s.metrics)),
grpc.ChainStreamInterceptor(observabilityStreamInterceptor(s.logger, s.metrics)),
)
gatewayv1.RegisterEdgeGatewayServer(grpcServer, s.service)
s.stateMu.Lock()
s.server = grpcServer
s.listener = listener
s.stateMu.Unlock()
s.logger.Info("authenticated gRPC server started", zap.String("addr", listener.Addr().String()))
defer func() {
s.stateMu.Lock()
s.server = nil
s.listener = nil
s.stateMu.Unlock()
}()
err = grpcServer.Serve(listener)
switch {
case err == nil:
return nil
case errors.Is(err, grpc.ErrServerStopped):
s.logger.Info("authenticated gRPC server stopped")
return nil
default:
return fmt.Errorf("run authenticated gRPC server: serve on %q: %w", s.cfg.Addr, err)
}
}
// Shutdown gracefully stops the authenticated gRPC server within ctx. When the
// graceful stop exceeds ctx, the server is force-stopped before returning the
// timeout to the caller.
func (s *Server) Shutdown(ctx context.Context) error {
if ctx == nil {
return errors.New("shutdown authenticated gRPC server: nil context")
}
s.stateMu.RLock()
server := s.server
s.stateMu.RUnlock()
if server == nil {
return nil
}
if s.pushHub != nil {
s.pushHub.Shutdown()
}
stopped := make(chan struct{})
go func() {
server.GracefulStop()
close(stopped)
}()
select {
case <-stopped:
return nil
case <-ctx.Done():
server.Stop()
<-stopped
return fmt.Errorf("shutdown authenticated gRPC server: %w", ctx.Err())
}
}
func (s *Server) listenAddr() string {
s.stateMu.RLock()
defer s.stateMu.RUnlock()
if s.listener == nil {
return ""
}
return s.listener.Addr().String()
}
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
if deps.Router == nil {
deps.Router = downstream.NewStaticRouter(nil)
}
if deps.ResponseSigner == nil {
deps.ResponseSigner = unavailableResponseSigner{}
}
if deps.SessionCache == nil {
deps.SessionCache = unavailableSessionCache{}
}
if deps.Clock == nil {
deps.Clock = clock.System{}
}
if deps.ReplayStore == nil {
deps.ReplayStore = unavailableReplayStore{}
}
if deps.Limiter == nil {
deps.Limiter = ratelimit.NewInMemory()
}
if deps.Policy == nil {
deps.Policy = noopAuthenticatedRequestPolicy{}
}
if deps.Logger == nil {
deps.Logger = zap.NewNop()
}
return deps
}
+332
View File
@@ -0,0 +1,332 @@
package grpcapi
import (
"context"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
}
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
require.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
}
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{
ProtocolVersion: "v2",
DeviceSessionId: "device-session-123",
MessageType: "fleet.move",
TimestampMs: 123456789,
RequestId: "request-123",
PayloadBytes: []byte("payload"),
PayloadHash: []byte("hash"),
Signature: []byte("signature"),
})
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
}
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return newActiveSessionRecord(), nil
},
},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unimplemented, status.Code(err))
}
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return newActiveSessionRecord(), nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server, runGateway := newTestGateway(t, ServerDependencies{
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return newActiveSessionRecord(), nil
},
},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
recvResult := make(chan error, 1)
go func() {
_, recvErr := stream.Recv()
recvResult <- recvErr
}()
require.Never(t, func() bool {
select {
case <-recvResult:
return true
default:
return false
}
}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")
cancel()
var recvErr error
require.Eventually(t, func() bool {
select {
case recvErr = <-recvResult:
return true
default:
return false
}
}, time.Second, 10*time.Millisecond, "stream did not stop after client cancellation")
require.Error(t, recvErr)
assert.Equal(t, codes.Canceled, status.Code(recvErr))
}
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return newActiveSessionRecord(), nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}
func TestServerLifecycle(t *testing.T) {
t.Parallel()
server, runGateway := newTestGateway(t, ServerDependencies{})
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
require.NoError(t, conn.Close())
runGateway.stop(t)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err := grpc.DialContext(
ctx,
addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.Error(t, err)
}
type runningGateway struct {
cancel context.CancelFunc
resultCh chan error
}
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
t.Helper()
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
grpcCfg.Addr = "127.0.0.1:0"
grpcCfg.FreshnessWindow = testFreshnessWindow
return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
}
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
t.Helper()
cfg := config.Config{
ShutdownTimeout: time.Second,
AuthenticatedGRPC: grpcCfg,
}
if deps.Clock == nil {
deps.Clock = fixedClock{now: testCurrentTime}
}
if deps.ResponseSigner == nil {
deps.ResponseSigner = newTestResponseSigner()
}
if deps.Router == nil && deps.Service != nil {
deps.Router = executeCommandAdapterRouter{service: deps.Service}
}
server := NewServer(cfg.AuthenticatedGRPC, deps)
application := app.New(cfg, server)
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- application.Run(ctx)
}()
return server, runningGateway{
cancel: cancel,
resultCh: resultCh,
}
}
func (g runningGateway) stop(t *testing.T) {
t.Helper()
g.cancel()
var err error
require.Eventually(t, func() bool {
select {
case err = <-g.resultCh:
return true
default:
return false
}
}, 2*time.Second, 10*time.Millisecond, "gateway did not stop after cancellation")
require.NoError(t, err)
}
func waitForListenAddr(t *testing.T, server *Server) string {
t.Helper()
var addr string
require.Eventually(t, func() bool {
addr = server.listenAddr()
return addr != ""
}, time.Second, 10*time.Millisecond, "server did not start listening")
return addr
}
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := grpc.DialContext(
ctx,
addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
return conn
}
+126
View File
@@ -0,0 +1,126 @@
package grpcapi
import (
"context"
"errors"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// resolvedSessionFromContext returns the session record previously attached to
// ctx by the session-lookup gateway wrapper.
func resolvedSessionFromContext(ctx context.Context) (session.Record, bool) {
if ctx == nil {
return session.Record{}, false
}
record, ok := ctx.Value(resolvedSessionContextKey{}).(session.Record)
if !ok {
return session.Record{}, false
}
return cloneSessionRecord(record), true
}
// sessionLookupService resolves the authenticated session from SessionCache
// after envelope parsing succeeds and before later auth steps run.
type sessionLookupService struct {
gatewayv1.UnimplementedEdgeGatewayServer
delegate gatewayv1.EdgeGatewayServer
cache session.Cache
}
// ExecuteCommand resolves the cached session for req and only then forwards it
// to the configured delegate with the resolved session attached to ctx.
func (s sessionLookupService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
record, err := s.lookupSession(ctx)
if err != nil {
return nil, err
}
return s.delegate.ExecuteCommand(context.WithValue(ctx, resolvedSessionContextKey{}, cloneSessionRecord(record)), req)
}
// SubscribeEvents resolves the cached session for req and only then forwards it
// to the configured delegate with the resolved session attached to the stream
// context.
func (s sessionLookupService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
record, err := s.lookupSession(stream.Context())
if err != nil {
return err
}
return s.delegate.SubscribeEvents(req, resolvedSessionContextStream{
ServerStreamingServer: stream,
ctx: context.WithValue(stream.Context(), resolvedSessionContextKey{}, cloneSessionRecord(record)),
})
}
// newSessionLookupService wraps delegate with the session-cache lookup gate.
func newSessionLookupService(delegate gatewayv1.EdgeGatewayServer, cache session.Cache) gatewayv1.EdgeGatewayServer {
return sessionLookupService{
delegate: delegate,
cache: cache,
}
}
func (s sessionLookupService) lookupSession(ctx context.Context) (session.Record, error) {
envelope, ok := parsedEnvelopeFromContext(ctx)
if !ok {
return session.Record{}, status.Error(codes.Internal, "authenticated request context is incomplete")
}
record, err := s.cache.Lookup(ctx, envelope.DeviceSessionID)
switch {
case err == nil:
case errors.Is(err, session.ErrNotFound):
return session.Record{}, status.Error(codes.Unauthenticated, "unknown device session")
default:
return session.Record{}, status.Error(codes.Unavailable, "session cache is unavailable")
}
if record.Status == session.StatusRevoked {
return session.Record{}, status.Error(codes.FailedPrecondition, "device session is revoked")
}
return cloneSessionRecord(record), nil
}
func cloneSessionRecord(record session.Record) session.Record {
cloned := record
if record.RevokedAtMS != nil {
value := *record.RevokedAtMS
cloned.RevokedAtMS = &value
}
return cloned
}
type resolvedSessionContextKey struct{}
type resolvedSessionContextStream struct {
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
ctx context.Context
}
func (s resolvedSessionContextStream) Context() context.Context {
if s.ctx == nil {
return context.Background()
}
return s.ctx
}
type unavailableSessionCache struct{}
func (unavailableSessionCache) Lookup(context.Context, string) (session.Record, error) {
return session.Record{}, errors.New("session cache is unavailable")
}
var _ gatewayv1.EdgeGatewayServer = sessionLookupService{}
@@ -0,0 +1,294 @@
package grpcapi
import (
"context"
"errors"
"io"
"testing"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestExecuteCommandRejectsUnknownSession(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return session.Record{}, session.ErrNotFound
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unauthenticated, status.Code(err))
assert.Equal(t, "unknown device session", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestSubscribeEventsRejectsUnknownSession(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return session.Record{}, session.ErrNotFound
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
require.Error(t, err)
assert.Equal(t, codes.Unauthenticated, status.Code(err))
assert.Equal(t, "unknown device session", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
func TestExecuteCommandRejectsRevokedSession(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestSubscribeEventsRejectsRevokedSession(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
func TestExecuteCommandRejectsSessionCacheUnavailable(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return session.Record{}, errors.New("redis down")
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestSubscribeEventsRejectsSessionCacheUnavailable(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
return session.Record{}, errors.New("redis down")
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
func TestExecuteCommandAttachesResolvedSession(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
record, ok := resolvedSessionFromContext(ctx)
require.True(t, ok)
assert.Equal(t, newActiveSessionRecord(), record)
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
},
}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.NoError(t, err)
assert.Equal(t, "request-123", response.GetRequestId())
}
func TestSubscribeEventsAttachesResolvedSession(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
record, ok := resolvedSessionFromContext(stream.Context())
require.True(t, ok)
assert.Equal(t, newActiveSessionRecord(), record)
return nil
},
}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
}
func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
binding, ok := authenticatedStreamBindingFromContext(stream.Context())
require.True(t, ok)
assert.Equal(t, authenticatedStreamBinding{
UserID: "user-123",
DeviceSessionID: "device-session-123",
MessageType: "gateway.subscribe",
RequestID: "request-123",
TraceID: "trace-123",
}, binding)
return nil
},
}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
}
type staticSessionCache struct {
lookupFunc func(context.Context, string) (session.Record, error)
}
func (c staticSessionCache) Lookup(ctx context.Context, deviceSessionID string) (session.Record, error) {
return c.lookupFunc(ctx, deviceSessionID)
}
+80
View File
@@ -0,0 +1,80 @@
package grpcapi
import (
"context"
"errors"
"galaxy/gateway/internal/authn"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// signatureVerifyingService applies client-signature verification after
// payload integrity checks and before later auth or routing steps run.
type signatureVerifyingService struct {
gatewayv1.UnimplementedEdgeGatewayServer
delegate gatewayv1.EdgeGatewayServer
}
// ExecuteCommand verifies req client signature before delegating to the
// configured service implementation.
func (s signatureVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
if err := verifyRequestSignature(ctx); err != nil {
return nil, err
}
return s.delegate.ExecuteCommand(ctx, req)
}
// SubscribeEvents verifies req client signature before delegating to the
// configured service implementation.
func (s signatureVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
if err := verifyRequestSignature(stream.Context()); err != nil {
return err
}
return s.delegate.SubscribeEvents(req, stream)
}
// newSignatureVerifyingService wraps delegate with the client-signature
// verification gate.
func newSignatureVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
return signatureVerifyingService{delegate: delegate}
}
func verifyRequestSignature(ctx context.Context) error {
envelope, ok := parsedEnvelopeFromContext(ctx)
if !ok {
return status.Error(codes.Internal, "authenticated request context is incomplete")
}
record, ok := resolvedSessionFromContext(ctx)
if !ok {
return status.Error(codes.Internal, "authenticated request context is incomplete")
}
err := authn.VerifyRequestSignature(record.ClientPublicKey, envelope.Signature, authn.RequestSigningFields{
ProtocolVersion: envelope.ProtocolVersion,
DeviceSessionID: envelope.DeviceSessionID,
MessageType: envelope.MessageType,
TimestampMS: envelope.TimestampMS,
RequestID: envelope.RequestID,
PayloadHash: envelope.PayloadHash,
})
switch {
case err == nil:
return nil
case errors.Is(err, authn.ErrInvalidClientPublicKey):
return status.Error(codes.Unavailable, "session cache is unavailable")
case errors.Is(err, authn.ErrInvalidRequestSignature):
return status.Error(codes.Unauthenticated, "invalid request signature")
default:
return status.Error(codes.Internal, "request signature verification failed")
}
}
var _ gatewayv1.EdgeGatewayServer = signatureVerifyingService{}
@@ -0,0 +1,188 @@
package grpcapi
import (
"context"
"testing"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestExecuteCommandRejectsInvalidSignature(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
req := newValidExecuteCommandRequest()
req.Signature[0] ^= 0xff
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), req)
require.Error(t, err)
assert.Equal(t, codes.Unauthenticated, status.Code(err))
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestExecuteCommandRejectsWrongKey(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
record := newActiveSessionRecord()
record.ClientPublicKey = alternateTestClientPublicKeyBase64()
return record, nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unauthenticated, status.Code(err))
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestExecuteCommandRejectsInvalidCachedPublicKey(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
record := newActiveSessionRecord()
record.ClientPublicKey = "%%%not-base64%%%"
return record, nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
assert.Zero(t, delegate.executeCalls)
}
func TestSubscribeEventsRejectsInvalidSignature(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
req := newValidSubscribeEventsRequest()
req.Signature[0] ^= 0xff
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, req)
require.Error(t, err)
assert.Equal(t, codes.Unauthenticated, status.Code(err))
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
func TestSubscribeEventsRejectsWrongKey(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
record := newActiveSessionRecord()
record.ClientPublicKey = alternateTestClientPublicKeyBase64()
return record, nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
require.Error(t, err)
assert.Equal(t, codes.Unauthenticated, status.Code(err))
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
func TestSubscribeEventsRejectsInvalidCachedPublicKey(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{
lookupFunc: func(context.Context, string) (session.Record, error) {
record := newActiveSessionRecord()
record.ClientPublicKey = "%%%not-base64%%%"
return record, nil
},
},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
require.Error(t, err)
assert.Equal(t, codes.Unavailable, status.Code(err))
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
assert.Zero(t, delegate.subscribeCalls)
}
@@ -0,0 +1,298 @@
package grpcapi
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
gatewayfbs "galaxy/schema/fbs/gateway"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
var (
testCurrentTime = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
testFreshnessWindow = 5 * time.Minute
)
func newValidExecuteCommandRequest() *gatewayv1.ExecuteCommandRequest {
return newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-123")
}
func newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.ExecuteCommandRequest {
return newValidExecuteCommandRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}
func newValidExecuteCommandRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.ExecuteCommandRequest {
payloadBytes := []byte("payload")
payloadHash := sha256.Sum256(payloadBytes)
req := &gatewayv1.ExecuteCommandRequest{
ProtocolVersion: supportedProtocolVersion,
DeviceSessionId: deviceSessionID,
MessageType: "fleet.move",
TimestampMs: timestampMS,
RequestId: requestID,
PayloadBytes: payloadBytes,
PayloadHash: payloadHash[:],
TraceId: "trace-123",
}
req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
return req
}
func newValidSubscribeEventsRequest() *gatewayv1.SubscribeEventsRequest {
return newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-123")
}
func newValidSubscribeEventsRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
return newValidSubscribeEventsRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}
func newValidSubscribeEventsRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.SubscribeEventsRequest {
payloadHash := sha256.Sum256(nil)
req := &gatewayv1.SubscribeEventsRequest{
ProtocolVersion: supportedProtocolVersion,
DeviceSessionId: deviceSessionID,
MessageType: "gateway.subscribe",
TimestampMs: timestampMS,
RequestId: requestID,
PayloadHash: payloadHash[:],
TraceId: "trace-123",
}
req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
return req
}
func newActiveSessionRecord() session.Record {
return newActiveSessionRecordWithSessionID("device-session-123")
}
func newActiveSessionRecordWithSessionID(deviceSessionID string) session.Record {
return session.Record{
DeviceSessionID: deviceSessionID,
UserID: "user-123",
ClientPublicKey: testClientPublicKeyBase64(),
Status: session.StatusActive,
}
}
func newRevokedSessionRecord() session.Record {
revokedAtMS := int64(123456789)
return session.Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: testClientPublicKeyBase64(),
Status: session.StatusRevoked,
RevokedAtMS: &revokedAtMS,
}
}
func alternateTestClientPublicKeyBase64() string {
return base64.StdEncoding.EncodeToString(newTestPrivateKey("alternate").Public().(ed25519.PublicKey))
}
func testClientPublicKeyBase64() string {
return base64.StdEncoding.EncodeToString(newTestPrivateKey("primary").Public().(ed25519.PublicKey))
}
func signRequest(protocolVersion, deviceSessionID, messageType string, timestampMS int64, requestID string, payloadHash []byte) []byte {
return ed25519.Sign(newTestPrivateKey("primary"), authn.BuildRequestSigningInput(authn.RequestSigningFields{
ProtocolVersion: protocolVersion,
DeviceSessionID: deviceSessionID,
MessageType: messageType,
TimestampMS: timestampMS,
RequestID: requestID,
PayloadHash: payloadHash,
}))
}
func newTestPrivateKey(label string) ed25519.PrivateKey {
seed := sha256.Sum256([]byte("gateway-grpcapi-signature-test-" + label))
return ed25519.NewKeyFromSeed(seed[:])
}
func newTestEd25519ResponseSigner() *authn.Ed25519ResponseSigner {
pemBytes := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: mustMarshalPKCS8PrivateKey(newTestPrivateKey("response-signer")),
})
signer, err := authn.ParseEd25519ResponseSignerPEM(pemBytes)
if err != nil {
panic(err)
}
return signer
}
func newTestResponseSigner() authn.ResponseSigner {
return newTestEd25519ResponseSigner()
}
func newTestResponseSignerPublicKey() ed25519.PublicKey {
return newTestEd25519ResponseSigner().PublicKey()
}
func mustMarshalPKCS8PrivateKey(privateKey ed25519.PrivateKey) []byte {
encoded, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
panic(err)
}
return encoded
}
type fixedClock struct {
now time.Time
}
func (c fixedClock) Now() time.Time {
return c.now
}
func recvBootstrapEvent(t interface {
require.TestingT
Helper()
}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
t.Helper()
event, err := stream.Recv()
require.NoError(t, err)
return event
}
func subscribeEventsError(t interface {
require.TestingT
Helper()
}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
t.Helper()
stream, err := client.SubscribeEvents(ctx, req)
if err != nil {
return err
}
_, err = stream.Recv()
return err
}
func assertServerTimeBootstrapEvent(t interface {
require.TestingT
Helper()
}, event *gatewayv1.GatewayEvent, publicKey ed25519.PublicKey, wantRequestID string, wantTraceID string, wantTimestampMS int64) {
t.Helper()
require.NotNil(t, event)
assert.Equal(t, serverTimeEventType, event.GetEventType())
assert.Equal(t, wantRequestID, event.GetEventId())
assert.Equal(t, wantRequestID, event.GetRequestId())
assert.Equal(t, wantTraceID, event.GetTraceId())
assert.Equal(t, wantTimestampMS, event.GetTimestampMs())
require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
require.NoError(t, authn.VerifyEventSignature(publicKey, event.GetSignature(), authn.EventSigningFields{
EventType: event.GetEventType(),
EventID: event.GetEventId(),
TimestampMS: event.GetTimestampMs(),
RequestID: event.GetRequestId(),
TraceID: event.GetTraceId(),
PayloadHash: event.GetPayloadHash(),
}))
payload := gatewayfbs.GetRootAsServerTimeEvent(event.GetPayloadBytes(), flatbuffers.UOffsetT(0))
assert.Equal(t, wantTimestampMS, payload.ServerTimeMs())
}
type staticReplayStore struct {
reserveFunc func(context.Context, string, string, time.Duration) error
}
func (s staticReplayStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
if s.reserveFunc != nil {
return s.reserveFunc(ctx, deviceSessionID, requestID, ttl)
}
return nil
}
type executeCommandAdapterRouter struct {
service gatewayv1.EdgeGatewayServer
}
func (r executeCommandAdapterRouter) Route(string) (downstream.Client, error) {
return executeCommandAdapterClient{service: r.service}, nil
}
type executeCommandAdapterClient struct {
service gatewayv1.EdgeGatewayServer
}
func (c executeCommandAdapterClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
response, err := c.service.ExecuteCommand(ctx, &gatewayv1.ExecuteCommandRequest{
ProtocolVersion: command.ProtocolVersion,
DeviceSessionId: command.DeviceSessionID,
MessageType: command.MessageType,
TimestampMs: command.TimestampMS,
RequestId: command.RequestID,
PayloadBytes: command.PayloadBytes,
TraceId: command.TraceID,
})
if err != nil {
return downstream.UnaryResult{}, err
}
resultCode := response.GetResultCode()
if resultCode == "" {
resultCode = "ok"
}
return downstream.UnaryResult{
ResultCode: resultCode,
PayloadBytes: response.GetPayloadBytes(),
}, nil
}
type recordingDownstreamClient struct {
executeCalls int
commands []downstream.AuthenticatedCommand
executeFunc func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error)
}
func (c *recordingDownstreamClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
c.executeCalls++
c.commands = append(c.commands, downstream.AuthenticatedCommand{
ProtocolVersion: command.ProtocolVersion,
UserID: command.UserID,
DeviceSessionID: command.DeviceSessionID,
MessageType: command.MessageType,
TimestampMS: command.TimestampMS,
RequestID: command.RequestID,
TraceID: command.TraceID,
PayloadBytes: append([]byte(nil), command.PayloadBytes...),
})
if c.executeFunc != nil {
return c.executeFunc(ctx, command)
}
return downstream.UnaryResult{
ResultCode: "ok",
PayloadBytes: []byte("response"),
}, nil
}