package session import ( "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io" "strings" "time" "galaxy/gateway/internal/config" "github.com/redis/go-redis/v9" ) // RedisCache implements Cache with Redis GET lookups over strict JSON session // records. type RedisCache struct { client *redis.Client keyPrefix string lookupTimeout time.Duration } type redisRecord struct { DeviceSessionID string `json:"device_session_id"` UserID string `json:"user_id"` ClientPublicKey string `json:"client_public_key"` Status Status `json:"status"` RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"` } // NewRedisCache constructs a Redis-backed SessionCache from cfg. The returned // cache is read-only from the gateway perspective and does not write or mutate // Redis state. func NewRedisCache(cfg config.SessionCacheRedisConfig) (*RedisCache, error) { if strings.TrimSpace(cfg.Addr) == "" { return nil, errors.New("new redis session cache: redis addr must not be empty") } if cfg.DB < 0 { return nil, errors.New("new redis session cache: redis db must not be negative") } if cfg.LookupTimeout <= 0 { return nil, errors.New("new redis session cache: lookup timeout must be positive") } options := &redis.Options{ Addr: cfg.Addr, Username: cfg.Username, Password: cfg.Password, DB: cfg.DB, Protocol: 2, DisableIdentity: true, } if cfg.TLSEnabled { options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} } return &RedisCache{ client: redis.NewClient(options), keyPrefix: cfg.KeyPrefix, lookupTimeout: cfg.LookupTimeout, }, nil } // Close releases the underlying Redis client resources. func (c *RedisCache) Close() error { if c == nil || c.client == nil { return nil } return c.client.Close() } // Ping verifies that the configured Redis backend is reachable within the // cache lookup timeout budget. func (c *RedisCache) Ping(ctx context.Context) error { if c == nil || c.client == nil { return errors.New("ping redis session cache: nil cache") } if ctx == nil { return errors.New("ping redis session cache: nil context") } pingCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout) defer cancel() if err := c.client.Ping(pingCtx).Err(); err != nil { return fmt.Errorf("ping redis session cache: %w", err) } return nil } // Lookup resolves deviceSessionID from Redis, validates the cached JSON // payload strictly, and returns the decoded session record. func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) { if c == nil || c.client == nil { return Record{}, errors.New("lookup session from redis: nil cache") } if ctx == nil || fmt.Sprint(ctx) == "context.TODO" { return Record{}, errors.New("lookup session from redis: nil context") } if strings.TrimSpace(deviceSessionID) == "" { return Record{}, errors.New("lookup session from redis: empty device session id") } lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout) defer cancel() payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes() switch { case errors.Is(err, redis.Nil): return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound) case err != nil: return Record{}, fmt.Errorf("lookup session from redis: %w", err) } record, err := decodeRedisRecord(deviceSessionID, payload) if err != nil { return Record{}, fmt.Errorf("lookup session from redis: %w", err) } return record, nil } func (c *RedisCache) lookupKey(deviceSessionID string) string { return c.keyPrefix + deviceSessionID } func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) { decoder := json.NewDecoder(bytes.NewReader(payload)) decoder.DisallowUnknownFields() var stored redisRecord if err := decoder.Decode(&stored); err != nil { return Record{}, fmt.Errorf("decode redis session record: %w", err) } if err := decoder.Decode(&struct{}{}); err != io.EOF { if err == nil { return Record{}, errors.New("decode redis session record: unexpected trailing JSON input") } return Record{}, fmt.Errorf("decode redis session record: %w", err) } record := Record{ DeviceSessionID: stored.DeviceSessionID, UserID: stored.UserID, ClientPublicKey: stored.ClientPublicKey, Status: stored.Status, RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS), } if err := validateRecord(expectedDeviceSessionID, record); err != nil { return Record{}, err } return record, nil } func validateRecord(expectedDeviceSessionID string, record Record) error { if record.DeviceSessionID == "" { return errors.New("session record device_session_id must not be empty") } if record.DeviceSessionID != expectedDeviceSessionID { return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID) } if record.UserID == "" { return errors.New("session record user_id must not be empty") } if record.ClientPublicKey == "" { return errors.New("session record client_public_key must not be empty") } if !record.Status.IsKnown() { return fmt.Errorf("session record status %q is unsupported", record.Status) } return nil } func cloneOptionalInt64(value *int64) *int64 { if value == nil { return nil } cloned := *value return &cloned } var _ Cache = (*RedisCache)(nil)