feat: edge gateway service
@@ -0,0 +1,341 @@
package events

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/push"
	"galaxy/gateway/internal/telemetry"

	"github.com/redis/go-redis/v9"
	"go.opentelemetry.io/otel/attribute"
	"go.uber.org/zap"
)

const clientEventReadCount int64 = 128

// ClientEventPublisher accepts decoded client-facing events from the internal
// event subscriber.
type ClientEventPublisher interface {
	// Publish fans out event to the currently active push streams.
	Publish(event push.Event)
}
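// Note (not part of the original diff): the push gateway tests later in this
// commit pass *push.Hub in as the publisher, so a compile-time assertion
// along these lines is assumed to hold:
//
//	var _ ClientEventPublisher = (*push.Hub)(nil)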

// RedisClientEventSubscriber consumes client-facing events from one Redis
// Stream and forwards them to the configured publisher.
type RedisClientEventSubscriber struct {
	client           *redis.Client
	stream           string
	pingTimeout      time.Duration
	readBlockTimeout time.Duration
	publisher        ClientEventPublisher
	logger           *zap.Logger
	metrics          *telemetry.Runtime

	closeOnce   sync.Once
	startedOnce sync.Once
	started     chan struct{}
}

// NewRedisClientEventSubscriber constructs a Redis Stream subscriber that
// reuses the SessionCache Redis connection settings and forwards decoded
// client-facing events to publisher.
func NewRedisClientEventSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) {
	return NewRedisClientEventSubscriberWithObservability(sessionCfg, eventsCfg, publisher, nil, nil)
}

// NewRedisClientEventSubscriberWithObservability constructs a Redis Stream
// subscriber that also records malformed or dropped internal events.
func NewRedisClientEventSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) {
	if strings.TrimSpace(sessionCfg.Addr) == "" {
		return nil, errors.New("new redis client event subscriber: redis addr must not be empty")
	}
	if sessionCfg.DB < 0 {
		return nil, errors.New("new redis client event subscriber: redis db must not be negative")
	}
	if sessionCfg.LookupTimeout <= 0 {
		return nil, errors.New("new redis client event subscriber: lookup timeout must be positive")
	}
	if strings.TrimSpace(eventsCfg.Stream) == "" {
		return nil, errors.New("new redis client event subscriber: stream must not be empty")
	}
	if eventsCfg.ReadBlockTimeout <= 0 {
		return nil, errors.New("new redis client event subscriber: read block timeout must be positive")
	}
	if publisher == nil {
		return nil, errors.New("new redis client event subscriber: nil publisher")
	}

	options := &redis.Options{
		Addr:            sessionCfg.Addr,
		Username:        sessionCfg.Username,
		Password:        sessionCfg.Password,
		DB:              sessionCfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	if logger == nil {
		logger = zap.NewNop()
	}

	return &RedisClientEventSubscriber{
		client:           redis.NewClient(options),
		stream:           eventsCfg.Stream,
		pingTimeout:      sessionCfg.LookupTimeout,
		readBlockTimeout: eventsCfg.ReadBlockTimeout,
		publisher:        publisher,
		logger:           logger.Named("client_event_subscriber"),
		metrics:          metrics,
		started:          make(chan struct{}),
	}, nil
}

// Ping verifies that the Redis backend used for client-facing event fan-out is
// reachable within the configured timeout budget.
func (s *RedisClientEventSubscriber) Ping(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("ping redis client event subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("ping redis client event subscriber: nil context")
	}

	pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
	defer cancel()

	if err := s.client.Ping(pingCtx).Err(); err != nil {
		return fmt.Errorf("ping redis client event subscriber: %w", err)
	}

	return nil
}

// Run consumes client-facing events until ctx is canceled or Redis returns an
// unexpected error.
func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("run redis client event subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("run redis client event subscriber: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	lastID, err := s.resolveStartID(ctx)
	if err != nil {
		return err
	}

	s.signalStarted()

	for {
		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
			Streams: []string{s.stream, lastID},
			Count:   clientEventReadCount,
			Block:   s.readBlockTimeout,
		}).Result()
		switch {
		case err == nil:
			for _, stream := range streams {
				for _, message := range stream.Messages {
					s.publishMessage(message)
					lastID = message.ID
				}
			}
			continue
		case errors.Is(err, redis.Nil):
			// Block timeout elapsed without new entries; poll again.
			continue
		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
			// Cancellation-driven shutdown surfaces as the context error.
			return ctx.Err()
		case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed):
			// The same errors without a canceled context mean the client was
			// closed or timed out unexpectedly.
			return fmt.Errorf("run redis client event subscriber: %w", err)
		default:
			return fmt.Errorf("run redis client event subscriber: %w", err)
		}
	}
}

// resolveStartID returns the ID of the newest entry already in the stream so
// the subscriber only observes events published after startup.
func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) {
	messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	switch {
	case err == nil:
	case errors.Is(err, redis.Nil):
		return "0-0", nil
	default:
		return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", err)
	}

	if len(messages) == 0 {
		return "0-0", nil
	}

	return messages[0].ID, nil
}

// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown.
func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown redis client event subscriber: nil context")
	}

	return s.Close()
}

// Close releases the underlying Redis client resources.
func (s *RedisClientEventSubscriber) Close() error {
	if s == nil || s.client == nil {
		return nil
	}

	var err error
	s.closeOnce.Do(func() {
		err = s.client.Close()
	})

	return err
}

func (s *RedisClientEventSubscriber) signalStarted() {
	s.startedOnce.Do(func() {
		close(s.started)
	})
}

func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) {
	event, err := decodeClientEvent(message.Values)
	if err != nil {
		s.logger.Warn("dropped malformed client event",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.Error(err),
		)
		if s.metrics != nil {
			// metrics is optional; the plain constructor leaves it nil.
			s.metrics.RecordInternalEventDrop(context.Background(),
				attribute.String("component", "client_event_subscriber"),
				attribute.String("reason", "malformed_event"),
			)
		}
		return
	}

	s.publisher.Publish(event)
}

func decodeClientEvent(values map[string]any) (push.Event, error) {
	requiredKeys := map[string]struct{}{
		"user_id":       {},
		"event_type":    {},
		"event_id":      {},
		"payload_bytes": {},
	}
	optionalKeys := map[string]struct{}{
		"device_session_id": {},
		"request_id":        {},
		"trace_id":          {},
	}

	for key := range values {
		if _, ok := requiredKeys[key]; ok {
			continue
		}
		if _, ok := optionalKeys[key]; ok {
			continue
		}

		return push.Event{}, fmt.Errorf("decode client event: unsupported field %q", key)
	}

	userID, err := requiredStringField(values, "user_id")
	if err != nil {
		return push.Event{}, err
	}
	eventType, err := requiredStringField(values, "event_type")
	if err != nil {
		return push.Event{}, err
	}
	eventID, err := requiredStringField(values, "event_id")
	if err != nil {
		return push.Event{}, err
	}
	payloadBytes, err := requiredBytesField(values, "payload_bytes")
	if err != nil {
		return push.Event{}, err
	}

	event := push.Event{
		UserID:       userID,
		EventType:    eventType,
		EventID:      eventID,
		PayloadBytes: payloadBytes,
	}

	if deviceSessionID, ok, err := optionalStringField(values, "device_session_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.DeviceSessionID = strings.TrimSpace(deviceSessionID)
	}

	if requestID, ok, err := optionalStringField(values, "request_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.RequestID = requestID
	}

	if traceID, ok, err := optionalStringField(values, "trace_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.TraceID = traceID
	}

	return event, nil
}

func requiredBytesField(values map[string]any, field string) ([]byte, error) {
	value, ok := values[field]
	if !ok {
		return nil, fmt.Errorf("decode client event: missing %s", field)
	}

	byteValue, err := coerceBytes(value)
	if err != nil {
		return nil, fmt.Errorf("decode client event: %s: %w", field, err)
	}

	return byteValue, nil
}

func optionalStringField(values map[string]any, field string) (string, bool, error) {
	value, ok := values[field]
	if !ok {
		return "", false, nil
	}

	stringValue, err := coerceString(value)
	if err != nil {
		return "", false, fmt.Errorf("decode client event: %s: %w", field, err)
	}

	return stringValue, true, nil
}

func coerceBytes(value any) ([]byte, error) {
	switch typed := value.(type) {
	case string:
		return []byte(typed), nil
	case []byte:
		return bytes.Clone(typed), nil
	default:
		return nil, fmt.Errorf("unsupported type %T", value)
	}
}
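The hunk above runs to 341 lines, but this excerpt stops at coerceBytes, so the requiredStringField and coerceString helpers called by decodeClientEvent and optionalStringField are not shown. A minimal sketch consistent with the byte-field helpers (an assumption, not the original implementation):

// requiredStringField mirrors requiredBytesField for string-typed fields
// (assumed shape; the real helper lives in the unshown tail of this file).
func requiredStringField(values map[string]any, field string) (string, error) {
	value, ok := values[field]
	if !ok {
		return "", fmt.Errorf("decode client event: missing %s", field)
	}

	stringValue, err := coerceString(value)
	if err != nil {
		return "", fmt.Errorf("decode client event: %s: %w", field, err)
	}

	return stringValue, nil
}

// coerceString mirrors coerceBytes; Redis stream values arrive as strings
// over RESP2, so the string case is the common path.
func coerceString(value any) (string, error) {
	switch typed := value.(type) {
	case string:
		return typed, nil
	case []byte:
		return string(typed), nil
	default:
		return "", fmt.Errorf("unsupported type %T", value)
	}
}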
@@ -0,0 +1,294 @@
package events

import (
	"context"
	"strings"
	"sync"
	"testing"
	"time"

	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/push"
	"galaxy/gateway/internal/testutil"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestRedisClientEventSubscriberPublishesValidEvent(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-123",
		"event_type":        "fleet.updated",
		"event_id":          "event-123",
		"payload_bytes":     []byte("payload-123"),
		"request_id":        "request-123",
		"trace_id":          "trace-123",
	})

	require.Eventually(t, func() bool {
		return len(publisher.events()) == 1
	}, time.Second, 10*time.Millisecond)

	assert.Equal(t, []push.Event{{
		UserID:          "user-123",
		DeviceSessionID: "device-session-123",
		EventType:       "fleet.updated",
		EventID:         "event-123",
		PayloadBytes:    []byte("payload-123"),
		RequestID:       "request-123",
		TraceID:         "trace-123",
	}}, publisher.events())
}

func TestRedisClientEventSubscriberSkipsMalformedEventAndContinues(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-good",
		"payload_bytes": []byte("payload-good"),
	})

	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-good"
	}, time.Second, 10*time.Millisecond)
}

func TestRedisClientEventSubscriberStartsFromCurrentTail(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-old",
		"payload_bytes": []byte("payload-old"),
	})

	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	assert.Never(t, func() bool {
		return len(publisher.events()) > 0
	}, 100*time.Millisecond, 10*time.Millisecond)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-new",
		"payload_bytes": []byte("payload-new"),
	})

	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-new"
	}, time.Second, 10*time.Millisecond)
}

func TestRedisClientEventSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))

	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}

func TestRedisClientEventSubscriberLogsAndCountsMalformedEvents(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	logger, logBuffer := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)

	subscriber, err := NewRedisClientEventSubscriberWithObservability(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
		logger,
		telemetryRuntime,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})

	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})

	require.Eventually(t, func() bool {
		return strings.Contains(logBuffer.String(), "dropped malformed client event")
	}, time.Second, 10*time.Millisecond)

	metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
	assert.Contains(t, metricsText, `gateway_internal_event_drops_total`)
	assert.Contains(t, metricsText, `component="client_event_subscriber"`)
	assert.Contains(t, metricsText, `reason="malformed_event"`)
}

func newTestRedisClientEventSubscriber(t *testing.T, server *miniredis.Miniredis, publisher ClientEventPublisher) *RedisClientEventSubscriber {
	t.Helper()

	subscriber, err := NewRedisClientEventSubscriber(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
	)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})

	return subscriber
}

func addClientEvent(t *testing.T, server *miniredis.Miniredis, stream string, values map[string]any) {
	t.Helper()

	client := redis.NewClient(&redis.Options{
		Addr:            server.Addr(),
		Protocol:        2,
		DisableIdentity: true,
	})
	defer func() {
		assert.NoError(t, client.Close())
	}()

	err := client.XAdd(context.Background(), &redis.XAddArgs{
		Stream: stream,
		Values: values,
	}).Err()
	require.NoError(t, err)
}

type runningClientEventSubscriber struct {
	cancel   context.CancelFunc
	resultCh chan error
}

func runTestClientEventSubscriber(t *testing.T, subscriber *RedisClientEventSubscriber) runningClientEventSubscriber {
	t.Helper()

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	return runningClientEventSubscriber{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

func (r runningClientEventSubscriber) stop(t *testing.T) {
	t.Helper()

	r.cancel()

	select {
	case err := <-r.resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}

type recordingClientEventPublisher struct {
	mu      sync.Mutex
	records []push.Event
}

func (p *recordingClientEventPublisher) Publish(event push.Event) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.records = append(p.records, event)
}

func (p *recordingClientEventPublisher) events() []push.Event {
	p.mu.Lock()
	defer p.mu.Unlock()

	cloned := make([]push.Event, len(p.records))
	copy(cloned, p.records)
	return cloned
}
@@ -0,0 +1,385 @@
package events

import (
	"context"
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"net"
	"sync"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/authn"
	"galaxy/gateway/internal/clock"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/downstream"
	"galaxy/gateway/internal/grpcapi"
	"galaxy/gateway/internal/replay"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
)

var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)

func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	assert.Len(t, downstreamClient.commands(), 2)
}

func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusActive),
	})

	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.UserID == "user-456"
	}, time.Second, 10*time.Millisecond)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	commands := downstreamClient.commands()
	require.Len(t, commands, 2)
	assert.Equal(t, "user-456", commands[1].UserID)
}

func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Equal(t, 1, fallback.lookupCalls())
}

type runningAuthenticatedGateway struct {
	cancel   context.CancelFunc
	resultCh chan error
}

func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
	t.Helper()

	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	grpcCfg.FreshnessWindow = 5 * time.Minute

	router := downstream.NewStaticRouter(map[string]downstream.Client{
		"fleet.move": downstreamClient,
	})

	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Router:         router,
		ResponseSigner: newTestResponseSigner(t),
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
	})

	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		gateway,
		subscriber,
	)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}

	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

func (g runningAuthenticatedGateway) stop(t *testing.T) {
	t.Helper()

	g.cancel()

	select {
	case err := <-g.resultCh:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		require.FailNow(t, "gateway did not stop after cancellation")
	}
}

func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	conn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.NoError(t, err)

	return conn
}

func unusedTCPAddr(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	addr := listener.Addr().String()
	require.NoError(t, listener.Close())

	return addr
}

func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
	payloadBytes := []byte("payload")
	payloadHash := sha256.Sum256(payloadBytes)

	req := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadBytes:    payloadBytes,
		PayloadHash:     payloadHash[:],
		TraceId:         "trace-123",
	}
	req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))

	return req
}

func newActiveSessionRecord(userID string) session.Record {
	return session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          userID,
		ClientPublicKey: testClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
}

func testClientPrivateKey() ed25519.PrivateKey {
	seed := sha256.Sum256([]byte("gateway-events-grpc-test-client"))
	return ed25519.NewKeyFromSeed(seed[:])
}

func testClientPublicKeyBase64() string {
	return base64.StdEncoding.EncodeToString(testClientPrivateKey().Public().(ed25519.PublicKey))
}

func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
	t.Helper()

	seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(seed[:]))
	require.NoError(t, err)

	return signer
}

type fixedClock struct {
	now time.Time
}

func (c fixedClock) Now() time.Time {
	return c.now
}

var _ clock.Clock = fixedClock{}

type staticReplayStore struct{}

func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error {
	return nil
}

var _ replay.Store = staticReplayStore{}

// countingSessionCache serves one fixed test session; it deliberately ignores
// the requested key and counts every fallback lookup it receives.
type countingSessionCache struct {
	mu          sync.Mutex
	records     map[string]session.Record
	lookupCount int
}

func (c *countingSessionCache) Lookup(context.Context, string) (session.Record, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.lookupCount++

	record, ok := c.records["device-session-123"]
	if !ok {
		return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
	}

	return record, nil
}

func (c *countingSessionCache) lookupCalls() int {
	c.mu.Lock()
	defer c.mu.Unlock()

	return c.lookupCount
}

type recordingDownstreamClient struct {
	mu       sync.Mutex
	captured []downstream.AuthenticatedCommand
}

func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	c.mu.Lock()
	c.captured = append(c.captured, command)
	c.mu.Unlock()

	return downstream.UnaryResult{
		ResultCode:   "ok",
		PayloadBytes: []byte("response"),
	}, nil
}

func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
	c.mu.Lock()
	defer c.mu.Unlock()

	cloned := make([]downstream.AuthenticatedCommand, len(c.captured))
	copy(cloned, c.captured)
	return cloned
}
@@ -0,0 +1,416 @@
package events

import (
	"context"
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/base64"
	"testing"
	"time"

	"galaxy/gateway/internal/app"
	"galaxy/gateway/internal/authn"
	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/grpcapi"
	"galaxy/gateway/internal/push"
	"galaxy/gateway/internal/session"
	gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
	defer cancelTargetOne()
	targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")

	targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
	defer cancelTargetTwo()
	targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")

	unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
	defer cancelUnrelated()
	unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-123",
		"payload_bytes": []byte("payload-123"),
		"request_id":    "request-123",
		"trace_id":      "trace-123",
	})

	assertSignedPushEvent(t, recvPushEvent(t, targetOne), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	assertSignedPushEvent(t, recvPushEvent(t, targetTwo), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	assertNoPushEvent(t, unrelated, cancelUnrelated)
}

func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	otherCtx, cancelOther := context.WithCancel(context.Background())
	defer cancelOther()
	otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")

	targetCtx, cancelTarget := context.WithCancel(context.Background())
	defer cancelTarget()
	targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-2",
		"event_type":        "fleet.updated",
		"event_id":          "event-456",
		"payload_bytes":     []byte("payload-456"),
	})

	assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-456",
		PayloadBytes:    []byte("payload-456"),
	})
	assertNoPushEvent(t, otherStream, cancelOther)
}

func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
	defer running.stop(t)

	select {
	case <-sessionSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	streamCtx, cancelStream := context.WithCancel(context.Background())
	defer cancelStream()

	stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-1",
		"user_id":           "user-123",
		"client_public_key": pushClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)

	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
		assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after revoke")
	}

	reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2"))
	if err == nil {
		_, err = reopened.Recv()
	}

	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
}

func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	running.cancel()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.Unavailable, status.Code(recvErr))
		assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after gateway shutdown")
	}
}

func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
	t.Helper()

	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	grpcCfg.FreshnessWindow = 5 * time.Minute

	responseSigner := newTestResponseSigner(t)
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Service:        grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
		ResponseSigner: responseSigner,
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
		PushHub:        pushHub,
	})

	components := []app.Component{gateway, clientSubscriber}
	components = append(components, extraComponents...)
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		components...,
	)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	select {
	case <-clientSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "client event subscriber did not start")
	}

	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
	return session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          userID,
		ClientPublicKey: pushClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
}

func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	payloadHash := sha256.Sum256(nil)
	traceID := "trace-" + deviceSessionID

	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadHash:     payloadHash[:],
		TraceId:         traceID,
	}
	req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))

	return req
}

func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()

	event, err := stream.Recv()
	require.NoError(t, err)
	return event
}

func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, "gateway.server_time", event.GetEventType())
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}

func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, want.EventType, event.GetEventType())
	assert.Equal(t, want.EventID, event.GetEventId())
	assert.Equal(t, want.RequestID, event.GetRequestId())
	assert.Equal(t, want.TraceID, event.GetTraceId())
	assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}

func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
	t.Helper()

	recvCh := make(chan *gatewayv1.GatewayEvent, 1)
	errCh := make(chan error, 1)
	go func() {
		event, err := stream.Recv()
		if err != nil {
			errCh <- err
			return
		}
		recvCh <- event
	}()

	select {
	case event := <-recvCh:
		require.FailNowf(t, "unexpected push event delivered", "%+v", event)
	case <-time.After(100 * time.Millisecond):
		// Quiet for the grace window: cancel the stream so the pending Recv
		// goroutine unblocks instead of leaking.
		cancel()
	case err := <-errCh:
		require.FailNowf(t, "stream closed unexpectedly", "%v", err)
	}
}

func pushClientPrivateKey() ed25519.PrivateKey {
	seed := sha256.Sum256([]byte("gateway-push-grpc-test-client"))
	return ed25519.NewKeyFromSeed(seed[:])
}

func pushClientPublicKeyBase64() string {
	return base64.StdEncoding.EncodeToString(pushClientPrivateKey().Public().(ed25519.PublicKey))
}

func pushResponseSignerPublicKey() ed25519.PublicKey {
	seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	return ed25519.NewKeyFromSeed(seed[:]).Public().(ed25519.PublicKey)
}
@@ -0,0 +1,389 @@
|
||||
// Package events subscribes to internal session lifecycle streams used to keep
|
||||
// the gateway hot-path session cache synchronized without per-request upstream
|
||||
// lookups.
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const sessionEventReadCount int64 = 128
|
||||
|
||||
// SessionRevocationHandler reacts to a successfully applied revoked session
|
||||
// snapshot and may tear down active resources bound to that session.
|
||||
type SessionRevocationHandler interface {
|
||||
// RevokeDeviceSession tears down active resources bound to deviceSessionID.
|
||||
RevokeDeviceSession(deviceSessionID string)
|
||||
}
|
||||
|
||||
// RedisSessionSubscriber consumes full session snapshots from one Redis Stream
|
||||
// and applies them to a process-local session snapshot store.
|
||||
type RedisSessionSubscriber struct {
|
||||
client *redis.Client
|
||||
stream string
|
||||
pingTimeout time.Duration
|
||||
readBlockTimeout time.Duration
|
||||
store session.SnapshotStore
|
||||
revocationHandler SessionRevocationHandler
|
||||
logger *zap.Logger
|
||||
metrics *telemetry.Runtime
|
||||
|
||||
closeOnce sync.Once
|
||||
startedOnce sync.Once
|
||||
started chan struct{}
|
||||
}
|
||||
|
||||
// NewRedisSessionSubscriber constructs a Redis Stream subscriber that reuses
|
||||
// the SessionCache Redis connection settings and applies updates to store.
|
||||
func NewRedisSessionSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
|
||||
return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, nil, nil, nil)
|
||||
}
|
||||
|
||||
// NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream
|
||||
// subscriber that reuses the SessionCache Redis connection settings, applies
|
||||
// updates to store, and optionally tears down active resources for revoked
|
||||
// sessions.
|
||||
func NewRedisSessionSubscriberWithRevocationHandler(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
|
||||
return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, revocationHandler, nil, nil)
|
||||
}
|
||||
|
||||
// NewRedisSessionSubscriberWithObservability constructs a Redis Stream
|
||||
// subscriber that also logs and counts malformed internal session events.
|
||||
func NewRedisSessionSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
|
||||
if strings.TrimSpace(sessionCfg.Addr) == "" {
|
||||
return nil, errors.New("new redis session subscriber: redis addr must not be empty")
|
||||
}
|
||||
if sessionCfg.DB < 0 {
|
||||
return nil, errors.New("new redis session subscriber: redis db must not be negative")
|
||||
}
|
||||
if sessionCfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
|
||||
}
|
||||
if strings.TrimSpace(eventsCfg.Stream) == "" {
|
||||
return nil, errors.New("new redis session subscriber: stream must not be empty")
|
||||
}
|
||||
if eventsCfg.ReadBlockTimeout <= 0 {
|
||||
return nil, errors.New("new redis session subscriber: read block timeout must be positive")
|
||||
}
|
||||
if store == nil {
|
||||
return nil, errors.New("new redis session subscriber: nil session snapshot store")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: sessionCfg.Addr,
|
||||
Username: sessionCfg.Username,
|
||||
Password: sessionCfg.Password,
|
||||
DB: sessionCfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if sessionCfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &RedisSessionSubscriber{
|
||||
client: redis.NewClient(options),
|
||||
stream: eventsCfg.Stream,
|
||||
pingTimeout: sessionCfg.LookupTimeout,
|
||||
readBlockTimeout: eventsCfg.ReadBlockTimeout,
|
||||
store: store,
|
||||
revocationHandler: revocationHandler,
|
||||
logger: logger.Named("session_subscriber"),
|
||||
metrics: metrics,
|
||||
started: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping verifies that the Redis backend used for session lifecycle events is
|
||||
// reachable within the configured timeout budget.
|
||||
func (s *RedisSessionSubscriber) Ping(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("ping redis session subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("ping redis session subscriber: nil context")
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(pingCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis session subscriber: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run consumes session lifecycle events until ctx is canceled or Redis returns
|
||||
// an unexpected error.
|
||||
func (s *RedisSessionSubscriber) Run(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("run redis session subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("run redis session subscriber: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lastID, err := s.resolveStartID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.signalStarted()
|
||||
|
||||
	for {
		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
			Streams: []string{s.stream, lastID},
			Count:   sessionEventReadCount,
			Block:   s.readBlockTimeout,
		}).Result()
		switch {
		case err == nil:
			for _, stream := range streams {
				for _, message := range stream.Messages {
					s.applyMessage(message)
					lastID = message.ID
				}
			}
			continue
		case errors.Is(err, redis.Nil):
			// The blocking read timed out with no new entries; poll again.
			continue
		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
			// The caller canceled ctx (or Shutdown closed the client), so
			// report a clean stop rather than a transport failure.
			return ctx.Err()
		default:
			return fmt.Errorf("run redis session subscriber: %w", err)
		}
	}
}
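
// resolveStartID chooses where Run starts reading: the ID of the newest entry
// when the stream already has history, or "0-0" when it is empty or missing.
// Skipping existing entries means events published before the subscriber came
// up are not replayed on startup.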
func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) {
	messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	switch {
	case err == nil:
	case errors.Is(err, redis.Nil):
		return "0-0", nil
	default:
		return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err)
	}

	if len(messages) == 0 {
		return "0-0", nil
	}

	return messages[0].ID, nil
}

// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown.
func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown redis session subscriber: nil context")
	}

	return s.Close()
}

// Close releases the underlying Redis client resources.
func (s *RedisSessionSubscriber) Close() error {
	if s == nil || s.client == nil {
		return nil
	}

	var err error
	s.closeOnce.Do(func() {
		err = s.client.Close()
	})

	return err
}
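
// signalStarted closes the started channel exactly once, letting tests and
// callers block until the initial stream position has been resolved.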
func (s *RedisSessionSubscriber) signalStarted() {
	s.startedOnce.Do(func() {
		close(s.started)
	})
}
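
// applyMessage fails closed: if an event cannot be decoded or the snapshot
// store rejects it, any cached snapshot for that device session is deleted so
// the gateway does not keep serving state it could not refresh. The revocation
// handler fires only after a revoked snapshot has been stored successfully.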
func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) {
	record, err := decodeSessionRecordSnapshot(message.Values)
	if err != nil {
		s.logger.Warn("dropped malformed session event",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.Error(err),
		)
		s.metrics.RecordInternalEventDrop(context.Background(),
			attribute.String("component", "session_subscriber"),
			attribute.String("reason", "malformed_event"),
		)
		if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok {
			s.store.Delete(deviceSessionID)
		}
		return
	}

	if err := s.store.Upsert(record); err != nil {
		s.logger.Warn("dropped session snapshot after store failure",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.String("device_session_id", record.DeviceSessionID),
			zap.Error(err),
		)
		s.metrics.RecordInternalEventDrop(context.Background(),
			attribute.String("component", "session_subscriber"),
			attribute.String("reason", "store_failure"),
		)
		s.store.Delete(record.DeviceSessionID)
		return
	}

	if record.Status == session.StatusRevoked && s.revocationHandler != nil {
		s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID)
	}
}
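
// decodeSessionRecordSnapshot maps a stream entry onto a session.Record. The
// field set is closed: any key outside the required and optional lists marks
// the whole event as malformed, so producer-side schema drift surfaces
// immediately instead of being silently ignored.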
func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) {
	requiredKeys := map[string]struct{}{
		"device_session_id": {},
		"user_id":           {},
		"client_public_key": {},
		"status":            {},
	}
	optionalKeys := map[string]struct{}{
		"revoked_at_ms": {},
	}

	for key := range values {
		if _, ok := requiredKeys[key]; ok {
			continue
		}
		if _, ok := optionalKeys[key]; ok {
			continue
		}

		return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key)
	}

	deviceSessionID, err := requiredStringField(values, "device_session_id")
	if err != nil {
		return session.Record{}, err
	}
	userID, err := requiredStringField(values, "user_id")
	if err != nil {
		return session.Record{}, err
	}
	clientPublicKey, err := requiredStringField(values, "client_public_key")
	if err != nil {
		return session.Record{}, err
	}
	statusValue, err := requiredStringField(values, "status")
	if err != nil {
		return session.Record{}, err
	}

	record := session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          userID,
		ClientPublicKey: clientPublicKey,
		Status:          session.Status(statusValue),
	}

	if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
		revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms")
		if err != nil {
			return session.Record{}, err
		}
		record.RevokedAtMS = &revokedAtMS
	}

	return record, nil
}
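
// extractDeviceSessionID is the best-effort fallback for malformed events: it
// recovers just the session key so eviction can proceed even when the rest of
// the payload fails to decode.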
func extractDeviceSessionID(values map[string]any) (string, bool) {
	value, ok := values["device_session_id"]
	if !ok {
		return "", false
	}

	deviceSessionID, err := coerceString(value)
	if err != nil {
		return "", false
	}
	if strings.TrimSpace(deviceSessionID) == "" {
		return "", false
	}

	return deviceSessionID, true
}

func requiredStringField(values map[string]any, field string) (string, error) {
	value, ok := values[field]
	if !ok {
		return "", fmt.Errorf("decode session event: missing %s", field)
	}

	stringValue, err := coerceString(value)
	if err != nil {
		return "", fmt.Errorf("decode session event: %s: %w", field, err)
	}
	if strings.TrimSpace(stringValue) == "" {
		return "", fmt.Errorf("decode session event: %s must not be empty", field)
	}

	return stringValue, nil
}

func parseInt64Field(value any, field string) (int64, error) {
	stringValue, err := coerceString(value)
	if err != nil {
		return 0, fmt.Errorf("decode session event: %s: %w", field, err)
	}

	parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64)
	if err != nil {
		return 0, fmt.Errorf("decode session event: %s: %w", field, err)
	}

	return parsed, nil
}
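
// coerceString normalizes the value types that can plausibly appear in a
// stream entry. go-redis typically surfaces XRead field values as strings;
// the remaining branches are defensive coverage for other producers.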
func coerceString(value any) (string, error) {
	switch typed := value.(type) {
	case string:
		return typed, nil
	case []byte:
		return string(typed), nil
	case fmt.Stringer:
		return typed.String(), nil
	case int:
		return strconv.Itoa(typed), nil
	case int64:
		return strconv.FormatInt(typed, 10), nil
	case uint64:
		return strconv.FormatUint(typed, 10), nil
	default:
		return "", fmt.Errorf("unsupported value type %T", value)
	}
}
@@ -0,0 +1,366 @@
package events

import (
	"context"
	"sync"
	"testing"
	"time"

	"galaxy/gateway/internal/config"
	"galaxy/gateway/internal/session"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}

		return record.UserID == "user-123" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}

func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	require.NoError(t, store.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))

	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.RevokedAtMS == nil {
			return false
		}

		return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789
	}, time.Second, 10*time.Millisecond)
}

func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.Status != session.StatusRevoked {
			return false
		}

		return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations())
	}, time.Second, 10*time.Millisecond)
}

func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})

	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}

func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, failingSnapshotStore{}, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}

func TestRedisSessionSubscriberLaterEventWins(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}

		return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456"
	}, time.Second, 10*time.Millisecond)
}

func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	require.NoError(t, store.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))

	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            "paused",
	})

	require.Eventually(t, func() bool {
		_, err := store.Lookup(context.Background(), "device-session-123")
		return err != nil
	}, time.Second, 10*time.Millisecond)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}

		return record.UserID == "user-456" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}

func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))

	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}

func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber {
	t.Helper()

	return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil)
}

func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber {
	t.Helper()

	subscriber, err := NewRedisSessionSubscriberWithRevocationHandler(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.SessionEventsRedisConfig{
			Stream:           "gateway:session_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		store,
		revocationHandler,
	)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})

	return subscriber
}
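
// recordingSessionRevocationHandler captures revoked device-session IDs so
// tests can assert on them while the subscriber goroutine is still running.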
type recordingSessionRevocationHandler struct {
	mu         sync.Mutex
	revokedIDs []string
}

func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) {
	h.mu.Lock()
	h.revokedIDs = append(h.revokedIDs, deviceSessionID)
	h.mu.Unlock()
}

func (h *recordingSessionRevocationHandler) revocations() []string {
	h.mu.Lock()
	defer h.mu.Unlock()

	return append([]string(nil), h.revokedIDs...)
}
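
// failingSnapshotStore rejects every Upsert, driving the store-failure path
// in applyMessage.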
type failingSnapshotStore struct{}

func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) {
	return session.Record{}, session.ErrNotFound
}

func (failingSnapshotStore) Upsert(session.Record) error {
	return context.DeadlineExceeded
}

func (failingSnapshotStore) Delete(string) {}

func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) {
	t.Helper()

	values := make([]string, 0, len(fields)*2)
	for key, value := range fields {
		values = append(values, key, value)
	}

	_, err := server.XAdd(stream, "*", values)
	require.NoError(t, err)
}

type runningSubscriber struct {
	cancel   context.CancelFunc
	resultCh chan error
}
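
// runTestSubscriber starts Run in the background and blocks until the start
// position has been resolved, so events added afterwards are guaranteed to be
// observed by the subscriber.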
func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber {
	t.Helper()

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	return runningSubscriber{
		cancel:   cancel,
		resultCh: resultCh,
	}
}

func (r runningSubscriber) stop(t *testing.T) {
	t.Helper()

	r.cancel()

	select {
	case err := <-r.resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}