443 lines
13 KiB
Go
443 lines
13 KiB
Go
package projectionpublisher
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ed25519"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/authsession/internal/domain/common"
|
|
"galaxy/authsession/internal/domain/gatewayprojection"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNew(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
|
|
tests := []struct {
|
|
name string
|
|
cfg Config
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "valid config",
|
|
cfg: Config{
|
|
Addr: server.Addr(),
|
|
DB: 3,
|
|
SessionCacheKeyPrefix: "gateway:session:",
|
|
SessionEventsStream: "gateway:session_events",
|
|
StreamMaxLen: 1024,
|
|
OperationTimeout: 250 * time.Millisecond,
|
|
},
|
|
},
|
|
{
|
|
name: "empty addr",
|
|
cfg: Config{
|
|
SessionCacheKeyPrefix: "gateway:session:",
|
|
SessionEventsStream: "gateway:session_events",
|
|
StreamMaxLen: 1024,
|
|
OperationTimeout: 250 * time.Millisecond,
|
|
},
|
|
wantErr: "redis addr must not be empty",
|
|
},
|
|
{
|
|
name: "negative db",
|
|
cfg: Config{
|
|
Addr: server.Addr(),
|
|
DB: -1,
|
|
SessionCacheKeyPrefix: "gateway:session:",
|
|
SessionEventsStream: "gateway:session_events",
|
|
StreamMaxLen: 1024,
|
|
OperationTimeout: 250 * time.Millisecond,
|
|
},
|
|
wantErr: "redis db must not be negative",
|
|
},
|
|
{
|
|
name: "empty session cache key prefix",
|
|
cfg: Config{
|
|
Addr: server.Addr(),
|
|
SessionEventsStream: "gateway:session_events",
|
|
StreamMaxLen: 1024,
|
|
OperationTimeout: 250 * time.Millisecond,
|
|
},
|
|
wantErr: "session cache key prefix must not be empty",
|
|
},
|
|
{
|
|
name: "empty session events stream",
|
|
cfg: Config{
|
|
Addr: server.Addr(),
|
|
SessionCacheKeyPrefix: "gateway:session:",
|
|
StreamMaxLen: 1024,
|
|
OperationTimeout: 250 * time.Millisecond,
|
|
},
|
|
wantErr: "session events stream must not be empty",
|
|
},
|
|
{
|
|
name: "non positive stream max len",
|
|
cfg: Config{
|
|
Addr: server.Addr(),
|
|
SessionCacheKeyPrefix: "gateway:session:",
|
|
SessionEventsStream: "gateway:session_events",
|
|
OperationTimeout: 250 * time.Millisecond,
|
|
},
|
|
wantErr: "stream max len must be positive",
|
|
},
|
|
{
|
|
name: "non positive timeout",
|
|
cfg: Config{
|
|
Addr: server.Addr(),
|
|
SessionCacheKeyPrefix: "gateway:session:",
|
|
SessionEventsStream: "gateway:session_events",
|
|
StreamMaxLen: 1024,
|
|
},
|
|
wantErr: "operation timeout must be positive",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
publisher, err := New(tt.cfg)
|
|
if tt.wantErr != "" {
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, tt.wantErr)
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
assert.NoError(t, publisher.Close())
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPublisherPing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{})
|
|
|
|
require.NoError(t, publisher.Ping(context.Background()))
|
|
}
|
|
|
|
func TestPublisherPublishSessionActive(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{})
|
|
snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)
|
|
|
|
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
|
|
|
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
|
|
assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key)
|
|
assert.True(t, server.Exists(key))
|
|
assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String())))
|
|
|
|
payload, err := server.Get(key)
|
|
require.NoError(t, err)
|
|
record := decodeCachePayload(t, payload)
|
|
assert.Equal(t, cacheRecord{
|
|
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
|
UserID: snapshot.UserID.String(),
|
|
ClientPublicKey: snapshot.ClientPublicKey,
|
|
Status: gatewayprojection.StatusActive,
|
|
}, record)
|
|
assert.Zero(t, server.TTL(key))
|
|
|
|
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 1)
|
|
assert.Equal(t, map[string]string{
|
|
"device_session_id": snapshot.DeviceSessionID.String(),
|
|
"user_id": snapshot.UserID.String(),
|
|
"client_public_key": snapshot.ClientPublicKey,
|
|
"status": string(gatewayprojection.StatusActive),
|
|
}, stringifyValues(entries[0].Values))
|
|
}
|
|
|
|
func TestPublisherPublishSessionRevoked(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{})
|
|
revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
|
|
snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)
|
|
|
|
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
|
|
|
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
|
|
payload, err := server.Get(key)
|
|
require.NoError(t, err)
|
|
record := decodeCachePayload(t, payload)
|
|
require.NotNil(t, record.RevokedAtMS)
|
|
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
|
|
assert.Equal(t, cacheRecord{
|
|
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
|
UserID: snapshot.UserID.String(),
|
|
ClientPublicKey: snapshot.ClientPublicKey,
|
|
Status: gatewayprojection.StatusRevoked,
|
|
RevokedAtMS: int64Pointer(revokedAt.UnixMilli()),
|
|
}, record)
|
|
|
|
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 1)
|
|
assert.Equal(t, map[string]string{
|
|
"device_session_id": snapshot.DeviceSessionID.String(),
|
|
"user_id": snapshot.UserID.String(),
|
|
"client_public_key": snapshot.ClientPublicKey,
|
|
"status": string(gatewayprojection.StatusRevoked),
|
|
"revoked_at_ms": "1776000123456",
|
|
}, stringifyValues(entries[0].Values))
|
|
}
|
|
|
|
func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
|
|
deviceSessionID := "device-session-456"
|
|
|
|
active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
|
|
revokedAt := time.Unix(1_776_010_000, 0).UTC()
|
|
revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)
|
|
|
|
require.NoError(t, publisher.PublishSession(context.Background(), active))
|
|
require.NoError(t, publisher.PublishSession(context.Background(), revoked))
|
|
|
|
payload, err := server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID))
|
|
require.NoError(t, err)
|
|
record := decodeCachePayload(t, payload)
|
|
require.NotNil(t, record.RevokedAtMS)
|
|
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
|
|
assert.Equal(t, gatewayprojection.StatusRevoked, record.Status)
|
|
|
|
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 2)
|
|
assert.Equal(t, map[string]string{
|
|
"device_session_id": active.DeviceSessionID.String(),
|
|
"user_id": active.UserID.String(),
|
|
"client_public_key": active.ClientPublicKey,
|
|
"status": string(gatewayprojection.StatusActive),
|
|
}, stringifyValues(entries[0].Values))
|
|
assert.Equal(t, map[string]string{
|
|
"device_session_id": revoked.DeviceSessionID.String(),
|
|
"user_id": revoked.UserID.String(),
|
|
"client_public_key": revoked.ClientPublicKey,
|
|
"status": string(gatewayprojection.StatusRevoked),
|
|
"revoked_at_ms": "1776010000000",
|
|
}, stringifyValues(entries[1].Values))
|
|
}
|
|
|
|
func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
|
|
snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)
|
|
|
|
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
|
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
|
|
|
payload, err := server.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID))
|
|
require.NoError(t, err)
|
|
record := decodeCachePayload(t, payload)
|
|
assert.Equal(t, cacheRecord{
|
|
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
|
UserID: snapshot.UserID.String(),
|
|
ClientPublicKey: snapshot.ClientPublicKey,
|
|
Status: gatewayprojection.StatusActive,
|
|
}, record)
|
|
|
|
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
|
require.NoError(t, err)
|
|
require.Len(t, entries, 2)
|
|
assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values))
|
|
}
|
|
|
|
func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 2})
|
|
|
|
for index := range 6 {
|
|
snapshot := testSnapshot(
|
|
common.DeviceSessionID("device-session-"+string(rune('a'+index))).String(),
|
|
gatewayprojection.StatusActive,
|
|
nil,
|
|
)
|
|
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
|
}
|
|
|
|
streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result()
|
|
require.NoError(t, err)
|
|
assert.LessOrEqual(t, streamLength, int64(2))
|
|
}
|
|
|
|
func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{})
|
|
snapshot := gatewayprojection.Snapshot{
|
|
DeviceSessionID: common.DeviceSessionID("device-session-123"),
|
|
UserID: common.UserID("user-123"),
|
|
Status: gatewayprojection.StatusActive,
|
|
}
|
|
|
|
err := publisher.PublishSession(context.Background(), snapshot)
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "gateway projection client public key")
|
|
assert.Empty(t, server.Keys())
|
|
}
|
|
|
|
func TestPublisherPublishSessionNilContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{})
|
|
|
|
err := publisher.PublishSession(nil, testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "nil context")
|
|
}
|
|
|
|
func TestPublisherPublishSessionBackendFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{})
|
|
server.Close()
|
|
|
|
err := publisher.PublishSession(context.Background(), testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "publish session projection")
|
|
}
|
|
|
|
func TestPublisherPingNilContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
publisher := newTestPublisher(t, server, Config{})
|
|
|
|
err := publisher.Ping(nil)
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "nil context")
|
|
}
|
|
|
|
func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
|
|
t.Helper()
|
|
|
|
if cfg.Addr == "" {
|
|
cfg.Addr = server.Addr()
|
|
}
|
|
if cfg.SessionCacheKeyPrefix == "" {
|
|
cfg.SessionCacheKeyPrefix = "gateway:session:"
|
|
}
|
|
if cfg.SessionEventsStream == "" {
|
|
cfg.SessionEventsStream = "gateway:session_events"
|
|
}
|
|
if cfg.StreamMaxLen == 0 {
|
|
cfg.StreamMaxLen = 1024
|
|
}
|
|
if cfg.OperationTimeout == 0 {
|
|
cfg.OperationTimeout = 250 * time.Millisecond
|
|
}
|
|
|
|
publisher, err := New(cfg)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
assert.NoError(t, publisher.Close())
|
|
})
|
|
|
|
return publisher
|
|
}
|
|
|
|
func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
|
|
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
|
|
for index := range raw {
|
|
raw[index] = byte(index + 1)
|
|
}
|
|
|
|
snapshot := gatewayprojection.Snapshot{
|
|
DeviceSessionID: common.DeviceSessionID(deviceSessionID),
|
|
UserID: common.UserID("user-123"),
|
|
ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
|
|
Status: status,
|
|
RevokedAt: revokedAt,
|
|
}
|
|
if status == gatewayprojection.StatusRevoked {
|
|
snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
|
|
snapshot.RevokeActorType = common.RevokeActorType("system")
|
|
}
|
|
|
|
return snapshot
|
|
}
|
|
|
|
func decodeCachePayload(t *testing.T, payload string) cacheRecord {
|
|
t.Helper()
|
|
|
|
decoder := json.NewDecoder(bytes.NewReader([]byte(payload)))
|
|
decoder.DisallowUnknownFields()
|
|
|
|
var record cacheRecord
|
|
require.NoError(t, decoder.Decode(&record))
|
|
err := decoder.Decode(&struct{}{})
|
|
if err == nil {
|
|
require.FailNow(t, "expected cache payload EOF after first JSON value")
|
|
}
|
|
require.ErrorIs(t, err, io.EOF)
|
|
|
|
var fieldSet map[string]json.RawMessage
|
|
require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet))
|
|
expectedFields := map[string]struct{}{
|
|
"device_session_id": {},
|
|
"user_id": {},
|
|
"client_public_key": {},
|
|
"status": {},
|
|
}
|
|
if record.RevokedAtMS != nil {
|
|
expectedFields["revoked_at_ms"] = struct{}{}
|
|
}
|
|
assert.Equal(t, len(expectedFields), len(fieldSet))
|
|
for field := range fieldSet {
|
|
_, ok := expectedFields[field]
|
|
assert.Truef(t, ok, "unexpected cache payload field %q", field)
|
|
}
|
|
|
|
return record
|
|
}
|
|
|
|
// stringifyValues converts each value of a stream entry to its fmt.Sprint
// form, so the entry can be compared with a map[string]string literal.
func stringifyValues(values map[string]any) map[string]string {
	out := make(map[string]string, len(values))
	for field := range values {
		out[field] = fmt.Sprint(values[field])
	}
	return out
}
|
|
|
|
// encodeBase64URL returns the unpadded, URL-safe base64 encoding of value's
// bytes.
func encodeBase64URL(value string) string {
	enc := base64.RawURLEncoding
	dst := make([]byte, enc.EncodedLen(len(value)))
	enc.Encode(dst, []byte(value))
	return string(dst)
}
|
|
|
|
// int64Pointer returns a pointer to a fresh int64 holding value.
func int64Pointer(value int64) *int64 {
	ptr := new(int64)
	*ptr = value
	return ptr
}
|