package grpcapi

import (
	"context"
	"crypto/ed25519"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"time"

	"galaxy/gateway/authn"
	"galaxy/gateway/internal/downstream"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
	gatewayfbs "galaxy/schema/fbs/gateway"

	flatbuffers "github.com/google/flatbuffers/go"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
)

// Shared time fixtures. testCurrentTime is the fixed instant the request
// builders below use as their default timestamp. testFreshnessWindow is a
// 5-minute duration that is not referenced in this file — presumably it is the
// request freshness window other tests configure; confirm against its callers.
var (
	testCurrentTime     = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
	testFreshnessWindow = 5 * time.Minute
)

// newValidExecuteCommandRequest returns a signed ExecuteCommand request for
// session "device-session-123" and request ID "request-123", timestamped at
// testCurrentTime.
func newValidExecuteCommandRequest() *gatewayv1.ExecuteCommandRequest {
	return newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-123")
}

// newValidExecuteCommandRequestWithSessionAndRequestID returns a signed
// ExecuteCommand request for the given session and request ID, timestamped at
// testCurrentTime.
func newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.ExecuteCommandRequest {
	return newValidExecuteCommandRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}

// newValidExecuteCommandRequestWithTimestamp builds a "fleet.move" request
// whose payload is the bytes "payload" with a matching SHA-256 PayloadHash.
// The request is signed with the "primary" test key via signRequest.
func newValidExecuteCommandRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.ExecuteCommandRequest {
	payloadBytes := []byte("payload")
	payloadHash := sha256.Sum256(payloadBytes)

	req := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: supportedProtocolVersion,
		DeviceSessionId: deviceSessionID,
		MessageType:     "fleet.move",
		TimestampMs:     timestampMS,
		RequestId:       requestID,
		PayloadBytes:    payloadBytes,
		PayloadHash:     payloadHash[:],
		TraceId:         "trace-123",
	}
	// Sign only after every signed field has its final value.
	req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
	return req
}

// newValidSubscribeEventsRequest returns a signed SubscribeEvents request for
// session "device-session-123" and request ID "request-123", timestamped at
// testCurrentTime.
func newValidSubscribeEventsRequest() *gatewayv1.SubscribeEventsRequest {
	return newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-123")
}

// newValidSubscribeEventsRequestWithSessionAndRequestID returns a signed
// SubscribeEvents request for the given session and request ID, timestamped at
// testCurrentTime.
func newValidSubscribeEventsRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	return newValidSubscribeEventsRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}

// newValidSubscribeEventsRequestWithTimestamp builds a "gateway.subscribe"
// request. It carries no payload bytes, and its PayloadHash is the SHA-256 of
// empty input. The request is signed with the "primary" test key via
// signRequest.
func newValidSubscribeEventsRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.SubscribeEventsRequest {
	payloadHash := sha256.Sum256(nil)

	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: supportedProtocolVersion,
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     timestampMS,
		RequestId:       requestID,
		PayloadHash:     payloadHash[:],
		TraceId:         "trace-123",
	}
	// Sign only after every signed field has its final value.
	req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
	return req
}

// newActiveSessionRecord returns an active session record for
// "device-session-123" that is bound to the "primary" test public key.
func newActiveSessionRecord() session.Record {
	return newActiveSessionRecordWithSessionID("device-session-123")
}

// newActiveSessionRecordWithSessionID returns an active session record for the
// given session ID. It belongs to user "user-123" and is bound to the "primary"
// test public key.
func newActiveSessionRecordWithSessionID(deviceSessionID string) session.Record {
	return session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          "user-123",
		ClientPublicKey: testClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
}

// newRevokedSessionRecord returns the newActiveSessionRecord fixture marked as
// revoked, with a fixed RevokedAtMS of 123456789.
func newRevokedSessionRecord() session.Record {
	revokedAtMS := int64(123456789)

	// Derive from the active fixture so the two cannot drift apart.
	record := newActiveSessionRecord()
	record.Status = session.StatusRevoked
	record.RevokedAtMS = &revokedAtMS
	return record
}

// alternateTestClientPublicKeyBase64 returns the standard-base64 encoding of
// the "alternate" test public key. No request builder in this file signs with
// this key.
func alternateTestClientPublicKeyBase64() string {
	return base64.StdEncoding.EncodeToString(newTestPrivateKey("alternate").Public().(ed25519.PublicKey))
}

