package session

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"galaxy/gateway/internal/config"

	"github.com/redis/go-redis/v9"
)

// RedisCache is a Cache backed by plain Redis GET lookups. Each cached value
// is expected to be a single, strictly formatted JSON session record.
type RedisCache struct {
	// client issues the Redis commands; it is supplied by the caller.
	client *redis.Client
	// keyPrefix is prepended to a device session ID to form the Redis key.
	keyPrefix string
	// lookupTimeout caps the duration of a single lookup.
	lookupTimeout time.Duration
}

// redisRecord mirrors the JSON document stored in Redis for one device
// session. Field order and tags define the accepted wire format.
type redisRecord struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	ClientPublicKey string `json:"client_public_key"`
	Status          Status `json:"status"`
	// RevokedAtMS is optional; presumably a Unix timestamp in milliseconds
	// judging by the name — confirm against the writer of these records.
	RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
}

// NewRedisCache returns a Redis-backed Cache that issues lookups through
// client, namespacing keys with cfg.KeyPrefix and bounding each lookup by
// cfg.LookupTimeout.
//
// It fails when client is nil, when the key prefix is blank (empty or only
// whitespace), or when the lookup timeout is not strictly positive. The cache
// does not take ownership of client: the runtime supplies a shared
// *redis.Client and remains responsible for it.
func NewRedisCache(client *redis.Client, cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
	switch {
	case client == nil:
		return nil, errors.New("new redis session cache: nil redis client")
	case strings.TrimSpace(cfg.KeyPrefix) == "":
		return nil, errors.New("new redis session cache: redis key prefix must not be empty")
	case cfg.LookupTimeout <= 0:
		return nil, errors.New("new redis session cache: lookup timeout must be positive")
	}

	cache := &RedisCache{client: client}
	cache.keyPrefix = cfg.KeyPrefix
	cache.lookupTimeout = cfg.LookupTimeout
	return cache, nil
}

// Lookup resolves deviceSessionID from Redis, validates the cached JSON
// payload strictly, and returns the decoded session record.
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) { if c == nil || c.client == nil { return Record{}, errors.New("lookup session from redis: nil cache") } if ctx == nil || fmt.Sprint(ctx) == "context.TODO" { return Record{}, errors.New("lookup session from redis: nil context") } if strings.TrimSpace(deviceSessionID) == "" { return Record{}, errors.New("lookup session from redis: empty device session id") } lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout) defer cancel() payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes() switch { case errors.Is(err, redis.Nil): return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound) case err != nil: return Record{}, fmt.Errorf("lookup session from redis: %w", err) } record, err := decodeRedisRecord(deviceSessionID, payload) if err != nil { return Record{}, fmt.Errorf("lookup session from redis: %w", err) } return record, nil } func (c *RedisCache) lookupKey(deviceSessionID string) string { return c.keyPrefix + deviceSessionID } func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) { decoder := json.NewDecoder(bytes.NewReader(payload)) decoder.DisallowUnknownFields() var stored redisRecord if err := decoder.Decode(&stored); err != nil { return Record{}, fmt.Errorf("decode redis session record: %w", err) } if err := decoder.Decode(&struct{}{}); err != io.EOF { if err == nil { return Record{}, errors.New("decode redis session record: unexpected trailing JSON input") } return Record{}, fmt.Errorf("decode redis session record: %w", err) } record := Record{ DeviceSessionID: stored.DeviceSessionID, UserID: stored.UserID, ClientPublicKey: stored.ClientPublicKey, Status: stored.Status, RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS), } if err := validateRecord(expectedDeviceSessionID, record); err != nil { return Record{}, err } return record, nil } func validateRecord(expectedDeviceSessionID string, record 
Record) error { if record.DeviceSessionID == "" { return errors.New("session record device_session_id must not be empty") } if record.DeviceSessionID != expectedDeviceSessionID { return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID) } if record.UserID == "" { return errors.New("session record user_id must not be empty") } if record.ClientPublicKey == "" { return errors.New("session record client_public_key must not be empty") } if !record.Status.IsKnown() { return fmt.Errorf("session record status %q is unsupported", record.Status) } return nil } func cloneOptionalInt64(value *int64) *int64 { if value == nil { return nil } cloned := *value return &cloned } var _ Cache = (*RedisCache)(nil)