feat: edge gateway service

This commit is contained in:
Ilia Denisov
2026-04-02 19:18:42 +02:00
committed by GitHub
parent 8cde99936c
commit 436c97a38b
95 changed files with 20504 additions and 57 deletions
@@ -0,0 +1,341 @@
package events
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
// clientEventReadCount caps how many stream entries a single XRead call may
// return, bounding per-iteration memory use and fan-out latency.
const clientEventReadCount int64 = 128

// ClientEventPublisher accepts decoded client-facing events from the internal
// event subscriber.
type ClientEventPublisher interface {
	// Publish fans out event to the currently active push streams.
	Publish(event push.Event)
}

// RedisClientEventSubscriber consumes client-facing events from one Redis
// Stream and forwards them to the configured publisher.
type RedisClientEventSubscriber struct {
	client           *redis.Client // dedicated Redis connection for stream reads
	stream           string        // Redis Stream key holding client-facing events
	pingTimeout      time.Duration // budget for health-check pings (reuses session-cache lookup timeout)
	readBlockTimeout time.Duration // maximum time one XRead call blocks waiting for new entries
	publisher        ClientEventPublisher
	logger           *zap.Logger
	metrics          *telemetry.Runtime // may be nil when constructed without observability
	closeOnce        sync.Once          // guards client.Close so Close is idempotent
	startedOnce      sync.Once          // guards closing the started channel
	started          chan struct{}      // closed once Run has resolved its start ID
}
// NewRedisClientEventSubscriber constructs a Redis Stream subscriber that
// reuses the SessionCache Redis connection settings and forwards decoded
// client-facing events to publisher.
func NewRedisClientEventSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) {
	// Delegate to the observability-aware constructor with no logger or metrics.
	subscriber, err := NewRedisClientEventSubscriberWithObservability(sessionCfg, eventsCfg, publisher, nil, nil)
	if err != nil {
		return nil, err
	}
	return subscriber, nil
}
// NewRedisClientEventSubscriberWithObservability constructs a Redis Stream
// subscriber that also records malformed or dropped internal events.
//
// logger and metrics may be nil; a nil logger is replaced with a no-op logger.
func NewRedisClientEventSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) {
	// Validate configuration up front so a misconfigured subscriber fails fast.
	switch {
	case strings.TrimSpace(sessionCfg.Addr) == "":
		return nil, errors.New("new redis client event subscriber: redis addr must not be empty")
	case sessionCfg.DB < 0:
		return nil, errors.New("new redis client event subscriber: redis db must not be negative")
	case sessionCfg.LookupTimeout <= 0:
		return nil, errors.New("new redis client event subscriber: lookup timeout must be positive")
	case strings.TrimSpace(eventsCfg.Stream) == "":
		return nil, errors.New("new redis client event subscriber: stream must not be empty")
	case eventsCfg.ReadBlockTimeout <= 0:
		return nil, errors.New("new redis client event subscriber: read block timeout must be positive")
	case publisher == nil:
		return nil, errors.New("new redis client event subscriber: nil publisher")
	}
	options := &redis.Options{
		Addr:            sessionCfg.Addr,
		Username:        sessionCfg.Username,
		Password:        sessionCfg.Password,
		DB:              sessionCfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		// Enforce a TLS floor; certificates come from the system trust store.
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	if logger == nil {
		logger = zap.NewNop()
	}
	subscriber := &RedisClientEventSubscriber{
		client:           redis.NewClient(options),
		stream:           eventsCfg.Stream,
		pingTimeout:      sessionCfg.LookupTimeout,
		readBlockTimeout: eventsCfg.ReadBlockTimeout,
		publisher:        publisher,
		logger:           logger.Named("client_event_subscriber"),
		metrics:          metrics,
		started:          make(chan struct{}),
	}
	return subscriber, nil
}
// Ping verifies that the Redis backend used for client-facing event fan-out is
// reachable within the configured timeout budget.
func (s *RedisClientEventSubscriber) Ping(ctx context.Context) error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("ping redis client event subscriber: nil subscriber")
	case ctx == nil:
		return errors.New("ping redis client event subscriber: nil context")
	}
	pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
	defer cancel()
	err := s.client.Ping(pingCtx).Err()
	if err == nil {
		return nil
	}
	return fmt.Errorf("ping redis client event subscriber: %w", err)
}
// Run consumes client-facing events until ctx is canceled or Redis returns an
// unexpected error. Reading starts at the current stream tail, so events
// published before Run was called are not replayed.
func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("run redis client event subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("run redis client event subscriber: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	lastID, err := s.resolveStartID(ctx)
	if err != nil {
		return err
	}
	// Unblock anyone waiting on the started channel for readiness.
	s.signalStarted()
	for {
		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
			Streams: []string{s.stream, lastID},
			Count:   clientEventReadCount,
			Block:   s.readBlockTimeout,
		}).Result()
		switch {
		case err == nil:
			for _, stream := range streams {
				for _, message := range stream.Messages {
					s.publishMessage(message)
					// Advance the cursor so the next XRead resumes after this entry.
					lastID = message.ID
				}
			}
		case errors.Is(err, redis.Nil):
			// Block timeout elapsed with no new entries; poll again.
		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
			// Orderly shutdown: the read failed because our own context ended
			// (or the client was closed during shutdown); surface ctx.Err().
			return ctx.Err()
		default:
			// Any other failure — including context/closed-client errors that
			// occurred without our context being canceled — is unexpected.
			// (The original had a separate case for that trio with a body
			// identical to this default; the duplicate branch is collapsed.)
			return fmt.Errorf("run redis client event subscriber: %w", err)
		}
	}
}
// resolveStartID returns the ID of the newest entry currently in the stream,
// or "0-0" when the stream is empty or missing, so consumption starts at the
// tail rather than replaying history.
func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) {
	tail, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	if errors.Is(err, redis.Nil) {
		return "0-0", nil
	}
	if err != nil {
		return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", err)
	}
	if len(tail) == 0 {
		return "0-0", nil
	}
	return tail[0].ID, nil
}
// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown.
func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown redis client event subscriber: nil context")
	}
	return s.Close()
}