// testClientPublicKeyBase64 returns the standard-base64 encoding of the
// "primary" test public key. signRequest signs with the matching private key.
func testClientPublicKeyBase64() string {
	return base64.StdEncoding.EncodeToString(newTestPrivateKey("primary").Public().(ed25519.PublicKey))
}

// signRequest returns an Ed25519 signature, made with the "primary" test key,
// over the signing input that authn.BuildRequestSigningInput produces for the
// given fields.
func signRequest(protocolVersion, deviceSessionID, messageType string, timestampMS int64, requestID string, payloadHash []byte) []byte {
	return ed25519.Sign(newTestPrivateKey("primary"), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: protocolVersion,
		DeviceSessionID: deviceSessionID,
		MessageType:     messageType,
		TimestampMS:     timestampMS,
		RequestID:       requestID,
		PayloadHash:     payloadHash,
	}))
}

// newTestPrivateKey derives an Ed25519 key from a seed equal to the SHA-256 of
// a label-specific string. The same label therefore always yields the same key.
func newTestPrivateKey(label string) ed25519.PrivateKey {
	seed := sha256.Sum256([]byte("gateway-grpcapi-signature-test-" + label))
	return ed25519.NewKeyFromSeed(seed[:])
}

// newTestEd25519ResponseSigner PEM-encodes the deterministic "response-signer"
// key as PKCS#8 and parses it back with authn.ParseEd25519ResponseSignerPEM.
// It panics if parsing fails, which is acceptable only in test setup.
func newTestEd25519ResponseSigner() *authn.Ed25519ResponseSigner {
	pemBytes := pem.EncodeToMemory(&pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: mustMarshalPKCS8PrivateKey(newTestPrivateKey("response-signer")),
	})

	signer, err := authn.ParseEd25519ResponseSignerPEM(pemBytes)
	if err != nil {
		panic(err)
	}
	return signer
}

// newTestResponseSigner returns newTestEd25519ResponseSigner behind the
// authn.ResponseSigner interface.
func newTestResponseSigner() authn.ResponseSigner {
	return newTestEd25519ResponseSigner()
}

// newTestResponseSignerPublicKey returns the public key of the test response
// signer. Because the key is derived deterministically, the result matches the
// signer built by newTestResponseSigner.
func newTestResponseSignerPublicKey() ed25519.PublicKey {
	return newTestEd25519ResponseSigner().PublicKey()
}

// mustMarshalPKCS8PrivateKey marshals privateKey to PKCS#8 DER and panics on
// failure.
func mustMarshalPKCS8PrivateKey(privateKey ed25519.PrivateKey) []byte {
	encoded, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		panic(err)
	}
	return encoded
}

// fixedClock is a clock whose Now always returns the stored instant.
// Presumably it satisfies the server's clock dependency — confirm against the
// code that consumes it.
type fixedClock struct {
	now time.Time
}

// Now returns c.now.
func (c fixedClock) Now() time.Time {
	return c.now
}

// recvBootstrapEvent receives the next event from stream and fails the test
// immediately if Recv returns an error.
func recvBootstrapEvent(t interface {
	require.TestingT
	Helper()
}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()

	event, err := stream.Recv()
	require.NoError(t, err)
	return event
}

// subscribeEventsError opens a SubscribeEvents stream and returns the error
// from opening it, or else the error from the first Recv. It returns nil if the
// first event arrives successfully.
func subscribeEventsError(t interface {
	require.TestingT
	Helper()
}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
	t.Helper()

	stream, err := client.SubscribeEvents(ctx, req)
	if err != nil {
		return err
	}
	_, err = stream.Recv()
	return err
}

