Files
galaxy-game/gateway/internal/events/grpc_integration_test.go
T
2026-04-09 12:34:55 +02:00

397 lines
11 KiB
Go

package events
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"errors"
"net"
"sync"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/grpcapi"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
// testNow is the fixed instant shared by this file: runAuthenticatedGateway
// hands it to the gateway through fixedClock, and newExecuteCommandRequest
// stamps every request with testNow.UnixMilli().
var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
// TestAuthenticatedGatewayWarmsLocalSessionCache sends two signed commands for
// the same device session through a running gateway. It asserts that the
// fallback cache's lookup counter reads 1 after each of them and that both
// commands were recorded by the downstream client.
func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
	t.Parallel()

	redisServer := miniredis.RunT(t)
	localCache := session.NewMemoryCache()
	fallbackCache := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	sessionCache, err := session.NewReadThroughCache(localCache, fallbackCache)
	require.NoError(t, err)

	recorder := &recordingDownstreamClient{}
	addr, gateway := runAuthenticatedGateway(
		t,
		sessionCache,
		newTestRedisSessionSubscriber(t, redisServer, localCache),
		recorder,
	)
	defer gateway.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	ctx := context.Background()

	// The counter must read exactly one after the first request and must not
	// move after the second.
	for _, requestID := range []string{"request-1", "request-2"} {
		_, err = client.ExecuteCommand(ctx, newExecuteCommandRequest(requestID))
		require.NoError(t, err)
		assert.Equal(t, 1, fallbackCache.lookupCalls())
	}

	assert.Len(t, recorder.commands(), 2)
}
// TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup executes
// one command, publishes a session event that reassigns the device session to
// "user-456", waits until the local cache returns that user, and then executes
// a second command. It asserts that the fallback lookup counter is still 1 and
// that the second recorded downstream command carries the updated user ID.
func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()

	const (
		deviceSessionID = "device-session-123"
		updatedUserID   = "user-456"
	)

	redisServer := miniredis.RunT(t)
	localCache := session.NewMemoryCache()
	fallbackCache := &countingSessionCache{
		records: map[string]session.Record{
			deviceSessionID: newActiveSessionRecord("user-123"),
		},
	}
	sessionCache, err := session.NewReadThroughCache(localCache, fallbackCache)
	require.NoError(t, err)

	recorder := &recordingDownstreamClient{}
	addr, gateway := runAuthenticatedGateway(
		t,
		sessionCache,
		newTestRedisSessionSubscriber(t, redisServer, localCache),
		recorder,
	)
	defer gateway.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	ctx := context.Background()

	_, err = client.ExecuteCommand(ctx, newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallbackCache.lookupCalls())

	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": deviceSessionID,
		"user_id":           updatedUserID,
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusActive),
	})

	// Poll the local cache directly so the wait itself never touches the
	// counting fallback.
	localHasUpdatedUser := func() bool {
		record, lookupErr := localCache.Lookup(ctx, deviceSessionID)
		if lookupErr != nil {
			return false
		}
		return record.UserID == updatedUserID
	}
	require.Eventually(t, localHasUpdatedUser, time.Second, 10*time.Millisecond)

	_, err = client.ExecuteCommand(ctx, newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallbackCache.lookupCalls())

	recorded := recorder.commands()
	require.Len(t, recorded, 2)
	assert.Equal(t, updatedUserID, recorded[1].UserID)
}
// TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup
// executes one command, publishes a session event with the revoked status,
// waits until the local cache returns the revoked record, and then executes a
// second command. It asserts that the second command fails with
// codes.FailedPrecondition and the message "device session is revoked", and
// that the fallback lookup counter is still 1.
func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()

	const (
		deviceSessionID = "device-session-123"
		userID          = "user-123"
	)

	redisServer := miniredis.RunT(t)
	localCache := session.NewMemoryCache()
	fallbackCache := &countingSessionCache{
		records: map[string]session.Record{
			deviceSessionID: newActiveSessionRecord(userID),
		},
	}
	sessionCache, err := session.NewReadThroughCache(localCache, fallbackCache)
	require.NoError(t, err)

	recorder := &recordingDownstreamClient{}
	addr, gateway := runAuthenticatedGateway(
		t,
		sessionCache,
		newTestRedisSessionSubscriber(t, redisServer, localCache),
		recorder,
	)
	defer gateway.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	ctx := context.Background()

	_, err = client.ExecuteCommand(ctx, newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallbackCache.lookupCalls())

	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": deviceSessionID,
		"user_id":           userID,
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	// Poll the local cache directly so the wait itself never touches the
	// counting fallback.
	localSeesRevocation := func() bool {
		record, lookupErr := localCache.Lookup(ctx, deviceSessionID)
		if lookupErr != nil {
			return false
		}
		return record.Status == session.StatusRevoked
	}
	require.Eventually(t, localSeesRevocation, time.Second, 10*time.Millisecond)

	_, err = client.ExecuteCommand(ctx, newExecuteCommandRequest("request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Equal(t, 1, fallbackCache.lookupCalls())
}
// runningAuthenticatedGateway is the handle returned by
// runAuthenticatedGateway for a gateway application running in a background
// goroutine.
type runningAuthenticatedGateway struct {
// cancel cancels the context that was passed to the application's Run.
cancel context.CancelFunc
// resultCh receives the single value returned by the application's Run.
resultCh chan error
}
// runAuthenticatedGateway builds an authenticated gRPC gateway wired to the
// given session cache, session subscriber and downstream client, starts it in
// a background goroutine, and returns the listen address together with a
// handle for stopping it.
//
// The gateway listens on an address obtained from unusedTCPAddr, routes the
// "fleet.move" message type to downstreamClient, uses a static replay store
// that accepts everything, and runs on a clock fixed at testNow.
//
// The helper waits up to one second for subscriber.started to signal.
// NOTE(review): started is declared on RedisSessionSubscriber outside this
// file; it is assumed to signal once the subscriber is running — confirm.
// If the application exits first, or the wait times out, the run context is
// cancelled and the test fails immediately.
func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
	t.Helper()

	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	grpcCfg.FreshnessWindow = 5 * time.Minute

	router := downstream.NewStaticRouter(map[string]downstream.Client{
		"fleet.move": downstreamClient,
	})
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Router:         router,
		ResponseSigner: newTestResponseSigner(t),
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
	})
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		gateway,
		subscriber,
	)

	ctx, cancel := context.WithCancel(context.Background())
	// Buffered so the goroutine can always deliver its result and exit, even
	// when nobody ends up reading it.
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	select {
	case <-subscriber.started:
	case err := <-resultCh:
		// Run returned before the subscriber came up; report its result
		// instead of waiting for the timeout and hiding the real cause.
		cancel()
		require.FailNow(t, "gateway exited before the session subscriber started", "run result: %v", err)
	case <-time.After(time.Second):
		// The caller has not registered stop yet, so cancel here to avoid
		// leaking the running application past the failed test.
		cancel()
		require.FailNow(t, "session subscriber did not start")
	}

	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