// Close releases the underlying Redis client resources. It is safe to call
// repeatedly and on a nil subscriber.
func (s *RedisClientEventSubscriber) Close() error {
	if s == nil || s.client == nil {
		return nil
	}
	var closeErr error
	s.closeOnce.Do(func() { closeErr = s.client.Close() })
	return closeErr
}

// signalStarted marks the subscriber as started exactly once by closing the
// started channel.
func (s *RedisClientEventSubscriber) signalStarted() {
	s.startedOnce.Do(func() { close(s.started) })
}
// publishMessage decodes one stream entry and forwards it to the publisher.
// Malformed entries are logged, counted (when metrics are configured), and
// dropped so a single bad producer cannot stall the stream.
func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) {
	event, err := decodeClientEvent(message.Values)
	if err != nil {
		s.logger.Warn("dropped malformed client event",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.Error(err),
		)
		// metrics is optional: NewRedisClientEventSubscriber wires it as nil.
		// Guard explicitly so the drop path never depends on the telemetry
		// type tolerating a nil receiver.
		if s.metrics != nil {
			s.metrics.RecordInternalEventDrop(context.Background(),
				attribute.String("component", "client_event_subscriber"),
				attribute.String("reason", "malformed_event"),
			)
		}
		return
	}
	s.publisher.Publish(event)
}
// decodeClientEvent converts raw Redis Stream values into a push.Event.
// Unknown fields are rejected so producer/consumer schema drift surfaces
// immediately instead of being silently ignored.
func decodeClientEvent(values map[string]any) (push.Event, error) {
	required := map[string]struct{}{
		"user_id":       {},
		"event_type":    {},
		"event_id":      {},
		"payload_bytes": {},
	}
	optional := map[string]struct{}{
		"device_session_id": {},
		"request_id":        {},
		"trace_id":          {},
	}
	// Reject any field outside the known schema.
	for key := range values {
		_, isRequired := required[key]
		_, isOptional := optional[key]
		if !isRequired && !isOptional {
			return push.Event{}, fmt.Errorf("decode client event: unsupported field %q", key)
		}
	}
	var event push.Event
	var err error
	if event.UserID, err = requiredStringField(values, "user_id"); err != nil {
		return push.Event{}, err
	}
	if event.EventType, err = requiredStringField(values, "event_type"); err != nil {
		return push.Event{}, err
	}
	if event.EventID, err = requiredStringField(values, "event_id"); err != nil {
		return push.Event{}, err
	}
	if event.PayloadBytes, err = requiredBytesField(values, "payload_bytes"); err != nil {
		return push.Event{}, err
	}
	if deviceSessionID, ok, fieldErr := optionalStringField(values, "device_session_id"); fieldErr != nil {
		return push.Event{}, fieldErr
	} else if ok {
		// Only the session target is whitespace-normalized; the remaining
		// optional fields are passed through verbatim.
		event.DeviceSessionID = strings.TrimSpace(deviceSessionID)
	}
	if requestID, ok, fieldErr := optionalStringField(values, "request_id"); fieldErr != nil {
		return push.Event{}, fieldErr
	} else if ok {
		event.RequestID = requestID
	}
	if traceID, ok, fieldErr := optionalStringField(values, "trace_id"); fieldErr != nil {
		return push.Event{}, fieldErr
	} else if ok {
		event.TraceID = traceID
	}
	return event, nil
}
// requiredBytesField extracts field from values, failing when it is absent or
// not coercible to bytes.
func requiredBytesField(values map[string]any, field string) ([]byte, error) {
	raw, present := values[field]
	if !present {
		return nil, fmt.Errorf("decode client event: missing %s", field)
	}
	decoded, err := coerceBytes(raw)
	if err != nil {
		return nil, fmt.Errorf("decode client event: %s: %w", field, err)
	}
	return decoded, nil
}

// optionalStringField extracts field from values when present; the boolean
// reports whether the field existed at all.
func optionalStringField(values map[string]any, field string) (string, bool, error) {
	raw, present := values[field]
	if !present {
		return "", false, nil
	}
	decoded, err := coerceString(raw)
	if err != nil {
		return "", false, fmt.Errorf("decode client event: %s: %w", field, err)
	}
	return decoded, true, nil
}
// coerceBytes normalizes a Redis stream value into an owned byte slice.
// []byte inputs are cloned so callers never alias Redis client buffers.
func coerceBytes(value any) ([]byte, error) {
	if text, ok := value.(string); ok {
		return []byte(text), nil
	}
	if raw, ok := value.([]byte); ok {
		return bytes.Clone(raw), nil
	}
	return nil, fmt.Errorf("unsupported type %T", value)
}
@@ -0,0 +1,294 @@
package events
import (
"context"
"strings"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/testutil"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisClientEventSubscriberPublishesValidEvent verifies that a fully
// populated stream entry is decoded and forwarded to the publisher intact.
func TestRedisClientEventSubscriberPublishesValidEvent(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-123",
		"event_type":        "fleet.updated",
		"event_id":          "event-123",
		"payload_bytes":     []byte("payload-123"),
		"request_id":        "request-123",
		"trace_id":          "trace-123",
	})
	// Delivery is asynchronous; poll until the event arrives.
	require.Eventually(t, func() bool {
		return len(publisher.events()) == 1
	}, time.Second, 10*time.Millisecond)
	assert.Equal(t, []push.Event{{
		UserID:          "user-123",
		DeviceSessionID: "device-session-123",
		EventType:       "fleet.updated",
		EventID:         "event-123",
		PayloadBytes:    []byte("payload-123"),
		RequestID:       "request-123",
		TraceID:         "trace-123",
	}}, publisher.events())
}
// TestRedisClientEventSubscriberSkipsMalformedEventAndContinues verifies that
// an entry carrying an unsupported field is dropped while a later valid entry
// is still delivered — one bad producer must not stall the stream.
func TestRedisClientEventSubscriberSkipsMalformedEventAndContinues(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	// "unexpected" is outside the event schema, so this entry must be dropped.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-good",
		"payload_bytes": []byte("payload-good"),
	})
	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-good"
	}, time.Second, 10*time.Millisecond)
}
// TestRedisClientEventSubscriberStartsFromCurrentTail verifies that events
// already in the stream before Run started are not replayed; only entries
// appended after startup are delivered.
func TestRedisClientEventSubscriberStartsFromCurrentTail(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	// Published before the subscriber starts — must never be seen.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-old",
		"payload_bytes": []byte("payload-old"),
	})
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	assert.Never(t, func() bool {
		return len(publisher.events()) > 0
	}, 100*time.Millisecond, 10*time.Millisecond)
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-new",
		"payload_bytes": []byte("payload-new"),
	})
	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-new"
	}, time.Second, 10*time.Millisecond)
}
// TestRedisClientEventSubscriberShutdownInterruptsBlockingRead verifies that
// canceling the context plus closing the client (via Shutdown) terminates a
// blocking XRead promptly and Run returns context.Canceled.
func TestRedisClientEventSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))
	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}
