feat: backend service
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// BackendLookup describes the slice of `backendclient.RESTClient`
|
||||
// SessionCache depends on. The narrow interface keeps this package free
|
||||
// of any backendclient import.
|
||||
type BackendLookup interface {
|
||||
LookupSession(ctx context.Context, deviceSessionID string) (Record, error)
|
||||
}
|
||||
|
||||
// BackendCache resolves authenticated device sessions by issuing one
|
||||
// synchronous REST call to backend per request. The canonical implementation replaces the
|
||||
// previous Redis-backed projection with this thin wrapper; gateway no
|
||||
// longer keeps a process-local snapshot. See ARCHITECTURE.md §11
|
||||
// «backend (sync REST), no Redis projection».
|
||||
type BackendCache struct {
|
||||
backend BackendLookup
|
||||
}
|
||||
|
||||
// NewBackendCache constructs a Cache that delegates every Lookup to
|
||||
// backend over REST. backend must not be nil.
|
||||
func NewBackendCache(backend BackendLookup) (*BackendCache, error) {
|
||||
if backend == nil {
|
||||
return nil, errors.New("session.NewBackendCache: backend lookup must not be nil")
|
||||
}
|
||||
return &BackendCache{backend: backend}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID via backend. ErrNotFound is forwarded
|
||||
// unchanged so callers can keep using the existing equality check.
|
||||
func (c *BackendCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("session backend cache: nil cache")
|
||||
}
|
||||
if c.backend == nil {
|
||||
return Record{}, errors.New("session backend cache: nil backend lookup")
|
||||
}
|
||||
rec, err := c.backend.LookupSession(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Record{}, fmt.Errorf("session backend cache: %w", err)
|
||||
}
|
||||
return rec, nil
|
||||
}
|
||||
|
||||
var _ Cache = (*BackendCache)(nil)
|
||||
@@ -1,88 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryCache stores session record snapshots in process-local memory. It is
|
||||
// intended for the authenticated gateway hot path and deliberately keeps no
|
||||
// TTL or size-based eviction policy.
|
||||
type MemoryCache struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]Record
|
||||
}
|
||||
|
||||
// NewMemoryCache constructs an empty process-local session snapshot store.
|
||||
func NewMemoryCache() *MemoryCache {
|
||||
return &MemoryCache{
|
||||
records: make(map[string]Record),
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from the process-local snapshot map.
|
||||
func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil cache")
|
||||
}
|
||||
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: empty device session id")
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
record, ok := c.records[deviceSessionID]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return Record{}, fmt.Errorf("lookup session from in-memory cache: %w", ErrNotFound)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Upsert stores record in the process-local snapshot map after validating the
|
||||
// same session invariants expected from the Redis-backed cache.
|
||||
func (c *MemoryCache) Upsert(record Record) error {
|
||||
if c == nil {
|
||||
return errors.New("upsert session into in-memory cache: nil cache")
|
||||
}
|
||||
if err := validateRecord(record.DeviceSessionID, record); err != nil {
|
||||
return fmt.Errorf("upsert session into in-memory cache: %w", err)
|
||||
}
|
||||
|
||||
cloned := cloneRecord(record)
|
||||
|
||||
c.mu.Lock()
|
||||
c.records[record.DeviceSessionID] = cloned
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes the local snapshot for deviceSessionID when one exists.
|
||||
func (c *MemoryCache) Delete(deviceSessionID string) {
|
||||
if c == nil || strings.TrimSpace(deviceSessionID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
delete(c.records, deviceSessionID)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func cloneRecord(record Record) Record {
|
||||
cloned := record
|
||||
if record.RevokedAtMS != nil {
|
||||
value := *record.RevokedAtMS
|
||||
cloned.RevokedAtMS = &value
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
var _ SnapshotStore = (*MemoryCache)(nil)
|
||||
@@ -1,68 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ReadThroughCache resolves authenticated sessions from a process-local
|
||||
// SnapshotStore first and falls back to another Cache only on a local miss.
|
||||
type ReadThroughCache struct {
|
||||
local SnapshotStore
|
||||
fallback Cache
|
||||
}
|
||||
|
||||
// NewReadThroughCache constructs a hot-path cache that seeds local snapshots
|
||||
// from fallback on demand.
|
||||
func NewReadThroughCache(local SnapshotStore, fallback Cache) (*ReadThroughCache, error) {
|
||||
if local == nil {
|
||||
return nil, errors.New("new read-through session cache: nil local cache")
|
||||
}
|
||||
if fallback == nil {
|
||||
return nil, errors.New("new read-through session cache: nil fallback cache")
|
||||
}
|
||||
|
||||
return &ReadThroughCache{
|
||||
local: local,
|
||||
fallback: fallback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from local first, then performs one fallback
|
||||
// lookup on a local miss and seeds the local cache with the returned snapshot.
|
||||
func (c *ReadThroughCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from read-through cache: nil cache")
|
||||
}
|
||||
|
||||
record, err := c.local.Lookup(ctx, deviceSessionID)
|
||||
switch {
|
||||
case err == nil:
|
||||
return record, nil
|
||||
case !errors.Is(err, ErrNotFound):
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err)
|
||||
}
|
||||
|
||||
record, err = c.fallback.Lookup(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
if err := c.local.Upsert(record); err != nil {
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: seed local cache: %w", err)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Local returns the mutable process-local snapshot store used by c.
|
||||
func (c *ReadThroughCache) Local() SnapshotStore {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.local
|
||||
}
|
||||
|
||||
var _ Cache = (*ReadThroughCache)(nil)
|
||||
@@ -1,176 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewMemoryCache()
|
||||
revokedAtMS := int64(123456789)
|
||||
|
||||
require.NoError(t, cache.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}))
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
|
||||
*record.RevokedAtMS = 1
|
||||
|
||||
stored, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.RevokedAtMS)
|
||||
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
require.NoError(t, local.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}))
|
||||
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{}, errors.New("fallback should not be called")
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, record)
|
||||
assert.Equal(t, 0, fallback.lookupCalls)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
|
||||
record, err = cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, StatusRevoked, record.Status)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
|
||||
record, err = cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, StatusRevoked, record.Status)
|
||||
assert.Equal(t, revokedAtMS, *record.RevokedAtMS)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
|
||||
*record.RevokedAtMS = 1
|
||||
|
||||
stored, err := local.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.RevokedAtMS)
|
||||
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
||||
}
|
||||
|
||||
type recordingCache struct {
|
||||
lookupCalls int
|
||||
lookupFunc func(context.Context, string) (Record, error)
|
||||
}
|
||||
|
||||
func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
c.lookupCalls++
|
||||
if c.lookupFunc != nil {
|
||||
return c.lookupFunc(ctx, deviceSessionID)
|
||||
}
|
||||
|
||||
return Record{}, errors.New("lookup is not implemented")
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisCache implements Cache with Redis GET lookups over strict JSON session
|
||||
// records.
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
keyPrefix string
|
||||
lookupTimeout time.Duration
|
||||
}
|
||||
|
||||
type redisRecord struct {
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ClientPublicKey string `json:"client_public_key"`
|
||||
Status Status `json:"status"`
|
||||
RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
|
||||
}
|
||||
|
||||
// NewRedisCache constructs a Redis-backed SessionCache that uses client and
|
||||
// applies the namespace and timeout settings from cfg. The cache does not own
|
||||
// the client; the runtime supplies a shared *redis.Client.
|
||||
func NewRedisCache(client *redis.Client, cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("new redis session cache: nil redis client")
|
||||
}
|
||||
if strings.TrimSpace(cfg.KeyPrefix) == "" {
|
||||
return nil, errors.New("new redis session cache: redis key prefix must not be empty")
|
||||
}
|
||||
if cfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis session cache: lookup timeout must be positive")
|
||||
}
|
||||
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
lookupTimeout: cfg.LookupTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from Redis, validates the cached JSON
|
||||
// payload strictly, and returns the decoded session record.
|
||||
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil || c.client == nil {
|
||||
return Record{}, errors.New("lookup session from redis: nil cache")
|
||||
}
|
||||
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
||||
return Record{}, errors.New("lookup session from redis: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return Record{}, errors.New("lookup session from redis: empty device session id")
|
||||
}
|
||||
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound)
|
||||
case err != nil:
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
||||
}
|
||||
|
||||
record, err := decodeRedisRecord(deviceSessionID, payload)
|
||||
if err != nil {
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (c *RedisCache) lookupKey(deviceSessionID string) string {
|
||||
return c.keyPrefix + deviceSessionID
|
||||
}
|
||||
|
||||
func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return Record{}, errors.New("decode redis session record: unexpected trailing JSON input")
|
||||
}
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
record := Record{
|
||||
DeviceSessionID: stored.DeviceSessionID,
|
||||
UserID: stored.UserID,
|
||||
ClientPublicKey: stored.ClientPublicKey,
|
||||
Status: stored.Status,
|
||||
RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS),
|
||||
}
|
||||
|
||||
if err := validateRecord(expectedDeviceSessionID, record); err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func validateRecord(expectedDeviceSessionID string, record Record) error {
|
||||
if record.DeviceSessionID == "" {
|
||||
return errors.New("session record device_session_id must not be empty")
|
||||
}
|
||||
if record.DeviceSessionID != expectedDeviceSessionID {
|
||||
return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID)
|
||||
}
|
||||
if record.UserID == "" {
|
||||
return errors.New("session record user_id must not be empty")
|
||||
}
|
||||
if record.ClientPublicKey == "" {
|
||||
return errors.New("session record client_public_key must not be empty")
|
||||
}
|
||||
if !record.Status.IsKnown() {
|
||||
return fmt.Errorf("session record status %q is unsupported", record.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneOptionalInt64(value *int64) *int64 {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
|
||||
var _ Cache = (*RedisCache)(nil)
|
||||
@@ -1,317 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newRedisClient(t *testing.T, server *miniredis.Miniredis) *redis.Client {
|
||||
t.Helper()
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: server.Addr(),
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func TestNewRedisCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
client := newRedisClient(t, server)
|
||||
|
||||
validCfg := config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
LookupTimeout: 250 * time.Millisecond,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
client *redis.Client
|
||||
cfg config.SessionCacheRedisConfig
|
||||
wantErr string
|
||||
}{
|
||||
{name: "valid config", client: client, cfg: validCfg},
|
||||
{name: "nil client", client: nil, cfg: validCfg, wantErr: "nil redis client"},
|
||||
{
|
||||
name: "empty key prefix",
|
||||
client: client,
|
||||
cfg: config.SessionCacheRedisConfig{LookupTimeout: 250 * time.Millisecond},
|
||||
wantErr: "redis key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non-positive lookup timeout",
|
||||
client: client,
|
||||
cfg: config.SessionCacheRedisConfig{KeyPrefix: "gateway:session:"},
|
||||
wantErr: "lookup timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache, err := NewRedisCache(tt.client, tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisCacheLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.SessionCacheRedisConfig
|
||||
requestID string
|
||||
seed func(*testing.T, *miniredis.Miniredis, config.SessionCacheRedisConfig)
|
||||
want Record
|
||||
wantErrIs error
|
||||
wantErrText string
|
||||
assertErrText string
|
||||
}{
|
||||
{
|
||||
name: "active cache hit",
|
||||
requestID: "device-session-123",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-123", redisRecord{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
})
|
||||
},
|
||||
want: Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing session",
|
||||
requestID: "device-session-404",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
wantErrIs: ErrNotFound,
|
||||
assertErrText: "session cache record not found",
|
||||
},
|
||||
{
|
||||
name: "revoked session",
|
||||
requestID: "device-session-revoked",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-revoked", redisRecord{
|
||||
DeviceSessionID: "device-session-revoked",
|
||||
UserID: "user-777",
|
||||
ClientPublicKey: "public-key-777",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
})
|
||||
},
|
||||
want: Record{
|
||||
DeviceSessionID: "device-session-revoked",
|
||||
UserID: "user-777",
|
||||
ClientPublicKey: "public-key-777",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "malformed json",
|
||||
requestID: "device-session-bad-json",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
server.Set(cfg.KeyPrefix+"device-session-bad-json", "{")
|
||||
},
|
||||
wantErrText: "decode redis session record",
|
||||
},
|
||||
{
|
||||
name: "unknown status",
|
||||
requestID: "device-session-unknown-status",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-unknown-status", redisRecord{
|
||||
DeviceSessionID: "device-session-unknown-status",
|
||||
UserID: "user-1",
|
||||
ClientPublicKey: "public-key-1",
|
||||
Status: Status("paused"),
|
||||
})
|
||||
},
|
||||
wantErrText: `status "paused" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "missing required field",
|
||||
requestID: "device-session-missing-user",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-missing-user", redisRecord{
|
||||
DeviceSessionID: "device-session-missing-user",
|
||||
ClientPublicKey: "public-key-1",
|
||||
Status: StatusActive,
|
||||
})
|
||||
},
|
||||
wantErrText: "user_id must not be empty",
|
||||
},
|
||||
{
|
||||
name: "device session id mismatch",
|
||||
requestID: "device-session-requested",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-requested", redisRecord{
|
||||
DeviceSessionID: "device-session-other",
|
||||
UserID: "user-1",
|
||||
ClientPublicKey: "public-key-1",
|
||||
Status: StatusActive,
|
||||
})
|
||||
},
|
||||
wantErrText: `does not match requested "device-session-requested"`,
|
||||
},
|
||||
{
|
||||
name: "key prefix is honored",
|
||||
requestID: "device-session-prefixed",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "custom:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-prefixed", redisRecord{
|
||||
DeviceSessionID: "device-session-prefixed",
|
||||
UserID: "user-prefixed",
|
||||
ClientPublicKey: "public-key-prefixed",
|
||||
Status: StatusActive,
|
||||
})
|
||||
setRedisSessionRecord(t, server, "gateway:session:device-session-prefixed", redisRecord{
|
||||
DeviceSessionID: "device-session-prefixed",
|
||||
UserID: "wrong-user",
|
||||
ClientPublicKey: "wrong-key",
|
||||
Status: StatusRevoked,
|
||||
})
|
||||
},
|
||||
want: Record{
|
||||
DeviceSessionID: "device-session-prefixed",
|
||||
UserID: "user-prefixed",
|
||||
ClientPublicKey: "public-key-prefixed",
|
||||
Status: StatusActive,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
cfg := tt.cfg
|
||||
cfg.LookupTimeout = 250 * time.Millisecond
|
||||
|
||||
if tt.seed != nil {
|
||||
tt.seed(t, server, cfg)
|
||||
}
|
||||
|
||||
cache := newTestRedisCache(t, server, cfg)
|
||||
record, err := cache.Lookup(context.Background(), tt.requestID)
|
||||
if tt.wantErrIs != nil || tt.wantErrText != "" {
|
||||
require.Error(t, err)
|
||||
if tt.wantErrIs != nil {
|
||||
assert.ErrorIs(t, err, tt.wantErrIs)
|
||||
}
|
||||
if tt.wantErrText != "" {
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
}
|
||||
if tt.assertErrText != "" {
|
||||
assert.ErrorContains(t, err, tt.assertErrText)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, record)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRedisCache(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) *RedisCache {
|
||||
t.Helper()
|
||||
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "gateway:session:"
|
||||
}
|
||||
if cfg.LookupTimeout == 0 {
|
||||
cfg.LookupTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
cache, err := NewRedisCache(newRedisClient(t, server), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func setRedisSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record redisRecord) {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(record)
|
||||
require.NoError(t, err)
|
||||
|
||||
server.Set(key, string(payload))
|
||||
}
|
||||
|
||||
func TestRedisCacheLookupNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{})
|
||||
|
||||
_, err := cache.Lookup(context.TODO(), "device-session-123")
|
||||
require.Error(t, err)
|
||||
assert.False(t, errors.Is(err, ErrNotFound))
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
@@ -13,27 +13,16 @@ var (
|
||||
ErrNotFound = errors.New("session cache record not found")
|
||||
)
|
||||
|
||||
// Cache resolves authenticated device-session state from the gateway hot-path
|
||||
// cache.
|
||||
// Cache resolves authenticated device-session state from the gateway
|
||||
// hot path. The implementation dropped the previous Redis projection: the only
|
||||
// implementation is *BackendCache, which calls backend's
|
||||
// `/api/v1/internal/sessions/{id}` synchronously per request.
|
||||
type Cache interface {
|
||||
// Lookup returns the cached record for deviceSessionID. Implementations must
|
||||
// wrap ErrNotFound when the cache does not contain the requested record.
|
||||
Lookup(ctx context.Context, deviceSessionID string) (Record, error)
|
||||
}
|
||||
|
||||
// SnapshotStore stores mutable session record snapshots inside one gateway
|
||||
// process and exposes the same read contract as Cache for the hot path.
|
||||
type SnapshotStore interface {
|
||||
Cache
|
||||
|
||||
// Upsert stores record under record.DeviceSessionID, replacing any previous
|
||||
// snapshot for that session.
|
||||
Upsert(record Record) error
|
||||
|
||||
// Delete removes the local snapshot for deviceSessionID when it exists.
|
||||
Delete(deviceSessionID string)
|
||||
}
|
||||
|
||||
// Status identifies the cached lifecycle state of a device session.
|
||||
type Status string
|
||||
|
||||
|
||||
Reference in New Issue
Block a user