299 lines
9.1 KiB
Go
299 lines
9.1 KiB
Go
package grpcapi
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/sha256"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"time"
|
|
|
|
"galaxy/gateway/internal/authn"
|
|
"galaxy/gateway/internal/downstream"
|
|
"galaxy/gateway/internal/session"
|
|
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
|
|
|
gatewayfbs "galaxy/schema/fbs/gateway"
|
|
|
|
flatbuffers "github.com/google/flatbuffers/go"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
// Shared fixed values for the helpers in this file.
var (
	// testCurrentTime is the fixed UTC instant (2026-04-01 12:00:00) whose
	// Unix-millisecond value the request builders below use as their default
	// timestamp.
	testCurrentTime = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)

	// testFreshnessWindow is a five-minute duration.
	// NOTE(review): presumably the request freshness window configured by the
	// tests; its consumers are not visible in this file section.
	testFreshnessWindow = 5 * time.Minute
)
|
|
|
|
func newValidExecuteCommandRequest() *gatewayv1.ExecuteCommandRequest {
|
|
return newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-123")
|
|
}
|
|
|
|
func newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.ExecuteCommandRequest {
|
|
return newValidExecuteCommandRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
|
|
}
|
|
|
|
func newValidExecuteCommandRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.ExecuteCommandRequest {
|
|
payloadBytes := []byte("payload")
|
|
payloadHash := sha256.Sum256(payloadBytes)
|
|
|
|
req := &gatewayv1.ExecuteCommandRequest{
|
|
ProtocolVersion: supportedProtocolVersion,
|
|
DeviceSessionId: deviceSessionID,
|
|
MessageType: "fleet.move",
|
|
TimestampMs: timestampMS,
|
|
RequestId: requestID,
|
|
PayloadBytes: payloadBytes,
|
|
PayloadHash: payloadHash[:],
|
|
TraceId: "trace-123",
|
|
}
|
|
req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
|
|
|
|
return req
|
|
}
|
|
|
|
func newValidSubscribeEventsRequest() *gatewayv1.SubscribeEventsRequest {
|
|
return newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-123")
|
|
}
|
|
|
|
func newValidSubscribeEventsRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
|
|
return newValidSubscribeEventsRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
|
|
}
|
|
|
|
func newValidSubscribeEventsRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.SubscribeEventsRequest {
|
|
payloadHash := sha256.Sum256(nil)
|
|
|
|
req := &gatewayv1.SubscribeEventsRequest{
|
|
ProtocolVersion: supportedProtocolVersion,
|
|
DeviceSessionId: deviceSessionID,
|
|
MessageType: "gateway.subscribe",
|
|
TimestampMs: timestampMS,
|
|
RequestId: requestID,
|
|
PayloadHash: payloadHash[:],
|
|
TraceId: "trace-123",
|
|
}
|
|
req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
|
|
|
|
return req
|
|
}
|
|
|
|
func newActiveSessionRecord() session.Record {
|
|
return newActiveSessionRecordWithSessionID("device-session-123")
|
|
}
|
|
|
|
func newActiveSessionRecordWithSessionID(deviceSessionID string) session.Record {
|
|
return session.Record{
|
|
DeviceSessionID: deviceSessionID,
|
|
UserID: "user-123",
|
|
ClientPublicKey: testClientPublicKeyBase64(),
|
|
Status: session.StatusActive,
|
|
}
|
|
}
|
|
|
|
func newRevokedSessionRecord() session.Record {
|
|
revokedAtMS := int64(123456789)
|
|
|
|
return session.Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: testClientPublicKeyBase64(),
|
|
Status: session.StatusRevoked,
|
|
RevokedAtMS: &revokedAtMS,
|
|
}
|
|
}
|
|
|
|
func alternateTestClientPublicKeyBase64() string {
|
|
return base64.StdEncoding.EncodeToString(newTestPrivateKey("alternate").Public().(ed25519.PublicKey))
|
|
}
|
|
|
|
func testClientPublicKeyBase64() string {
|
|
return base64.StdEncoding.EncodeToString(newTestPrivateKey("primary").Public().(ed25519.PublicKey))
|
|
}
|
|
|
|
func signRequest(protocolVersion, deviceSessionID, messageType string, timestampMS int64, requestID string, payloadHash []byte) []byte {
|
|
return ed25519.Sign(newTestPrivateKey("primary"), authn.BuildRequestSigningInput(authn.RequestSigningFields{
|
|
ProtocolVersion: protocolVersion,
|
|
DeviceSessionID: deviceSessionID,
|
|
MessageType: messageType,
|
|
TimestampMS: timestampMS,
|
|
RequestID: requestID,
|
|
PayloadHash: payloadHash,
|
|
}))
|
|
}
|
|
|
|
// newTestPrivateKey deterministically derives an Ed25519 private key from
// label. The key seed is the SHA-256 digest of the string
// "gateway-grpcapi-signature-test-" followed by label, so equal labels always
// yield equal keys.
func newTestPrivateKey(label string) ed25519.PrivateKey {
	hasher := sha256.New()
	// Writing the prefix and the label in sequence hashes the same bytes as
	// hashing their concatenation; hash.Hash writes never return an error.
	hasher.Write([]byte("gateway-grpcapi-signature-test-"))
	hasher.Write([]byte(label))

	// A SHA-256 digest is 32 bytes, exactly the Ed25519 seed size.
	return ed25519.NewKeyFromSeed(hasher.Sum(nil))
}
|
|
|
|
func newTestEd25519ResponseSigner() *authn.Ed25519ResponseSigner {
|
|
pemBytes := pem.EncodeToMemory(&pem.Block{
|
|
Type: "PRIVATE KEY",
|
|
Bytes: mustMarshalPKCS8PrivateKey(newTestPrivateKey("response-signer")),
|
|
})
|
|
|
|
signer, err := authn.ParseEd25519ResponseSignerPEM(pemBytes)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return signer
|
|
}
|
|
|
|
func newTestResponseSigner() authn.ResponseSigner {
|
|
return newTestEd25519ResponseSigner()
|
|
}
|
|
|
|
func newTestResponseSignerPublicKey() ed25519.PublicKey {
|
|
return newTestEd25519ResponseSigner().PublicKey()
|
|
}
|
|
|
|
// mustMarshalPKCS8PrivateKey returns the PKCS#8 DER encoding of privateKey.
// It panics if x509.MarshalPKCS8PrivateKey reports an error.
func mustMarshalPKCS8PrivateKey(privateKey ed25519.PrivateKey) []byte {
	der, marshalErr := x509.MarshalPKCS8PrivateKey(privateKey)
	if marshalErr == nil {
		return der
	}

	panic(marshalErr)
}
|
|
|
|
// fixedClock is a clock stub whose Now method always reports the same
// preconfigured instant.
type fixedClock struct {
	// now is the instant returned by Now.
	now time.Time
}