// TestRedisClientEventSubscriberLogsAndCountsMalformedEvents verifies the
// observability path: a malformed entry produces a warning log and increments
// the internal-event-drop counter with component/reason attributes.
func TestRedisClientEventSubscriberLogsAndCountsMalformedEvents(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	logger, logBuffer := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
	subscriber, err := NewRedisClientEventSubscriberWithObservability(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
		logger,
		telemetryRuntime,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})
	require.Eventually(t, func() bool {
		return strings.Contains(logBuffer.String(), "dropped malformed client event")
	}, time.Second, 10*time.Millisecond)
	metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
	assert.Contains(t, metricsText, `gateway_internal_event_drops_total`)
	assert.Contains(t, metricsText, `component="client_event_subscriber"`)
	assert.Contains(t, metricsText, `reason="malformed_event"`)
}
// newTestRedisClientEventSubscriber builds a subscriber against the given
// miniredis server with short test timeouts and registers Close on cleanup.
func newTestRedisClientEventSubscriber(t *testing.T, server *miniredis.Miniredis, publisher ClientEventPublisher) *RedisClientEventSubscriber {
	t.Helper()
	subscriber, err := NewRedisClientEventSubscriber(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})
	return subscriber
}

// addClientEvent appends one entry to stream using a short-lived Redis client
// pointed at the miniredis server.
func addClientEvent(t *testing.T, server *miniredis.Miniredis, stream string, values map[string]any) {
	t.Helper()
	client := redis.NewClient(&redis.Options{
		Addr:            server.Addr(),
		Protocol:        2,
		DisableIdentity: true,
	})
	defer func() {
		assert.NoError(t, client.Close())
	}()
	err := client.XAdd(context.Background(), &redis.XAddArgs{
		Stream: stream,
		Values: values,
	}).Err()
	require.NoError(t, err)
}
// runningClientEventSubscriber tracks a subscriber running in a background
// goroutine so tests can stop it deterministically.
type runningClientEventSubscriber struct {
	cancel   context.CancelFunc
	resultCh chan error
}

// runTestClientEventSubscriber starts subscriber.Run in a goroutine and blocks
// until the subscriber signals readiness (or fails the test after a second).
func runTestClientEventSubscriber(t *testing.T, subscriber *RedisClientEventSubscriber) runningClientEventSubscriber {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	return runningClientEventSubscriber{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

// stop cancels the subscriber's context and asserts Run exits with
// context.Canceled within a second.
func (r runningClientEventSubscriber) stop(t *testing.T) {
	t.Helper()
	r.cancel()
	select {
	case err := <-r.resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}
// recordingClientEventPublisher captures published events for assertions.
type recordingClientEventPublisher struct {
	mu      sync.Mutex
	records []push.Event
}

// Publish appends event to the in-memory record under the lock.
func (p *recordingClientEventPublisher) Publish(event push.Event) {
	p.mu.Lock()
	p.records = append(p.records, event)
	p.mu.Unlock()
}

// events returns a snapshot copy of everything published so far.
func (p *recordingClientEventPublisher) events() []push.Event {
	p.mu.Lock()
	defer p.mu.Unlock()
	snapshot := make([]push.Event, len(p.records))
	copy(snapshot, p.records)
	return snapshot
}
@@ -0,0 +1,385 @@
package events
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"errors"
"net"
"sync"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/grpcapi"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
// testNow is the fixed wall-clock instant used by fixedClock so request
// freshness checks are deterministic.
var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
// TestAuthenticatedGatewayWarmsLocalSessionCache verifies that the first
// command triggers exactly one fallback lookup and the second command is
// served from the warmed local cache without another fallback hit.
func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)
	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	// Second command: still exactly one fallback lookup — local cache hit.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	assert.Len(t, downstreamClient.commands(), 2)
}
// TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup verifies
// that a session update delivered via the Redis stream refreshes the local
// cache, so subsequent commands run as the updated user without another
// fallback lookup.
func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)
	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	// Rebind the session to user-456 through the event stream.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusActive),
	})
	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.UserID == "user-456"
	}, time.Second, 10*time.Millisecond)
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	commands := downstreamClient.commands()
	require.Len(t, commands, 2)
	assert.Equal(t, "user-456", commands[1].UserID)
}
// TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup
// verifies that a revocation event propagated through the stream causes
// commands to fail with FailedPrecondition, without consulting the fallback
// cache again.
func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)
	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Equal(t, 1, fallback.lookupCalls())
}
// runningAuthenticatedGateway tracks a gateway application running in a
// background goroutine.
type runningAuthenticatedGateway struct {
	cancel   context.CancelFunc
	resultCh chan error
}

// runAuthenticatedGateway wires a gRPC gateway with the given session cache,
// session subscriber, and a static router for "fleet.move", starts it on an
// ephemeral port, and waits until the subscriber signals readiness.
func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
	t.Helper()
	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	grpcCfg.FreshnessWindow = 5 * time.Minute
	router := downstream.NewStaticRouter(map[string]downstream.Client{
		"fleet.move": downstreamClient,
	})
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Router:         router,
		ResponseSigner: newTestResponseSigner(t),
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
	})
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		gateway,
		subscriber,
	)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}
	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