// stop cancels the gateway's run context and waits up to two seconds for the
// run result. A non-nil result or a timeout fails the test.
func (g runningAuthenticatedGateway) stop(t *testing.T) {
	t.Helper()

	g.cancel()

	deadline := time.After(2 * time.Second)
	select {
	case runErr := <-g.resultCh:
		require.NoError(t, runErr)
	case <-deadline:
		require.FailNow(t, "gateway did not stop after cancellation")
	}
}
// dialGatewayClient returns a plaintext (insecure credentials) gRPC client
// connection to addr. It retries a blocking dial, each attempt bounded to
// 100ms, every 10ms for up to two seconds, and fails the test if no attempt
// succeeds.
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()

	var established *grpc.ClientConn
	tryDial := func() bool {
		dialCtx, release := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer release()

		attempt, dialErr := grpc.DialContext(
			dialCtx,
			addr,
			grpc.WithTransportCredentials(insecure.NewCredentials()),
			grpc.WithBlock(),
		)
		if dialErr == nil {
			established = attempt
			return true
		}
		if attempt != nil {
			// Best-effort cleanup of a failed attempt; the close error is
			// irrelevant because the dial is retried.
			_ = attempt.Close()
		}
		return false
	}

	require.Eventually(t, tryDial, 2*time.Second, 10*time.Millisecond, "gateway did not accept gRPC connections")
	return established
}
// unusedTCPAddr asks the OS for a free loopback port by listening on
// 127.0.0.1:0, records the assigned address, closes the listener, and returns
// the address. The port is not reserved once the listener is closed.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()

	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	picked := probe.Addr().String()
	require.NoError(t, probe.Close())

	return picked
}
// newExecuteCommandRequest builds a signed "fleet.move" ExecuteCommandRequest
// for device session "device-session-123" with the given request ID. The
// request carries the payload "payload" with its SHA-256 hash, is timestamped
// at testNow, and is signed with the test client private key over the input
// produced by authn.BuildRequestSigningInput.
func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
	const (
		protocolVersion = "v1"
		deviceSessionID = "device-session-123"
		messageType     = "fleet.move"
	)

	payload := []byte("payload")
	digest := sha256.Sum256(payload)
	timestampMS := testNow.UnixMilli()

	// The signature covers the same field values that are placed on the
	// request below.
	signingInput := authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: protocolVersion,
		DeviceSessionID: deviceSessionID,
		MessageType:     messageType,
		TimestampMS:     timestampMS,
		RequestID:       requestID,
		PayloadHash:     digest[:],
	})

	return &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: protocolVersion,
		DeviceSessionId: deviceSessionID,
		MessageType:     messageType,
		TimestampMs:     timestampMS,
		RequestId:       requestID,
		PayloadBytes:    payload,
		PayloadHash:     digest[:],
		TraceId:         "trace-123",
		Signature:       ed25519.Sign(testClientPrivateKey(), signingInput),
	}
}
// newActiveSessionRecord returns an active session record for device session
// "device-session-123" owned by userID and bound to the test client public
// key.
func newActiveSessionRecord(userID string) session.Record {
	var record session.Record
	record.DeviceSessionID = "device-session-123"
	record.UserID = userID
	record.ClientPublicKey = testClientPublicKeyBase64()
	record.Status = session.StatusActive
	return record
}
// testClientPrivateKey returns the deterministic Ed25519 private key used to
// sign test requests. Its seed is the SHA-256 digest of a fixed label, so
// every call yields the same key.
func testClientPrivateKey() ed25519.PrivateKey {
	const seedLabel = "gateway-events-grpc-test-client"

	digest := sha256.Sum256([]byte(seedLabel))
	// A SHA-256 digest is exactly ed25519.SeedSize (32) bytes long.
	return ed25519.NewKeyFromSeed(digest[:ed25519.SeedSize])
}
// testClientPublicKeyBase64 returns the standard base64 encoding of the public
// key that belongs to testClientPrivateKey.
func testClientPublicKeyBase64() string {
	publicKey := testClientPrivateKey().Public().(ed25519.PublicKey)
	return base64.StdEncoding.EncodeToString(publicKey)
}
// newTestResponseSigner returns an Ed25519 response signer whose key is
// derived from the SHA-256 digest of a fixed label. It fails the test if the
// signer cannot be constructed.
func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
	t.Helper()

	digest := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	privateKey := ed25519.NewKeyFromSeed(digest[:])

	signer, err := authn.NewEd25519ResponseSigner(privateKey)
	require.NoError(t, err)
	return signer
}
// fixedClock is a clock whose reading never changes.
type fixedClock struct {
	// now is the instant reported by every call to Now.
	now time.Time
}

