Files
galaxy-game/authsession/internal/adapters/redis/projectionpublisher/publisher_test.go
T
2026-04-08 16:23:07 +02:00

443 lines
13 KiB
Go

package projectionpublisher
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/gatewayprojection"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNew exercises New with one valid configuration and a series of
// configurations that each have a single invalid or missing field, checking
// that the returned error mentions the offending setting.
func TestNew(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name string
		cfg  Config
		// wantErr is the substring expected in the error returned by New;
		// an empty value means construction must succeed.
		wantErr string
	}{
		{
			name: "valid config",
			cfg: Config{
				Addr:                  server.Addr(),
				DB:                    3,
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: Config{
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: Config{
				Addr:                  server.Addr(),
				DB:                    -1,
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty session cache key prefix",
			cfg: Config{
				Addr:                server.Addr(),
				SessionEventsStream: "gateway:session_events",
				StreamMaxLen:        1024,
				OperationTimeout:    250 * time.Millisecond,
			},
			wantErr: "session cache key prefix must not be empty",
		},
		{
			name: "empty session events stream",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				StreamMaxLen:          1024,
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "session events stream must not be empty",
		},
		{
			name: "non positive stream max len",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				OperationTimeout:      250 * time.Millisecond,
			},
			wantErr: "stream max len must be positive",
		},
		{
			name: "non positive timeout",
			cfg: Config{
				Addr:                  server.Addr(),
				SessionCacheKeyPrefix: "gateway:session:",
				SessionEventsStream:   "gateway:session_events",
				StreamMaxLen:          1024,
			},
			wantErr: "operation timeout must be positive",
		},
	}

	// No per-iteration copy of tt is required: this file already relies on
	// Go 1.22 language features (range over an integer), and from Go 1.22 on
	// every loop iteration gets its own variable, so the parallel closures
	// below cannot observe a later iteration's value.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			publisher, err := New(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, publisher.Close())
			})
		})
	}
}
// TestPublisherPing checks that Ping returns no error while the miniredis
// server backing the publisher is running.
func TestPublisherPing(t *testing.T) {
	t.Parallel()

	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{})

	err := pub.Ping(context.Background())
	require.NoError(t, err)
}
// TestPublisherPublishSessionActive publishes an active snapshot and inspects
// both the cache entry and the single stream entry that result from it.
func TestPublisherPublishSessionActive(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{})

	// The identifier contains '/', ':' and '?' so the key assertions below can
	// tell a verbatim key apart from a base64url-encoded one.
	snap := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)
	require.NoError(t, pub.PublishSession(ctx, snap))

	sessionID := snap.DeviceSessionID.String()
	cacheKey := pub.sessionCacheKey(snap.DeviceSessionID)
	assert.Equal(t, "gateway:session:"+sessionID, cacheKey)
	assert.True(t, srv.Exists(cacheKey))
	assert.False(t, srv.Exists("gateway:session:"+encodeBase64URL(sessionID)))

	raw, err := srv.Get(cacheKey)
	require.NoError(t, err)

	wantRecord := cacheRecord{
		DeviceSessionID: sessionID,
		UserID:          snap.UserID.String(),
		ClientPublicKey: snap.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}
	assert.Equal(t, wantRecord, decodeCachePayload(t, raw))

	// The TTL reported for the cache key is expected to be zero.
	assert.Zero(t, srv.TTL(cacheKey))

	streamEntries, err := pub.client.XRange(ctx, pub.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, streamEntries, 1)

	wantEvent := map[string]string{
		"device_session_id": sessionID,
		"user_id":           snap.UserID.String(),
		"client_public_key": snap.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}
	assert.Equal(t, wantEvent, stringifyValues(streamEntries[0].Values))
}
// TestPublisherPublishSessionRevoked publishes a revoked snapshot and checks
// that the revocation time appears, in Unix milliseconds, in both the cache
// record and the stream entry.
func TestPublisherPublishSessionRevoked(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{})

	revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
	revokedAtMS := revokedAt.UnixMilli()
	snap := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)
	require.NoError(t, pub.PublishSession(ctx, snap))

	raw, err := srv.Get(pub.sessionCacheKey(snap.DeviceSessionID))
	require.NoError(t, err)

	got := decodeCachePayload(t, raw)
	require.NotNil(t, got.RevokedAtMS)
	assert.Equal(t, revokedAtMS, *got.RevokedAtMS)

	wantRecord := cacheRecord{
		DeviceSessionID: snap.DeviceSessionID.String(),
		UserID:          snap.UserID.String(),
		ClientPublicKey: snap.ClientPublicKey,
		Status:          gatewayprojection.StatusRevoked,
		RevokedAtMS:     int64Pointer(revokedAtMS),
	}
	assert.Equal(t, wantRecord, got)

	streamEntries, err := pub.client.XRange(ctx, pub.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, streamEntries, 1)

	wantEvent := map[string]string{
		"device_session_id": snap.DeviceSessionID.String(),
		"user_id":           snap.UserID.String(),
		"client_public_key": snap.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776000123456",
	}
	assert.Equal(t, wantEvent, stringifyValues(streamEntries[0].Values))
}
// TestPublisherPublishSessionLaterSnapshotWinsInCache publishes an active and
// then a revoked snapshot for the same device session. The cache must hold the
// revoked state, while the stream must keep both events in publish order.
func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{StreamMaxLen: 8})

	const deviceSessionID = "device-session-456"
	revokedAt := time.Unix(1_776_010_000, 0).UTC()
	activeSnap := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
	revokedSnap := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)

	require.NoError(t, pub.PublishSession(ctx, activeSnap))
	require.NoError(t, pub.PublishSession(ctx, revokedSnap))

	raw, err := srv.Get(pub.sessionCacheKey(revokedSnap.DeviceSessionID))
	require.NoError(t, err)

	cached := decodeCachePayload(t, raw)
	require.NotNil(t, cached.RevokedAtMS)
	assert.Equal(t, revokedAt.UnixMilli(), *cached.RevokedAtMS)
	assert.Equal(t, gatewayprojection.StatusRevoked, cached.Status)

	streamEntries, err := pub.client.XRange(ctx, pub.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, streamEntries, 2)

	wantActiveEvent := map[string]string{
		"device_session_id": activeSnap.DeviceSessionID.String(),
		"user_id":           activeSnap.UserID.String(),
		"client_public_key": activeSnap.ClientPublicKey,
		"status":            string(gatewayprojection.StatusActive),
	}
	assert.Equal(t, wantActiveEvent, stringifyValues(streamEntries[0].Values))

	wantRevokedEvent := map[string]string{
		"device_session_id": revokedSnap.DeviceSessionID.String(),
		"user_id":           revokedSnap.UserID.String(),
		"client_public_key": revokedSnap.ClientPublicKey,
		"status":            string(gatewayprojection.StatusRevoked),
		"revoked_at_ms":     "1776010000000",
	}
	assert.Equal(t, wantRevokedEvent, stringifyValues(streamEntries[1].Values))
}
// TestPublisherPublishSessionRepeatedPublishIsRetrySafe publishes the same
// snapshot twice. The cache record must stay unchanged and the stream must
// contain two entries with identical field values.
func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{StreamMaxLen: 8})

	snap := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)
	require.NoError(t, pub.PublishSession(ctx, snap))
	require.NoError(t, pub.PublishSession(ctx, snap))

	raw, err := srv.Get(pub.sessionCacheKey(snap.DeviceSessionID))
	require.NoError(t, err)

	wantRecord := cacheRecord{
		DeviceSessionID: snap.DeviceSessionID.String(),
		UserID:          snap.UserID.String(),
		ClientPublicKey: snap.ClientPublicKey,
		Status:          gatewayprojection.StatusActive,
	}
	assert.Equal(t, wantRecord, decodeCachePayload(t, raw))

	streamEntries, err := pub.client.XRange(ctx, pub.sessionEventsStream, "-", "+").Result()
	require.NoError(t, err)
	require.Len(t, streamEntries, 2)

	first := stringifyValues(streamEntries[0].Values)
	second := stringifyValues(streamEntries[1].Values)
	assert.Equal(t, first, second)
}
// TestPublisherPublishSessionStreamMaxLenApprox publishes six distinct
// snapshots through a publisher configured with StreamMaxLen 2 and asserts
// that the stream length does not exceed that limit afterwards.
func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{StreamMaxLen: 2})

	for i := range 6 {
		// Produces device-session-a through device-session-f.
		sessionID := common.DeviceSessionID("device-session-" + string(rune('a'+i))).String()
		snap := testSnapshot(sessionID, gatewayprojection.StatusActive, nil)
		require.NoError(t, pub.PublishSession(ctx, snap))
	}

	length, err := pub.client.XLen(ctx, pub.sessionEventsStream).Result()
	require.NoError(t, err)
	assert.LessOrEqual(t, length, int64(2))
}
// TestPublisherPublishSessionInvalidSnapshot publishes a snapshot that has no
// client public key and expects an error naming that field, with no keys
// written to the server.
func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
	t.Parallel()

	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{})

	// ClientPublicKey is intentionally left unset.
	invalid := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID("device-session-123"),
		UserID:          common.UserID("user-123"),
		Status:          gatewayprojection.StatusActive,
	}

	publishErr := pub.PublishSession(context.Background(), invalid)
	require.Error(t, publishErr)
	assert.ErrorContains(t, publishErr, "gateway projection client public key")
	assert.Empty(t, srv.Keys())
}
// TestPublisherPublishSessionNilContext checks that PublishSession reports a
// "nil context" error when it is handed a nil context.
func TestPublisherPublishSessionNilContext(t *testing.T) {
	t.Parallel()

	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{})
	snap := testSnapshot("device-session-123", gatewayprojection.StatusActive, nil)

	// The nil context is deliberate: it is the input under test.
	var nilCtx context.Context
	publishErr := pub.PublishSession(nilCtx, snap)
	require.Error(t, publishErr)
	assert.ErrorContains(t, publishErr, "nil context")
}
// TestPublisherPublishSessionBackendFailure shuts the miniredis server down
// before publishing and expects PublishSession to return an error containing
// "publish session projection".
func TestPublisherPublishSessionBackendFailure(t *testing.T) {
	t.Parallel()

	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{})
	snap := testSnapshot("device-session-123", gatewayprojection.StatusActive, nil)

	// Stop the backend so the publish below has nothing to talk to.
	srv.Close()

	publishErr := pub.PublishSession(context.Background(), snap)
	require.Error(t, publishErr)
	assert.ErrorContains(t, publishErr, "publish session projection")
}
// TestPublisherPingNilContext checks that Ping reports a "nil context" error
// when it is handed a nil context.
func TestPublisherPingNilContext(t *testing.T) {
	t.Parallel()

	srv := miniredis.RunT(t)
	pub := newTestPublisher(t, srv, Config{})

	// The nil context is deliberate: it is the input under test.
	var nilCtx context.Context
	pingErr := pub.Ping(nilCtx)
	require.Error(t, pingErr)
	assert.ErrorContains(t, pingErr, "nil context")
}
// newTestPublisher builds a Publisher for tests. Any of Addr,
// SessionCacheKeyPrefix, SessionEventsStream, StreamMaxLen and
// OperationTimeout left at its zero value in cfg is replaced with a test
// default (Addr defaults to the given miniredis server). Construction must
// succeed, and the publisher is closed automatically when the test finishes.
func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
	t.Helper()

	// Only fill the gaps; values supplied by the caller are kept as they are.
	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	if cfg.SessionCacheKeyPrefix == "" {
		cfg.SessionCacheKeyPrefix = "gateway:session:"
	}
	if cfg.SessionEventsStream == "" {
		cfg.SessionEventsStream = "gateway:session_events"
	}
	if cfg.StreamMaxLen == 0 {
		cfg.StreamMaxLen = 1024
	}
	if cfg.OperationTimeout == 0 {
		cfg.OperationTimeout = 250 * time.Millisecond
	}

	pub, newErr := New(cfg)
	require.NoError(t, newErr)

	t.Cleanup(func() {
		closeErr := pub.Close()
		assert.NoError(t, closeErr)
	})

	return pub
}
// testSnapshot returns a snapshot for the given device session with a fixed
// user ID ("user-123") and a deterministic client public key: the standard
// base64 encoding of the bytes 1..ed25519.PublicKeySize. For revoked snapshots
// the revoke reason code and actor type are populated as well.
func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
	keyBytes := make(ed25519.PublicKey, ed25519.PublicKeySize)
	for i := range keyBytes {
		keyBytes[i] = byte(i + 1)
	}

	snap := gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID("user-123"),
		ClientPublicKey: base64.StdEncoding.EncodeToString(keyBytes),
		Status:          status,
		RevokedAt:       revokedAt,
	}
	if status != gatewayprojection.StatusRevoked {
		return snap
	}

	// Revocation metadata is only set for revoked snapshots.
	snap.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
	snap.RevokeActorType = common.RevokeActorType("system")
	return snap
}
// decodeCachePayload strictly decodes a cache payload into a cacheRecord.
// It fails the test when the payload has unknown fields or anything other
// than EOF after the first JSON value, and it asserts that the payload's
// top-level field set is exactly the four mandatory fields plus
// "revoked_at_ms" when the decoded record carries a revocation time.
func decodeCachePayload(t *testing.T, payload string) cacheRecord {
	t.Helper()

	dec := json.NewDecoder(bytes.NewBufferString(payload))
	dec.DisallowUnknownFields()

	var record cacheRecord
	require.NoError(t, dec.Decode(&record))

	// A second Decode must hit EOF; anything else means trailing content.
	trailingErr := dec.Decode(&struct{}{})
	if trailingErr == nil {
		require.FailNow(t, "expected cache payload EOF after first JSON value")
	}
	require.ErrorIs(t, trailingErr, io.EOF)

	// Decode again into a raw map to inspect which keys were actually emitted.
	var present map[string]json.RawMessage
	require.NoError(t, json.Unmarshal([]byte(payload), &present))

	allowed := map[string]struct{}{
		"device_session_id": {},
		"user_id":           {},
		"client_public_key": {},
		"status":            {},
	}
	if record.RevokedAtMS != nil {
		allowed["revoked_at_ms"] = struct{}{}
	}

	assert.Equal(t, len(allowed), len(present))
	for name := range present {
		_, known := allowed[name]
		assert.Truef(t, known, "unexpected cache payload field %q", name)
	}

	return record
}
// stringifyValues returns a new map with the same keys as values, where each
// value is rendered with fmt.Sprint. The result is never nil, even when
// values is nil or empty.
func stringifyValues(values map[string]any) map[string]string {
	out := make(map[string]string, len(values))
	for field, raw := range values {
		text := fmt.Sprint(raw)
		out[field] = text
	}
	return out
}
// encodeBase64URL returns the unpadded, URL-safe base64 encoding of value.
func encodeBase64URL(value string) string {
	src := []byte(value)
	dst := make([]byte, base64.RawURLEncoding.EncodedLen(len(src)))
	base64.RawURLEncoding.Encode(dst, src)
	return string(dst)
}
// int64Pointer returns a pointer to a fresh int64 holding value.
func int64Pointer(value int64) *int64 {
	ptr := new(int64)
	*ptr = value
	return ptr
}