// stop cancels the gateway's context and asserts a clean shutdown within the
// two-second budget.
func (g runningAuthenticatedGateway) stop(t *testing.T) {
	t.Helper()
	g.cancel()
	select {
	case err := <-g.resultCh:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		require.FailNow(t, "gateway did not stop after cancellation")
	}
}
// dialGatewayClient opens a blocking, plaintext client connection to the
// gateway under test.
// NOTE(review): grpc.DialContext and grpc.WithBlock are deprecated in newer
// grpc-go releases; consider migrating to grpc.NewClient once the pinned
// dependency supports it.
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	conn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.NoError(t, err)
	return conn
}

// unusedTCPAddr reserves an ephemeral loopback port and immediately releases
// it so the gateway can bind the same address.
// NOTE(review): inherently racy — another process could claim the port between
// Close and the gateway's Listen; acceptable for tests.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := listener.Addr().String()
	require.NoError(t, listener.Close())
	return addr
}
// newExecuteCommandRequest builds an ExecuteCommand request for the canonical
// test session, signed with the deterministic test client key and timestamped
// at the fixed test clock so freshness checks pass.
func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
	payloadBytes := []byte("payload")
	payloadHash := sha256.Sum256(payloadBytes)
	req := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadBytes:    payloadBytes,
		PayloadHash:     payloadHash[:],
		TraceId:         "trace-123",
	}
	// Sign exactly the fields the gateway verifies.
	req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))
	return req
}
// newActiveSessionRecord returns an active record for the canonical test
// device session, owned by userID.
func newActiveSessionRecord(userID string) session.Record {
	record := session.Record{
		DeviceSessionID: "device-session-123",
		ClientPublicKey: testClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
	record.UserID = userID
	return record
}
// testClientPrivateKey derives a deterministic Ed25519 client key from a
// fixed seed so signatures are reproducible across test runs.
func testClientPrivateKey() ed25519.PrivateKey {
	digest := sha256.Sum256([]byte("gateway-events-grpc-test-client"))
	return ed25519.NewKeyFromSeed(digest[:])
}
// testClientPublicKeyBase64 returns the base64 encoding of the deterministic
// test client's public key, in the form stored on session records.
func testClientPublicKeyBase64() string {
	publicKey := testClientPrivateKey().Public().(ed25519.PublicKey)
	return base64.StdEncoding.EncodeToString(publicKey)
}
// newTestResponseSigner builds a deterministic Ed25519 response signer for the
// gateway under test.
func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
	t.Helper()
	digest := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(digest[:]))
	require.NoError(t, err)
	return signer
}
// fixedClock implements clock.Clock with a constant instant, making request
// freshness checks deterministic in tests.
type fixedClock struct {
	now time.Time
}

// Now returns the fixed instant.
func (c fixedClock) Now() time.Time {
	return c.now
}

var _ clock.Clock = fixedClock{}

// staticReplayStore implements replay.Store and accepts every reservation, so
// replay protection never interferes with these tests.
type staticReplayStore struct{}

// Reserve always succeeds.
func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error {
	return nil
}

var _ replay.Store = staticReplayStore{}
// countingSessionCache is a session.Cache fallback that counts lookups so
// tests can assert how often the read-through cache misses locally.
type countingSessionCache struct {
	mu          sync.Mutex
	records     map[string]session.Record
	lookupCount int
}

// Lookup returns the record for deviceSessionID and increments the call
// counter. The original implementation ignored its parameter and always read
// the hardcoded "device-session-123" entry; using the argument keeps the fake
// honest for any session ID while remaining behavior-compatible with every
// existing call site (which all use that ID).
func (c *countingSessionCache) Lookup(_ context.Context, deviceSessionID string) (session.Record, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.lookupCount++
	record, ok := c.records[deviceSessionID]
	if !ok {
		return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
	}
	return record, nil
}

// lookupCalls reports how many times Lookup has been invoked.
func (c *countingSessionCache) lookupCalls() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.lookupCount
}
// recordingDownstreamClient captures every authenticated command it receives
// and answers with a fixed successful result.
type recordingDownstreamClient struct {
	mu       sync.Mutex
	captured []downstream.AuthenticatedCommand
}

// ExecuteCommand records command and returns a canned "ok" result.
func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.captured = append(c.captured, command)
	return downstream.UnaryResult{
		ResultCode:   "ok",
		PayloadBytes: []byte("response"),
	}, nil
}

