// Package events subscribes to internal session lifecycle streams used to keep // the gateway hot-path session cache synchronized without per-request upstream // lookups. package events import ( "context" "crypto/tls" "errors" "fmt" "strconv" "strings" "sync" "time" "galaxy/gateway/internal/config" "galaxy/gateway/internal/session" "galaxy/gateway/internal/telemetry" "github.com/redis/go-redis/v9" "go.opentelemetry.io/otel/attribute" "go.uber.org/zap" ) const sessionEventReadCount int64 = 128 // SessionRevocationHandler reacts to a successfully applied revoked session // snapshot and may tear down active resources bound to that session. type SessionRevocationHandler interface { // RevokeDeviceSession tears down active resources bound to deviceSessionID. RevokeDeviceSession(deviceSessionID string) } // RedisSessionSubscriber consumes full session snapshots from one Redis Stream // and applies them to a process-local session snapshot store. type RedisSessionSubscriber struct { client *redis.Client stream string pingTimeout time.Duration readBlockTimeout time.Duration store session.SnapshotStore revocationHandler SessionRevocationHandler logger *zap.Logger metrics *telemetry.Runtime closeOnce sync.Once startedOnce sync.Once started chan struct{} } // NewRedisSessionSubscriber constructs a Redis Stream subscriber that reuses // the SessionCache Redis connection settings and applies updates to store. func NewRedisSessionSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) { return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, nil, nil, nil) } // NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream // subscriber that reuses the SessionCache Redis connection settings, applies // updates to store, and optionally tears down active resources for revoked // sessions. func NewRedisSessionSubscriberWithRevocationHandler(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) { return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, revocationHandler, nil, nil) } // NewRedisSessionSubscriberWithObservability constructs a Redis Stream // subscriber that also logs and counts malformed internal session events. func NewRedisSessionSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) { if strings.TrimSpace(sessionCfg.Addr) == "" { return nil, errors.New("new redis session subscriber: redis addr must not be empty") } if sessionCfg.DB < 0 { return nil, errors.New("new redis session subscriber: redis db must not be negative") } if sessionCfg.LookupTimeout <= 0 { return nil, errors.New("new redis session subscriber: lookup timeout must be positive") } if strings.TrimSpace(eventsCfg.Stream) == "" { return nil, errors.New("new redis session subscriber: stream must not be empty") } if eventsCfg.ReadBlockTimeout <= 0 { return nil, errors.New("new redis session subscriber: read block timeout must be positive") } if store == nil { return nil, errors.New("new redis session subscriber: nil session snapshot store") } options := &redis.Options{ Addr: sessionCfg.Addr, Username: sessionCfg.Username, Password: sessionCfg.Password, DB: sessionCfg.DB, Protocol: 2, DisableIdentity: true, } if sessionCfg.TLSEnabled { options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} } if logger == nil { logger = zap.NewNop() } return &RedisSessionSubscriber{ client: redis.NewClient(options), stream: eventsCfg.Stream, pingTimeout: sessionCfg.LookupTimeout, readBlockTimeout: eventsCfg.ReadBlockTimeout, store: store, revocationHandler: revocationHandler, logger: logger.Named("session_subscriber"), metrics: metrics, started: make(chan struct{}), }, nil } // Ping verifies that the Redis backend used for session lifecycle events is // reachable within the configured timeout budget. func (s *RedisSessionSubscriber) Ping(ctx context.Context) error { if s == nil || s.client == nil { return errors.New("ping redis session subscriber: nil subscriber") } if ctx == nil { return errors.New("ping redis session subscriber: nil context") } pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout) defer cancel() if err := s.client.Ping(pingCtx).Err(); err != nil { return fmt.Errorf("ping redis session subscriber: %w", err) } return nil } // Run consumes session lifecycle events until ctx is canceled or Redis returns // an unexpected error. func (s *RedisSessionSubscriber) Run(ctx context.Context) error { if s == nil || s.client == nil { return errors.New("run redis session subscriber: nil subscriber") } if ctx == nil { return errors.New("run redis session subscriber: nil context") } if err := ctx.Err(); err != nil { return err } lastID, err := s.resolveStartID(ctx) if err != nil { return err } s.signalStarted() for { streams, err := s.client.XRead(ctx, &redis.XReadArgs{ Streams: []string{s.stream, lastID}, Count: sessionEventReadCount, Block: s.readBlockTimeout, }).Result() switch { case err == nil: for _, stream := range streams { for _, message := range stream.Messages { s.applyMessage(message) lastID = message.ID } } continue case errors.Is(err, redis.Nil): continue case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)): return ctx.Err() case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed): return fmt.Errorf("run redis session subscriber: %w", err) default: return fmt.Errorf("run redis session subscriber: %w", err) } } } func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) { messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result() switch { case err == nil: case errors.Is(err, redis.Nil): return "0-0", nil default: return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err) } if len(messages) == 0 { return "0-0", nil } return messages[0].ID, nil } // Shutdown closes the Redis client so a blocking stream read can terminate // promptly during gateway shutdown. func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error { if ctx == nil { return errors.New("shutdown redis session subscriber: nil context") } return s.Close() } // Close releases the underlying Redis client resources. func (s *RedisSessionSubscriber) Close() error { if s == nil || s.client == nil { return nil } var err error s.closeOnce.Do(func() { err = s.client.Close() }) return err } func (s *RedisSessionSubscriber) signalStarted() { s.startedOnce.Do(func() { close(s.started) }) } func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) { record, err := decodeSessionRecordSnapshot(message.Values) if err != nil { s.logger.Warn("dropped malformed session event", zap.String("stream", s.stream), zap.String("message_id", message.ID), zap.Error(err), ) s.metrics.RecordInternalEventDrop(context.Background(), attribute.String("component", "session_subscriber"), attribute.String("reason", "malformed_event"), ) if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok { s.store.Delete(deviceSessionID) } return } if err := s.store.Upsert(record); err != nil { s.logger.Warn("dropped session snapshot after store failure", zap.String("stream", s.stream), zap.String("message_id", message.ID), zap.String("device_session_id", record.DeviceSessionID), zap.Error(err), ) s.metrics.RecordInternalEventDrop(context.Background(), attribute.String("component", "session_subscriber"), attribute.String("reason", "store_failure"), ) s.store.Delete(record.DeviceSessionID) return } if record.Status == session.StatusRevoked && s.revocationHandler != nil { s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID) } } func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) { requiredKeys := map[string]struct{}{ "device_session_id": {}, "user_id": {}, "client_public_key": {}, "status": {}, } optionalKeys := map[string]struct{}{ "revoked_at_ms": {}, } for key := range values { if _, ok := requiredKeys[key]; ok { continue } if _, ok := optionalKeys[key]; ok { continue } return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key) } deviceSessionID, err := requiredStringField(values, "device_session_id") if err != nil { return session.Record{}, err } userID, err := requiredStringField(values, "user_id") if err != nil { return session.Record{}, err } clientPublicKey, err := requiredStringField(values, "client_public_key") if err != nil { return session.Record{}, err } statusValue, err := requiredStringField(values, "status") if err != nil { return session.Record{}, err } record := session.Record{ DeviceSessionID: deviceSessionID, UserID: userID, ClientPublicKey: clientPublicKey, Status: session.Status(statusValue), } if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok { revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms") if err != nil { return session.Record{}, err } record.RevokedAtMS = &revokedAtMS } return record, nil } func extractDeviceSessionID(values map[string]any) (string, bool) { value, ok := values["device_session_id"] if !ok { return "", false } deviceSessionID, err := coerceString(value) if err != nil { return "", false } if strings.TrimSpace(deviceSessionID) == "" { return "", false } return deviceSessionID, true } func requiredStringField(values map[string]any, field string) (string, error) { value, ok := values[field] if !ok { return "", fmt.Errorf("decode session event: missing %s", field) } stringValue, err := coerceString(value) if err != nil { return "", fmt.Errorf("decode session event: %s: %w", field, err) } if strings.TrimSpace(stringValue) == "" { return "", fmt.Errorf("decode session event: %s must not be empty", field) } return stringValue, nil } func parseInt64Field(value any, field string) (int64, error) { stringValue, err := coerceString(value) if err != nil { return 0, fmt.Errorf("decode session event: %s: %w", field, err) } parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64) if err != nil { return 0, fmt.Errorf("decode session event: %s: %w", field, err) } return parsed, nil } func coerceString(value any) (string, error) { switch typed := value.(type) { case string: return typed, nil case []byte: return string(typed), nil case fmt.Stringer: return typed.String(), nil case int: return strconv.Itoa(typed), nil case int64: return strconv.FormatInt(typed, 10), nil case uint64: return strconv.FormatUint(typed, 10), nil default: return "", fmt.Errorf("unsupported value type %T", value) } }