239 lines
6.4 KiB
Go
239 lines
6.4 KiB
Go
package session
|
|
|
|
import (
|
|
"container/list"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Defaults applied by NewMemoryCache when MemoryCacheOptions leaves the
// corresponding field zero or negative.
const (
	// DefaultMaxEntries is the LRU bound applied when MemoryCacheOptions
	// does not supply a positive MaxEntries. Holds well below the
	// per-process memory budget for the documented MVP scale (≤10K active
	// accounts, ≤100K device sessions).
	DefaultMaxEntries = 50_000

	// DefaultTTL is the safety-net freshness window applied when
	// MemoryCacheOptions does not supply a positive TTL. Push events drive
	// invalidation in the steady state; the TTL guards against missed
	// events (cursor aged out, gateway restart) by forcing a fresh backend
	// lookup at most once per window.
	DefaultTTL = 10 * time.Minute
)
|
|
|
|
// MemoryCache is the canonical Cache implementation. Hot-path Lookup
// reads serve from a process-local LRU + TTL map; misses delegate to
// BackendLookup and seed the cache. session_invalidation push events
// flip cached records to a revoked status without a backend
// roundtrip, after which Lookup returns the revoked record straight
// from memory and gateway rejects the request.
//
// MemoryCache is safe for concurrent use.
type MemoryCache struct {
	mu      sync.Mutex               // guards every field below
	entries map[string]*list.Element // deviceSessionID -> element in order; Value is *memoryEntry
	byUser  map[string]map[string]struct{} // userID -> set of deviceSessionIDs, for MarkAllRevokedForUser
	order   *list.List               // LRU order: front = most recently used, back = eviction candidate
	max     int                      // entry bound; exceeding it evicts from the back of order
	ttl     time.Duration            // per-entry freshness window
	backend BackendLookup            // source of truth consulted on miss/expiry
	now     func() time.Time         // injectable clock (tests); time.Now in production
	logger  *zap.Logger              // named "session.cache"; never nil
}
|
|
|
|
// memoryEntry is the value stored inside the LRU list. The key
// duplication keeps Element.Value self-describing for eviction.
type memoryEntry struct {
	key       string    // deviceSessionID; lets evictLocked delete map entries from the element alone
	record    Record    // cached session record; Status may be flipped in place by MarkRevoked
	expiresAt time.Time // entry serves Lookup only while this is in the future
}
|
|
|
|
// MemoryCacheOptions tunes the cache. The zero value is valid: every
// field falls back to a documented default in NewMemoryCache.
type MemoryCacheOptions struct {
	// MaxEntries bounds the number of cached records. Zero or
	// negative values default to DefaultMaxEntries.
	MaxEntries int
	// TTL bounds how long a cached entry serves the hot path before
	// a fresh backend lookup. Zero or negative values default to
	// DefaultTTL.
	TTL time.Duration
	// Now overrides time.Now for tests.
	Now func() time.Time
	// Logger is named "session.cache". A nil value uses zap.NewNop.
	Logger *zap.Logger
}
|
|
|
|
// NewMemoryCache constructs a MemoryCache. backend must not be nil.
|
|
func NewMemoryCache(backend BackendLookup, opts MemoryCacheOptions) (*MemoryCache, error) {
|
|
if backend == nil {
|
|
return nil, errors.New("session.NewMemoryCache: backend lookup must not be nil")
|
|
}
|
|
max := opts.MaxEntries
|
|
if max <= 0 {
|
|
max = DefaultMaxEntries
|
|
}
|
|
ttl := opts.TTL
|
|
if ttl <= 0 {
|
|
ttl = DefaultTTL
|
|
}
|
|
now := opts.Now
|
|
if now == nil {
|
|
now = time.Now
|
|
}
|
|
logger := opts.Logger
|
|
if logger == nil {
|
|
logger = zap.NewNop()
|
|
}
|
|
return &MemoryCache{
|
|
entries: make(map[string]*list.Element, max),
|
|
byUser: make(map[string]map[string]struct{}),
|
|
order: list.New(),
|
|
max: max,
|
|
ttl: ttl,
|
|
backend: backend,
|
|
now: now,
|
|
logger: logger.Named("session.cache"),
|
|
}, nil
|
|
}
|
|
|
|
// Lookup serves deviceSessionID from the cache. A miss (or an entry
|
|
// past its TTL) triggers a backend lookup and seeds the cache before
|
|
// returning. Concurrent Lookups for the same key are not coalesced —
|
|
// that level of optimisation is not needed at the documented MVP
|
|
// scale.
|
|
func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
|
if c == nil {
|
|
return Record{}, errors.New("session memory cache: nil cache")
|
|
}
|
|
if deviceSessionID == "" {
|
|
return Record{}, ErrNotFound
|
|
}
|
|
now := c.now()
|
|
c.mu.Lock()
|
|
if elem, ok := c.entries[deviceSessionID]; ok {
|
|
entry := elem.Value.(*memoryEntry)
|
|
if entry.expiresAt.After(now) {
|
|
c.order.MoveToFront(elem)
|
|
rec := entry.record
|
|
c.mu.Unlock()
|
|
return rec, nil
|
|
}
|
|
// Expired — evict and fall through to backend.
|
|
c.evictLocked(elem)
|
|
}
|
|
c.mu.Unlock()
|
|
|
|
rec, err := c.backend.LookupSession(ctx, deviceSessionID)
|
|
if err != nil {
|
|
return Record{}, fmt.Errorf("session memory cache: %w", err)
|
|
}
|
|
c.mu.Lock()
|
|
c.insertLocked(deviceSessionID, rec, now.Add(c.ttl))
|
|
c.mu.Unlock()
|
|
return rec, nil
|
|
}
|
|
|
|
// MarkRevoked flips the cached record for deviceSessionID to a
|
|
// revoked status. Calling on a missing entry is a no-op.
|
|
func (c *MemoryCache) MarkRevoked(deviceSessionID string) {
|
|
if c == nil || deviceSessionID == "" {
|
|
return
|
|
}
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
elem, ok := c.entries[deviceSessionID]
|
|
if !ok {
|
|
return
|
|
}
|
|
entry := elem.Value.(*memoryEntry)
|
|
entry.record.Status = StatusRevoked
|
|
}
|
|
|
|
// MarkAllRevokedForUser flips every cached record whose UserID is
|
|
// userID to revoked. The user index is updated in O(n) over the
|
|
// user's session set, not the whole cache.
|
|
func (c *MemoryCache) MarkAllRevokedForUser(userID string) {
|
|
if c == nil || userID == "" {
|
|
return
|
|
}
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
set, ok := c.byUser[userID]
|
|
if !ok {
|
|
return
|
|
}
|
|
for id := range set {
|
|
if elem, ok := c.entries[id]; ok {
|
|
elem.Value.(*memoryEntry).record.Status = StatusRevoked
|
|
}
|
|
}
|
|
}
|
|
|
|
// Len returns the current number of cached entries. Useful for
|
|
// metrics and tests.
|
|
func (c *MemoryCache) Len() int {
|
|
if c == nil {
|
|
return 0
|
|
}
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.order.Len()
|
|
}
|
|
|
|
// insertLocked stores rec under deviceSessionID. The caller holds c.mu.
|
|
func (c *MemoryCache) insertLocked(deviceSessionID string, rec Record, expiresAt time.Time) {
|
|
if existing, ok := c.entries[deviceSessionID]; ok {
|
|
existing.Value.(*memoryEntry).record = rec
|
|
existing.Value.(*memoryEntry).expiresAt = expiresAt
|
|
c.order.MoveToFront(existing)
|
|
c.indexUserLocked(deviceSessionID, rec.UserID)
|
|
return
|
|
}
|
|
elem := c.order.PushFront(&memoryEntry{
|
|
key: deviceSessionID,
|
|
record: rec,
|
|
expiresAt: expiresAt,
|
|
})
|
|
c.entries[deviceSessionID] = elem
|
|
c.indexUserLocked(deviceSessionID, rec.UserID)
|
|
if c.order.Len() > c.max {
|
|
oldest := c.order.Back()
|
|
if oldest != nil {
|
|
c.evictLocked(oldest)
|
|
}
|
|
}
|
|
}
|
|
|
|
// evictLocked removes elem from every internal index. The caller holds c.mu.
|
|
func (c *MemoryCache) evictLocked(elem *list.Element) {
|
|
entry := elem.Value.(*memoryEntry)
|
|
delete(c.entries, entry.key)
|
|
if set := c.byUser[entry.record.UserID]; set != nil {
|
|
delete(set, entry.key)
|
|
if len(set) == 0 {
|
|
delete(c.byUser, entry.record.UserID)
|
|
}
|
|
}
|
|
c.order.Remove(elem)
|
|
}
|
|
|
|
// indexUserLocked associates deviceSessionID with userID in byUser.
|
|
// The caller holds c.mu.
|
|
func (c *MemoryCache) indexUserLocked(deviceSessionID, userID string) {
|
|
if userID == "" {
|
|
return
|
|
}
|
|
set, ok := c.byUser[userID]
|
|
if !ok {
|
|
set = make(map[string]struct{})
|
|
c.byUser[userID] = set
|
|
}
|
|
set[deviceSessionID] = struct{}{}
|
|
}
|
|
|
|
// Compile-time assertion that *MemoryCache satisfies the Cache interface.
var _ Cache = (*MemoryCache)(nil)
|