// commands returns a snapshot of every captured command.
func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
	c.mu.Lock()
	defer c.mu.Unlock()
	snapshot := make([]downstream.AuthenticatedCommand, len(c.captured))
	copy(snapshot, c.captured)
	return snapshot
}
@@ -0,0 +1,416 @@
package events
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/grpcapi"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions verifies that
// an event addressed only by user_id is delivered to every active stream of
// that user, and not to streams owned by other users.
func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
	defer cancelTargetOne()
	targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")
	targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
	defer cancelTargetTwo()
	targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")
	unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
	defer cancelUnrelated()
	unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")
	// No device_session_id: the event targets every session of user-123.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-123",
		"payload_bytes": []byte("payload-123"),
		"request_id":    "request-123",
		"trace_id":      "trace-123",
	})
	assertSignedPushEvent(t, recvPushEvent(t, targetOne), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	assertSignedPushEvent(t, recvPushEvent(t, targetTwo), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	assertNoPushEvent(t, unrelated, cancelUnrelated)
}
// TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession verifies
// that an event carrying a device_session_id reaches only that session's
// stream, not sibling sessions of the same user.
func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	otherCtx, cancelOther := context.WithCancel(context.Background())
	defer cancelOther()
	otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")
	targetCtx, cancelTarget := context.WithCancel(context.Background())
	defer cancelTarget()
	targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")
	// device_session_id narrows delivery to device-session-2 only.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-2",
		"event_type":        "fleet.updated",
		"event_id":          "event-456",
		"payload_bytes":     []byte("payload-456"),
	})
	assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-456",
		PayloadBytes:    []byte("payload-456"),
	})
	assertNoPushEvent(t, otherStream, cancelOther)
}
// TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen revokes a
// session over the session-events stream while a push stream is open, then
// checks that the live stream closes with FailedPrecondition and that a reopen
// for the same session is rejected with the same status and message.
func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	// The session subscriber uses pushHub as its revocation handler.
	sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
	defer running.stop(t)
	// Wait for the session subscriber to reach its read loop before publishing.
	select {
	case <-sessionSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	streamCtx, cancelStream := context.WithCancel(context.Background())
	defer cancelStream()
	stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
	// Publish a revoked snapshot for the session backing the open stream.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-1",
		"user_id":           "user-123",
		"client_public_key": pushClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	// Wait until the revocation has been applied to the local cache.
	require.Eventually(t, func() bool {
		record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()
	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
		assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after revoke")
	}
	// Reopening may fail at open or on the first receive; either way the
	// status must match.
	reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2"))
	if err == nil {
		_, err = reopened.Recv()
	}
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
}
// TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown cancels the
// gateway's run context while a stream is open and checks that the stream
// closes with Unavailable and the shutdown message.
func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()
	// Trigger gateway shutdown while the receive is in flight.
	running.cancel()
	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.Unavailable, status.Code(recvErr))
		assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after gateway shutdown")
	}
}
// runPushGateway boots a full authenticated gateway application — gRPC server,
// the client-event subscriber, and any extraComponents — on an unused TCP port
// and returns the listen address together with handles to stop it and observe
// the application's exit error.
//
// It blocks until the client event subscriber has signaled readiness, so tests
// may publish events immediately after it returns.
func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
	t.Helper()
	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	// Generous freshness window so signed test requests stay valid.
	grpcCfg.FreshnessWindow = 5 * time.Minute
	responseSigner := newTestResponseSigner(t)
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Service:        grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
		ResponseSigner: responseSigner,
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
		PushHub:        pushHub,
	})
	components := []app.Component{gateway, clientSubscriber}
	components = append(components, extraComponents...)
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		components...,
	)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	// Block until the subscriber is consuming its stream.
	select {
	case <-clientSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "client event subscriber did not start")
	}
	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
