phase 4: connectrpc on the gateway authenticated edge
Replace the native-gRPC server bootstrap with a single `connectrpc.com/connect` HTTP/h2c listener. Connect-Go natively serves Connect, gRPC, and gRPC-Web on the same port, so browsers can now reach the authenticated surface without giving up the gRPC framing native and desktop clients may use later. The decorator stack (envelope → session → payload-hash → signature → freshness/replay → rate-limit → routing/push) is reused unchanged behind a small Connect → gRPC adapter and a `grpc.ServerStream` shim around `*connect.ServerStream`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -11,14 +11,12 @@ import (
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
|
||||
@@ -58,32 +56,27 @@ func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
response, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v1", response.GetProtocolVersion())
|
||||
assert.Equal(t, "request-123", response.GetRequestId())
|
||||
assert.Equal(t, testCurrentTime.UnixMilli(), response.GetTimestampMs())
|
||||
assert.Equal(t, "accepted", response.GetResultCode())
|
||||
assert.Equal(t, []byte("downstream-response"), response.GetPayloadBytes())
|
||||
assert.Equal(t, "v1", response.Msg.GetProtocolVersion())
|
||||
assert.Equal(t, "request-123", response.Msg.GetRequestId())
|
||||
assert.Equal(t, testCurrentTime.UnixMilli(), response.Msg.GetTimestampMs())
|
||||
assert.Equal(t, "accepted", response.Msg.GetResultCode())
|
||||
assert.Equal(t, []byte("downstream-response"), response.Msg.GetPayloadBytes())
|
||||
assert.Equal(t, 1, moveClient.executeCalls)
|
||||
assert.Zero(t, renameClient.executeCalls)
|
||||
|
||||
wantHash := sha256.Sum256([]byte("downstream-response"))
|
||||
assert.Equal(t, wantHash[:], response.GetPayloadHash())
|
||||
require.NoError(t, authn.VerifyPayloadHash(response.GetPayloadBytes(), response.GetPayloadHash()))
|
||||
require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.GetSignature(), authn.ResponseSigningFields{
|
||||
ProtocolVersion: response.GetProtocolVersion(),
|
||||
RequestID: response.GetRequestId(),
|
||||
TimestampMS: response.GetTimestampMs(),
|
||||
ResultCode: response.GetResultCode(),
|
||||
PayloadHash: response.GetPayloadHash(),
|
||||
assert.Equal(t, wantHash[:], response.Msg.GetPayloadHash())
|
||||
require.NoError(t, authn.VerifyPayloadHash(response.Msg.GetPayloadBytes(), response.Msg.GetPayloadHash()))
|
||||
require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.Msg.GetSignature(), authn.ResponseSigningFields{
|
||||
ProtocolVersion: response.Msg.GetProtocolVersion(),
|
||||
RequestID: response.Msg.GetRequestId(),
|
||||
TimestampMS: response.Msg.GetTimestampMs(),
|
||||
ResultCode: response.Msg.GetResultCode(),
|
||||
PayloadHash: response.Msg.GetPayloadHash(),
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -99,16 +92,11 @@ func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unimplemented, status.Code(err))
|
||||
assert.Equal(t, "message_type is not routed", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err))
|
||||
assert.Equal(t, "message_type is not routed", connectErrorMessage(t, err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
|
||||
@@ -131,16 +119,11 @@ func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "downstream service is unavailable", connectErrorMessage(t, err))
|
||||
assert.Equal(t, 1, failingClient.executeCalls)
|
||||
}
|
||||
|
||||
@@ -167,16 +150,11 @@ func TestExecuteCommandMapsDownstreamTimeoutToUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "downstream service is unavailable", connectErrorMessage(t, err))
|
||||
assert.Equal(t, 1, stallingClient.executeCalls)
|
||||
}
|
||||
|
||||
@@ -203,16 +181,11 @@ func TestExecuteCommandFailsClosedWhenResponseSignerUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "response signer is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "response signer is unavailable", connectErrorMessage(t, err))
|
||||
assert.Equal(t, 1, successClient.executeCalls)
|
||||
}
|
||||
|
||||
@@ -250,13 +223,8 @@ func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, seenSpanContext.IsValid())
|
||||
@@ -290,15 +258,10 @@ func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
resultCh <- err
|
||||
}()
|
||||
|
||||
@@ -353,13 +316,8 @@ func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T)
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.NoError(t, err)
|
||||
|
||||
logOutput := logBuffer.String()
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
"galaxy/gateway/proto/galaxy/gateway/v1/gatewayv1connect"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// connectEdgeAdapter exposes the existing gRPC-shaped authenticated edge
|
||||
// service decorator stack (envelope → session → payload-hash → signature →
|
||||
// freshness/replay → rate-limit → routing/push) through the
|
||||
// gatewayv1connect.EdgeGatewayHandler interface. It owns no logic of its
|
||||
// own; the underlying decorator stack carries the full ingress contract
|
||||
// unchanged.
|
||||
type connectEdgeAdapter struct {
|
||||
impl gatewayv1.EdgeGatewayServer
|
||||
}
|
||||
|
||||
// newConnectEdgeAdapter wraps impl as a Connect handler.
|
||||
func newConnectEdgeAdapter(impl gatewayv1.EdgeGatewayServer) gatewayv1connect.EdgeGatewayHandler {
|
||||
return &connectEdgeAdapter{impl: impl}
|
||||
}
|
||||
|
||||
// ExecuteCommand unwraps the typed Connect request, calls the underlying
|
||||
// service, and wraps the typed response. gRPC `status.Error` values
|
||||
// returned by the decorator stack are translated to *connect.Error so
|
||||
// the Connect client receives the matching code and message.
|
||||
func (a *connectEdgeAdapter) ExecuteCommand(ctx context.Context, req *connect.Request[gatewayv1.ExecuteCommandRequest]) (*connect.Response[gatewayv1.ExecuteCommandResponse], error) {
|
||||
resp, err := a.impl.ExecuteCommand(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, translateGRPCStatusError(err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
// SubscribeEvents adapts the Connect server stream to the
|
||||
// grpc.ServerStreamingServer contract expected by the existing decorator
|
||||
// stack. The decorator stack only ever calls Send and Context on the
|
||||
// stream; the remaining grpc.ServerStream surface is satisfied by no-op
|
||||
// shims so the interface contract is met without panicking. Errors
|
||||
// returned by the decorator stack are translated to *connect.Error.
|
||||
func (a *connectEdgeAdapter) SubscribeEvents(ctx context.Context, req *connect.Request[gatewayv1.SubscribeEventsRequest], stream *connect.ServerStream[gatewayv1.GatewayEvent]) error {
|
||||
wrapped := &connectEdgeStream{ctx: ctx, stream: stream}
|
||||
if err := a.impl.SubscribeEvents(req.Msg, wrapped); err != nil {
|
||||
return translateGRPCStatusError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// translateGRPCStatusError maps gRPC status.Error values returned by the
|
||||
// decorator stack into *connect.Error with the equivalent code and message.
|
||||
// Errors that are already *connect.Error pass through unchanged. Errors
|
||||
// without a recognisable gRPC status are returned verbatim — connect-go
|
||||
// renders those as CodeUnknown.
|
||||
func translateGRPCStatusError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var connectErr *connect.Error
|
||||
if errors.As(err, &connectErr) {
|
||||
return err
|
||||
}
|
||||
|
||||
grpcStatus, ok := grpcstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if grpcStatus.Code() == codes.OK {
|
||||
return nil
|
||||
}
|
||||
|
||||
return connect.NewError(connect.Code(grpcStatus.Code()), errors.New(grpcStatus.Message()))
|
||||
}
|
||||
|
||||
// connectEdgeStream satisfies grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
|
||||
// on top of *connect.ServerStream. The decorator stack reads the request
|
||||
// context and pushes outbound events through Send; the rest of the
|
||||
// grpc.ServerStream surface is not exercised in the gateway, so the no-op
|
||||
// implementations preserve the type contract without surprising behaviour.
|
||||
type connectEdgeStream struct {
|
||||
ctx context.Context
|
||||
stream *connect.ServerStream[gatewayv1.GatewayEvent]
|
||||
}
|
||||
|
||||
// Send forwards a typed gateway event through the underlying Connect server
|
||||
// stream.
|
||||
func (s *connectEdgeStream) Send(event *gatewayv1.GatewayEvent) error {
|
||||
return s.stream.Send(event)
|
||||
}
|
||||
|
||||
// Context returns the request context handed to the Connect handler.
|
||||
func (s *connectEdgeStream) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
// SetHeader is part of grpc.ServerStream. The Connect transport exposes
|
||||
// response headers through ResponseHeader() at construction time; metadata
|
||||
// supplied here is intentionally ignored because no decorator in the
|
||||
// gateway exercises the gRPC-only metadata path.
|
||||
func (s *connectEdgeStream) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendHeader is part of grpc.ServerStream. Connect-served streams flush
|
||||
// headers automatically on the first Send; manual header dispatch is not
|
||||
// modelled.
|
||||
func (s *connectEdgeStream) SendHeader(metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTrailer is part of grpc.ServerStream. Trailer metadata has no
|
||||
// corresponding Connect concept on server-streaming responses.
|
||||
func (s *connectEdgeStream) SetTrailer(metadata.MD) {}
|
||||
|
||||
// SendMsg is part of grpc.ServerStream. The decorator stack never calls
|
||||
// SendMsg directly; if a future caller does, the typed Send path is used
|
||||
// when the message is a GatewayEvent.
|
||||
func (s *connectEdgeStream) SendMsg(m any) error {
|
||||
event, ok := m.(*gatewayv1.GatewayEvent)
|
||||
if !ok {
|
||||
return fmt.Errorf("connectEdgeStream.SendMsg: unsupported message type %T", m)
|
||||
}
|
||||
|
||||
return s.stream.Send(event)
|
||||
}
|
||||
|
||||
// RecvMsg is part of grpc.ServerStream. Server-streaming server handlers
|
||||
// have no client messages to receive after the initial request, so this
|
||||
// method is intentionally an error path.
|
||||
func (s *connectEdgeStream) RecvMsg(any) error {
|
||||
return errors.New("connectEdgeStream.RecvMsg: server-streaming has no client messages")
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// observabilityConnectInterceptor returns a Connect interceptor that records
|
||||
// the same structured log entry and authenticated edge metric pair as the
|
||||
// gRPC instrumentation it replaced. It also injects the parsed peer IP into
|
||||
// the request context so the rate-limit decorator can attribute requests
|
||||
// without depending on the gRPC `peer` package.
|
||||
func observabilityConnectInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) connect.Interceptor {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &connectObservability{logger: logger, metrics: metrics}
|
||||
}
|
||||
|
||||
type connectObservability struct {
|
||||
logger *zap.Logger
|
||||
metrics *telemetry.Runtime
|
||||
}
|
||||
|
||||
// WrapUnary records timing and outcome for a single unary edge call.
|
||||
func (o *connectObservability) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
ctx = contextWithPeerIP(ctx, hostFromConnectPeerAddr(req.Peer().Addr))
|
||||
|
||||
start := time.Now()
|
||||
resp, err := next(ctx, req)
|
||||
|
||||
var respValue any
|
||||
if resp != nil {
|
||||
respValue = resp.Any()
|
||||
}
|
||||
recordEdgeRequest(o.logger, o.metrics, ctx, "connect", req.Spec().Procedure, req.Any(), respValue, err, time.Since(start), "unary")
|
||||
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
// WrapStreamingClient is the client-side hook required by the
|
||||
// connect.Interceptor contract. The gateway only acts as a Connect server,
|
||||
// so this hook is a pass-through.
|
||||
func (o *connectObservability) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
||||
return next
|
||||
}
|
||||
|
||||
// WrapStreamingHandler records timing and outcome for one server-streaming
|
||||
// edge call. The wrapped conn captures the first received request so the
|
||||
// log/metric pair carries the same envelope fields the gRPC instrumentation
|
||||
// emitted before.
|
||||
func (o *connectObservability) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
||||
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
|
||||
ctx = contextWithPeerIP(ctx, hostFromConnectPeerAddr(conn.Peer().Addr))
|
||||
|
||||
start := time.Now()
|
||||
wrapped := &observabilityStreamingConn{StreamingHandlerConn: conn}
|
||||
err := next(ctx, wrapped)
|
||||
|
||||
recordEdgeRequest(o.logger, o.metrics, ctx, "connect", conn.Spec().Procedure, wrapped.firstRequest, nil, err, time.Since(start), "stream")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// observabilityStreamingConn captures the first received request so the
|
||||
// streaming-handler interceptor can derive the envelope log fields after
|
||||
// the handler returns.
|
||||
type observabilityStreamingConn struct {
|
||||
connect.StreamingHandlerConn
|
||||
|
||||
firstRequest any
|
||||
}
|
||||
|
||||
// Receive forwards to the underlying conn and stores the first successful
|
||||
// message, so envelopeFieldsFromRequest can read message_type, request_id,
|
||||
// and trace_id from it.
|
||||
func (c *observabilityStreamingConn) Receive(msg any) error {
|
||||
err := c.StreamingHandlerConn.Receive(msg)
|
||||
if err == nil && c.firstRequest == nil {
|
||||
c.firstRequest = msg
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// hostFromConnectPeerAddr returns the host part of a "host:port" peer
|
||||
// address, or the address verbatim when it cannot be split. Empty input
|
||||
// yields an empty string so peerIPFromContext falls back to the canonical
|
||||
// `unknown` bucket.
|
||||
func hostFromConnectPeerAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
@@ -4,8 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"buf.build/go/protovalidate"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
@@ -3,7 +3,6 @@ package grpcapi
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,11 +11,10 @@ import (
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsStaleTimestamp(t *testing.T) {
|
||||
@@ -51,16 +49,11 @@ func TestExecuteCommandRejectsStaleTimestamp(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS)))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
|
||||
assert.Equal(t, "request timestamp is outside the freshness window", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
})
|
||||
}
|
||||
@@ -98,16 +91,11 @@ func TestSubscribeEventsRejectsStaleTimestamp(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
|
||||
assert.Equal(t, "request timestamp is outside the freshness window", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
})
|
||||
}
|
||||
@@ -127,21 +115,16 @@ func TestExecuteCommandRejectsReplay(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
req := newValidExecuteCommandRequest()
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(req))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), req)
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(req))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request replay detected", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
|
||||
assert.Equal(t, "request replay detected", connectErrorMessage(t, err))
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -159,25 +142,20 @@ func TestSubscribeEventsRejectsReplay(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
req := newValidSubscribeEventsRequest()
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), req)
|
||||
stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(req))
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.False(t, stream.Receive())
|
||||
require.NoError(t, stream.Err())
|
||||
|
||||
err = subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "request replay detected", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
|
||||
assert.Equal(t, "request replay detected", connectErrorMessage(t, err))
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -204,17 +182,12 @@ func TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-shared"))
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-shared")))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-456", "request-shared"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-456", "request-shared")))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
@@ -243,26 +216,21 @@ func TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions(t *testing.T)
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-shared"))
|
||||
stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-shared")))
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.False(t, stream.Receive())
|
||||
require.NoError(t, stream.Err())
|
||||
|
||||
stream, err = client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-456", "request-shared"))
|
||||
stream, err = client.SubscribeEvents(context.Background(), connect.NewRequest(newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-456", "request-shared")))
|
||||
require.NoError(t, err)
|
||||
event = recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.False(t, stream.Receive())
|
||||
require.NoError(t, stream.Err())
|
||||
|
||||
assert.Equal(t, 2, delegate.subscribeCalls)
|
||||
}
|
||||
@@ -283,16 +251,11 @@ func TestExecuteCommandRejectsReplayStoreUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "replay store is unavailable", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -312,16 +275,11 @@ func TestSubscribeEventsRejectsReplayStoreUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "replay store is unavailable", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -353,15 +311,10 @@ func TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *tes
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
response, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "request-123", response.GetRequestId())
|
||||
assert.Equal(t, "request-123", response.Msg.GetRequestId())
|
||||
assert.Equal(t, "device-session-123", reservedDeviceSessionID)
|
||||
assert.Equal(t, "request-123", reservedRequestID)
|
||||
assert.Equal(t, testFreshnessWindow, reservedTTL)
|
||||
@@ -394,18 +347,13 @@ func TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *te
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(newValidSubscribeEventsRequest()))
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.False(t, stream.Receive())
|
||||
require.NoError(t, stream.Err())
|
||||
assert.Equal(t, testFreshnessWindow, reservedTTL)
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
@@ -434,15 +382,10 @@ func TestExecuteCommandFutureSkewUsesExtendedReplayTTL(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(
|
||||
context.Background(),
|
||||
newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(2*time.Minute).UnixMilli()),
|
||||
connect.NewRequest(newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(2*time.Minute).UnixMilli())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 7*time.Minute, reservedTTL)
|
||||
@@ -473,15 +416,10 @@ func TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(
|
||||
context.Background(),
|
||||
newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(-testFreshnessWindow).UnixMilli()),
|
||||
connect.NewRequest(newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(-testFreshnessWindow).UnixMilli())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, minimumReplayReservationTTL, reservedTTL)
|
||||
|
||||
@@ -12,59 +12,21 @@ import (
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func observabilityUnaryInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.UnaryServerInterceptor {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
start := time.Now()
|
||||
resp, err := handler(ctx, req)
|
||||
|
||||
recordGRPCRequest(logger, metrics, ctx, info.FullMethod, req, resp, err, time.Since(start), "unary")
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func observabilityStreamInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.StreamServerInterceptor {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
start := time.Now()
|
||||
wrapped := &observabilityServerStream{ServerStream: stream}
|
||||
err := handler(srv, wrapped)
|
||||
|
||||
recordGRPCRequest(logger, metrics, stream.Context(), info.FullMethod, wrapped.request, nil, err, time.Since(start), "stream")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
type observabilityServerStream struct {
|
||||
grpc.ServerStream
|
||||
request any
|
||||
}
|
||||
|
||||
func (s *observabilityServerStream) RecvMsg(m any) error {
|
||||
err := s.ServerStream.RecvMsg(m)
|
||||
if err == nil && s.request == nil {
|
||||
s.request = m
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx context.Context, fullMethod string, req any, resp any, err error, duration time.Duration, streamKind string) {
|
||||
// recordEdgeRequest emits the structured log entry and the
|
||||
// `gateway.authenticated_grpc.*` metric pair for one authenticated edge
|
||||
// request or stream outcome. The transport parameter labels the wire
|
||||
// protocol the request travelled over (`connect`, `grpc`, or `grpc-web`),
|
||||
// preserving stable observability semantics across the unified Connect-go
|
||||
// listener.
|
||||
func recordEdgeRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx context.Context, transport string, fullMethod string, req any, resp any, err error, duration time.Duration, streamKind string) {
|
||||
rpcMethod := path.Base(fullMethod)
|
||||
messageType, requestID, traceID := grpcEnvelopeFields(req)
|
||||
resultCode := grpcResultCode(resp)
|
||||
grpcCode, grpcMessage, outcome := grpcOutcome(err)
|
||||
messageType, requestID, traceID := envelopeFieldsFromRequest(req)
|
||||
resultCode := resultCodeFromResponse(resp)
|
||||
grpcCode, grpcMessage, outcome := outcomeFromError(err)
|
||||
rejectReason := telemetry.RejectReason(outcome)
|
||||
|
||||
attrs := []attribute.KeyValue{
|
||||
@@ -82,7 +44,7 @@ func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx conte
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "authenticated_grpc"),
|
||||
zap.String("transport", "grpc"),
|
||||
zap.String("transport", transport),
|
||||
zap.String("stream_kind", streamKind),
|
||||
zap.String("rpc_method", rpcMethod),
|
||||
zap.String("message_type", messageType),
|
||||
@@ -106,15 +68,15 @@ func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx conte
|
||||
|
||||
switch outcome {
|
||||
case telemetry.EdgeOutcomeSuccess:
|
||||
logger.Info("authenticated gRPC request completed", fields...)
|
||||
logger.Info("authenticated edge request completed", fields...)
|
||||
case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeDownstreamUnavailable, telemetry.EdgeOutcomeInternalError:
|
||||
logger.Error("authenticated gRPC request failed", fields...)
|
||||
logger.Error("authenticated edge request failed", fields...)
|
||||
default:
|
||||
logger.Warn("authenticated gRPC request rejected", fields...)
|
||||
logger.Warn("authenticated edge request rejected", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
func grpcEnvelopeFields(req any) (messageType string, requestID string, traceID string) {
|
||||
func envelopeFieldsFromRequest(req any) (messageType string, requestID string, traceID string) {
|
||||
switch typed := req.(type) {
|
||||
case *gatewayv1.ExecuteCommandRequest:
|
||||
return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
|
||||
@@ -125,7 +87,7 @@ func grpcEnvelopeFields(req any) (messageType string, requestID string, traceID
|
||||
}
|
||||
}
|
||||
|
||||
func grpcResultCode(resp any) string {
|
||||
func resultCodeFromResponse(resp any) string {
|
||||
typed, ok := resp.(*gatewayv1.ExecuteCommandResponse)
|
||||
if !ok {
|
||||
return ""
|
||||
@@ -134,7 +96,7 @@ func grpcResultCode(resp any) string {
|
||||
return typed.GetResultCode()
|
||||
}
|
||||
|
||||
func grpcOutcome(err error) (codes.Code, string, telemetry.EdgeOutcome) {
|
||||
func outcomeFromError(err error) (codes.Code, string, telemetry.EdgeOutcome) {
|
||||
switch {
|
||||
case err == nil:
|
||||
return codes.OK, "", telemetry.EdgeOutcomeSuccess
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsPayloadHashWithInvalidLength(t *testing.T) {
|
||||
@@ -25,19 +23,15 @@ func TestExecuteCommandRejectsPayloadHashWithInvalidLength(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
req := newValidExecuteCommandRequest()
|
||||
req.PayloadHash = []byte("short")
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(req))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
|
||||
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -52,20 +46,16 @@ func TestExecuteCommandRejectsPayloadHashMismatch(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
req := newValidExecuteCommandRequest()
|
||||
sum := sha256.Sum256([]byte("other"))
|
||||
req.PayloadHash = sum[:]
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(req))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
|
||||
assert.Equal(t, "payload_hash does not match payload_bytes", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -80,19 +70,15 @@ func TestSubscribeEventsRejectsPayloadHashWithInvalidLength(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
req := newValidSubscribeEventsRequest()
|
||||
req.PayloadHash = []byte("short")
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
|
||||
assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -107,19 +93,15 @@ func TestSubscribeEventsRejectsPayloadHashMismatch(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
req := newValidSubscribeEventsRequest()
|
||||
sum := sha256.Sum256([]byte("other"))
|
||||
req.PayloadHash = sum[:]
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
|
||||
assert.Equal(t, "payload_hash does not match payload_bytes", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package grpcapi
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
@@ -13,7 +11,6 @@ import (
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
@@ -41,7 +38,7 @@ var (
|
||||
ErrAuthenticatedPolicyUnavailable = errors.New("authenticated request policy is unavailable")
|
||||
)
|
||||
|
||||
// AuthenticatedRequestLimiter applies authenticated gRPC rate-limit policy to
|
||||
// AuthenticatedRequestLimiter applies authenticated edge rate-limit policy to
|
||||
// one concrete bucket key.
|
||||
type AuthenticatedRequestLimiter interface {
|
||||
// Reserve evaluates key under policy and reports whether the request may
|
||||
@@ -52,10 +49,11 @@ type AuthenticatedRequestLimiter interface {
|
||||
// AuthenticatedRequest describes the authenticated request metadata exposed to
|
||||
// the edge-policy hook.
|
||||
type AuthenticatedRequest struct {
|
||||
// RPCMethod identifies the public gRPC method being processed.
|
||||
// RPCMethod identifies the public RPC method being processed.
|
||||
RPCMethod string
|
||||
|
||||
// PeerIP is the transport peer IP derived from the gRPC connection.
|
||||
// PeerIP is the transport peer IP host part derived from the
|
||||
// authenticated edge HTTP listener peer address.
|
||||
PeerIP string
|
||||
|
||||
// MessageClass is the stable rate-limit and policy class. The gateway uses
|
||||
@@ -258,23 +256,21 @@ func authenticatedMessageClass(messageType string) string {
|
||||
return messageType
|
||||
}
|
||||
|
||||
type peerIPContextKey struct{}
|
||||
|
||||
// contextWithPeerIP attaches the authenticated edge transport peer IP to ctx.
|
||||
// It is set by the transport interceptor before the service decorator stack
|
||||
// runs, and read back via peerIPFromContext.
|
||||
func contextWithPeerIP(ctx context.Context, ip string) context.Context {
|
||||
return context.WithValue(ctx, peerIPContextKey{}, ip)
|
||||
}
|
||||
|
||||
func peerIPFromContext(ctx context.Context) string {
|
||||
peerInfo, ok := peer.FromContext(ctx)
|
||||
if !ok || peerInfo.Addr == nil {
|
||||
return unknownAuthenticatedPeerIP
|
||||
if ip, ok := ctx.Value(peerIPContextKey{}).(string); ok && ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(peerInfo.Addr.String())
|
||||
if value == "" {
|
||||
return unknownAuthenticatedPeerIP
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(value)
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
|
||||
return value
|
||||
return unknownAuthenticatedPeerIP
|
||||
}
|
||||
|
||||
type noopAuthenticatedRequestPolicy struct{}
|
||||
|
||||
@@ -3,7 +3,6 @@ package grpcapi
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -17,10 +16,9 @@ import (
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
|
||||
@@ -41,20 +39,15 @@ func TestExecuteCommandRateLimitsByIP(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2")))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", connectErrorMessage(t, err))
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -76,21 +69,16 @@ func TestExecuteCommandRateLimitsBySession(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2")))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3")))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
@@ -118,21 +106,16 @@ func TestExecuteCommandRateLimitsByUser(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1")))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2")))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3")))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
@@ -159,21 +142,16 @@ func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move"))
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move")))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move")))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename"))
|
||||
_, err = client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename")))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, delegate.executeCalls)
|
||||
@@ -193,13 +171,8 @@ func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, policy.requests, 1)
|
||||
@@ -228,16 +201,11 @@ func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rejected by edge policy", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err))
|
||||
assert.Equal(t, "authenticated request rejected by edge policy", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -259,24 +227,19 @@ func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1"))
|
||||
stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1")))
|
||||
require.NoError(t, err)
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.False(t, stream.Receive())
|
||||
require.NoError(t, stream.Err())
|
||||
|
||||
err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(err))
|
||||
assert.Equal(t, "authenticated request rate limit exceeded", connectErrorMessage(t, err))
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -342,13 +305,8 @@ func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
|
||||
require.NoError(t, firstPublic.Body.Close())
|
||||
require.NoError(t, secondPublic.Body.Close())
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
// Package grpcapi exposes the authenticated gRPC surface of the gateway.
|
||||
// Package grpcapi exposes the authenticated edge transport surface of the
|
||||
// gateway. Despite the historical package name, the listener is built on
|
||||
// `connectrpc.com/connect` and natively serves the Connect, gRPC, and
|
||||
// gRPC-Web protocols on a single HTTP/h2c listener. The configured Go
|
||||
// types and environment variable names retain the `gRPC` infix for
|
||||
// operational stability — they describe the authenticated edge tier, not
|
||||
// the wire protocol.
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
@@ -6,6 +12,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/authn"
|
||||
@@ -18,14 +25,17 @@ import (
|
||||
"galaxy/gateway/internal/session"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
"galaxy/gateway/proto/galaxy/gateway/v1/gatewayv1connect"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"connectrpc.com/connect"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
// ServerDependencies describes the optional collaborators used by the
|
||||
// authenticated gRPC server. The zero value is valid and keeps the process
|
||||
// authenticated edge server. The zero value is valid and keeps the process
|
||||
// runnable with the built-in unimplemented service stub.
|
||||
type ServerDependencies struct {
|
||||
// Service optionally handles the post-bootstrap SubscribeEvents lifecycle
|
||||
@@ -45,12 +55,12 @@ type ServerDependencies struct {
|
||||
ResponseSigner authn.ResponseSigner
|
||||
|
||||
// SessionCache resolves authenticated device sessions after the envelope
|
||||
// gate succeeds. When nil, the authenticated gRPC surface remains runnable
|
||||
// gate succeeds. When nil, the authenticated edge surface remains runnable
|
||||
// but valid envelopes fail closed as session-cache unavailable.
|
||||
SessionCache session.Cache
|
||||
|
||||
// Clock provides current server time for freshness checks. When nil, the
|
||||
// authenticated gRPC surface uses the system clock.
|
||||
// authenticated edge surface uses the system clock.
|
||||
Clock clock.Clock
|
||||
|
||||
// ReplayStore reserves authenticated request identifiers after signature
|
||||
@@ -59,26 +69,28 @@ type ServerDependencies struct {
|
||||
ReplayStore replay.Store
|
||||
|
||||
// Limiter applies authenticated rate limits after the request passes the
|
||||
// transport authenticity checks. When nil, the authenticated gRPC surface
|
||||
// transport authenticity checks. When nil, the authenticated edge surface
|
||||
// uses a process-local in-memory limiter.
|
||||
Limiter AuthenticatedRequestLimiter
|
||||
|
||||
// Policy evaluates later authenticated edge policy after rate limits pass.
|
||||
// When nil, the authenticated gRPC surface applies a no-op allow policy.
|
||||
// When nil, the authenticated edge surface applies a no-op allow policy.
|
||||
Policy AuthenticatedRequestPolicy
|
||||
|
||||
// Logger writes structured logs for authenticated gRPC traffic.
|
||||
// Logger writes structured logs for authenticated edge traffic.
|
||||
Logger *zap.Logger
|
||||
|
||||
// Telemetry records low-cardinality gRPC metrics.
|
||||
// Telemetry records low-cardinality edge metrics.
|
||||
Telemetry *telemetry.Runtime
|
||||
|
||||
// PushHub is the active authenticated push-stream hub. When present, the
|
||||
// server closes active streams before GracefulStop during shutdown.
|
||||
// server closes active streams before HTTP graceful shutdown.
|
||||
PushHub *push.Hub
|
||||
}
|
||||
|
||||
// Server owns the authenticated gRPC listener exposed by the gateway.
|
||||
// Server owns the authenticated edge HTTP/h2c listener exposed by the
|
||||
// gateway. It serves the Connect, gRPC, and gRPC-Web protocols from a
|
||||
// single net/http listener.
|
||||
type Server struct {
|
||||
cfg config.AuthenticatedGRPCConfig
|
||||
service gatewayv1.EdgeGatewayServer
|
||||
@@ -87,11 +99,11 @@ type Server struct {
|
||||
metrics *telemetry.Runtime
|
||||
|
||||
stateMu sync.RWMutex
|
||||
server *grpc.Server
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
// NewServer constructs an authenticated gRPC server for the supplied listener
|
||||
// NewServer constructs an authenticated edge server for the supplied listener
|
||||
// configuration and dependency bundle. Nil dependencies are replaced with safe
|
||||
// defaults so the gateway can expose the documented transport surface with the
|
||||
// full auth pipeline wired from built-in fallbacks.
|
||||
@@ -128,17 +140,17 @@ func NewServer(cfg config.AuthenticatedGRPCConfig, deps ServerDependencies) *Ser
|
||||
deps.SessionCache,
|
||||
),
|
||||
),
|
||||
logger: deps.Logger.Named("authenticated_grpc"),
|
||||
logger: deps.Logger.Named("authenticated_edge"),
|
||||
pushHub: deps.PushHub,
|
||||
metrics: deps.Telemetry,
|
||||
}
|
||||
}
|
||||
|
||||
// Run binds the configured listener and serves the authenticated gRPC surface
|
||||
// until Shutdown closes the server.
|
||||
// Run binds the configured listener and serves the authenticated edge
|
||||
// surface until Shutdown closes the server.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("run authenticated gRPC server: nil context")
|
||||
return errors.New("run authenticated edge server: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
@@ -146,23 +158,30 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
|
||||
listener, err := net.Listen("tcp", s.cfg.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("run authenticated gRPC server: listen on %q: %w", s.cfg.Addr, err)
|
||||
return fmt.Errorf("run authenticated edge server: listen on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
grpc.ConnectionTimeout(s.cfg.ConnectionTimeout),
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
grpc.ChainUnaryInterceptor(observabilityUnaryInterceptor(s.logger, s.metrics)),
|
||||
grpc.ChainStreamInterceptor(observabilityStreamInterceptor(s.logger, s.metrics)),
|
||||
mux := http.NewServeMux()
|
||||
connectHandler := newConnectEdgeAdapter(s.service)
|
||||
path, handler := gatewayv1connect.NewEdgeGatewayHandler(
|
||||
connectHandler,
|
||||
connect.WithInterceptors(observabilityConnectInterceptor(s.logger, s.metrics)),
|
||||
)
|
||||
gatewayv1.RegisterEdgeGatewayServer(grpcServer, s.service)
|
||||
mux.Handle(path, handler)
|
||||
|
||||
tracedHandler := otelhttp.NewHandler(mux, "authenticated_edge")
|
||||
http2Server := &http2.Server{IdleTimeout: s.cfg.ConnectionTimeout}
|
||||
httpServer := &http.Server{
|
||||
Handler: h2c.NewHandler(tracedHandler, http2Server),
|
||||
ReadHeaderTimeout: s.cfg.ConnectionTimeout,
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
s.server = grpcServer
|
||||
s.server = httpServer
|
||||
s.listener = listener
|
||||
s.stateMu.Unlock()
|
||||
|
||||
s.logger.Info("authenticated gRPC server started", zap.String("addr", listener.Addr().String()))
|
||||
s.logger.Info("authenticated edge server started", zap.String("addr", listener.Addr().String()))
|
||||
|
||||
defer func() {
|
||||
s.stateMu.Lock()
|
||||
@@ -171,24 +190,22 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
s.stateMu.Unlock()
|
||||
}()
|
||||
|
||||
err = grpcServer.Serve(listener)
|
||||
err = httpServer.Serve(listener)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, grpc.ErrServerStopped):
|
||||
s.logger.Info("authenticated gRPC server stopped")
|
||||
case err == nil, errors.Is(err, http.ErrServerClosed):
|
||||
s.logger.Info("authenticated edge server stopped")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("run authenticated gRPC server: serve on %q: %w", s.cfg.Addr, err)
|
||||
return fmt.Errorf("run authenticated edge server: serve on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the authenticated gRPC server within ctx. When the
|
||||
// graceful stop exceeds ctx, the server is force-stopped before returning the
|
||||
// Shutdown gracefully stops the authenticated edge server within ctx. When the
|
||||
// graceful stop exceeds ctx, the server is force-closed before returning the
|
||||
// timeout to the caller.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown authenticated gRPC server: nil context")
|
||||
return errors.New("shutdown authenticated edge server: nil context")
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
@@ -203,20 +220,16 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
||||
s.pushHub.Shutdown()
|
||||
}
|
||||
|
||||
stopped := make(chan struct{})
|
||||
go func() {
|
||||
server.GracefulStop()
|
||||
close(stopped)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stopped:
|
||||
err := server.Shutdown(ctx)
|
||||
if err == nil {
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
server.Stop()
|
||||
<-stopped
|
||||
return fmt.Errorf("shutdown authenticated gRPC server: %w", ctx.Err())
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
_ = server.Close()
|
||||
return fmt.Errorf("shutdown authenticated edge server: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("shutdown authenticated edge server: %w", err)
|
||||
}
|
||||
|
||||
func (s *Server) listenAddr() string {
|
||||
|
||||
@@ -2,6 +2,10 @@ package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -9,13 +13,12 @@ import (
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
"galaxy/gateway/proto/galaxy/gateway/v1/gatewayv1connect"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
|
||||
@@ -25,15 +28,11 @@ func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(&gatewayv1.ExecuteCommandRequest{}))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
|
||||
@@ -43,15 +42,11 @@ func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
|
||||
@@ -61,13 +56,9 @@ func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(&gatewayv1.ExecuteCommandRequest{
|
||||
ProtocolVersion: "v2",
|
||||
DeviceSessionId: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
@@ -76,10 +67,10 @@ func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
|
||||
PayloadBytes: []byte("payload"),
|
||||
PayloadHash: []byte("hash"),
|
||||
Signature: []byte("signature"),
|
||||
})
|
||||
}))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
|
||||
assert.Equal(t, `unsupported protocol_version "v2"`, connectErrorMessage(t, err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
|
||||
@@ -96,15 +87,11 @@ func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unimplemented, status.Code(err))
|
||||
assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
@@ -120,16 +107,12 @@ func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "replay store is unavailable", connectErrorMessage(t, err))
|
||||
}
|
||||
|
||||
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
|
||||
@@ -149,22 +132,22 @@ func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
|
||||
stream, err := client.SubscribeEvents(ctx, connect.NewRequest(newValidSubscribeEventsRequest()))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = stream.Close() })
|
||||
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
|
||||
recvResult := make(chan error, 1)
|
||||
go func() {
|
||||
_, recvErr := stream.Recv()
|
||||
recvResult <- recvErr
|
||||
if stream.Receive() {
|
||||
recvResult <- errors.New("stream produced unexpected event")
|
||||
return
|
||||
}
|
||||
recvResult <- stream.Err()
|
||||
}()
|
||||
|
||||
require.Never(t, func() bool {
|
||||
@@ -188,7 +171,7 @@ func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(
|
||||
}
|
||||
}, time.Second, 10*time.Millisecond, "stream did not stop after client cancellation")
|
||||
require.Error(t, recvErr)
|
||||
assert.Equal(t, codes.Canceled, status.Code(recvErr))
|
||||
assert.Equal(t, connect.CodeCanceled, connect.CodeOf(recvErr))
|
||||
}
|
||||
|
||||
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
@@ -204,16 +187,12 @@ func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "replay store is unavailable", connectErrorMessage(t, err))
|
||||
}
|
||||
|
||||
func TestSubscribeEventsFailsClosedWhenResponseSignerUnavailable(t *testing.T) {
|
||||
@@ -231,16 +210,12 @@ func TestSubscribeEventsFailsClosedWhenResponseSignerUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "response signer is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "response signer is unavailable", connectErrorMessage(t, err))
|
||||
}
|
||||
|
||||
func TestServerLifecycle(t *testing.T) {
|
||||
@@ -248,21 +223,23 @@ func TestServerLifecycle(t *testing.T) {
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
require.NoError(t, conn.Close())
|
||||
// Probe the listener before shutdown so we know it accepted at
|
||||
// least one TCP connection.
|
||||
probe, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, probe.Close())
|
||||
|
||||
runGateway.stop(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
// After shutdown the listener must refuse new TCP connections.
|
||||
dialCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.Error(t, err)
|
||||
dialer := &net.Dialer{}
|
||||
closedConn, err := dialer.DialContext(dialCtx, "tcp", addr)
|
||||
if err == nil {
|
||||
_ = closedConn.Close()
|
||||
t.Fatalf("expected dial to %s to fail after shutdown", addr)
|
||||
}
|
||||
}
|
||||
|
||||
type runningGateway struct {
|
||||
@@ -341,19 +318,36 @@ func waitForListenAddr(t *testing.T, server *Server) string {
|
||||
return addr
|
||||
}
|
||||
|
||||
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
|
||||
// newEdgeClient returns a Connect client speaking HTTP/2 cleartext to the
|
||||
// authenticated edge listener. AllowHTTP forces the client to issue plain
|
||||
// HTTP/2 requests (h2c) instead of attempting TLS, which the gateway's
|
||||
// in-process test bootstrap does not configure.
|
||||
func newEdgeClient(t *testing.T, addr string) gatewayv1connect.EdgeGatewayClient {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return conn
|
||||
httpClient := &http.Client{
|
||||
Transport: &http2.Transport{
|
||||
AllowHTTP: true,
|
||||
DialTLSContext: func(ctx context.Context, network, target string, _ *tls.Config) (net.Conn, error) {
|
||||
return (&net.Dialer{}).DialContext(ctx, network, target)
|
||||
},
|
||||
},
|
||||
}
|
||||
return gatewayv1connect.NewEdgeGatewayClient(httpClient, "http://"+addr)
|
||||
}
|
||||
|
||||
// connectErrorMessage extracts the *connect.Error message from err. It
|
||||
// fails the test if err is not a *connect.Error so the caller's expected
|
||||
// message comparison doesn't accidentally match the wrapped Go error
|
||||
// string instead of the protocol-level message.
|
||||
func connectErrorMessage(t require.TestingT, err error) string {
|
||||
if helper, ok := t.(interface{ Helper() }); ok {
|
||||
helper.Helper()
|
||||
}
|
||||
|
||||
var connectErr *connect.Error
|
||||
if !errors.As(err, &connectErr) {
|
||||
require.FailNowf(t, "expected *connect.Error", "got %T: %v", err, err)
|
||||
}
|
||||
return connectErr.Message()
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (unavailableSessionCache) Lookup(context.Context, string) (session.Record,
|
||||
return session.Record{}, errors.New("session cache is unavailable")
|
||||
}
|
||||
|
||||
func (unavailableSessionCache) MarkRevoked(string) {}
|
||||
func (unavailableSessionCache) MarkRevoked(string) {}
|
||||
func (unavailableSessionCache) MarkAllRevokedForUser(string) {}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = sessionLookupService{}
|
||||
|
||||
@@ -3,17 +3,15 @@ package grpcapi
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsUnknownSession(t *testing.T) {
|
||||
@@ -31,16 +29,11 @@ func TestExecuteCommandRejectsUnknownSession(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "unknown device session", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
|
||||
assert.Equal(t, "unknown device session", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -59,16 +52,11 @@ func TestSubscribeEventsRejectsUnknownSession(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "unknown device session", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
|
||||
assert.Equal(t, "unknown device session", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -83,16 +71,11 @@ func TestExecuteCommandRejectsRevokedSession(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
|
||||
assert.Equal(t, "device session is revoked", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -107,16 +90,11 @@ func TestSubscribeEventsRejectsRevokedSession(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err))
|
||||
assert.Equal(t, "device session is revoked", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -135,16 +113,11 @@ func TestExecuteCommandRejectsSessionCacheUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "session cache is unavailable", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -163,16 +136,11 @@ func TestSubscribeEventsRejectsSessionCacheUnavailable(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "session cache is unavailable", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -196,15 +164,10 @@ func TestExecuteCommandAttachesResolvedSession(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
response, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "request-123", response.GetRequestId())
|
||||
assert.Equal(t, "request-123", response.Msg.GetRequestId())
|
||||
}
|
||||
|
||||
func TestSubscribeEventsAttachesResolvedSession(t *testing.T) {
|
||||
@@ -227,20 +190,15 @@ func TestSubscribeEventsAttachesResolvedSession(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(newValidSubscribeEventsRequest()))
|
||||
require.NoError(t, err)
|
||||
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.False(t, stream.Receive())
|
||||
require.NoError(t, stream.Err())
|
||||
}
|
||||
|
||||
func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) {
|
||||
@@ -269,20 +227,15 @@ func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
stream, err := client.SubscribeEvents(context.Background(), connect.NewRequest(newValidSubscribeEventsRequest()))
|
||||
require.NoError(t, err)
|
||||
|
||||
event := recvBootstrapEvent(t, stream)
|
||||
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.False(t, stream.Receive())
|
||||
require.NoError(t, stream.Err())
|
||||
}
|
||||
|
||||
type staticSessionCache struct {
|
||||
@@ -293,5 +246,5 @@ func (c staticSessionCache) Lookup(ctx context.Context, deviceSessionID string)
|
||||
return c.lookupFunc(ctx, deviceSessionID)
|
||||
}
|
||||
|
||||
func (staticSessionCache) MarkRevoked(string) {}
|
||||
func (staticSessionCache) MarkRevoked(string) {}
|
||||
func (staticSessionCache) MarkAllRevokedForUser(string) {}
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsInvalidSignature(t *testing.T) {
|
||||
@@ -24,19 +22,15 @@ func TestExecuteCommandRejectsInvalidSignature(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
req := newValidExecuteCommandRequest()
|
||||
req.Signature[0] ^= 0xff
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), req)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(req))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
|
||||
assert.Equal(t, "invalid request signature", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -57,16 +51,11 @@ func TestExecuteCommandRejectsWrongKey(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
|
||||
assert.Equal(t, "invalid request signature", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -87,16 +76,11 @@ func TestExecuteCommandRejectsInvalidCachedPublicKey(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
client := newEdgeClient(t, addr)
|
||||
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "session cache is unavailable", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
@@ -111,19 +95,15 @@ func TestSubscribeEventsRejectsInvalidSignature(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := newEdgeClient(t, addr)
|
||||
|
||||
req := newValidSubscribeEventsRequest()
|
||||
req.Signature[0] ^= 0xff
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, req)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
|
||||
assert.Equal(t, "invalid request signature", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -144,16 +124,11 @@ func TestSubscribeEventsRejectsWrongKey(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
|
||||
assert.Equal(t, "invalid request signature", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -174,15 +149,10 @@ func TestSubscribeEventsRejectsInvalidCachedPublicKey(t *testing.T) {
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
client := newEdgeClient(t, addr)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
||||
assert.Equal(t, "session cache is unavailable", connectErrorMessage(t, err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
@@ -7,19 +7,21 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/authn"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
"galaxy/gateway/proto/galaxy/gateway/v1/gatewayv1connect"
|
||||
|
||||
gatewayfbs "galaxy/schema/fbs/gateway"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -170,28 +172,37 @@ func (c fixedClock) Now() time.Time {
|
||||
func recvBootstrapEvent(t interface {
|
||||
require.TestingT
|
||||
Helper()
|
||||
}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
|
||||
}, stream *connect.ServerStreamForClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
|
||||
t.Helper()
|
||||
|
||||
event, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
if !stream.Receive() {
|
||||
err := stream.Err()
|
||||
if err == nil {
|
||||
err = errors.New("stream closed before bootstrap event")
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return event
|
||||
return stream.Msg()
|
||||
}
|
||||
|
||||
func subscribeEventsError(t interface {
|
||||
require.TestingT
|
||||
Helper()
|
||||
}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
|
||||
}, ctx context.Context, client gatewayv1connect.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
|
||||
t.Helper()
|
||||
|
||||
stream, err := client.SubscribeEvents(ctx, req)
|
||||
stream, err := client.SubscribeEvents(ctx, connect.NewRequest(req))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = stream.Close() }()
|
||||
|
||||
_, err = stream.Recv()
|
||||
return err
|
||||
if stream.Receive() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return stream.Err()
|
||||
}
|
||||
|
||||
func assertServerTimeBootstrapEvent(t interface {
|
||||
|
||||
Reference in New Issue
Block a user