// assertServerTimeBootstrapEvent checks that event is a server-time bootstrap
// event with the following properties:
//   - its event ID and request ID both equal wantRequestID;
//   - its trace ID and timestamp match the expected values;
//   - its payload hash and its signature under publicKey both verify;
//   - its FlatBuffers ServerTimeEvent payload carries wantTimestampMS.
func assertServerTimeBootstrapEvent(t interface {
	require.TestingT
	Helper()
}, event *gatewayv1.GatewayEvent, publicKey ed25519.PublicKey, wantRequestID string, wantTraceID string, wantTimestampMS int64) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, serverTimeEventType, event.GetEventType())
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	assert.Equal(t, wantTimestampMS, event.GetTimestampMs())

	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(publicKey, event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))

	// GetRootAs* reads a 4-byte root offset and panics on shorter input. An
	// empty payload can still pass the hash check above, so fail cleanly here
	// instead of crashing the test binary.
	require.GreaterOrEqual(t, len(event.GetPayloadBytes()), flatbuffers.SizeUOffsetT, "server time payload too short to hold a FlatBuffers root")
	payload := gatewayfbs.GetRootAsServerTimeEvent(event.GetPayloadBytes(), flatbuffers.UOffsetT(0))
	assert.Equal(t, wantTimestampMS, payload.ServerTimeMs())
}

// staticReplayStore is a replay-store stub whose Reserve behavior is
// configured by reserveFunc.
type staticReplayStore struct {
	reserveFunc func(context.Context, string, string, time.Duration) error
}

// Reserve delegates to reserveFunc when it is set; otherwise it reports
// success by returning nil.
func (s staticReplayStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
	if s.reserveFunc != nil {
		return s.reserveFunc(ctx, deviceSessionID, requestID, ttl)
	}
	return nil
}

// executeCommandAdapterRouter routes every message type to a client that calls
// service directly.
type executeCommandAdapterRouter struct {
	service gatewayv1.EdgeGatewayServer
}

// Route ignores its argument and always returns an executeCommandAdapterClient
// that wraps r.service.
func (r executeCommandAdapterRouter) Route(string) (downstream.Client, error) {
	return executeCommandAdapterClient{service: r.service}, nil
}

// executeCommandAdapterClient adapts a gatewayv1.EdgeGatewayServer to the
// downstream client's ExecuteCommand method.
type executeCommandAdapterClient struct {
	service gatewayv1.EdgeGatewayServer
}

// ExecuteCommand converts command into an ExecuteCommandRequest and calls
// c.service.ExecuteCommand in-process. The request is built without
// PayloadHash or Signature. An empty result code in the response is reported
// as "ok".
func (c executeCommandAdapterClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	response, err := c.service.ExecuteCommand(ctx, &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: command.ProtocolVersion,
		DeviceSessionId: command.DeviceSessionID,
		MessageType:     command.MessageType,
		TimestampMs:     command.TimestampMS,
		RequestId:       command.RequestID,
		PayloadBytes:    command.PayloadBytes,
		TraceId:         command.TraceID,
	})
	if err != nil {
		return downstream.UnaryResult{}, err
	}

	resultCode := response.GetResultCode()
	if resultCode == "" {
		resultCode = "ok"
	}
	return downstream.UnaryResult{
		ResultCode:   resultCode,
		PayloadBytes: response.GetPayloadBytes(),
	}, nil
}

// recordingDownstreamClient records every command it receives. It returns the
// result of executeFunc when that is set, or else a fixed "ok"/"response"
// result. It uses no locking; NOTE(review): confirm that callers never invoke
// it concurrently.
type recordingDownstreamClient struct {
	executeCalls int
	commands     []downstream.AuthenticatedCommand
	executeFunc  func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error)
}

// ExecuteCommand increments executeCalls and appends a copy of command to
// commands. The copy owns its PayloadBytes, so later changes to the caller's
// buffer cannot alter the record. It then returns executeFunc's result, or the
// default result when executeFunc is nil.
func (c *recordingDownstreamClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	c.executeCalls++
	c.commands = append(c.commands, downstream.AuthenticatedCommand{
		ProtocolVersion: command.ProtocolVersion,
		UserID:          command.UserID,
		DeviceSessionID: command.DeviceSessionID,
		MessageType:     command.MessageType,
		TimestampMS:     command.TimestampMS,
		RequestID:       command.RequestID,
		TraceID:         command.TraceID,
		PayloadBytes:    append([]byte(nil), command.PayloadBytes...),
	})

	if c.executeFunc != nil {
		return c.executeFunc(ctx, command)
	}
	return downstream.UnaryResult{
		ResultCode:   "ok",
		PayloadBytes: []byte("response"),
	}, nil
}