// newPushActiveSessionRecord builds an active session.Record for the given
// device session and user, carrying the shared test client public key.
func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
	record := session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          userID,
		Status:          session.StatusActive,
	}
	record.ClientPublicKey = pushClientPublicKeyBase64()
	return record
}
// newPushSubscribeEventsRequest builds a SubscribeEvents request for
// deviceSessionID signed with the shared deterministic test client key. The
// payload hash is the SHA-256 of an empty payload, and the trace ID is derived
// from the device session ID so assertions can predict it.
func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	payloadHash := sha256.Sum256(nil)
	traceID := "trace-" + deviceSessionID
	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadHash:     payloadHash[:],
		TraceId:         traceID,
	}
	// Sign the canonical request fields. TraceId is not part of the signing
	// input — authn.RequestSigningFields has no trace field.
	req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))
	return req
}
// recvPushEvent reads the next event from stream, failing the test on any
// receive error.
func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()
	received, recvErr := stream.Recv()
	require.NoError(t, recvErr)
	return received
}
// assertPushBootstrapEvent checks the synthetic gateway.server_time bootstrap
// event: its identifiers echo the request, its payload hash matches the
// payload, and its signature verifies under the test response-signer key.
func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
	t.Helper()
	require.NotNil(t, event)
	assert.Equal(t, "gateway.server_time", event.GetEventType())
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	signedFields := authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), signedFields))
}
// assertSignedPushEvent checks that a fanned-out event carries the expected
// metadata and payload, and that its hash and signature verify under the test
// response-signer key.
func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
	t.Helper()
	require.NotNil(t, event)
	assert.Equal(t, want.EventType, event.GetEventType())
	assert.Equal(t, want.EventID, event.GetEventId())
	assert.Equal(t, want.RequestID, event.GetRequestId())
	assert.Equal(t, want.TraceID, event.GetTraceId())
	assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	signedFields := authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), signedFields))
}
// assertNoPushEvent asserts that stream delivers nothing within 100ms, then
// cancels the stream context so the background Recv unblocks. The post-cancel
// Recv error lands in the buffered errCh and is discarded, so the goroutine
// does not leak.
func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
	t.Helper()
	recvCh := make(chan *gatewayv1.GatewayEvent, 1)
	errCh := make(chan error, 1)
	go func() {
		event, err := stream.Recv()
		if err != nil {
			errCh <- err
			return
		}
		recvCh <- event
	}()
	select {
	case event := <-recvCh:
		require.FailNowf(t, "unexpected push event delivered", "%+v", event)
	case <-time.After(100 * time.Millisecond):
		// Quiet window elapsed: nothing was fanned out to this stream.
		cancel()
	case err := <-errCh:
		require.FailNowf(t, "stream closed unexpectedly", "%v", err)
	}
}
// pushClientPrivateKey derives a deterministic Ed25519 client key from a fixed
// test label so every run signs with the same key pair.
func pushClientPrivateKey() ed25519.PrivateKey {
	digest := sha256.Sum256([]byte("gateway-push-grpc-test-client"))
	return ed25519.NewKeyFromSeed(digest[:])
}
// pushClientPublicKeyBase64 returns the test client public key encoded as
// standard base64, the representation stored in session records.
func pushClientPublicKeyBase64() string {
	publicKey := pushClientPrivateKey().Public().(ed25519.PublicKey)
	return base64.StdEncoding.EncodeToString(publicKey)
}
// pushResponseSignerPublicKey derives the public half of the deterministic
// response-signer key used by the gateway under test.
func pushResponseSignerPublicKey() ed25519.PublicKey {
	digest := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	privateKey := ed25519.NewKeyFromSeed(digest[:])
	return privateKey.Public().(ed25519.PublicKey)
}
+389
View File
@@ -0,0 +1,389 @@
// Package events subscribes to internal session lifecycle streams used to keep
// the gateway hot-path session cache synchronized without per-request upstream
// lookups.
package events
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
// sessionEventReadCount caps how many stream entries a single XRead call may
// return, bounding per-iteration work in the read loop.
const sessionEventReadCount int64 = 128
// SessionRevocationHandler reacts to a successfully applied revoked session
// snapshot and may tear down active resources bound to that session.
type SessionRevocationHandler interface {
	// RevokeDeviceSession tears down active resources bound to deviceSessionID.
	// It may be invoked more than once for the same session, since every
	// revoked snapshot on the stream triggers a call.
	RevokeDeviceSession(deviceSessionID string)
}
// RedisSessionSubscriber consumes full session snapshots from one Redis Stream
// and applies them to a process-local session snapshot store.
type RedisSessionSubscriber struct {
	client            *redis.Client            // connection used for stream reads and pings
	stream            string                   // Redis Stream key carrying session snapshots
	pingTimeout       time.Duration            // per-call budget for health-check pings
	readBlockTimeout  time.Duration            // max block per XRead call
	store             session.SnapshotStore    // local cache receiving decoded snapshots
	revocationHandler SessionRevocationHandler // optional; notified on revoked snapshots
	logger            *zap.Logger              // never nil after construction (defaults to no-op)
	metrics           *telemetry.Runtime       // optional drop counters; may be nil
	closeOnce         sync.Once                // guards client.Close
	startedOnce       sync.Once                // guards closing started
	started           chan struct{}            // closed once the read loop is entered
}
// NewRedisSessionSubscriber constructs a Redis Stream subscriber that reuses
// the SessionCache Redis connection settings and applies updates to store.
// Logging is a no-op and metrics are disabled in this variant.
func NewRedisSessionSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, nil, nil, nil)
}
// NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream
// subscriber that reuses the SessionCache Redis connection settings, applies
// updates to store, and optionally tears down active resources for revoked
// sessions. Logging is a no-op and metrics are disabled in this variant.
func NewRedisSessionSubscriberWithRevocationHandler(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, revocationHandler, nil, nil)
}
// NewRedisSessionSubscriberWithObservability constructs a Redis Stream
// subscriber that also logs and counts malformed internal session events.
// logger and metrics may be nil; a nil logger is replaced with a no-op.
func NewRedisSessionSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
	// Validate inputs in a fixed order so the first violation is reported.
	switch {
	case strings.TrimSpace(sessionCfg.Addr) == "":
		return nil, errors.New("new redis session subscriber: redis addr must not be empty")
	case sessionCfg.DB < 0:
		return nil, errors.New("new redis session subscriber: redis db must not be negative")
	case sessionCfg.LookupTimeout <= 0:
		return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
	case strings.TrimSpace(eventsCfg.Stream) == "":
		return nil, errors.New("new redis session subscriber: stream must not be empty")
	case eventsCfg.ReadBlockTimeout <= 0:
		return nil, errors.New("new redis session subscriber: read block timeout must be positive")
	case store == nil:
		return nil, errors.New("new redis session subscriber: nil session snapshot store")
	}
	if logger == nil {
		logger = zap.NewNop()
	}
	options := &redis.Options{
		Addr:            sessionCfg.Addr,
		Username:        sessionCfg.Username,
		Password:        sessionCfg.Password,
		DB:              sessionCfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	return &RedisSessionSubscriber{
		client:            redis.NewClient(options),
		stream:            eventsCfg.Stream,
		pingTimeout:       sessionCfg.LookupTimeout,
		readBlockTimeout:  eventsCfg.ReadBlockTimeout,
		store:             store,
		revocationHandler: revocationHandler,
		logger:            logger.Named("session_subscriber"),
		metrics:           metrics,
		started:           make(chan struct{}),
	}, nil
}
// Ping verifies that the Redis backend used for session lifecycle events is
// reachable within the configured timeout budget.
func (s *RedisSessionSubscriber) Ping(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("ping redis session subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("ping redis session subscriber: nil context")
	}
	// Bound the round trip by the configured lookup timeout.
	pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
	defer cancel()
	pingErr := s.client.Ping(pingCtx).Err()
	if pingErr == nil {
		return nil
	}
	return fmt.Errorf("ping redis session subscriber: %w", pingErr)
}
// Run consumes session lifecycle events until ctx is canceled or Redis returns
// an unexpected error.
//
// The subscriber starts reading after the current stream tail, signals
// readiness via the started channel, and then loops on blocking XRead calls,
// applying every message in arrival order.
func (s *RedisSessionSubscriber) Run(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("run redis session subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("run redis session subscriber: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	lastID, err := s.resolveStartID(ctx)
	if err != nil {
		return err
	}
	s.signalStarted()
	for {
		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
			Streams: []string{s.stream, lastID},
			Count:   sessionEventReadCount,
			Block:   s.readBlockTimeout,
		}).Result()
		switch {
		case err == nil:
			for _, stream := range streams {
				for _, message := range stream.Messages {
					s.applyMessage(message)
					lastID = message.ID
				}
			}
			continue
		case errors.Is(err, redis.Nil):
			// Block timeout expired with no new entries; poll again.
			continue
		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
			// Clean shutdown: the read was interrupted because ctx ended (the
			// client may surface this as ErrClosed after Close). Report the
			// context error so callers see a plain cancellation.
			return ctx.Err()
		default:
			// Genuine Redis failures, plus Canceled/DeadlineExceeded/ErrClosed
			// observed while ctx is still live. (A previous explicit case for
			// the latter had a body identical to this default and was merged.)
			return fmt.Errorf("run redis session subscriber: %w", err)
		}
	}
}
// resolveStartID returns the ID of the newest entry currently in the stream so
// Run only consumes events appended afterwards; an empty or missing stream
// yields "0-0", i.e. read from the very beginning.
func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) {
	messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return "0-0", nil
		}
		return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err)
	}
	if len(messages) == 0 {
		return "0-0", nil
	}
	return messages[0].ID, nil
}
// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown. The context is only validated, not
// awaited: closing the client is what unblocks an in-flight XRead.
func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown redis session subscriber: nil context")
	}
	return s.Close()
}
// Close releases the underlying Redis client resources. It is safe to call on
// a nil subscriber and is idempotent: only the first call actually closes.
func (s *RedisSessionSubscriber) Close() error {
	if s == nil || s.client == nil {
		return nil
	}
	var closeErr error
	s.closeOnce.Do(func() {
		closeErr = s.client.Close()
	})
	return closeErr
}
// signalStarted closes the started channel exactly once, marking that the
// subscriber has resolved its start position and entered the read loop.
func (s *RedisSessionSubscriber) signalStarted() {
	s.startedOnce.Do(func() { close(s.started) })
}
// applyMessage decodes one stream entry and applies it to the snapshot store.
//
// Malformed events and store failures are logged, counted, and answered by
// evicting the affected session so the gateway fails closed rather than
// serving a stale snapshot. Revoked snapshots additionally notify the optional
// revocation handler.
func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) {
	record, err := decodeSessionRecordSnapshot(message.Values)
	if err != nil {
		s.logger.Warn("dropped malformed session event",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.Error(err),
		)
		// metrics is optional (nil in the plain constructors); guard the call
		// rather than relying on telemetry.Runtime tolerating a nil receiver.
		if s.metrics != nil {
			s.metrics.RecordInternalEventDrop(context.Background(),
				attribute.String("component", "session_subscriber"),
				attribute.String("reason", "malformed_event"),
			)
		}
		// Best-effort eviction when the bad event still names a session.
		if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok {
			s.store.Delete(deviceSessionID)
		}
		return
	}
	if err := s.store.Upsert(record); err != nil {
		s.logger.Warn("dropped session snapshot after store failure",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.String("device_session_id", record.DeviceSessionID),
			zap.Error(err),
		)
		if s.metrics != nil {
			s.metrics.RecordInternalEventDrop(context.Background(),
				attribute.String("component", "session_subscriber"),
				attribute.String("reason", "store_failure"),
			)
		}
		s.store.Delete(record.DeviceSessionID)
		return
	}
	if record.Status == session.StatusRevoked && s.revocationHandler != nil {
		s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID)
	}
}
// decodeSessionRecordSnapshot converts a raw stream entry into a session
// record. Unknown fields are rejected outright; the four identity fields are
// required and must be non-blank, while revoked_at_ms is an optional int64.
func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) {
	allowed := map[string]bool{
		"device_session_id": true,
		"user_id":           true,
		"client_public_key": true,
		"status":            true,
		"revoked_at_ms":     true,
	}
	for key := range values {
		if !allowed[key] {
			return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key)
		}
	}
	var record session.Record
	var err error
	if record.DeviceSessionID, err = requiredStringField(values, "device_session_id"); err != nil {
		return session.Record{}, err
	}
	if record.UserID, err = requiredStringField(values, "user_id"); err != nil {
		return session.Record{}, err
	}
	if record.ClientPublicKey, err = requiredStringField(values, "client_public_key"); err != nil {
		return session.Record{}, err
	}
	statusValue, err := requiredStringField(values, "status")
	if err != nil {
		return session.Record{}, err
	}
	record.Status = session.Status(statusValue)
	if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
		revokedAtMS, parseErr := parseInt64Field(rawRevokedAtMS, "revoked_at_ms")
		if parseErr != nil {
			return session.Record{}, parseErr
		}
		record.RevokedAtMS = &revokedAtMS
	}
	return record, nil
}
// extractDeviceSessionID best-effort reads a usable device_session_id from raw
// event fields so even malformed events can still evict the session they name.
func extractDeviceSessionID(values map[string]any) (string, bool) {
	raw, ok := values["device_session_id"]
	if !ok {
		return "", false
	}
	deviceSessionID, err := coerceString(raw)
	if err != nil || strings.TrimSpace(deviceSessionID) == "" {
		return "", false
	}
	return deviceSessionID, true
}
// requiredStringField fetches field from values and returns it as a string,
// failing when the field is absent, not string-coercible, or blank.
func requiredStringField(values map[string]any, field string) (string, error) {
	raw, ok := values[field]
	if !ok {
		return "", fmt.Errorf("decode session event: missing %s", field)
	}
	stringValue, coerceErr := coerceString(raw)
	switch {
	case coerceErr != nil:
		return "", fmt.Errorf("decode session event: %s: %w", field, coerceErr)
	case strings.TrimSpace(stringValue) == "":
		return "", fmt.Errorf("decode session event: %s must not be empty", field)
	}
	return stringValue, nil
}
// parseInt64Field coerces value to a string and parses it as a base-10 int64,
// tolerating surrounding whitespace.
func parseInt64Field(value any, field string) (int64, error) {
	stringValue, coerceErr := coerceString(value)
	if coerceErr != nil {
		return 0, fmt.Errorf("decode session event: %s: %w", field, coerceErr)
	}
	parsed, parseErr := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64)
	if parseErr != nil {
		return 0, fmt.Errorf("decode session event: %s: %w", field, parseErr)
	}
	return parsed, nil
}
// coerceString normalizes the dynamically typed values go-redis yields for
// stream fields into a plain string. The fmt.Stringer case is checked before
// the numeric cases so named types with a String method keep their own
// formatting.
func coerceString(value any) (string, error) {
	switch v := value.(type) {
	case string:
		return v, nil
	case []byte:
		return string(v), nil
	case fmt.Stringer:
		return v.String(), nil
	case int:
		return strconv.Itoa(v), nil
	case int64:
		return strconv.FormatInt(v, 10), nil
	case uint64:
		return strconv.FormatUint(v, 10), nil
	}
	return "", fmt.Errorf("unsupported value type %T", value)
}
+366
View File
@@ -0,0 +1,366 @@
package events
import (
"context"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisSessionSubscriberAppliesActiveSnapshot verifies that an active
// snapshot published on the stream lands in the local store.
func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	snapshotStore := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, redisServer, snapshotStore)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	require.Eventually(t, func() bool {
		record, lookupErr := snapshotStore.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.UserID == "user-123" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberAppliesRevokedSnapshot verifies that a revoked
// snapshot overwrites an existing active record, including the revocation
// timestamp.
func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	snapshotStore := session.NewMemoryCache()
	seed := session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}
	require.NoError(t, snapshotStore.Upsert(seed))
	subscriber := newTestRedisSessionSubscriber(t, redisServer, snapshotStore)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	require.Eventually(t, func() bool {
		record, lookupErr := snapshotStore.Lookup(context.Background(), "device-session-123")
		if lookupErr != nil || record.RevokedAtMS == nil {
			return false
		}
		return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler checks
// that applying a revoked snapshot also notifies the revocation handler with
// the affected device session ID.
func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	snapshotStore := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, redisServer, snapshotStore, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	require.Eventually(t, func() bool {
		record, lookupErr := snapshotStore.Lookup(context.Background(), "device-session-123")
		if lookupErr != nil || record.Status != session.StatusRevoked {
			return false
		}
		return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations())
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler
// ensures only revoked snapshots reach the revocation handler.
func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	snapshotStore := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, redisServer, snapshotStore, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	// The handler must stay untouched for the whole observation window.
	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
// TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler checks
// that a revoked snapshot which cannot be persisted locally does not notify
// the revocation handler — the handler only fires on applied snapshots.
func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, redisServer, failingSnapshotStore{}, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
// TestRedisSessionSubscriberLaterEventWins publishes two snapshots for the
// same session and verifies the most recent one ends up in the store.
func TestRedisSessionSubscriberLaterEventWins(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	snapshotStore := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, redisServer, snapshotStore)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})
	require.Eventually(t, func() bool {
		record, lookupErr := snapshotStore.Lookup(context.Background(), "device-session-123")
		if lookupErr != nil {
			return false
		}
		return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456"
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberMalformedEventEvictsAndContinues publishes a
// snapshot the gateway cannot apply, expects the cached record to be evicted,
// and then confirms a subsequent valid snapshot is still consumed.
func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) {
	t.Parallel()
	redisServer := miniredis.RunT(t)
	snapshotStore := session.NewMemoryCache()
	require.NoError(t, snapshotStore.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))
	subscriber := newTestRedisSessionSubscriber(t, redisServer, snapshotStore)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	// "paused" is not an accepted status, so this snapshot must evict.
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            "paused",
	})
	require.Eventually(t, func() bool {
		_, lookupErr := snapshotStore.Lookup(context.Background(), "device-session-123")
		return lookupErr != nil
	}, time.Second, 10*time.Millisecond)
	// The subscriber must keep consuming after the bad event.
	addSessionEvent(t, redisServer, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})
	require.Eventually(t, func() bool {
		record, lookupErr := snapshotStore.Lookup(context.Background(), "device-session-123")
		if lookupErr != nil {
			return false
		}
		return record.UserID == "user-456" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberShutdownInterruptsBlockingRead verifies that
// canceling the run context and closing the client unblocks an in-flight
// stream read, and that Run reports a clean context.Canceled.
func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()
	// Only cancel once the subscriber is inside its read loop.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	// Cancel first, then close the client via Shutdown to interrupt the read.
	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))
	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}
