151 lines
4.4 KiB
Go
151 lines
4.4 KiB
Go
package session
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"galaxy/gateway/internal/config"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// RedisCache implements Cache with Redis GET lookups over strict JSON session
// records.
//
// Build instances with NewRedisCache, which rejects a nil client, an empty key
// prefix, and a non-positive lookup timeout. The zero value is not usable:
// Lookup returns an error when the client is nil.
type RedisCache struct {
	// client is the Redis client used for GET lookups. It is supplied by the
	// caller of NewRedisCache and is not owned by the cache.
	client *redis.Client

	// keyPrefix is concatenated verbatim in front of the device session ID
	// to form the Redis key (see lookupKey).
	keyPrefix string

	// lookupTimeout bounds each Lookup's Redis round trip via
	// context.WithTimeout.
	lookupTimeout time.Duration
}
|
|
|
|
// redisRecord is the JSON shape of a session record as stored in Redis.
// Payloads are decoded into it with unknown fields disallowed, so this field
// set and these JSON tags define exactly which keys a cached payload may hold.
type redisRecord struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	ClientPublicKey string `json:"client_public_key"`
	Status          Status `json:"status"`

	// RevokedAtMS is optional: an absent or null value decodes to nil.
	// NOTE(review): presumably a Unix timestamp in milliseconds, judging by
	// the name — confirm against whatever writes these records.
	RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
}
|
|
|
|
// NewRedisCache constructs a Redis-backed SessionCache that uses client and
|
|
// applies the namespace and timeout settings from cfg. The cache does not own
|
|
// the client; the runtime supplies a shared *redis.Client.
|
|
func NewRedisCache(client *redis.Client, cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
|
|
if client == nil {
|
|
return nil, errors.New("new redis session cache: nil redis client")
|
|
}
|
|
if strings.TrimSpace(cfg.KeyPrefix) == "" {
|
|
return nil, errors.New("new redis session cache: redis key prefix must not be empty")
|
|
}
|
|
if cfg.LookupTimeout <= 0 {
|
|
return nil, errors.New("new redis session cache: lookup timeout must be positive")
|
|
}
|
|
|
|
return &RedisCache{
|
|
client: client,
|
|
keyPrefix: cfg.KeyPrefix,
|
|
lookupTimeout: cfg.LookupTimeout,
|
|
}, nil
|
|
}
|
|
|
|
// Lookup resolves deviceSessionID from Redis, validates the cached JSON
|
|
// payload strictly, and returns the decoded session record.
|
|
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
|
if c == nil || c.client == nil {
|
|
return Record{}, errors.New("lookup session from redis: nil cache")
|
|
}
|
|
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
|
return Record{}, errors.New("lookup session from redis: nil context")
|
|
}
|
|
if strings.TrimSpace(deviceSessionID) == "" {
|
|
return Record{}, errors.New("lookup session from redis: empty device session id")
|
|
}
|
|
|
|
lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
|
|
defer cancel()
|
|
|
|
payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes()
|
|
switch {
|
|
case errors.Is(err, redis.Nil):
|
|
return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound)
|
|
case err != nil:
|
|
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
|
}
|
|
|
|
record, err := decodeRedisRecord(deviceSessionID, payload)
|
|
if err != nil {
|
|
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
func (c *RedisCache) lookupKey(deviceSessionID string) string {
|
|
return c.keyPrefix + deviceSessionID
|
|
}
|
|
|
|
func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) {
|
|
decoder := json.NewDecoder(bytes.NewReader(payload))
|
|
decoder.DisallowUnknownFields()
|
|
|
|
var stored redisRecord
|
|
if err := decoder.Decode(&stored); err != nil {
|
|
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
|
}
|
|
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
|
if err == nil {
|
|
return Record{}, errors.New("decode redis session record: unexpected trailing JSON input")
|
|
}
|
|
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
|
}
|
|
|
|
record := Record{
|
|
DeviceSessionID: stored.DeviceSessionID,
|
|
UserID: stored.UserID,
|
|
ClientPublicKey: stored.ClientPublicKey,
|
|
Status: stored.Status,
|
|
RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS),
|
|
}
|
|
|
|
if err := validateRecord(expectedDeviceSessionID, record); err != nil {
|
|
return Record{}, err
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
func validateRecord(expectedDeviceSessionID string, record Record) error {
|
|
if record.DeviceSessionID == "" {
|
|
return errors.New("session record device_session_id must not be empty")
|
|
}
|
|
if record.DeviceSessionID != expectedDeviceSessionID {
|
|
return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID)
|
|
}
|
|
if record.UserID == "" {
|
|
return errors.New("session record user_id must not be empty")
|
|
}
|
|
if record.ClientPublicKey == "" {
|
|
return errors.New("session record client_public_key must not be empty")
|
|
}
|
|
if !record.Status.IsKnown() {
|
|
return fmt.Errorf("session record status %q is unsupported", record.Status)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cloneOptionalInt64 returns nil for a nil input, and otherwise a pointer to a
// fresh copy of *value that does not alias the caller's variable.
func cloneOptionalInt64(value *int64) *int64 {
	var out *int64
	if value != nil {
		copied := *value
		out = &copied
	}
	return out
}
|
|
|
|
// Compile-time assertion that *RedisCache satisfies the Cache interface.
var _ Cache = (*RedisCache)(nil)
|