Files
galaxy-game/gateway/internal/events/subscriber.go
T
2026-04-02 19:18:42 +02:00

390 lines
11 KiB
Go

// Package events subscribes to internal session lifecycle streams used to keep
// the gateway hot-path session cache synchronized without per-request upstream
// lookups.
package events
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
// sessionEventReadCount is the Count argument passed to every XRead call that
// Run issues, i.e. the number of stream entries requested per read.
const sessionEventReadCount int64 = 128
// SessionRevocationHandler reacts to a successfully applied revoked session
// snapshot and may tear down active resources bound to that session.
//
// The subscriber only invokes it after a decoded snapshot whose status equals
// session.StatusRevoked has been upserted into the store without error.
type SessionRevocationHandler interface {
// RevokeDeviceSession tears down active resources bound to deviceSessionID.
// It is called directly from the subscriber's message-processing path, so a
// slow implementation delays the handling of subsequent stream entries.
RevokeDeviceSession(deviceSessionID string)
}
// RedisSessionSubscriber consumes full session snapshots from one Redis Stream
// and applies them to a process-local session snapshot store.
type RedisSessionSubscriber struct {
// client issues the Ping, XRevRangeN, and XRead commands and is closed by Close.
client *redis.Client
// stream is the Redis Stream key that Run reads from.
stream string
// pingTimeout bounds a single Ping call.
pingTimeout time.Duration
// readBlockTimeout is the Block duration supplied to each XRead call.
readBlockTimeout time.Duration
// store receives Upsert for decoded snapshots and Delete when an event is
// malformed or the upsert fails.
store session.SnapshotStore
// revocationHandler is optional; when non-nil it is notified about applied
// snapshots whose status is session.StatusRevoked.
revocationHandler SessionRevocationHandler
// logger reports dropped events; the constructor substitutes a no-op logger
// when none is supplied.
logger *zap.Logger
// metrics counts dropped internal events. The constructor stores it as given,
// so it may be nil.
metrics *telemetry.Runtime
// closeOnce ensures the Redis client is closed at most once.
closeOnce sync.Once
// startedOnce ensures the started channel is closed at most once.
startedOnce sync.Once
// started is closed by Run after the starting stream ID has been resolved.
started chan struct{}
}
// NewRedisSessionSubscriber builds a Redis Stream subscriber that connects
// with the SessionCache Redis settings and applies every received snapshot to
// store. It configures no revocation handler, logger, or metrics; validation
// and construction are delegated to NewRedisSessionSubscriberWithObservability.
func NewRedisSessionSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
	// Named zero values make the omitted optional collaborators explicit.
	var (
		noRevocationHandler SessionRevocationHandler
		noLogger            *zap.Logger
		noMetrics           *telemetry.Runtime
	)
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, noRevocationHandler, noLogger, noMetrics)
}
// NewRedisSessionSubscriberWithRevocationHandler builds a Redis Stream
// subscriber that connects with the SessionCache Redis settings, applies every
// received snapshot to store, and, when revocationHandler is non-nil, notifies
// it about revoked sessions. No logger or metrics are configured; validation
// and construction are delegated to NewRedisSessionSubscriberWithObservability.
func NewRedisSessionSubscriberWithRevocationHandler(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
	// Named zero values make the omitted observability collaborators explicit.
	var (
		noLogger  *zap.Logger
		noMetrics *telemetry.Runtime
	)
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, revocationHandler, noLogger, noMetrics)
}
// NewRedisSessionSubscriberWithObservability builds a Redis Stream subscriber
// that additionally logs and counts malformed internal session events.
//
// It rejects an empty Redis address, a negative DB index, a non-positive
// lookup timeout, an empty stream name, a non-positive read block timeout, and
// a nil store, in that order. revocationHandler and metrics are stored as
// given; a nil logger is replaced with a no-op logger. The Redis client is
// created here but no command is sent to the server.
func NewRedisSessionSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
	// Cases are evaluated top to bottom, so the first violated rule wins.
	switch {
	case strings.TrimSpace(sessionCfg.Addr) == "":
		return nil, errors.New("new redis session subscriber: redis addr must not be empty")
	case sessionCfg.DB < 0:
		return nil, errors.New("new redis session subscriber: redis db must not be negative")
	case sessionCfg.LookupTimeout <= 0:
		return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
	case strings.TrimSpace(eventsCfg.Stream) == "":
		return nil, errors.New("new redis session subscriber: stream must not be empty")
	case eventsCfg.ReadBlockTimeout <= 0:
		return nil, errors.New("new redis session subscriber: read block timeout must be positive")
	case store == nil:
		return nil, errors.New("new redis session subscriber: nil session snapshot store")
	}

	if logger == nil {
		logger = zap.NewNop()
	}

	redisOptions := &redis.Options{
		Addr:            sessionCfg.Addr,
		Username:        sessionCfg.Username,
		Password:        sessionCfg.Password,
		DB:              sessionCfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		redisOptions.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	subscriber := &RedisSessionSubscriber{
		client:            redis.NewClient(redisOptions),
		stream:            eventsCfg.Stream,
		pingTimeout:       sessionCfg.LookupTimeout,
		readBlockTimeout:  eventsCfg.ReadBlockTimeout,
		store:             store,
		revocationHandler: revocationHandler,
		logger:            logger.Named("session_subscriber"),
		metrics:           metrics,
		started:           make(chan struct{}),
	}
	return subscriber, nil
}
// Ping checks that the Redis backend carrying session lifecycle events answers
// a PING within the configured ping timeout. It returns an error for a nil
// subscriber, a nil context, or a failed ping (wrapped with %w).
func (s *RedisSessionSubscriber) Ping(ctx context.Context) error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("ping redis session subscriber: nil subscriber")
	case ctx == nil:
		return errors.New("ping redis session subscriber: nil context")
	}

	// Derive a child context so the caller's deadline still applies if shorter.
	boundedCtx, release := context.WithTimeout(ctx, s.pingTimeout)
	defer release()

	err := s.client.Ping(boundedCtx).Err()
	if err == nil {
		return nil
	}
	return fmt.Errorf("ping redis session subscriber: %w", err)
}
// Run consumes session lifecycle events until ctx is canceled or Redis returns
// an unexpected error.
//
// It first resolves the current tail of the stream, closes the started channel,
// and then repeatedly reads entries newer than the last one it has seen,
// applying each in order. An empty blocking read (redis.Nil) simply triggers
// another read. When a read fails with a cancellation, deadline, or
// closed-client error while ctx is already done, ctx.Err() is returned; every
// other read failure is returned wrapped.
func (s *RedisSessionSubscriber) Run(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("run redis session subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("run redis session subscriber: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	cursor, err := s.resolveStartID(ctx)
	if err != nil {
		return err
	}
	s.signalStarted()

	for {
		readArgs := &redis.XReadArgs{
			Streams: []string{s.stream, cursor},
			Count:   sessionEventReadCount,
			Block:   s.readBlockTimeout,
		}
		batches, readErr := s.client.XRead(ctx, readArgs).Result()

		if readErr == nil {
			for _, batch := range batches {
				for _, entry := range batch.Messages {
					s.applyMessage(entry)
					// Advance past the entry whether or not it was usable.
					cursor = entry.ID
				}
			}
			continue
		}
		if errors.Is(readErr, redis.Nil) {
			// The block timeout elapsed without new entries.
			continue
		}

		// An interruption observed after ctx finished is reported as the
		// context's own error rather than as a Redis failure.
		if ctxErr := ctx.Err(); ctxErr != nil &&
			(errors.Is(readErr, context.Canceled) || errors.Is(readErr, context.DeadlineExceeded) || errors.Is(readErr, redis.ErrClosed)) {
			return ctxErr
		}
		return fmt.Errorf("run redis session subscriber: %w", readErr)
	}
}
// resolveStartID returns the stream ID from which Run starts reading: the ID
// of the newest entry currently in the stream, or "0-0" when the stream has no
// entries (or the lookup yields redis.Nil). Any other lookup failure is
// returned wrapped.
func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) {
	const streamBeginning = "0-0"

	// Ask for at most one entry, walking the stream from newest to oldest.
	newest, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return streamBeginning, nil
		}
		return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err)
	}

	if len(newest) > 0 {
		return newest[0].ID, nil
	}
	return streamBeginning, nil
}
// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown. ctx is only checked for nil; it does not
// bound the close itself. The result of Close is returned unchanged.
func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error {
	if ctx != nil {
		return s.Close()
	}
	return errors.New("shutdown redis session subscriber: nil context")
}
// Close releases the underlying Redis client resources. It is a no-op on a nil
// subscriber or a subscriber without a client. The client is closed at most
// once: only the first call can return the client's close error, and later
// calls return nil.
func (s *RedisSessionSubscriber) Close() error {
	if s == nil || s.client == nil {
		return nil
	}

	var closeErr error
	s.closeOnce.Do(func() { closeErr = s.client.Close() })
	return closeErr
}
// signalStarted closes the started channel exactly once, no matter how many
// times it is called.
func (s *RedisSessionSubscriber) signalStarted() {
	s.startedOnce.Do(func() { close(s.started) })
}
// applyMessage decodes one stream entry and applies it to the snapshot store.
//
// A malformed entry is logged and counted, and if a device session ID can
// still be extracted from it, that session is deleted from the store. A
// snapshot the store refuses to upsert is logged, counted, and likewise
// deleted. After a successful upsert of a snapshot whose status is
// session.StatusRevoked, the revocation handler (when configured) is invoked.
//
// NOTE(review): s.metrics is stored as passed to the constructor and may be
// nil; this presumably relies on RecordInternalEventDrop tolerating a nil
// receiver — confirm against the telemetry package.
func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) {
	record, decodeErr := decodeSessionRecordSnapshot(message.Values)
	if decodeErr != nil {
		s.logger.Warn("dropped malformed session event",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.Error(decodeErr),
		)
		s.metrics.RecordInternalEventDrop(context.Background(),
			attribute.String("component", "session_subscriber"),
			attribute.String("reason", "malformed_event"),
		)
		// The cached entry can no longer be trusted, so evict it if the
		// event at least identifies which session it was about.
		deviceSessionID, identified := extractDeviceSessionID(message.Values)
		if identified {
			s.store.Delete(deviceSessionID)
		}
		return
	}

	if upsertErr := s.store.Upsert(record); upsertErr != nil {
		s.logger.Warn("dropped session snapshot after store failure",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.String("device_session_id", record.DeviceSessionID),
			zap.Error(upsertErr),
		)
		s.metrics.RecordInternalEventDrop(context.Background(),
			attribute.String("component", "session_subscriber"),
			attribute.String("reason", "store_failure"),
		)
		// Evict rather than keep whatever older snapshot might be cached.
		s.store.Delete(record.DeviceSessionID)
		return
	}

	if s.revocationHandler == nil || record.Status != session.StatusRevoked {
		return
	}
	s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID)
}
// decodeSessionRecordSnapshot converts the field map of one stream entry into
// a session.Record.
//
// The fields device_session_id, user_id, client_public_key, and status are
// required and must be non-blank; revoked_at_ms is optional and, when present,
// must parse as a base-10 int64. Any other field makes the entry invalid. When
// several unsupported fields are present, the lexicographically smallest one
// is named in the error so the message does not depend on map iteration order.
//
// On error the returned record is the zero value.
func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) {
	// A switch over the handful of known keys avoids building lookup maps for
	// every message.
	unsupported, hasUnsupported := "", false
	for key := range values {
		switch key {
		case "device_session_id", "user_id", "client_public_key", "status", "revoked_at_ms":
			continue
		}
		if !hasUnsupported || key < unsupported {
			unsupported, hasUnsupported = key, true
		}
	}
	if hasUnsupported {
		return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", unsupported)
	}

	deviceSessionID, err := requiredStringField(values, "device_session_id")
	if err != nil {
		return session.Record{}, err
	}
	userID, err := requiredStringField(values, "user_id")
	if err != nil {
		return session.Record{}, err
	}
	clientPublicKey, err := requiredStringField(values, "client_public_key")
	if err != nil {
		return session.Record{}, err
	}
	statusValue, err := requiredStringField(values, "status")
	if err != nil {
		return session.Record{}, err
	}

	record := session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          userID,
		ClientPublicKey: clientPublicKey,
		Status:          session.Status(statusValue),
	}

	// RevokedAtMS stays nil unless the event carries the optional field.
	if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
		revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms")
		if err != nil {
			return session.Record{}, err
		}
		record.RevokedAtMS = &revokedAtMS
	}
	return record, nil
}
// extractDeviceSessionID makes a best-effort attempt to read the
// device_session_id field from a stream entry. It reports false when the field
// is absent, cannot be coerced to a string, or is blank; otherwise it returns
// the value exactly as coerced (not trimmed).
func extractDeviceSessionID(values map[string]any) (string, bool) {
	raw, present := values["device_session_id"]
	if !present {
		return "", false
	}

	id, err := coerceString(raw)
	if err != nil || strings.TrimSpace(id) == "" {
		return "", false
	}
	return id, true
}
// requiredStringField reads field from values as a string. It returns an error
// when the field is missing, cannot be coerced to a string, or contains only
// whitespace; on success the value is returned exactly as coerced (not
// trimmed).
func requiredStringField(values map[string]any, field string) (string, error) {
	raw, present := values[field]
	if !present {
		return "", fmt.Errorf("decode session event: missing %s", field)
	}

	text, err := coerceString(raw)
	switch {
	case err != nil:
		return "", fmt.Errorf("decode session event: %s: %w", field, err)
	case strings.TrimSpace(text) == "":
		return "", fmt.Errorf("decode session event: %s must not be empty", field)
	default:
		return text, nil
	}
}
// parseInt64Field coerces value to a string, trims surrounding whitespace, and
// parses it as a base-10 int64. field is used only to label the error, which
// wraps either the coercion failure or the strconv parse failure.
func parseInt64Field(value any, field string) (int64, error) {
	text, err := coerceString(value)
	if err == nil {
		var parsed int64
		if parsed, err = strconv.ParseInt(strings.TrimSpace(text), 10, 64); err == nil {
			return parsed, nil
		}
	}
	// Both failure modes share the same message shape.
	return 0, fmt.Errorf("decode session event: %s: %w", field, err)
}
// coerceString renders a stream field value as a string. It accepts string,
// []byte, any fmt.Stringer, int, int64, and uint64; integers are formatted in
// base 10. Every other dynamic type (including nil) yields an error naming the
// type.
func coerceString(value any) (string, error) {
	switch v := value.(type) {
	case string:
		return v, nil
	case []byte:
		return string(v), nil
	case fmt.Stringer:
		return v.String(), nil
	case int:
		return strconv.FormatInt(int64(v), 10), nil
	case int64:
		return strconv.FormatInt(v, 10), nil
	case uint64:
		return strconv.FormatUint(v, 10), nil
	}
	return "", fmt.Errorf("unsupported value type %T", value)
}