// newTestRedisSessionSubscriber builds a subscriber without a revocation
// handler, wired to the given miniredis server and snapshot store.
func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber {
	t.Helper()
	return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil)
}
// newTestRedisSessionSubscriberWithRevocationHandler builds a subscriber with
// short test timeouts against the given miniredis server, registering a
// cleanup that closes it when the test ends.
func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber {
	t.Helper()
	sessionCfg := config.SessionCacheRedisConfig{
		Addr:          server.Addr(),
		LookupTimeout: 250 * time.Millisecond,
	}
	eventsCfg := config.SessionEventsRedisConfig{
		Stream:           "gateway:session_events",
		ReadBlockTimeout: 25 * time.Millisecond,
	}
	subscriber, err := NewRedisSessionSubscriberWithRevocationHandler(sessionCfg, eventsCfg, store, revocationHandler)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})
	return subscriber
}
// recordingSessionRevocationHandler is a thread-safe SessionRevocationHandler
// stub that records every revoked device session ID in arrival order.
type recordingSessionRevocationHandler struct {
	mu         sync.Mutex
	revokedIDs []string
}

// RevokeDeviceSession appends deviceSessionID to the recorded revocations.
func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.revokedIDs = append(h.revokedIDs, deviceSessionID)
}

// revocations returns a defensive copy of the recorded IDs.
func (h *recordingSessionRevocationHandler) revocations() []string {
	h.mu.Lock()
	defer h.mu.Unlock()
	out := make([]string, len(h.revokedIDs))
	copy(out, h.revokedIDs)
	return out
}
// failingSnapshotStore satisfies session.SnapshotStore but rejects every
// Upsert, exercising the subscriber's store-failure path.
type failingSnapshotStore struct{}