// Now returns the instant stored in the clock, unchanged.
func (fc fixedClock) Now() time.Time {
	return fc.now
}
|
|
|
|
func recvBootstrapEvent(t interface {
|
|
require.TestingT
|
|
Helper()
|
|
}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
|
|
t.Helper()
|
|
|
|
event, err := stream.Recv()
|
|
require.NoError(t, err)
|
|
|
|
return event
|
|
}
|
|
|
|
func subscribeEventsError(t interface {
|
|
require.TestingT
|
|
Helper()
|
|
}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
|
|
t.Helper()
|
|
|
|
stream, err := client.SubscribeEvents(ctx, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = stream.Recv()
|
|
return err
|
|
}
|
|
|
|
func assertServerTimeBootstrapEvent(t interface {
|
|
require.TestingT
|
|
Helper()
|
|
}, event *gatewayv1.GatewayEvent, publicKey ed25519.PublicKey, wantRequestID string, wantTraceID string, wantTimestampMS int64) {
|
|
t.Helper()
|
|
|
|
require.NotNil(t, event)
|
|
assert.Equal(t, serverTimeEventType, event.GetEventType())
|
|
assert.Equal(t, wantRequestID, event.GetEventId())
|
|
assert.Equal(t, wantRequestID, event.GetRequestId())
|
|
assert.Equal(t, wantTraceID, event.GetTraceId())
|
|
assert.Equal(t, wantTimestampMS, event.GetTimestampMs())
|
|
require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
|
|
require.NoError(t, authn.VerifyEventSignature(publicKey, event.GetSignature(), authn.EventSigningFields{
|
|
EventType: event.GetEventType(),
|
|
EventID: event.GetEventId(),
|
|
TimestampMS: event.GetTimestampMs(),
|
|
RequestID: event.GetRequestId(),
|
|
TraceID: event.GetTraceId(),
|
|
PayloadHash: event.GetPayloadHash(),
|
|
}))
|
|
|
|
payload := gatewayfbs.GetRootAsServerTimeEvent(event.GetPayloadBytes(), flatbuffers.UOffsetT(0))
|
|
assert.Equal(t, wantTimestampMS, payload.ServerTimeMs())
|
|
}
|
|
|
|
// staticReplayStore is a replay-store stub whose Reserve behavior is supplied
// by an optional callback.
type staticReplayStore struct {
	// reserveFunc, when non-nil, handles every Reserve call.
	reserveFunc func(context.Context, string, string, time.Duration) error
}

// Reserve delegates to reserveFunc with the same arguments when it is set, and
// otherwise reports success by returning nil.
func (s staticReplayStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
	if s.reserveFunc == nil {
		return nil
	}

	return s.reserveFunc(ctx, deviceSessionID, requestID, ttl)
}
|
|
|
|
type executeCommandAdapterRouter struct {
|
|
service gatewayv1.EdgeGatewayServer
|
|
}
|
|
|
|
func (r executeCommandAdapterRouter) Route(string) (downstream.Client, error) {
|
|
return executeCommandAdapterClient{service: r.service}, nil
|
|
}
|
|
|
|
type executeCommandAdapterClient struct {
|
|
service gatewayv1.EdgeGatewayServer
|
|
}
|
|
|
|
func (c executeCommandAdapterClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
response, err := c.service.ExecuteCommand(ctx, &gatewayv1.ExecuteCommandRequest{
|
|
ProtocolVersion: command.ProtocolVersion,
|
|
DeviceSessionId: command.DeviceSessionID,
|
|
MessageType: command.MessageType,
|
|
TimestampMs: command.TimestampMS,
|
|
RequestId: command.RequestID,
|
|
PayloadBytes: command.PayloadBytes,
|
|
TraceId: command.TraceID,
|
|
})
|
|
if err != nil {
|
|
return downstream.UnaryResult{}, err
|
|
}
|
|
|
|
resultCode := response.GetResultCode()
|
|
if resultCode == "" {
|
|
resultCode = "ok"
|
|
}
|
|
|
|
return downstream.UnaryResult{
|
|
ResultCode: resultCode,
|
|
PayloadBytes: response.GetPayloadBytes(),
|
|
}, nil
|
|
}
|
|
|
|
type recordingDownstreamClient struct {
|
|
executeCalls int
|
|
commands []downstream.AuthenticatedCommand
|
|
executeFunc func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error)
|
|
}
|
|
|
|
func (c *recordingDownstreamClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
c.executeCalls++
|
|
c.commands = append(c.commands, downstream.AuthenticatedCommand{
|
|
ProtocolVersion: command.ProtocolVersion,
|
|
UserID: command.UserID,
|
|
DeviceSessionID: command.DeviceSessionID,
|
|
MessageType: command.MessageType,
|
|
TimestampMS: command.TimestampMS,
|
|
RequestID: command.RequestID,
|
|
TraceID: command.TraceID,
|
|
PayloadBytes: append([]byte(nil), command.PayloadBytes...),
|
|
})
|
|
if c.executeFunc != nil {
|
|
return c.executeFunc(ctx, command)
|
|
}
|
|
|
|
return downstream.UnaryResult{
|
|
ResultCode: "ok",
|
|
PayloadBytes: []byte("response"),
|
|
}, nil
|
|
}
|