package session import ( "context" "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) { t.Parallel() cache := NewMemoryCache() revokedAtMS := int64(123456789) require.NoError(t, cache.Upsert(Record{ DeviceSessionID: "device-session-123", UserID: "user-123", ClientPublicKey: "public-key-123", Status: StatusRevoked, RevokedAtMS: &revokedAtMS, })) record, err := cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) require.NotNil(t, record.RevokedAtMS) *record.RevokedAtMS = 1 stored, err := cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) require.NotNil(t, stored.RevokedAtMS) assert.Equal(t, revokedAtMS, *stored.RevokedAtMS) } func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) { t.Parallel() local := NewMemoryCache() require.NoError(t, local.Upsert(Record{ DeviceSessionID: "device-session-123", UserID: "user-123", ClientPublicKey: "public-key-123", Status: StatusActive, })) fallback := &recordingCache{ lookupFunc: func(context.Context, string) (Record, error) { return Record{}, errors.New("fallback should not be called") }, } cache, err := NewReadThroughCache(local, fallback) require.NoError(t, err) record, err := cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) assert.Equal(t, Record{ DeviceSessionID: "device-session-123", UserID: "user-123", ClientPublicKey: "public-key-123", Status: StatusActive, }, record) assert.Equal(t, 0, fallback.lookupCalls) } func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) { t.Parallel() local := NewMemoryCache() fallback := &recordingCache{ lookupFunc: func(context.Context, string) (Record, error) { return Record{ DeviceSessionID: "device-session-123", UserID: "user-123", ClientPublicKey: "public-key-123", Status: StatusActive, }, nil }, } cache, err := NewReadThroughCache(local, fallback) require.NoError(t, err) record, err := cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) assert.Equal(t, 1, fallback.lookupCalls) assert.Equal(t, "user-123", record.UserID) record, err = cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) assert.Equal(t, 1, fallback.lookupCalls) assert.Equal(t, "user-123", record.UserID) } func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) { t.Parallel() revokedAtMS := int64(123456789) local := NewMemoryCache() fallback := &recordingCache{ lookupFunc: func(context.Context, string) (Record, error) { return Record{ DeviceSessionID: "device-session-123", UserID: "user-123", ClientPublicKey: "public-key-123", Status: StatusRevoked, RevokedAtMS: &revokedAtMS, }, nil }, } cache, err := NewReadThroughCache(local, fallback) require.NoError(t, err) record, err := cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) require.NotNil(t, record.RevokedAtMS) assert.Equal(t, StatusRevoked, record.Status) assert.Equal(t, 1, fallback.lookupCalls) record, err = cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) require.NotNil(t, record.RevokedAtMS) assert.Equal(t, StatusRevoked, record.Status) assert.Equal(t, revokedAtMS, *record.RevokedAtMS) assert.Equal(t, 1, fallback.lookupCalls) } func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) { t.Parallel() revokedAtMS := int64(123456789) local := NewMemoryCache() fallback := &recordingCache{ lookupFunc: func(context.Context, string) (Record, error) { return Record{ DeviceSessionID: "device-session-123", UserID: "user-123", ClientPublicKey: "public-key-123", Status: StatusRevoked, RevokedAtMS: &revokedAtMS, }, nil }, } cache, err := NewReadThroughCache(local, fallback) require.NoError(t, err) record, err := cache.Lookup(context.Background(), "device-session-123") require.NoError(t, err) require.NotNil(t, record.RevokedAtMS) *record.RevokedAtMS = 1 stored, err := local.Lookup(context.Background(), "device-session-123") require.NoError(t, err) require.NotNil(t, stored.RevokedAtMS) assert.Equal(t, revokedAtMS, *stored.RevokedAtMS) } type recordingCache struct { lookupCalls int lookupFunc func(context.Context, string) (Record, error) } func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) { c.lookupCalls++ if c.lookupFunc != nil { return c.lookupFunc(ctx, deviceSessionID) } return Record{}, errors.New("lookup is not implemented") }