Files
galaxy-game/gateway/internal/events/client_subscriber.go
T
2026-04-02 19:18:42 +02:00

342 lines
9.4 KiB
Go

package events
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
// clientEventReadCount caps how many stream entries a single XREAD call may
// return (the COUNT argument) before the subscriber issues the next read.
const clientEventReadCount = int64(128)
// ClientEventPublisher accepts decoded client-facing events from the internal
// event subscriber.
//
// RedisClientEventSubscriber calls Publish synchronously from its Run loop,
// once per successfully decoded stream entry, so a slow implementation delays
// the subscriber's subsequent stream reads.
type ClientEventPublisher interface {
	// Publish fans out event to the currently active push streams.
	Publish(event push.Event)
}
// RedisClientEventSubscriber consumes client-facing events from one Redis
// Stream and forwards them to the configured publisher.
type RedisClientEventSubscriber struct {
	// client is the Redis client owned by this subscriber; it is created in
	// the constructor and released by Close.
	client *redis.Client
	// stream is the Redis Stream key that events are read from.
	stream string
	// pingTimeout bounds each Ping call.
	pingTimeout time.Duration
	// readBlockTimeout is used as the BLOCK duration of every XREAD.
	readBlockTimeout time.Duration
	// publisher receives every successfully decoded event.
	publisher ClientEventPublisher
	// logger records dropped malformed events; the constructor substitutes a
	// no-op logger when none is supplied.
	logger *zap.Logger
	// metrics records dropped events; may be nil (NewRedisClientEventSubscriber
	// passes nil).
	metrics *telemetry.Runtime
	// closeOnce ensures the Redis client is closed at most once.
	closeOnce sync.Once
	// startedOnce guards the single close of started.
	startedOnce sync.Once
	// started is closed once Run has resolved its starting stream ID and is
	// about to begin reading.
	started chan struct{}
}
// NewRedisClientEventSubscriber builds a Redis Stream subscriber from the
// SessionCache connection settings and the client-events stream settings,
// forwarding decoded events to publisher.
//
// It is NewRedisClientEventSubscriberWithObservability without a logger or
// metrics runtime.
func NewRedisClientEventSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) {
	var (
		noLogger  *zap.Logger
		noMetrics *telemetry.Runtime
	)
	return NewRedisClientEventSubscriberWithObservability(sessionCfg, eventsCfg, publisher, noLogger, noMetrics)
}
// NewRedisClientEventSubscriberWithObservability validates its configuration
// and constructs a Redis Stream subscriber that additionally logs and counts
// malformed internal events.
//
// The Redis client reuses the SessionCache address, credentials, DB and TLS
// flag, and sessionCfg.LookupTimeout becomes the Ping timeout. A nil logger is
// replaced with a no-op logger; metrics may be nil.
func NewRedisClientEventSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) {
	// Checks are evaluated in order; the first failing one decides the error.
	validations := []struct {
		failed bool
		msg    string
	}{
		{strings.TrimSpace(sessionCfg.Addr) == "", "new redis client event subscriber: redis addr must not be empty"},
		{sessionCfg.DB < 0, "new redis client event subscriber: redis db must not be negative"},
		{sessionCfg.LookupTimeout <= 0, "new redis client event subscriber: lookup timeout must be positive"},
		{strings.TrimSpace(eventsCfg.Stream) == "", "new redis client event subscriber: stream must not be empty"},
		{eventsCfg.ReadBlockTimeout <= 0, "new redis client event subscriber: read block timeout must be positive"},
		{publisher == nil, "new redis client event subscriber: nil publisher"},
	}
	for _, v := range validations {
		if v.failed {
			return nil, errors.New(v.msg)
		}
	}

	if logger == nil {
		logger = zap.NewNop()
	}

	redisOpts := &redis.Options{
		Addr:     sessionCfg.Addr,
		Username: sessionCfg.Username,
		Password: sessionCfg.Password,
		DB:       sessionCfg.DB,
		// RESP2, and no client-identity handshake on connect.
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		redisOpts.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	subscriber := &RedisClientEventSubscriber{
		client:           redis.NewClient(redisOpts),
		stream:           eventsCfg.Stream,
		pingTimeout:      sessionCfg.LookupTimeout,
		readBlockTimeout: eventsCfg.ReadBlockTimeout,
		publisher:        publisher,
		logger:           logger.Named("client_event_subscriber"),
		metrics:          metrics,
		started:          make(chan struct{}),
	}
	return subscriber, nil
}
// Ping checks that the Redis backend used for client-facing event fan-out
// answers a PING within the configured lookup timeout. It rejects a nil
// subscriber and a nil context before touching Redis.
func (s *RedisClientEventSubscriber) Ping(ctx context.Context) error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("ping redis client event subscriber: nil subscriber")
	case ctx == nil:
		return errors.New("ping redis client event subscriber: nil context")
	}

	boundedCtx, cancelBounded := context.WithTimeout(ctx, s.pingTimeout)
	defer cancelBounded()

	pingErr := s.client.Ping(boundedCtx).Err()
	if pingErr == nil {
		return nil
	}
	return fmt.Errorf("ping redis client event subscriber: %w", pingErr)
}
// Run consumes client-facing events until ctx is canceled or Redis returns an
// unexpected error.
//
// Before reading, Run resolves the ID of the newest entry already in the
// stream (or "0-0" for an empty stream) and then closes the started channel,
// so only entries appended after that point are delivered. Each entry is
// decoded and handed to the publisher synchronously; malformed entries are
// logged and skipped. An XREAD that times out with no entries simply polls
// again.
//
// If ctx is done and the read failed because of it (a context error or a
// closed client), Run returns ctx.Err() unwrapped. Any other read failure is
// returned wrapped with the "run redis client event subscriber" prefix. This
// includes a client closed by Shutdown while ctx is still live.
func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("run redis client event subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("run redis client event subscriber: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	lastID, err := s.resolveStartID(ctx)
	if err != nil {
		return err
	}
	s.signalStarted()

	for {
		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
			Streams: []string{s.stream, lastID},
			Count:   clientEventReadCount,
			Block:   s.readBlockTimeout,
		}).Result()

		switch {
		case err == nil:
			for _, stream := range streams {
				for _, message := range stream.Messages {
					s.publishMessage(message)
					// Advance even for dropped entries so they are not re-read.
					lastID = message.ID
				}
			}
		case errors.Is(err, redis.Nil):
			// Block timeout elapsed without new entries; poll again.
		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
			// Orderly termination driven by the caller's context.
			return ctx.Err()
		default:
			return fmt.Errorf("run redis client event subscriber: %w", err)
		}
	}
}
// resolveStartID returns the ID of the newest entry currently in the stream,
// so that Run only delivers entries appended afterwards. It returns "0-0"
// when the lookup yields no entries or redis.Nil.
func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) {
	tail, lookupErr := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	if errors.Is(lookupErr, redis.Nil) {
		return "0-0", nil
	}
	if lookupErr != nil {
		return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", lookupErr)
	}
	if len(tail) > 0 {
		return tail[0].ID, nil
	}
	return "0-0", nil
}
// Shutdown closes the Redis client so that a blocked stream read in Run can
// return promptly during gateway shutdown. ctx must be non-nil but is not
// otherwise consulted; the actual work is delegated to Close.
func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error {
	if ctx != nil {
		return s.Close()
	}
	return errors.New("shutdown redis client event subscriber: nil context")
}
// Close releases the underlying Redis client resources. It is a no-op on a
// nil subscriber. Only the first call closes the client and reports its
// error; later calls return nil.
func (s *RedisClientEventSubscriber) Close() error {
	if s != nil && s.client != nil {
		var closeErr error
		s.closeOnce.Do(func() { closeErr = s.client.Close() })
		return closeErr
	}
	return nil
}
// signalStarted closes the started channel exactly once, however many times
// it is invoked.
func (s *RedisClientEventSubscriber) signalStarted() {
	closeStarted := func() { close(s.started) }
	s.startedOnce.Do(closeStarted)
}
// publishMessage decodes one stream entry and forwards it to the publisher.
//
// Entries that fail to decode are dropped. A warning is logged with the
// stream name, entry ID and decode error. When a metrics runtime is
// configured, an internal-event drop is also recorded with reason
// "malformed_event".
func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) {
	event, err := decodeClientEvent(message.Values)
	if err == nil {
		s.publisher.Publish(event)
		return
	}

	s.logger.Warn("dropped malformed client event",
		zap.String("stream", s.stream),
		zap.String("message_id", message.ID),
		zap.Error(err),
	)
	// metrics is nil when the subscriber was built via
	// NewRedisClientEventSubscriber. Guard so a malformed entry cannot panic
	// the Run goroutine if telemetry.Runtime methods are not nil-receiver safe.
	if s.metrics != nil {
		s.metrics.RecordInternalEventDrop(context.Background(),
			attribute.String("component", "client_event_subscriber"),
			attribute.String("reason", "malformed_event"),
		)
	}
}
// decodeClientEvent converts the field/value map of one Redis Stream entry
// into a push.Event.
//
// Required fields are user_id, event_type, event_id and payload_bytes.
// Optional fields are device_session_id (whitespace-trimmed), request_id and
// trace_id. Any other field name invalidates the whole entry. When several
// unsupported fields are present, which one the error names depends on map
// iteration order.
func decodeClientEvent(values map[string]any) (push.Event, error) {
	// A switch over the constant field names avoids allocating lookup maps
	// for every decoded stream entry.
	for key := range values {
		switch key {
		case "user_id", "event_type", "event_id", "payload_bytes": // required
		case "device_session_id", "request_id", "trace_id": // optional
		default:
			return push.Event{}, fmt.Errorf("decode client event: unsupported field %q", key)
		}
	}

	userID, err := requiredStringField(values, "user_id")
	if err != nil {
		return push.Event{}, err
	}
	eventType, err := requiredStringField(values, "event_type")
	if err != nil {
		return push.Event{}, err
	}
	eventID, err := requiredStringField(values, "event_id")
	if err != nil {
		return push.Event{}, err
	}
	payloadBytes, err := requiredBytesField(values, "payload_bytes")
	if err != nil {
		return push.Event{}, err
	}

	event := push.Event{
		UserID:       userID,
		EventType:    eventType,
		EventID:      eventID,
		PayloadBytes: payloadBytes,
	}

	if deviceSessionID, ok, err := optionalStringField(values, "device_session_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.DeviceSessionID = strings.TrimSpace(deviceSessionID)
	}
	if requestID, ok, err := optionalStringField(values, "request_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.RequestID = requestID
	}
	if traceID, ok, err := optionalStringField(values, "trace_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.TraceID = traceID
	}

	return event, nil
}
// requiredBytesField extracts field from values as a byte slice using
// coerceBytes. A missing field and a value coerceBytes rejects are both
// reported as errors prefixed with "decode client event".
func requiredBytesField(values map[string]any, field string) ([]byte, error) {
	raw, present := values[field]
	if !present {
		return nil, fmt.Errorf("decode client event: missing %s", field)
	}

	decoded, decodeErr := coerceBytes(raw)
	if decodeErr != nil {
		return nil, fmt.Errorf("decode client event: %s: %w", field, decodeErr)
	}
	return decoded, nil
}
// optionalStringField extracts field from values as a string using
// coerceString. The boolean result reports whether the field was present;
// an absent field is not an error. A present value that coerceString rejects
// yields an error prefixed with "decode client event".
func optionalStringField(values map[string]any, field string) (string, bool, error) {
	raw, present := values[field]
	if present {
		text, convErr := coerceString(raw)
		if convErr != nil {
			return "", false, fmt.Errorf("decode client event: %s: %w", field, convErr)
		}
		return text, true, nil
	}
	return "", false, nil
}
// coerceBytes converts a stream field value into a byte slice. Strings are
// converted directly. []byte values are cloned, so the result never aliases
// the caller's buffer. Any other dynamic type, including nil, is rejected.
func coerceBytes(value any) ([]byte, error) {
	if text, isString := value.(string); isString {
		return []byte(text), nil
	}
	if raw, isBytes := value.([]byte); isBytes {
		return bytes.Clone(raw), nil
	}
	return nil, fmt.Errorf("unsupported type %T", value)
}