feat: edge gateway service
This commit is contained in:
@@ -0,0 +1,88 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryCache stores session record snapshots in process-local memory. It is
|
||||
// intended for the authenticated gateway hot path and deliberately keeps no
|
||||
// TTL or size-based eviction policy.
|
||||
type MemoryCache struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]Record
|
||||
}
|
||||
|
||||
// NewMemoryCache constructs an empty process-local session snapshot store.
|
||||
func NewMemoryCache() *MemoryCache {
|
||||
return &MemoryCache{
|
||||
records: make(map[string]Record),
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from the process-local snapshot map.
|
||||
func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil cache")
|
||||
}
|
||||
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: empty device session id")
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
record, ok := c.records[deviceSessionID]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return Record{}, fmt.Errorf("lookup session from in-memory cache: %w", ErrNotFound)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Upsert stores record in the process-local snapshot map after validating the
|
||||
// same session invariants expected from the Redis-backed cache.
|
||||
func (c *MemoryCache) Upsert(record Record) error {
|
||||
if c == nil {
|
||||
return errors.New("upsert session into in-memory cache: nil cache")
|
||||
}
|
||||
if err := validateRecord(record.DeviceSessionID, record); err != nil {
|
||||
return fmt.Errorf("upsert session into in-memory cache: %w", err)
|
||||
}
|
||||
|
||||
cloned := cloneRecord(record)
|
||||
|
||||
c.mu.Lock()
|
||||
c.records[record.DeviceSessionID] = cloned
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes the local snapshot for deviceSessionID when one exists.
|
||||
func (c *MemoryCache) Delete(deviceSessionID string) {
|
||||
if c == nil || strings.TrimSpace(deviceSessionID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
delete(c.records, deviceSessionID)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func cloneRecord(record Record) Record {
|
||||
cloned := record
|
||||
if record.RevokedAtMS != nil {
|
||||
value := *record.RevokedAtMS
|
||||
cloned.RevokedAtMS = &value
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
var _ SnapshotStore = (*MemoryCache)(nil)
|
||||
@@ -0,0 +1,68 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ReadThroughCache resolves authenticated sessions from a process-local
|
||||
// SnapshotStore first and falls back to another Cache only on a local miss.
|
||||
type ReadThroughCache struct {
|
||||
local SnapshotStore
|
||||
fallback Cache
|
||||
}
|
||||
|
||||
// NewReadThroughCache constructs a hot-path cache that seeds local snapshots
|
||||
// from fallback on demand.
|
||||
func NewReadThroughCache(local SnapshotStore, fallback Cache) (*ReadThroughCache, error) {
|
||||
if local == nil {
|
||||
return nil, errors.New("new read-through session cache: nil local cache")
|
||||
}
|
||||
if fallback == nil {
|
||||
return nil, errors.New("new read-through session cache: nil fallback cache")
|
||||
}
|
||||
|
||||
return &ReadThroughCache{
|
||||
local: local,
|
||||
fallback: fallback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from local first, then performs one fallback
|
||||
// lookup on a local miss and seeds the local cache with the returned snapshot.
|
||||
func (c *ReadThroughCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from read-through cache: nil cache")
|
||||
}
|
||||
|
||||
record, err := c.local.Lookup(ctx, deviceSessionID)
|
||||
switch {
|
||||
case err == nil:
|
||||
return record, nil
|
||||
case !errors.Is(err, ErrNotFound):
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err)
|
||||
}
|
||||
|
||||
record, err = c.fallback.Lookup(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
if err := c.local.Upsert(record); err != nil {
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: seed local cache: %w", err)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Local returns the mutable process-local snapshot store used by c.
|
||||
func (c *ReadThroughCache) Local() SnapshotStore {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.local
|
||||
}
|
||||
|
||||
var _ Cache = (*ReadThroughCache)(nil)
|
||||
@@ -0,0 +1,176 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewMemoryCache()
|
||||
revokedAtMS := int64(123456789)
|
||||
|
||||
require.NoError(t, cache.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}))
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
|
||||
*record.RevokedAtMS = 1
|
||||
|
||||
stored, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.RevokedAtMS)
|
||||
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
require.NoError(t, local.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}))
|
||||
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{}, errors.New("fallback should not be called")
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, record)
|
||||
assert.Equal(t, 0, fallback.lookupCalls)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
|
||||
record, err = cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, StatusRevoked, record.Status)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
|
||||
record, err = cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, StatusRevoked, record.Status)
|
||||
assert.Equal(t, revokedAtMS, *record.RevokedAtMS)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
|
||||
*record.RevokedAtMS = 1
|
||||
|
||||
stored, err := local.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.RevokedAtMS)
|
||||
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
||||
}
|
||||
|
||||
type recordingCache struct {
|
||||
lookupCalls int
|
||||
lookupFunc func(context.Context, string) (Record, error)
|
||||
}
|
||||
|
||||
func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
c.lookupCalls++
|
||||
if c.lookupFunc != nil {
|
||||
return c.lookupFunc(ctx, deviceSessionID)
|
||||
}
|
||||
|
||||
return Record{}, errors.New("lookup is not implemented")
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisCache implements Cache with Redis GET lookups over strict JSON session
|
||||
// records.
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
keyPrefix string
|
||||
lookupTimeout time.Duration
|
||||
}
|
||||
|
||||
type redisRecord struct {
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ClientPublicKey string `json:"client_public_key"`
|
||||
Status Status `json:"status"`
|
||||
RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
|
||||
}
|
||||
|
||||
// NewRedisCache constructs a Redis-backed SessionCache from cfg. The returned
|
||||
// cache is read-only from the gateway perspective and does not write or mutate
|
||||
// Redis state.
|
||||
func NewRedisCache(cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
|
||||
if strings.TrimSpace(cfg.Addr) == "" {
|
||||
return nil, errors.New("new redis session cache: redis addr must not be empty")
|
||||
}
|
||||
if cfg.DB < 0 {
|
||||
return nil, errors.New("new redis session cache: redis db must not be negative")
|
||||
}
|
||||
if cfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis session cache: lookup timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &RedisCache{
|
||||
client: redis.NewClient(options),
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
lookupTimeout: cfg.LookupTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (c *RedisCache) Close() error {
|
||||
if c == nil || c.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// cache lookup timeout budget.
|
||||
func (c *RedisCache) Ping(ctx context.Context) error {
|
||||
if c == nil || c.client == nil {
|
||||
return errors.New("ping redis session cache: nil cache")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("ping redis session cache: nil context")
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := c.client.Ping(pingCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis session cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from Redis, validates the cached JSON
|
||||
// payload strictly, and returns the decoded session record.
|
||||
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil || c.client == nil {
|
||||
return Record{}, errors.New("lookup session from redis: nil cache")
|
||||
}
|
||||
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
||||
return Record{}, errors.New("lookup session from redis: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return Record{}, errors.New("lookup session from redis: empty device session id")
|
||||
}
|
||||
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound)
|
||||
case err != nil:
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
||||
}
|
||||
|
||||
record, err := decodeRedisRecord(deviceSessionID, payload)
|
||||
if err != nil {
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (c *RedisCache) lookupKey(deviceSessionID string) string {
|
||||
return c.keyPrefix + deviceSessionID
|
||||
}
|
||||
|
||||
func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return Record{}, errors.New("decode redis session record: unexpected trailing JSON input")
|
||||
}
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
record := Record{
|
||||
DeviceSessionID: stored.DeviceSessionID,
|
||||
UserID: stored.UserID,
|
||||
ClientPublicKey: stored.ClientPublicKey,
|
||||
Status: stored.Status,
|
||||
RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS),
|
||||
}
|
||||
|
||||
if err := validateRecord(expectedDeviceSessionID, record); err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func validateRecord(expectedDeviceSessionID string, record Record) error {
|
||||
if record.DeviceSessionID == "" {
|
||||
return errors.New("session record device_session_id must not be empty")
|
||||
}
|
||||
if record.DeviceSessionID != expectedDeviceSessionID {
|
||||
return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID)
|
||||
}
|
||||
if record.UserID == "" {
|
||||
return errors.New("session record user_id must not be empty")
|
||||
}
|
||||
if record.ClientPublicKey == "" {
|
||||
return errors.New("session record client_public_key must not be empty")
|
||||
}
|
||||
if !record.Status.IsKnown() {
|
||||
return fmt.Errorf("session record status %q is unsupported", record.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneOptionalInt64(value *int64) *int64 {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
|
||||
var _ Cache = (*RedisCache)(nil)
|
||||
@@ -0,0 +1,331 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewRedisCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.SessionCacheRedisConfig
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
Addr: server.Addr(),
|
||||
DB: 2,
|
||||
KeyPrefix: "gateway:session:",
|
||||
LookupTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
LookupTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
LookupTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "non-positive lookup timeout",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
Addr: server.Addr(),
|
||||
},
|
||||
wantErr: "lookup timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache, err := NewRedisCache(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cache.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisCachePing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{})
|
||||
|
||||
require.NoError(t, cache.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestRedisCacheLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.SessionCacheRedisConfig
|
||||
requestID string
|
||||
seed func(*testing.T, *miniredis.Miniredis, config.SessionCacheRedisConfig)
|
||||
want Record
|
||||
wantErrIs error
|
||||
wantErrText string
|
||||
assertErrText string
|
||||
}{
|
||||
{
|
||||
name: "active cache hit",
|
||||
requestID: "device-session-123",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-123", redisRecord{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
})
|
||||
},
|
||||
want: Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing session",
|
||||
requestID: "device-session-404",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
wantErrIs: ErrNotFound,
|
||||
assertErrText: "session cache record not found",
|
||||
},
|
||||
{
|
||||
name: "revoked session",
|
||||
requestID: "device-session-revoked",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-revoked", redisRecord{
|
||||
DeviceSessionID: "device-session-revoked",
|
||||
UserID: "user-777",
|
||||
ClientPublicKey: "public-key-777",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
})
|
||||
},
|
||||
want: Record{
|
||||
DeviceSessionID: "device-session-revoked",
|
||||
UserID: "user-777",
|
||||
ClientPublicKey: "public-key-777",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "malformed json",
|
||||
requestID: "device-session-bad-json",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
server.Set(cfg.KeyPrefix+"device-session-bad-json", "{")
|
||||
},
|
||||
wantErrText: "decode redis session record",
|
||||
},
|
||||
{
|
||||
name: "unknown status",
|
||||
requestID: "device-session-unknown-status",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-unknown-status", redisRecord{
|
||||
DeviceSessionID: "device-session-unknown-status",
|
||||
UserID: "user-1",
|
||||
ClientPublicKey: "public-key-1",
|
||||
Status: Status("paused"),
|
||||
})
|
||||
},
|
||||
wantErrText: `status "paused" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "missing required field",
|
||||
requestID: "device-session-missing-user",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-missing-user", redisRecord{
|
||||
DeviceSessionID: "device-session-missing-user",
|
||||
ClientPublicKey: "public-key-1",
|
||||
Status: StatusActive,
|
||||
})
|
||||
},
|
||||
wantErrText: "user_id must not be empty",
|
||||
},
|
||||
{
|
||||
name: "device session id mismatch",
|
||||
requestID: "device-session-requested",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-requested", redisRecord{
|
||||
DeviceSessionID: "device-session-other",
|
||||
UserID: "user-1",
|
||||
ClientPublicKey: "public-key-1",
|
||||
Status: StatusActive,
|
||||
})
|
||||
},
|
||||
wantErrText: `does not match requested "device-session-requested"`,
|
||||
},
|
||||
{
|
||||
name: "key prefix is honored",
|
||||
requestID: "device-session-prefixed",
|
||||
cfg: config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "custom:session:",
|
||||
},
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
|
||||
t.Helper()
|
||||
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-prefixed", redisRecord{
|
||||
DeviceSessionID: "device-session-prefixed",
|
||||
UserID: "user-prefixed",
|
||||
ClientPublicKey: "public-key-prefixed",
|
||||
Status: StatusActive,
|
||||
})
|
||||
setRedisSessionRecord(t, server, "gateway:session:device-session-prefixed", redisRecord{
|
||||
DeviceSessionID: "device-session-prefixed",
|
||||
UserID: "wrong-user",
|
||||
ClientPublicKey: "wrong-key",
|
||||
Status: StatusRevoked,
|
||||
})
|
||||
},
|
||||
want: Record{
|
||||
DeviceSessionID: "device-session-prefixed",
|
||||
UserID: "user-prefixed",
|
||||
ClientPublicKey: "public-key-prefixed",
|
||||
Status: StatusActive,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
cfg := tt.cfg
|
||||
cfg.Addr = server.Addr()
|
||||
cfg.DB = 0
|
||||
cfg.LookupTimeout = 250 * time.Millisecond
|
||||
|
||||
if tt.seed != nil {
|
||||
tt.seed(t, server, cfg)
|
||||
}
|
||||
|
||||
cache := newTestRedisCache(t, server, cfg)
|
||||
record, err := cache.Lookup(context.Background(), tt.requestID)
|
||||
if tt.wantErrIs != nil || tt.wantErrText != "" {
|
||||
require.Error(t, err)
|
||||
if tt.wantErrIs != nil {
|
||||
assert.ErrorIs(t, err, tt.wantErrIs)
|
||||
}
|
||||
if tt.wantErrText != "" {
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
}
|
||||
if tt.assertErrText != "" {
|
||||
assert.ErrorContains(t, err, tt.assertErrText)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, record)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRedisCache(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) *RedisCache {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.LookupTimeout == 0 {
|
||||
cfg.LookupTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
cache, err := NewRedisCache(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cache.Close())
|
||||
})
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func setRedisSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record redisRecord) {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(record)
|
||||
require.NoError(t, err)
|
||||
|
||||
server.Set(key, string(payload))
|
||||
}
|
||||
|
||||
func TestRedisCacheLookupNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{})
|
||||
|
||||
_, err := cache.Lookup(context.TODO(), "device-session-123")
|
||||
require.Error(t, err)
|
||||
assert.False(t, errors.Is(err, ErrNotFound))
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
// Package session defines the authenticated session-cache contract used by the
|
||||
// gateway hot path.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotFound reports that SessionCache does not currently contain the
|
||||
// requested device session identifier.
|
||||
ErrNotFound = errors.New("session cache record not found")
|
||||
)
|
||||
|
||||
// Cache resolves authenticated device-session state from the gateway hot-path
|
||||
// cache.
|
||||
type Cache interface {
|
||||
// Lookup returns the cached record for deviceSessionID. Implementations must
|
||||
// wrap ErrNotFound when the cache does not contain the requested record.
|
||||
Lookup(ctx context.Context, deviceSessionID string) (Record, error)
|
||||
}
|
||||
|
||||
// SnapshotStore stores mutable session record snapshots inside one gateway
|
||||
// process and exposes the same read contract as Cache for the hot path.
|
||||
type SnapshotStore interface {
|
||||
Cache
|
||||
|
||||
// Upsert stores record under record.DeviceSessionID, replacing any previous
|
||||
// snapshot for that session.
|
||||
Upsert(record Record) error
|
||||
|
||||
// Delete removes the local snapshot for deviceSessionID when it exists.
|
||||
Delete(deviceSessionID string)
|
||||
}
|
||||
|
||||
// Status identifies the cached lifecycle state of a device session.
|
||||
type Status string
|
||||
|
||||
const (
|
||||
// StatusActive reports that the cached device session may continue through
|
||||
// later authenticated gateway checks.
|
||||
StatusActive Status = "active"
|
||||
|
||||
// StatusRevoked reports that the cached device session has been revoked and
|
||||
// must be rejected before later auth steps run.
|
||||
StatusRevoked Status = "revoked"
|
||||
)
|
||||
|
||||
// Record is the minimum authenticated session state required by the gateway
|
||||
// before signature verification begins.
|
||||
type Record struct {
|
||||
// DeviceSessionID is the stable device-session identifier resolved from the
|
||||
// hot-path cache.
|
||||
DeviceSessionID string
|
||||
|
||||
// UserID is the authenticated user identity bound to DeviceSessionID.
|
||||
UserID string
|
||||
|
||||
// ClientPublicKey is the standard base64-encoded raw Ed25519 public key
|
||||
// material used for request-signature verification.
|
||||
ClientPublicKey string
|
||||
|
||||
// Status reports whether the cached session is active or revoked.
|
||||
Status Status
|
||||
|
||||
// RevokedAtMS optionally records when the device session was revoked.
|
||||
RevokedAtMS *int64
|
||||
}
|
||||
|
||||
// IsKnown reports whether s is one of the session states supported by the
|
||||
// gateway.
|
||||
func (s Status) IsKnown() bool {
|
||||
switch s {
|
||||
case StatusActive, StatusRevoked:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user