// Lookup always misses.
func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) {
	return session.Record{}, session.ErrNotFound
}

// Upsert always fails with a recognizable stdlib error.
func (failingSnapshotStore) Upsert(session.Record) error {
	return context.DeadlineExceeded
}

// Delete is a no-op.
func (failingSnapshotStore) Delete(string) {}
// addSessionEvent XADDs fields as a new entry on stream via the miniredis
// server, failing the test if the append does not succeed.
func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) {
	t.Helper()
	flat := make([]string, 0, 2*len(fields))
	for key, value := range fields {
		flat = append(flat, key, value)
	}
	_, err := server.XAdd(stream, "*", flat)
	require.NoError(t, err)
}
// runningSubscriber tracks a subscriber goroutine started by
// runTestSubscriber: cancel stops its context and resultCh yields the Run
// return value. (An unused stopOnce bool field was removed — nothing in this
// file read or wrote it.)
type runningSubscriber struct {
	cancel   context.CancelFunc
	resultCh chan error
}
// runTestSubscriber launches subscriber.Run in a goroutine, waits for it to
// signal readiness, and returns handles to stop it and observe its result.
func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber {
	t.Helper()
	runCtx, cancelRun := context.WithCancel(context.Background())
	results := make(chan error, 1)
	go func() {
		results <- subscriber.Run(runCtx)
	}()
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	return runningSubscriber{
		cancel:   cancelRun,
		resultCh: results,
	}
}
// stop cancels the subscriber context and asserts it exits with
// context.Canceled within one second.
func (r runningSubscriber) stop(t *testing.T) {
	t.Helper()
	r.cancel()
	select {
	case runErr := <-r.resultCh:
		require.ErrorIs(t, runErr, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}