118f7c17a2
Replace the native-gRPC server bootstrap with a single `connectrpc.com/connect` HTTP/h2c listener. Connect-Go natively serves Connect, gRPC, and gRPC-Web on the same port, so browsers can now reach the authenticated surface without giving up the gRPC framing native and desktop clients may use later. The decorator stack (envelope → session → payload-hash → signature → freshness/replay → rate-limit → routing/push) is reused unchanged behind a small Connect → gRPC adapter and a `grpc.ServerStream` shim around `*connect.ServerStream`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
328 lines
11 KiB
Go
328 lines
11 KiB
Go
package grpcapi
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/gateway/authn"
|
|
"galaxy/gateway/internal/config"
|
|
"galaxy/gateway/internal/downstream"
|
|
"galaxy/gateway/internal/testutil"
|
|
|
|
"connectrpc.com/connect"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
signer := newTestEd25519ResponseSigner()
|
|
moveClient := &recordingDownstreamClient{
|
|
executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
assert.Equal(t, downstream.AuthenticatedCommand{
|
|
ProtocolVersion: "v1",
|
|
UserID: "user-123",
|
|
DeviceSessionID: "device-session-123",
|
|
MessageType: "fleet.move",
|
|
TimestampMS: testCurrentTime.UnixMilli(),
|
|
RequestID: "request-123",
|
|
TraceID: "trace-123",
|
|
PayloadBytes: []byte("payload"),
|
|
}, command)
|
|
|
|
return downstream.UnaryResult{
|
|
ResultCode: "accepted",
|
|
PayloadBytes: []byte("downstream-response"),
|
|
}, nil
|
|
},
|
|
}
|
|
renameClient := &recordingDownstreamClient{}
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": moveClient,
|
|
"fleet.rename": renameClient,
|
|
}),
|
|
ResponseSigner: signer,
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
response, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, "v1", response.Msg.GetProtocolVersion())
|
|
assert.Equal(t, "request-123", response.Msg.GetRequestId())
|
|
assert.Equal(t, testCurrentTime.UnixMilli(), response.Msg.GetTimestampMs())
|
|
assert.Equal(t, "accepted", response.Msg.GetResultCode())
|
|
assert.Equal(t, []byte("downstream-response"), response.Msg.GetPayloadBytes())
|
|
assert.Equal(t, 1, moveClient.executeCalls)
|
|
assert.Zero(t, renameClient.executeCalls)
|
|
|
|
wantHash := sha256.Sum256([]byte("downstream-response"))
|
|
assert.Equal(t, wantHash[:], response.Msg.GetPayloadHash())
|
|
require.NoError(t, authn.VerifyPayloadHash(response.Msg.GetPayloadBytes(), response.Msg.GetPayloadHash()))
|
|
require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.Msg.GetSignature(), authn.ResponseSigningFields{
|
|
ProtocolVersion: response.Msg.GetProtocolVersion(),
|
|
RequestID: response.Msg.GetRequestId(),
|
|
TimestampMS: response.Msg.GetTimestampMs(),
|
|
ResultCode: response.Msg.GetResultCode(),
|
|
PayloadHash: response.Msg.GetPayloadHash(),
|
|
}))
|
|
}
|
|
|
|
func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Router: downstream.NewStaticRouter(nil),
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
ResponseSigner: newTestResponseSigner(),
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
require.Error(t, err)
|
|
assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err))
|
|
assert.Equal(t, "message_type is not routed", connectErrorMessage(t, err))
|
|
}
|
|
|
|
func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
failingClient := &recordingDownstreamClient{
|
|
executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
return downstream.UnaryResult{}, fmt.Errorf("rpc transport failed: %w", downstream.ErrDownstreamUnavailable)
|
|
},
|
|
}
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": failingClient,
|
|
}),
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
ResponseSigner: newTestResponseSigner(),
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
require.Error(t, err)
|
|
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
|
assert.Equal(t, "downstream service is unavailable", connectErrorMessage(t, err))
|
|
assert.Equal(t, 1, failingClient.executeCalls)
|
|
}
|
|
|
|
func TestExecuteCommandMapsDownstreamTimeoutToUnavailable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
stallingClient := &recordingDownstreamClient{
|
|
executeFunc: func(ctx context.Context, _ downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
<-ctx.Done()
|
|
return downstream.UnaryResult{}, ctx.Err()
|
|
},
|
|
}
|
|
|
|
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
|
|
cfg.DownstreamTimeout = 50 * time.Millisecond
|
|
}), ServerDependencies{
|
|
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": stallingClient,
|
|
}),
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
ResponseSigner: newTestResponseSigner(),
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
require.Error(t, err)
|
|
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
|
assert.Equal(t, "downstream service is unavailable", connectErrorMessage(t, err))
|
|
assert.Equal(t, 1, stallingClient.executeCalls)
|
|
}
|
|
|
|
func TestExecuteCommandFailsClosedWhenResponseSignerUnavailable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
successClient := &recordingDownstreamClient{
|
|
executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
return downstream.UnaryResult{
|
|
ResultCode: "accepted",
|
|
PayloadBytes: []byte("downstream-response"),
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": successClient,
|
|
}),
|
|
ResponseSigner: unavailableResponseSigner{},
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
require.Error(t, err)
|
|
assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err))
|
|
assert.Equal(t, "response signer is unavailable", connectErrorMessage(t, err))
|
|
assert.Equal(t, 1, successClient.executeCalls)
|
|
}
|
|
|
|
func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := zap.NewNop()
|
|
telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
|
|
|
|
var (
|
|
seenSpanContext trace.SpanContext
|
|
seenCommand downstream.AuthenticatedCommand
|
|
)
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": &recordingDownstreamClient{
|
|
executeFunc: func(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
seenSpanContext = trace.SpanContextFromContext(ctx)
|
|
seenCommand = command
|
|
|
|
return downstream.UnaryResult{
|
|
ResultCode: "accepted",
|
|
PayloadBytes: []byte("downstream-response"),
|
|
}, nil
|
|
},
|
|
},
|
|
}),
|
|
ResponseSigner: newTestResponseSigner(),
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
Logger: logger,
|
|
Telemetry: telemetryRuntime,
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
require.NoError(t, err)
|
|
|
|
assert.True(t, seenSpanContext.IsValid())
|
|
assert.Equal(t, "trace-123", seenCommand.TraceID)
|
|
}
|
|
|
|
func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
started := make(chan struct{})
|
|
release := make(chan struct{})
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": &recordingDownstreamClient{
|
|
executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
close(started)
|
|
<-release
|
|
|
|
return downstream.UnaryResult{
|
|
ResultCode: "accepted",
|
|
PayloadBytes: []byte("downstream-response"),
|
|
}, nil
|
|
},
|
|
},
|
|
}),
|
|
ResponseSigner: newTestResponseSigner(),
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
resultCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
resultCh <- err
|
|
}()
|
|
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case <-started:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, time.Second, 10*time.Millisecond, "downstream execution did not start")
|
|
|
|
runGateway.cancel()
|
|
|
|
require.Never(t, func() bool {
|
|
select {
|
|
case <-resultCh:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, 100*time.Millisecond, 10*time.Millisecond, "unary request returned before downstream release")
|
|
|
|
close(release)
|
|
|
|
var err error
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case err = <-resultCh:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, time.Second, 10*time.Millisecond, "unary request did not drain before shutdown timeout")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger, logBuffer := testutil.NewObservedLogger(t)
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Router: downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": &recordingDownstreamClient{},
|
|
}),
|
|
ResponseSigner: newTestResponseSigner(),
|
|
SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
|
|
ReplayStore: staticReplayStore{},
|
|
Logger: logger,
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
client := newEdgeClient(t, addr)
|
|
_, err := client.ExecuteCommand(context.Background(), connect.NewRequest(newValidExecuteCommandRequest()))
|
|
require.NoError(t, err)
|
|
|
|
logOutput := logBuffer.String()
|
|
assert.NotContains(t, logOutput, "payload_hash")
|
|
assert.NotContains(t, logOutput, "signature")
|
|
assert.NotContains(t, logOutput, `"payload"`)
|
|
}
|