177 lines
4.7 KiB
Go
177 lines
4.7 KiB
Go
package session
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cache := NewMemoryCache()
|
|
revokedAtMS := int64(123456789)
|
|
|
|
require.NoError(t, cache.Upsert(Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: StatusRevoked,
|
|
RevokedAtMS: &revokedAtMS,
|
|
}))
|
|
|
|
record, err := cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, record.RevokedAtMS)
|
|
|
|
*record.RevokedAtMS = 1
|
|
|
|
stored, err := cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, stored.RevokedAtMS)
|
|
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
|
}
|
|
|
|
func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
local := NewMemoryCache()
|
|
require.NoError(t, local.Upsert(Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: StatusActive,
|
|
}))
|
|
|
|
fallback := &recordingCache{
|
|
lookupFunc: func(context.Context, string) (Record, error) {
|
|
return Record{}, errors.New("fallback should not be called")
|
|
},
|
|
}
|
|
|
|
cache, err := NewReadThroughCache(local, fallback)
|
|
require.NoError(t, err)
|
|
|
|
record, err := cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: StatusActive,
|
|
}, record)
|
|
assert.Equal(t, 0, fallback.lookupCalls)
|
|
}
|
|
|
|
func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
local := NewMemoryCache()
|
|
fallback := &recordingCache{
|
|
lookupFunc: func(context.Context, string) (Record, error) {
|
|
return Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: StatusActive,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
cache, err := NewReadThroughCache(local, fallback)
|
|
require.NoError(t, err)
|
|
|
|
record, err := cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, fallback.lookupCalls)
|
|
assert.Equal(t, "user-123", record.UserID)
|
|
|
|
record, err = cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, fallback.lookupCalls)
|
|
assert.Equal(t, "user-123", record.UserID)
|
|
}
|
|
|
|
func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
revokedAtMS := int64(123456789)
|
|
local := NewMemoryCache()
|
|
fallback := &recordingCache{
|
|
lookupFunc: func(context.Context, string) (Record, error) {
|
|
return Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: StatusRevoked,
|
|
RevokedAtMS: &revokedAtMS,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
cache, err := NewReadThroughCache(local, fallback)
|
|
require.NoError(t, err)
|
|
|
|
record, err := cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, record.RevokedAtMS)
|
|
assert.Equal(t, StatusRevoked, record.Status)
|
|
assert.Equal(t, 1, fallback.lookupCalls)
|
|
|
|
record, err = cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, record.RevokedAtMS)
|
|
assert.Equal(t, StatusRevoked, record.Status)
|
|
assert.Equal(t, revokedAtMS, *record.RevokedAtMS)
|
|
assert.Equal(t, 1, fallback.lookupCalls)
|
|
}
|
|
|
|
func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
revokedAtMS := int64(123456789)
|
|
local := NewMemoryCache()
|
|
fallback := &recordingCache{
|
|
lookupFunc: func(context.Context, string) (Record, error) {
|
|
return Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: "user-123",
|
|
ClientPublicKey: "public-key-123",
|
|
Status: StatusRevoked,
|
|
RevokedAtMS: &revokedAtMS,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
cache, err := NewReadThroughCache(local, fallback)
|
|
require.NoError(t, err)
|
|
|
|
record, err := cache.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, record.RevokedAtMS)
|
|
|
|
*record.RevokedAtMS = 1
|
|
|
|
stored, err := local.Lookup(context.Background(), "device-session-123")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, stored.RevokedAtMS)
|
|
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
|
}
|
|
|
|
type recordingCache struct {
|
|
lookupCalls int
|
|
lookupFunc func(context.Context, string) (Record, error)
|
|
}
|
|
|
|
func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
|
c.lookupCalls++
|
|
if c.lookupFunc != nil {
|
|
return c.lookupFunc(ctx, deviceSessionID)
|
|
}
|
|
|
|
return Record{}, errors.New("lookup is not implemented")
|
|
}
|