feat: authsession service
This commit is contained in:
@@ -0,0 +1,442 @@
|
||||
package projectionpublisher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 3,
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty session cache key prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session cache key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty session events stream",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session events stream must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non positive stream max len",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "stream max len must be positive",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publisher, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, publisher.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublisherPing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
|
||||
require.NoError(t, publisher.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionActive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
|
||||
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
|
||||
assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key)
|
||||
assert.True(t, server.Exists(key))
|
||||
assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String())))
|
||||
|
||||
payload, err := server.Get(key)
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
assert.Equal(t, cacheRecord{
|
||||
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
||||
UserID: snapshot.UserID.String(),
|
||||
ClientPublicKey: snapshot.ClientPublicKey,
|
||||
Status: gatewayprojection.StatusActive,
|
||||
}, record)
|
||||
assert.Zero(t, server.TTL(key))
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": snapshot.DeviceSessionID.String(),
|
||||
"user_id": snapshot.UserID.String(),
|
||||
"client_public_key": snapshot.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusActive),
|
||||
}, stringifyValues(entries[0].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionRevoked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
|
||||
snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
|
||||
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
|
||||
payload, err := server.Get(key)
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
|
||||
assert.Equal(t, cacheRecord{
|
||||
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
||||
UserID: snapshot.UserID.String(),
|
||||
ClientPublicKey: snapshot.ClientPublicKey,
|
||||
Status: gatewayprojection.StatusRevoked,
|
||||
RevokedAtMS: int64Pointer(revokedAt.UnixMilli()),
|
||||
}, record)
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": snapshot.DeviceSessionID.String(),
|
||||
"user_id": snapshot.UserID.String(),
|
||||
"client_public_key": snapshot.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusRevoked),
|
||||
"revoked_at_ms": "1776000123456",
|
||||
}, stringifyValues(entries[0].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
|
||||
deviceSessionID := "device-session-456"
|
||||
|
||||
active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
|
||||
revokedAt := time.Unix(1_776_010_000, 0).UTC()
|
||||
revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), active))
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), revoked))
|
||||
|
||||
payload, err := server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID))
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
|
||||
assert.Equal(t, gatewayprojection.StatusRevoked, record.Status)
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": active.DeviceSessionID.String(),
|
||||
"user_id": active.UserID.String(),
|
||||
"client_public_key": active.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusActive),
|
||||
}, stringifyValues(entries[0].Values))
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": revoked.DeviceSessionID.String(),
|
||||
"user_id": revoked.UserID.String(),
|
||||
"client_public_key": revoked.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusRevoked),
|
||||
"revoked_at_ms": "1776010000000",
|
||||
}, stringifyValues(entries[1].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
|
||||
snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
|
||||
payload, err := server.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID))
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
assert.Equal(t, cacheRecord{
|
||||
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
||||
UserID: snapshot.UserID.String(),
|
||||
ClientPublicKey: snapshot.ClientPublicKey,
|
||||
Status: gatewayprojection.StatusActive,
|
||||
}, record)
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 2})
|
||||
|
||||
for index := range 6 {
|
||||
snapshot := testSnapshot(
|
||||
common.DeviceSessionID("device-session-"+string(rune('a'+index))).String(),
|
||||
gatewayprojection.StatusActive,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
}
|
||||
|
||||
streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result()
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, streamLength, int64(2))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
snapshot := gatewayprojection.Snapshot{
|
||||
DeviceSessionID: common.DeviceSessionID("device-session-123"),
|
||||
UserID: common.UserID("user-123"),
|
||||
Status: gatewayprojection.StatusActive,
|
||||
}
|
||||
|
||||
err := publisher.PublishSession(context.Background(), snapshot)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "gateway projection client public key")
|
||||
assert.Empty(t, server.Keys())
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
|
||||
err := publisher.PublishSession(nil, testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionBackendFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
server.Close()
|
||||
|
||||
err := publisher.PublishSession(context.Background(), testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "publish session projection")
|
||||
}
|
||||
|
||||
func TestPublisherPingNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
|
||||
err := publisher.Ping(nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.SessionCacheKeyPrefix == "" {
|
||||
cfg.SessionCacheKeyPrefix = "gateway:session:"
|
||||
}
|
||||
if cfg.SessionEventsStream == "" {
|
||||
cfg.SessionEventsStream = "gateway:session_events"
|
||||
}
|
||||
if cfg.StreamMaxLen == 0 {
|
||||
cfg.StreamMaxLen = 1024
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
publisher, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, publisher.Close())
|
||||
})
|
||||
|
||||
return publisher
|
||||
}
|
||||
|
||||
func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
|
||||
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
|
||||
for index := range raw {
|
||||
raw[index] = byte(index + 1)
|
||||
}
|
||||
|
||||
snapshot := gatewayprojection.Snapshot{
|
||||
DeviceSessionID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID("user-123"),
|
||||
ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
|
||||
Status: status,
|
||||
RevokedAt: revokedAt,
|
||||
}
|
||||
if status == gatewayprojection.StatusRevoked {
|
||||
snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
|
||||
snapshot.RevokeActorType = common.RevokeActorType("system")
|
||||
}
|
||||
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func decodeCachePayload(t *testing.T, payload string) cacheRecord {
|
||||
t.Helper()
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader([]byte(payload)))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var record cacheRecord
|
||||
require.NoError(t, decoder.Decode(&record))
|
||||
err := decoder.Decode(&struct{}{})
|
||||
if err == nil {
|
||||
require.FailNow(t, "expected cache payload EOF after first JSON value")
|
||||
}
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
var fieldSet map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet))
|
||||
expectedFields := map[string]struct{}{
|
||||
"device_session_id": {},
|
||||
"user_id": {},
|
||||
"client_public_key": {},
|
||||
"status": {},
|
||||
}
|
||||
if record.RevokedAtMS != nil {
|
||||
expectedFields["revoked_at_ms"] = struct{}{}
|
||||
}
|
||||
assert.Equal(t, len(expectedFields), len(fieldSet))
|
||||
for field := range fieldSet {
|
||||
_, ok := expectedFields[field]
|
||||
assert.Truef(t, ok, "unexpected cache payload field %q", field)
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func stringifyValues(values map[string]any) map[string]string {
|
||||
stringified := make(map[string]string, len(values))
|
||||
for key, value := range values {
|
||||
stringified[key] = fmt.Sprint(value)
|
||||
}
|
||||
return stringified
|
||||
}
|
||||
|
||||
func encodeBase64URL(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func int64Pointer(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
Reference in New Issue
Block a user