// Now returns the instant the clock was constructed with.
func (f fixedClock) Now() time.Time { return f.now }
// Compile-time check that fixedClock satisfies clock.Clock.
var _ clock.Clock = fixedClock{}
// staticReplayStore is a stateless replay store stub.
type staticReplayStore struct{}

// Reserve ignores all of its arguments and always reports success.
func (staticReplayStore) Reserve(_ context.Context, _, _ string, _ time.Duration) error {
	return nil
}
// Compile-time check that staticReplayStore satisfies replay.Store.
var _ replay.Store = staticReplayStore{}
// countingSessionCache is a map-backed session cache fake that counts how many
// times Lookup has been called.
type countingSessionCache struct {
// mu guards lookupCount and is held while records is read in Lookup.
mu sync.Mutex
// records holds the session records served by Lookup.
records map[string]session.Record
// lookupCount is the number of Lookup calls made so far.
lookupCount int
}
// Lookup increments the lookup counter and returns the record stored under
// deviceSessionID. Every call is counted, including misses. A miss returns the
// zero Record and an error.
func (c *countingSessionCache) Lookup(_ context.Context, deviceSessionID string) (session.Record, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.lookupCount++

	// Look up the requested session rather than a fixed key, so the fake
	// cannot serve one session's record for a different session ID.
	record, ok := c.records[deviceSessionID]
	if !ok {
		return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
	}
	return record, nil
}
// lookupCalls reports how many times Lookup has been called so far, reading
// the counter under the mutex.
func (c *countingSessionCache) lookupCalls() int {
	c.mu.Lock()
	calls := c.lookupCount
	c.mu.Unlock()
	return calls
}
// recordingDownstreamClient is a downstream client fake that records every
// command passed to ExecuteCommand.
type recordingDownstreamClient struct {
// mu guards captured.
mu sync.Mutex
// captured holds the received commands in arrival order.
captured []downstream.AuthenticatedCommand
}
// ExecuteCommand appends command to the captured list under the mutex and
// returns a fixed result with code "ok" and payload "response". It never
// returns an error.
func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		c.captured = append(c.captured, command)
	}()

	var result downstream.UnaryResult
	result.ResultCode = "ok"
	result.PayloadBytes = []byte("response")
	return result, nil
}
// commands returns a copy of the commands captured so far, so callers can
// inspect it without holding the mutex. The returned slice is never nil.
func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
	c.mu.Lock()
	defer c.mu.Unlock()

	snapshot := make([]downstream.AuthenticatedCommand, 0, len(c.captured))
	return append(snapshot, c.captured...)
}