360 lines
9.5 KiB
Go
360 lines
9.5 KiB
Go
package grpcapi
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/gateway/internal/app"
|
|
"galaxy/gateway/internal/config"
|
|
"galaxy/gateway/internal/session"
|
|
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
|
}
|
|
|
|
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
|
}
|
|
|
|
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{
|
|
ProtocolVersion: "v2",
|
|
DeviceSessionId: "device-session-123",
|
|
MessageType: "fleet.move",
|
|
TimestampMs: 123456789,
|
|
RequestId: "request-123",
|
|
PayloadBytes: []byte("payload"),
|
|
PayloadHash: []byte("hash"),
|
|
Signature: []byte("signature"),
|
|
})
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
|
assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
|
|
}
|
|
|
|
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
return newActiveSessionRecord(), nil
|
|
},
|
|
},
|
|
ReplayStore: staticReplayStore{},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unimplemented, status.Code(err))
|
|
}
|
|
|
|
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
return newActiveSessionRecord(), nil
|
|
},
|
|
},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unavailable, status.Code(err))
|
|
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
|
}
|
|
|
|
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
return newActiveSessionRecord(), nil
|
|
},
|
|
},
|
|
ReplayStore: staticReplayStore{},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
|
|
require.NoError(t, err)
|
|
|
|
event := recvBootstrapEvent(t, stream)
|
|
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
|
|
|
|
recvResult := make(chan error, 1)
|
|
go func() {
|
|
_, recvErr := stream.Recv()
|
|
recvResult <- recvErr
|
|
}()
|
|
|
|
require.Never(t, func() bool {
|
|
select {
|
|
case <-recvResult:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")
|
|
|
|
cancel()
|
|
|
|
var recvErr error
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case recvErr = <-recvResult:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, time.Second, 10*time.Millisecond, "stream did not stop after client cancellation")
|
|
require.Error(t, recvErr)
|
|
assert.Equal(t, codes.Canceled, status.Code(recvErr))
|
|
}
|
|
|
|
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
return newActiveSessionRecord(), nil
|
|
},
|
|
},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unavailable, status.Code(err))
|
|
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
|
}
|
|
|
|
func TestSubscribeEventsFailsClosedWhenResponseSignerUnavailable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
ResponseSigner: unavailableResponseSigner{},
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
return newActiveSessionRecord(), nil
|
|
},
|
|
},
|
|
ReplayStore: staticReplayStore{},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unavailable, status.Code(err))
|
|
assert.Equal(t, "response signer is unavailable", status.Convert(err).Message())
|
|
}
|
|
|
|
func TestServerLifecycle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, runGateway := newTestGateway(t, ServerDependencies{})
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
require.NoError(t, conn.Close())
|
|
|
|
runGateway.stop(t)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err := grpc.DialContext(
|
|
ctx,
|
|
addr,
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// runningGateway is the handle returned by newTestGatewayWithGRPCConfig for a
// gateway application running on a background goroutine.
type runningGateway struct {
	// cancel cancels the context that was passed to the application's Run.
	cancel context.CancelFunc

	// resultCh receives the single error value returned by Run.
	resultCh chan error
}
|
|
|
|
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
|
|
t.Helper()
|
|
|
|
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
|
|
grpcCfg.Addr = "127.0.0.1:0"
|
|
grpcCfg.FreshnessWindow = testFreshnessWindow
|
|
|
|
return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
|
|
}
|
|
|
|
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
|
|
t.Helper()
|
|
|
|
cfg := config.Config{
|
|
ShutdownTimeout: time.Second,
|
|
AuthenticatedGRPC: grpcCfg,
|
|
}
|
|
|
|
if deps.Clock == nil {
|
|
deps.Clock = fixedClock{now: testCurrentTime}
|
|
}
|
|
if deps.ResponseSigner == nil {
|
|
deps.ResponseSigner = newTestResponseSigner()
|
|
}
|
|
if deps.Router == nil && deps.Service != nil {
|
|
deps.Router = executeCommandAdapterRouter{service: deps.Service}
|
|
}
|
|
|
|
server := NewServer(cfg.AuthenticatedGRPC, deps)
|
|
application := app.New(cfg, server)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
resultCh := make(chan error, 1)
|
|
go func() {
|
|
resultCh <- application.Run(ctx)
|
|
}()
|
|
|
|
return server, runningGateway{
|
|
cancel: cancel,
|
|
resultCh: resultCh,
|
|
}
|
|
}
|
|
|
|
func (g runningGateway) stop(t *testing.T) {
|
|
t.Helper()
|
|
|
|
g.cancel()
|
|
|
|
var err error
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case err = <-g.resultCh:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, 2*time.Second, 10*time.Millisecond, "gateway did not stop after cancellation")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func waitForListenAddr(t *testing.T, server *Server) string {
|
|
t.Helper()
|
|
|
|
var addr string
|
|
require.Eventually(t, func() bool {
|
|
addr = server.listenAddr()
|
|
return addr != ""
|
|
}, time.Second, 10*time.Millisecond, "server did not start listening")
|
|
return addr
|
|
}
|
|
|
|
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
|
|
t.Helper()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := grpc.DialContext(
|
|
ctx,
|
|
addr,
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
return conn
|
|
}
|