package events import ( "context" "crypto/ed25519" "crypto/sha256" "encoding/base64" "errors" "net" "sync" "testing" "time" "galaxy/gateway/internal/app" "galaxy/gateway/internal/authn" "galaxy/gateway/internal/clock" "galaxy/gateway/internal/config" "galaxy/gateway/internal/downstream" "galaxy/gateway/internal/grpcapi" "galaxy/gateway/internal/replay" "galaxy/gateway/internal/session" gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1" "github.com/alicebob/miniredis/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" ) var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC) func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) { t.Parallel() server := miniredis.RunT(t) local := session.NewMemoryCache() fallback := &countingSessionCache{ records: map[string]session.Record{ "device-session-123": newActiveSessionRecord("user-123"), }, } readThrough, err := session.NewReadThroughCache(local, fallback) require.NoError(t, err) subscriber := newTestRedisSessionSubscriber(t, server, local) downstreamClient := &recordingDownstreamClient{} addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient) defer running.stop(t) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1")) require.NoError(t, err) assert.Equal(t, 1, fallback.lookupCalls()) _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2")) require.NoError(t, err) assert.Equal(t, 1, fallback.lookupCalls()) assert.Len(t, downstreamClient.commands(), 2) } func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) { t.Parallel() server := miniredis.RunT(t) local := session.NewMemoryCache() fallback := &countingSessionCache{ records: map[string]session.Record{ "device-session-123": newActiveSessionRecord("user-123"), }, } readThrough, err := session.NewReadThroughCache(local, fallback) require.NoError(t, err) subscriber := newTestRedisSessionSubscriber(t, server, local) downstreamClient := &recordingDownstreamClient{} addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient) defer running.stop(t) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1")) require.NoError(t, err) assert.Equal(t, 1, fallback.lookupCalls()) addSessionEvent(t, server, "gateway:session_events", map[string]string{ "device_session_id": "device-session-123", "user_id": "user-456", "client_public_key": testClientPublicKeyBase64(), "status": string(session.StatusActive), }) require.Eventually(t, func() bool { record, lookupErr := local.Lookup(context.Background(), "device-session-123") return lookupErr == nil && record.UserID == "user-456" }, time.Second, 10*time.Millisecond) _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2")) require.NoError(t, err) assert.Equal(t, 1, fallback.lookupCalls()) commands := downstreamClient.commands() require.Len(t, commands, 2) assert.Equal(t, "user-456", commands[1].UserID) } func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) { t.Parallel() server := miniredis.RunT(t) local := session.NewMemoryCache() fallback := &countingSessionCache{ records: map[string]session.Record{ "device-session-123": newActiveSessionRecord("user-123"), }, } readThrough, err := session.NewReadThroughCache(local, fallback) require.NoError(t, err) subscriber := newTestRedisSessionSubscriber(t, server, local) downstreamClient := &recordingDownstreamClient{} addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient) defer running.stop(t) conn := dialGatewayClient(t, addr) defer func() { require.NoError(t, conn.Close()) }() client := gatewayv1.NewEdgeGatewayClient(conn) _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1")) require.NoError(t, err) assert.Equal(t, 1, fallback.lookupCalls()) addSessionEvent(t, server, "gateway:session_events", map[string]string{ "device_session_id": "device-session-123", "user_id": "user-123", "client_public_key": testClientPublicKeyBase64(), "status": string(session.StatusRevoked), "revoked_at_ms": "123456789", }) require.Eventually(t, func() bool { record, lookupErr := local.Lookup(context.Background(), "device-session-123") return lookupErr == nil && record.Status == session.StatusRevoked }, time.Second, 10*time.Millisecond) _, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2")) require.Error(t, err) assert.Equal(t, codes.FailedPrecondition, status.Code(err)) assert.Equal(t, "device session is revoked", status.Convert(err).Message()) assert.Equal(t, 1, fallback.lookupCalls()) } type runningAuthenticatedGateway struct { cancel context.CancelFunc resultCh chan error } func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) { t.Helper() addr := unusedTCPAddr(t) grpcCfg := config.DefaultAuthenticatedGRPCConfig() grpcCfg.Addr = addr grpcCfg.FreshnessWindow = 5 * time.Minute router := downstream.NewStaticRouter(map[string]downstream.Client{ "fleet.move": downstreamClient, }) gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{ Router: router, ResponseSigner: newTestResponseSigner(t), SessionCache: sessionCache, ReplayStore: staticReplayStore{}, Clock: fixedClock{now: testNow}, }) application := app.New( config.Config{ ShutdownTimeout: time.Second, AuthenticatedGRPC: grpcCfg, }, gateway, subscriber, ) ctx, cancel := context.WithCancel(context.Background()) resultCh := make(chan error, 1) go func() { resultCh <- application.Run(ctx) }() select { case <-subscriber.started: case <-time.After(time.Second): require.FailNow(t, "session subscriber did not start") } return addr, runningAuthenticatedGateway{ cancel: cancel, resultCh: resultCh, } } func (g runningAuthenticatedGateway) stop(t *testing.T) { t.Helper() g.cancel() select { case err := <-g.resultCh: require.NoError(t, err) case <-time.After(2 * time.Second): require.FailNow(t, "gateway did not stop after cancellation") } } func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := grpc.DialContext( ctx, addr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), ) require.NoError(t, err) return conn } func unusedTCPAddr(t *testing.T) string { t.Helper() listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) addr := listener.Addr().String() require.NoError(t, listener.Close()) return addr } func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest { payloadBytes := []byte("payload") payloadHash := sha256.Sum256(payloadBytes) req := &gatewayv1.ExecuteCommandRequest{ ProtocolVersion: "v1", DeviceSessionId: "device-session-123", MessageType: "fleet.move", TimestampMs: testNow.UnixMilli(), RequestId: requestID, PayloadBytes: payloadBytes, PayloadHash: payloadHash[:], TraceId: "trace-123", } req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{ ProtocolVersion: req.GetProtocolVersion(), DeviceSessionID: req.GetDeviceSessionId(), MessageType: req.GetMessageType(), TimestampMS: req.GetTimestampMs(), RequestID: req.GetRequestId(), PayloadHash: req.GetPayloadHash(), })) return req } func newActiveSessionRecord(userID string) session.Record { return session.Record{ DeviceSessionID: "device-session-123", UserID: userID, ClientPublicKey: testClientPublicKeyBase64(), Status: session.StatusActive, } } func testClientPrivateKey() ed25519.PrivateKey { seed := sha256.Sum256([]byte("gateway-events-grpc-test-client")) return ed25519.NewKeyFromSeed(seed[:]) } func testClientPublicKeyBase64() string { return base64.StdEncoding.EncodeToString(testClientPrivateKey().Public().(ed25519.PublicKey)) } func newTestResponseSigner(t *testing.T) authn.ResponseSigner { t.Helper() seed := sha256.Sum256([]byte("gateway-events-grpc-test-response")) signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(seed[:])) require.NoError(t, err) return signer } type fixedClock struct { now time.Time } func (c fixedClock) Now() time.Time { return c.now } var _ clock.Clock = fixedClock{} type staticReplayStore struct{} func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error { return nil } var _ replay.Store = staticReplayStore{} type countingSessionCache struct { mu sync.Mutex records map[string]session.Record lookupCount int } func (c *countingSessionCache) Lookup(context.Context, string) (session.Record, error) { c.mu.Lock() defer c.mu.Unlock() c.lookupCount++ record, ok := c.records["device-session-123"] if !ok { return session.Record{}, errors.New("lookup session from counting cache: session cache record not found") } return record, nil } func (c *countingSessionCache) lookupCalls() int { c.mu.Lock() defer c.mu.Unlock() return c.lookupCount } type recordingDownstreamClient struct { mu sync.Mutex captured []downstream.AuthenticatedCommand } func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) { c.mu.Lock() c.captured = append(c.captured, command) c.mu.Unlock() return downstream.UnaryResult{ ResultCode: "ok", PayloadBytes: []byte("response"), }, nil } func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand { c.mu.Lock() defer c.mu.Unlock() cloned := make([]downstream.AuthenticatedCommand, len(c.captured)) copy(cloned, c.captured) return cloned }