Files
2026-04-02 19:18:42 +02:00

421 lines
12 KiB
Go

package grpcapi
import (
"context"
"testing"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// TestParseExecuteCommandRequest checks that parseExecuteCommandRequest rejects
// each malformed envelope with the expected gRPC status code and message, and
// that a well-formed envelope is copied field-by-field into parsedEnvelope
// without aliasing the request's byte slices.
func TestParseExecuteCommandRequest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		nilRequest  bool // pass a nil request instead of a (mutated) valid fixture
		mutate      func(*gatewayv1.ExecuteCommandRequest)
		wantCode    codes.Code
		wantMessage string
		assertValid func(*testing.T, *gatewayv1.ExecuteCommandRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			nilRequest:  true,
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name: "empty protocol version",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.ProtocolVersion = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name: "empty device session id",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.DeviceSessionId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name: "empty message type",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.MessageType = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name: "zero timestamp",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.TimestampMs = 0
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name: "empty request id",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.RequestId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name: "empty payload bytes",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.PayloadBytes = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_bytes must not be empty",
		},
		{
			name: "empty payload hash",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.PayloadHash = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name: "empty signature",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.Signature = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name: "unsupported protocol version",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.ProtocolVersion = "v2"
			},
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name:     "valid request",
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.ExecuteCommandRequest, envelope parsedEnvelope) {
				t.Helper()

				assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion)
				assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
				assert.Equal(t, req.GetMessageType(), envelope.MessageType)
				assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS)
				assert.Equal(t, req.GetRequestId(), envelope.RequestID)
				assert.Equal(t, req.GetTraceId(), envelope.TraceID)
				assert.Equal(t, req.GetPayloadBytes(), envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)

				// The aliasing check below writes index 0 of each slice; fail
				// cleanly instead of panicking if the fixture ever yields an
				// empty value.
				require.NotEmpty(t, envelope.PayloadBytes)
				require.NotEmpty(t, envelope.PayloadHash)
				require.NotEmpty(t, envelope.Signature)

				// Snapshot the request bytes, then mutate the envelope: the
				// request must be unaffected, i.e. the parser copies rather
				// than aliases the byte slices.
				originalPayloadBytes := append([]byte(nil), req.GetPayloadBytes()...)
				originalPayloadHash := append([]byte(nil), req.GetPayloadHash()...)
				originalSignature := append([]byte(nil), req.GetSignature()...)

				envelope.PayloadBytes[0] = 'X'
				envelope.PayloadHash[0] = 'Y'
				envelope.Signature[0] = 'Z'

				assert.Equal(t, originalPayloadBytes, req.GetPayloadBytes())
				assert.Equal(t, originalPayloadHash, req.GetPayloadHash())
				assert.Equal(t, originalSignature, req.GetSignature())
			},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var req *gatewayv1.ExecuteCommandRequest
			if !tt.nilRequest {
				req = newValidExecuteCommandRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}

			envelope, err := parseExecuteCommandRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}

			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}
// TestParseSubscribeEventsRequest checks that parseSubscribeEventsRequest
// rejects each malformed envelope with the expected gRPC status code and
// message, and that a well-formed envelope, which may carry empty payload
// bytes, is copied into parsedEnvelope.
func TestParseSubscribeEventsRequest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		nilRequest  bool // pass a nil request instead of a (mutated) valid fixture
		mutate      func(*gatewayv1.SubscribeEventsRequest)
		wantCode    codes.Code
		wantMessage string
		assertValid func(*testing.T, *gatewayv1.SubscribeEventsRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			nilRequest:  true,
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name: "empty protocol version",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.ProtocolVersion = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name: "empty device session id",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.DeviceSessionId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name: "empty message type",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.MessageType = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name: "zero timestamp",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.TimestampMs = 0
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name: "empty request id",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.RequestId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name: "empty payload hash",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.PayloadHash = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name: "empty signature",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.Signature = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name: "unsupported protocol version",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.ProtocolVersion = "v2"
			},
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name: "valid request with empty payload bytes",
			// Clear the payload explicitly so the case does not depend on
			// what the shared fixture happens to contain.
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.PayloadBytes = nil
			},
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.SubscribeEventsRequest, envelope parsedEnvelope) {
				t.Helper()

				assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion)
				assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
				assert.Equal(t, req.GetMessageType(), envelope.MessageType)
				assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS)
				assert.Equal(t, req.GetRequestId(), envelope.RequestID)
				assert.Empty(t, req.GetPayloadBytes())
				assert.Empty(t, envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)
			},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var req *gatewayv1.SubscribeEventsRequest
			if !tt.nilRequest {
				req = newValidSubscribeEventsRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}

			envelope, err := parseSubscribeEventsRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}

			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}
// TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate
// sends an all-zero ExecuteCommandRequest through the validating wrapper and
// expects codes.InvalidArgument with no call reaching the wrapped service.
func TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate(t *testing.T) {
	t.Parallel()

	inner := &recordingEdgeGatewayService{}
	wrapped := newEnvelopeValidatingService(inner)
	ctx := context.Background()

	_, err := wrapped.ExecuteCommand(ctx, &gatewayv1.ExecuteCommandRequest{})

	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Zero(t, inner.executeCalls)
}
// TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate
// sends an all-zero SubscribeEventsRequest through the validating wrapper and
// expects codes.InvalidArgument with no call reaching the wrapped service.
func TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate(t *testing.T) {
	t.Parallel()

	inner := &recordingEdgeGatewayService{}
	wrapped := newEnvelopeValidatingService(inner)
	stream := stubGatewayEventStream{}

	err := wrapped.SubscribeEvents(&gatewayv1.SubscribeEventsRequest{}, stream)

	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Zero(t, inner.subscribeCalls)
}
// TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope verifies
// that, for a valid request, the wrapped service is called exactly once and
// can read the parsed envelope back from its context via
// parsedEnvelopeFromContext.
func TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope(t *testing.T) {
	t.Parallel()

	request := newValidExecuteCommandRequest()

	recorder := &recordingEdgeGatewayService{}
	recorder.executeCommandFunc = func(ctx context.Context, got *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
		envelope, found := parsedEnvelopeFromContext(ctx)
		require.True(t, found)

		assert.Equal(t, request.GetRequestId(), envelope.RequestID)
		assert.Equal(t, request.GetDeviceSessionId(), envelope.DeviceSessionID)
		assert.Equal(t, request.GetMessageType(), envelope.MessageType)
		assert.Equal(t, request.GetPayloadBytes(), envelope.PayloadBytes)

		// Echo the request id so the caller can confirm the response path.
		return &gatewayv1.ExecuteCommandResponse{RequestId: got.GetRequestId()}, nil
	}

	resp, err := newEnvelopeValidatingService(recorder).ExecuteCommand(context.Background(), request)

	require.NoError(t, err)
	assert.Equal(t, request.GetRequestId(), resp.GetRequestId())
	assert.Equal(t, 1, recorder.executeCalls)
}
// TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope verifies
// that, for a valid subscription, the wrapped service is called exactly once
// and can read the parsed envelope back from the stream's context via
// parsedEnvelopeFromContext.
func TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope(t *testing.T) {
	t.Parallel()

	request := newValidSubscribeEventsRequest()

	recorder := &recordingEdgeGatewayService{}
	recorder.subscribeEventsFunc = func(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
		envelope, found := parsedEnvelopeFromContext(stream.Context())
		require.True(t, found)

		assert.Equal(t, request.GetRequestId(), envelope.RequestID)
		assert.Equal(t, request.GetDeviceSessionId(), envelope.DeviceSessionID)
		assert.Equal(t, request.GetMessageType(), envelope.MessageType)
		assert.Equal(t, request.GetPayloadHash(), envelope.PayloadHash)
		assert.Equal(t, request.GetSignature(), envelope.Signature)

		return nil
	}

	err := newEnvelopeValidatingService(recorder).SubscribeEvents(request, stubGatewayEventStream{})

	require.NoError(t, err)
	assert.Equal(t, 1, recorder.subscribeCalls)
}
// recordingEdgeGatewayService is a gateway server test double. It counts
// invocations of each RPC and, when a hook is set, forwards the call to it.
// The counters are plain ints with no synchronization.
type recordingEdgeGatewayService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// Invocation counters, incremented on every call to the matching RPC.
	executeCalls   int
	subscribeCalls int

	// Optional per-test hooks; when nil the RPC succeeds with a zero value.
	executeCommandFunc  func(context.Context, *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error)
	subscribeEventsFunc func(*gatewayv1.SubscribeEventsRequest, grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error
}

// ExecuteCommand bumps executeCalls and defers to executeCommandFunc if one is
// configured; otherwise it returns an empty response and a nil error.
func (s *recordingEdgeGatewayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	s.executeCalls++

	if s.executeCommandFunc == nil {
		return &gatewayv1.ExecuteCommandResponse{}, nil
	}
	return s.executeCommandFunc(ctx, req)
}

// SubscribeEvents bumps subscribeCalls and defers to subscribeEventsFunc if one
// is configured; otherwise it returns nil without sending any events.
func (s *recordingEdgeGatewayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	s.subscribeCalls++

	if s.subscribeEventsFunc == nil {
		return nil
	}
	return s.subscribeEventsFunc(req, stream)
}
// stubGatewayEventStream is a no-op server stream for tests. Every send,
// header, trailer and receive call succeeds without side effects. Context
// returns ctx, or context.Background() when ctx is unset.
type stubGatewayEventStream struct {
	// Embedded nil interface; the methods declared below cover the calls the
	// tests make, so the embedded value is not expected to be reached.
	grpc.ServerStream
	ctx context.Context
}

// Context returns the configured context, falling back to Background.
func (s stubGatewayEventStream) Context() context.Context {
	if s.ctx != nil {
		return s.ctx
	}
	return context.Background()
}

// Send discards the event and reports success.
func (s stubGatewayEventStream) Send(_ *gatewayv1.GatewayEvent) error { return nil }

// SetHeader ignores the metadata and reports success.
func (s stubGatewayEventStream) SetHeader(_ metadata.MD) error { return nil }

// SendHeader ignores the metadata and reports success.
func (s stubGatewayEventStream) SendHeader(_ metadata.MD) error { return nil }

// SetTrailer ignores the metadata.
func (s stubGatewayEventStream) SetTrailer(_ metadata.MD) {}

// SendMsg discards the message and reports success.
func (s stubGatewayEventStream) SendMsg(_ any) error { return nil }

// RecvMsg leaves the target untouched and reports success.
func (s stubGatewayEventStream) RecvMsg(_ any) error { return nil }