package grpcapi

import (
	"context"
	"testing"

	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// TestParseExecuteCommandRequest checks that parseExecuteCommandRequest rejects
// each missing or unsupported envelope field with the expected gRPC status code
// and message, and that a valid request is copied into a parsedEnvelope whose
// byte slices do not alias the request's buffers.
func TestParseExecuteCommandRequest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		// nilRequest makes the case pass a nil request instead of a
		// (possibly mutated) newValidExecuteCommandRequest.
		nilRequest  bool
		mutate      func(*gatewayv1.ExecuteCommandRequest)
		wantCode    codes.Code
		wantMessage string
		assertValid func(*testing.T, *gatewayv1.ExecuteCommandRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			nilRequest:  true,
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name:        "empty protocol version",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.ProtocolVersion = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name:        "empty device session id",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.DeviceSessionId = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name:        "empty message type",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.MessageType = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name:        "zero timestamp",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.TimestampMs = 0 },
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name:        "empty request id",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.RequestId = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name:        "empty payload bytes",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.PayloadBytes = nil },
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_bytes must not be empty",
		},
		{
			name:        "empty payload hash",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.PayloadHash = nil },
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name:        "empty signature",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.Signature = nil },
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name:        "unsupported protocol version",
			mutate:      func(req *gatewayv1.ExecuteCommandRequest) { req.ProtocolVersion = "v2" },
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name:     "valid request",
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.ExecuteCommandRequest, envelope parsedEnvelope) {
				t.Helper()

				assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion)
				assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
				assert.Equal(t, req.GetMessageType(), envelope.MessageType)
				assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS)
				assert.Equal(t, req.GetRequestId(), envelope.RequestID)
				assert.Equal(t, req.GetTraceId(), envelope.TraceID)
				assert.Equal(t, req.GetPayloadBytes(), envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)

				// Snapshot the request's bytes, then scribble over the envelope's
				// copies: if the envelope aliased the request buffers, the request
				// would change too.
				originalPayloadBytes := append([]byte(nil), req.GetPayloadBytes()...)
				originalPayloadHash := append([]byte(nil), req.GetPayloadHash()...)
				originalSignature := append([]byte(nil), req.GetSignature()...)

				// Guard the index writes below so a changed fixture fails cleanly
				// instead of panicking with an index-out-of-range.
				require.NotEmpty(t, envelope.PayloadBytes)
				require.NotEmpty(t, envelope.PayloadHash)
				require.NotEmpty(t, envelope.Signature)

				envelope.PayloadBytes[0] = 'X'
				envelope.PayloadHash[0] = 'Y'
				envelope.Signature[0] = 'Z'

				assert.Equal(t, originalPayloadBytes, req.GetPayloadBytes())
				assert.Equal(t, originalPayloadHash, req.GetPayloadHash())
				assert.Equal(t, originalSignature, req.GetSignature())
			},
		},
	}

	for _, tt := range tests {
		tt := tt // per-iteration copy for the parallel subtest (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var req *gatewayv1.ExecuteCommandRequest
			if !tt.nilRequest {
				req = newValidExecuteCommandRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}

			envelope, err := parseExecuteCommandRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}

			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}

// TestParseSubscribeEventsRequest checks that parseSubscribeEventsRequest rejects
// each missing or unsupported envelope field with the expected gRPC status code
// and message. Unlike ExecuteCommand, an empty payload_bytes is accepted here.
func TestParseSubscribeEventsRequest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		// nilRequest makes the case pass a nil request instead of a
		// (possibly mutated) newValidSubscribeEventsRequest.
		nilRequest  bool
		mutate      func(*gatewayv1.SubscribeEventsRequest)
		wantCode    codes.Code
		wantMessage string
		assertValid func(*testing.T, *gatewayv1.SubscribeEventsRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			nilRequest:  true,
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name:        "empty protocol version",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.ProtocolVersion = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name:        "empty device session id",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.DeviceSessionId = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name:        "empty message type",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.MessageType = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name:        "zero timestamp",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.TimestampMs = 0 },
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name:        "empty request id",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.RequestId = "" },
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name:        "empty payload hash",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.PayloadHash = nil },
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name:        "empty signature",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.Signature = nil },
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name:        "unsupported protocol version",
			mutate:      func(req *gatewayv1.SubscribeEventsRequest) { req.ProtocolVersion = "v2" },
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name:     "valid request with empty payload bytes",
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.SubscribeEventsRequest, envelope parsedEnvelope) {
				t.Helper()

				assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
				assert.Equal(t, req.GetMessageType(), envelope.MessageType)
				assert.Equal(t, req.GetRequestId(), envelope.RequestID)
				assert.Empty(t, req.GetPayloadBytes())
				assert.Empty(t, envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)
			},
		},
	}

	for _, tt := range tests {
		tt := tt // per-iteration copy for the parallel subtest (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var req *gatewayv1.SubscribeEventsRequest
			if !tt.nilRequest {
				req = newValidSubscribeEventsRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}

			envelope, err := parseSubscribeEventsRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}

			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}

// TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate
// verifies that an empty ExecuteCommandRequest fails with InvalidArgument and
// never reaches the wrapped service.
func TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	service := newEnvelopeValidatingService(delegate)

	_, err := service.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Zero(t, delegate.executeCalls)
}

// TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate
// verifies that an empty SubscribeEventsRequest fails with InvalidArgument and
// never reaches the wrapped service.
func TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	service := newEnvelopeValidatingService(delegate)

	err := service.SubscribeEvents(&gatewayv1.SubscribeEventsRequest{}, stubGatewayEventStream{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Zero(t, delegate.subscribeCalls)
}

// TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope verifies that
// a valid ExecuteCommand call is forwarded exactly once and that the delegate's
// context carries the parsed envelope.
func TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope(t *testing.T) {
	t.Parallel()

	want := newValidExecuteCommandRequest()
	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			envelope, ok := parsedEnvelopeFromContext(ctx)
			require.True(t, ok)
			assert.Equal(t, want.GetRequestId(), envelope.RequestID)
			assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
			assert.Equal(t, want.GetMessageType(), envelope.MessageType)
			assert.Equal(t, want.GetPayloadBytes(), envelope.PayloadBytes)

			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}
	service := newEnvelopeValidatingService(delegate)

	response, err := service.ExecuteCommand(context.Background(), want)
	require.NoError(t, err)
	assert.Equal(t, want.GetRequestId(), response.GetRequestId())
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope verifies that
// a valid SubscribeEvents call is forwarded exactly once and that the stream
// context seen by the delegate carries the parsed envelope.
func TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope(t *testing.T) {
	t.Parallel()

	want := newValidSubscribeEventsRequest()
	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			envelope, ok := parsedEnvelopeFromContext(stream.Context())
			require.True(t, ok)
			assert.Equal(t, want.GetRequestId(), envelope.RequestID)
			assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
			assert.Equal(t, want.GetMessageType(), envelope.MessageType)
			assert.Equal(t, want.GetPayloadHash(), envelope.PayloadHash)
			assert.Equal(t, want.GetSignature(), envelope.Signature)

			return nil
		},
	}
	service := newEnvelopeValidatingService(delegate)

	err := service.SubscribeEvents(want, stubGatewayEventStream{})
	require.NoError(t, err)
	assert.Equal(t, 1, delegate.subscribeCalls)
}

// recordingEdgeGatewayService is a test double for the EdgeGateway server that
// counts how often each RPC is invoked and optionally delegates to a per-test
// hook. The counters are plain ints with no synchronization; each test builds
// its own instance.
type recordingEdgeGatewayService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	executeCalls   int
	subscribeCalls int

	executeCommandFunc  func(context.Context, *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error)
	subscribeEventsFunc func(*gatewayv1.SubscribeEventsRequest, grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error
}

// ExecuteCommand records the call and runs executeCommandFunc when set;
// otherwise it returns an empty response and no error.
func (s *recordingEdgeGatewayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	s.executeCalls++
	if s.executeCommandFunc != nil {
		return s.executeCommandFunc(ctx, req)
	}

	return &gatewayv1.ExecuteCommandResponse{}, nil
}

// SubscribeEvents records the call and runs subscribeEventsFunc when set;
// otherwise it returns nil.
func (s *recordingEdgeGatewayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	s.subscribeCalls++
	if s.subscribeEventsFunc != nil {
		return s.subscribeEventsFunc(req, stream)
	}

	return nil
}

// stubGatewayEventStream is a no-op server stream for GatewayEvent. It embeds
// grpc.ServerStream only to satisfy the interface; every method is overridden
// below, so the nil embedded value is never called.
type stubGatewayEventStream struct {
	grpc.ServerStream
	// ctx is returned by Context; context.Background() is used when nil.
	ctx context.Context
}

func (s stubGatewayEventStream) Send(*gatewayv1.GatewayEvent) error { return nil }

func (s stubGatewayEventStream) SetHeader(metadata.MD) error { return nil }

func (s stubGatewayEventStream) SendHeader(metadata.MD) error { return nil }

func (s stubGatewayEventStream) SetTrailer(metadata.MD) {}

// Context returns the configured ctx, falling back to context.Background().
func (s stubGatewayEventStream) Context() context.Context {
	if s.ctx == nil {
		return context.Background()
	}

	return s.ctx
}

func (s stubGatewayEventStream) SendMsg(any) error { return nil }

func (s stubGatewayEventStream) RecvMsg(any) error { return nil }