feat: authsession service
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
// Package antiabuse provides runtime in-process adapters for auth-specific
|
||||
// public abuse controls.
|
||||
package antiabuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// SendEmailCodeProtector is a concurrency-safe in-process resend-throttle
|
||||
// adapter for public send-email-code attempts.
|
||||
type SendEmailCodeProtector struct {
|
||||
mu sync.Mutex
|
||||
reservedUntil map[common.Email]time.Time
|
||||
}
|
||||
|
||||
// CheckAndReserve applies the fixed Stage-17 resend cooldown using input.Now
|
||||
// as the authoritative decision timestamp.
|
||||
func (p *SendEmailCodeProtector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
|
||||
if ctx == nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err)
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.reservedUntil == nil {
|
||||
p.reservedUntil = make(map[common.Email]time.Time)
|
||||
}
|
||||
|
||||
reservedUntil, exists := p.reservedUntil[input.Email]
|
||||
if exists && input.Now.Before(reservedUntil) {
|
||||
return ports.SendEmailCodeAbuseResult{
|
||||
Outcome: ports.SendEmailCodeAbuseOutcomeThrottled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
p.reservedUntil[input.Email] = input.Now.UTC().Add(challenge.ResendThrottleCooldown)
|
||||
return ports.SendEmailCodeAbuseResult{
|
||||
Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ ports.SendEmailCodeAbuseProtector = (*SendEmailCodeProtector)(nil)
|
||||
@@ -0,0 +1,64 @@
|
||||
package antiabuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSendEmailCodeProtectorCheckAndReserve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
protector := &SendEmailCodeProtector{}
|
||||
email := common.Email("pilot@example.com")
|
||||
now := time.Unix(10, 0).UTC()
|
||||
|
||||
result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
|
||||
Email: email,
|
||||
Now: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
|
||||
|
||||
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
|
||||
Email: email,
|
||||
Now: now.Add(30 * time.Second),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
|
||||
|
||||
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
|
||||
Email: email,
|
||||
Now: now.Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
|
||||
}
|
||||
|
||||
func TestSendEmailCodeProtectorNilOrCanceledContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
protector := &SendEmailCodeProtector{}
|
||||
_, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
Now: time.Unix(10, 0).UTC(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err = protector.CheckAndReserve(ctx, ports.SendEmailCodeAbuseInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
Now: time.Unix(10, 0).UTC(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
// Package contracttest provides reusable adapter conformance suites that
|
||||
// exercise storage-agnostic port contracts without depending on one concrete
|
||||
// backend implementation.
|
||||
package contracttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ChallengeStoreFactory constructs a fresh ChallengeStore instance suitable
|
||||
// for one isolated contract subtest.
|
||||
type ChallengeStoreFactory func(t *testing.T) ports.ChallengeStore
|
||||
|
||||
// RunChallengeStoreContractTests executes the backend-agnostic ChallengeStore
|
||||
// contract suite against newStore.
|
||||
func RunChallengeStoreContractTests(t *testing.T, newStore ChallengeStoreFactory) {
|
||||
t.Helper()
|
||||
|
||||
t.Run("create and get", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractConfirmedChallenge(t, time.Unix(1_775_130_000, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
})
|
||||
|
||||
t.Run("get not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
|
||||
_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
|
||||
t.Run("create conflict", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractPendingChallenge(time.Unix(1_775_130_100, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
err := store.Create(context.Background(), record)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
})
|
||||
|
||||
t.Run("compare and swap success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
now := time.Unix(1_775_130_200, 0).UTC()
|
||||
previous := contractPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
next.Attempts.Send = 1
|
||||
next.Abuse.LastAttemptAt = contractTimePointer(now.Add(time.Minute))
|
||||
require.NoError(t, next.Validate())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), previous))
|
||||
require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))
|
||||
|
||||
got, err := store.Get(context.Background(), previous.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, next, got)
|
||||
})
|
||||
|
||||
t.Run("compare and swap conflict", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
now := time.Unix(1_775_130_300, 0).UTC()
|
||||
stored := contractPendingChallenge(now)
|
||||
previous := stored
|
||||
previous.Attempts.Send = 99
|
||||
require.NoError(t, previous.Validate())
|
||||
next := stored
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
require.NoError(t, next.Validate())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), stored))
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
})
|
||||
|
||||
t.Run("compare and swap not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
now := time.Unix(1_775_130_400, 0).UTC()
|
||||
previous := contractPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
require.NoError(t, next.Validate())
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
|
||||
t.Run("get returns defensive copies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractConfirmedChallenge(t, time.Unix(1_775_130_500, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, got.CodeHash)
|
||||
got.CodeHash[0] = 0xFF
|
||||
if got.Confirmation != nil {
|
||||
keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
|
||||
if len(keyBytes) > 0 {
|
||||
keyBytes[0] = 0xFE
|
||||
}
|
||||
}
|
||||
|
||||
again, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.CodeHash, again.CodeHash)
|
||||
require.NotNil(t, again.Confirmation)
|
||||
assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
|
||||
})
|
||||
}
|
||||
|
||||
func contractPendingChallenge(now time.Time) challenge.Challenge {
|
||||
record := challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-pending"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-pending-code"),
|
||||
Status: challenge.StatusPendingSend,
|
||||
DeliveryState: challenge.DeliveryPending,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.InitialTTL),
|
||||
}
|
||||
if err := record.Validate(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func contractConfirmedChallenge(t *testing.T, now time.Time) challenge.Challenge {
|
||||
t.Helper()
|
||||
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
record := challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-confirmed"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-code"),
|
||||
Status: challenge.StatusConfirmedPendingExpire,
|
||||
DeliveryState: challenge.DeliverySent,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.ConfirmedRetention),
|
||||
Attempts: challenge.AttemptCounters{
|
||||
Send: 1,
|
||||
Confirm: 2,
|
||||
},
|
||||
Abuse: challenge.AbuseMetadata{
|
||||
LastAttemptAt: contractTimePointer(now.Add(30 * time.Second)),
|
||||
},
|
||||
Confirmation: &challenge.Confirmation{
|
||||
SessionID: common.DeviceSessionID("device-session-1"),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
ConfirmedAt: now.Add(time.Minute),
|
||||
},
|
||||
}
|
||||
require.NoError(t, record.Validate())
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func contractTimePointer(value time.Time) *time.Time {
|
||||
return &value
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package contracttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ConfigProviderHarnessFactory constructs a fresh semantic ConfigProvider
|
||||
// harness suitable for one isolated contract subtest.
|
||||
type ConfigProviderHarnessFactory func(t *testing.T) ConfigProviderHarness
|
||||
|
||||
// ConfigProviderHarness bundles one semantic ConfigProvider instance with the
|
||||
// seed hooks needed by the backend-agnostic contract suite.
|
||||
type ConfigProviderHarness struct {
|
||||
// Provider is the semantic ConfigProvider under test.
|
||||
Provider ports.ConfigProvider
|
||||
|
||||
// SeedDisabled prepares storage so LoadSessionLimit observes “limit absent”.
|
||||
SeedDisabled func(t *testing.T)
|
||||
|
||||
// SeedLimit prepares storage so LoadSessionLimit observes a valid positive
|
||||
// configured limit.
|
||||
SeedLimit func(t *testing.T, limit int)
|
||||
}
|
||||
|
||||
// RunConfigProviderContractTests executes the backend-agnostic ConfigProvider
|
||||
// semantic contract suite against newHarness.
|
||||
func RunConfigProviderContractTests(t *testing.T, newHarness ConfigProviderHarnessFactory) {
|
||||
t.Helper()
|
||||
|
||||
t.Run("limit absent means disabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
harness := newHarness(t)
|
||||
require.NotNil(t, harness.Provider)
|
||||
require.NotNil(t, harness.SeedDisabled)
|
||||
|
||||
harness.SeedDisabled(t)
|
||||
|
||||
got, err := harness.Provider.LoadSessionLimit(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SessionLimitConfig{}, got)
|
||||
})
|
||||
|
||||
t.Run("valid positive limit means configured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
harness := newHarness(t)
|
||||
require.NotNil(t, harness.Provider)
|
||||
require.NotNil(t, harness.SeedLimit)
|
||||
|
||||
want := 5
|
||||
harness.SeedLimit(t, want)
|
||||
|
||||
got, err := harness.Provider.LoadSessionLimit(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.ActiveSessionLimit)
|
||||
assert.Equal(t, want, *got.ActiveSessionLimit)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package contracttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// SessionStoreFactory constructs a fresh SessionStore instance suitable for
|
||||
// one isolated contract subtest.
|
||||
type SessionStoreFactory func(t *testing.T) ports.SessionStore
|
||||
|
||||
// RunSessionStoreContractTests executes the backend-agnostic SessionStore
|
||||
// contract suite against newStore.
|
||||
func RunSessionStoreContractTests(t *testing.T, newStore SessionStoreFactory) {
|
||||
t.Helper()
|
||||
|
||||
t.Run("create and get", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
})
|
||||
|
||||
t.Run("create conflict", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_050, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
err := store.Create(context.Background(), record)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
})
|
||||
|
||||
t.Run("get not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
|
||||
_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
|
||||
t.Run("list by user id returns newest first", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
older := contractActiveSession(t, "device-session-old", "user-1", time.Unix(10, 0).UTC())
|
||||
newer := contractActiveSession(t, "device-session-new", "user-1", time.Unix(20, 0).UTC())
|
||||
revoked := contractRevokedSession(t, "device-session-revoked", "user-1", time.Unix(15, 0).UTC())
|
||||
otherUser := contractActiveSession(t, "device-session-other", "user-2", time.Unix(30, 0).UTC())
|
||||
|
||||
for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
}
|
||||
|
||||
got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got, 3)
|
||||
assert.Equal(
|
||||
t,
|
||||
[]common.DeviceSessionID{newer.ID, revoked.ID, older.ID},
|
||||
[]common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID},
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("list by user id returns empty slice for unknown user", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
|
||||
got, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
|
||||
t.Run("count active by user id", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
activeOne := contractActiveSession(t, "device-session-1", "user-1", time.Unix(40, 0).UTC())
|
||||
activeTwo := contractActiveSession(t, "device-session-2", "user-1", time.Unix(50, 0).UTC())
|
||||
revoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(60, 0).UTC())
|
||||
otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(70, 0).UTC())
|
||||
|
||||
for _, record := range []devicesession.Session{activeOne, activeTwo, revoked, otherUser} {
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
}
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
})
|
||||
|
||||
t.Run("revoke active session", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
revocation := contractRevocation(time.Unix(200, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", "")
|
||||
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: record.ID,
|
||||
Revocation: revocation,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
|
||||
require.NotNil(t, result.Session.Revocation)
|
||||
assert.Equal(t, revocation, *result.Session.Revocation)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
})
|
||||
|
||||
t.Run("revoke already revoked preserves stored revocation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractRevokedSession(t, "device-session-2", "user-1", time.Unix(110, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: record.ID,
|
||||
Revocation: contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
|
||||
require.NotNil(t, result.Session.Revocation)
|
||||
assert.Equal(t, *record.Revocation, *result.Session.Revocation)
|
||||
})
|
||||
|
||||
t.Run("revoke not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
|
||||
_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: common.DeviceSessionID("missing-session"),
|
||||
Revocation: contractRevocation(time.Unix(210, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", ""),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
|
||||
t.Run("revoke all by user id revokes active sessions newest first", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
older := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
newer := contractActiveSession(t, "device-session-2", "user-1", time.Unix(200, 0).UTC())
|
||||
alreadyRevoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(150, 0).UTC())
|
||||
otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(250, 0).UTC())
|
||||
|
||||
for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
}
|
||||
|
||||
revocation := contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1")
|
||||
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
Revocation: revocation,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
|
||||
require.Len(t, result.Sessions, 2)
|
||||
assert.Equal(
|
||||
t,
|
||||
[]common.DeviceSessionID{newer.ID, older.ID},
|
||||
[]common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID},
|
||||
)
|
||||
assert.Equal(t, revocation, *result.Sessions[0].Revocation)
|
||||
assert.Equal(t, revocation, *result.Sessions[1].Revocation)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
})
|
||||
|
||||
t.Run("revoke all by user id reports no active sessions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractRevokedSession(t, "device-session-5", "user-1", time.Unix(120, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
Revocation: contractRevocation(time.Unix(400, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", ""),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
|
||||
require.NotNil(t, result.Sessions)
|
||||
assert.Empty(t, result.Sessions)
|
||||
})
|
||||
|
||||
t.Run("get returns defensive copies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newStore(t)
|
||||
record := contractRevokedSession(t, "device-session-copy", "user-1", time.Unix(130, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.Revocation)
|
||||
got.Revocation.ActorID = "mutated"
|
||||
|
||||
again, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, again.Revocation)
|
||||
assert.Equal(t, record, again)
|
||||
})
|
||||
}
|
||||
|
||||
func contractActiveSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
t.Helper()
|
||||
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
record := devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
require.NoError(t, record.Validate())
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func contractRevokedSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
t.Helper()
|
||||
|
||||
record := contractActiveSession(t, deviceSessionID, userID, createdAt)
|
||||
revocation := contractRevocation(createdAt.Add(time.Minute), devicesession.RevokeReasonDeviceLogout, "user", "user-actor")
|
||||
record.Status = devicesession.StatusRevoked
|
||||
record.Revocation = &revocation
|
||||
require.NoError(t, record.Validate())
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func contractRevocation(at time.Time, reasonCode common.RevokeReasonCode, actorType string, actorID string) devicesession.Revocation {
|
||||
record := devicesession.Revocation{
|
||||
At: at,
|
||||
ReasonCode: reasonCode,
|
||||
ActorType: common.RevokeActorType(actorType),
|
||||
ActorID: actorID,
|
||||
}
|
||||
if err := record.Validate(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
// Package local provides small in-process runtime implementations for
|
||||
// authsession ports that do not require network dependencies.
|
||||
package local
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
challengeIDPrefix = "challenge-"
|
||||
deviceSessionIDPrefix = "device-session-"
|
||||
codeDigits = 6
|
||||
)
|
||||
|
||||
// Clock implements ports.Clock using the local system clock in UTC.
|
||||
type Clock struct{}
|
||||
|
||||
// Now returns the current system time normalized to UTC.
|
||||
func (Clock) Now() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
// IDGenerator implements ports.IDGenerator with cryptographically random
|
||||
// opaque identifiers.
|
||||
type IDGenerator struct{}
|
||||
|
||||
// NewChallengeID returns a fresh random challenge identifier.
|
||||
func (IDGenerator) NewChallengeID() (common.ChallengeID, error) {
|
||||
value, err := newOpaqueIDString(challengeIDPrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return common.ChallengeID(value), nil
|
||||
}
|
||||
|
||||
// NewDeviceSessionID returns a fresh random device-session identifier.
|
||||
func (IDGenerator) NewDeviceSessionID() (common.DeviceSessionID, error) {
|
||||
value, err := newOpaqueIDString(deviceSessionIDPrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return common.DeviceSessionID(value), nil
|
||||
}
|
||||
|
||||
// CodeGenerator implements ports.CodeGenerator with random 6-digit decimal
|
||||
// confirmation codes.
|
||||
type CodeGenerator struct{}
|
||||
|
||||
// Generate returns one fresh random 6-digit decimal code.
|
||||
func (CodeGenerator) Generate() (string, error) {
|
||||
var builder strings.Builder
|
||||
builder.Grow(codeDigits)
|
||||
|
||||
for idx := 0; idx < codeDigits; idx++ {
|
||||
digit, err := rand.Int(rand.Reader, big.NewInt(10))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate confirmation code: %w", err)
|
||||
}
|
||||
builder.WriteByte(byte('0' + digit.Int64()))
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
// CodeHasher implements ports.CodeHasher with bcrypt-backed hashes.
|
||||
type CodeHasher struct{}
|
||||
|
||||
// Hash returns the bcrypt hash of code.
|
||||
func (CodeHasher) Hash(code string) ([]byte, error) {
|
||||
if err := validateCode(code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash confirmation code: %w", err)
|
||||
}
|
||||
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// Compare reports whether hash matches code.
|
||||
func (CodeHasher) Compare(hash []byte, code string) (bool, error) {
|
||||
if err := validateCode(code); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(hash) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword(hash, []byte(code))
|
||||
switch err {
|
||||
case nil:
|
||||
return true, nil
|
||||
case bcrypt.ErrMismatchedHashAndPassword:
|
||||
return false, nil
|
||||
default:
|
||||
return false, fmt.Errorf("compare confirmation code hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newOpaqueIDString(prefix string) (string, error) {
|
||||
randomBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", fmt.Errorf("generate opaque identifier: %w", err)
|
||||
}
|
||||
|
||||
return prefix + base64.RawURLEncoding.EncodeToString(randomBytes), nil
|
||||
}
|
||||
|
||||
func validateCode(code string) error {
|
||||
switch {
|
||||
case strings.TrimSpace(code) == "":
|
||||
return fmt.Errorf("code must not be empty")
|
||||
case strings.TrimSpace(code) != code:
|
||||
return fmt.Errorf("code must not contain surrounding whitespace")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ ports.Clock = Clock{}
|
||||
_ ports.IDGenerator = IDGenerator{}
|
||||
_ ports.CodeGenerator = CodeGenerator{}
|
||||
_ ports.CodeHasher = CodeHasher{}
|
||||
)
|
||||
@@ -0,0 +1,60 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClockNowReturnsUTC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := Clock{}.Now()
|
||||
|
||||
assert.Equal(t, time.UTC, now.Location())
|
||||
}
|
||||
|
||||
func TestIDGeneratorProducesValidOpaqueIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
generator := IDGenerator{}
|
||||
|
||||
challengeID, err := generator.NewChallengeID()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, challengeID.Validate())
|
||||
assert.Regexp(t, regexp.MustCompile(`^challenge-[A-Za-z0-9_-]+$`), challengeID.String())
|
||||
|
||||
deviceSessionID, err := generator.NewDeviceSessionID()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, deviceSessionID.Validate())
|
||||
assert.Regexp(t, regexp.MustCompile(`^device-session-[A-Za-z0-9_-]+$`), deviceSessionID.String())
|
||||
}
|
||||
|
||||
func TestCodeGeneratorProducesSixDigitNumericCodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
code, err := CodeGenerator{}.Generate()
|
||||
require.NoError(t, err)
|
||||
assert.Regexp(t, regexp.MustCompile(`^\d{6}$`), code)
|
||||
}
|
||||
|
||||
func TestCodeHasherHashesAndComparesCodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hasher := CodeHasher{}
|
||||
|
||||
hash, err := hasher.Hash("123456")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
match, err := hasher.Compare(hash, "123456")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, match)
|
||||
|
||||
match, err = hasher.Compare(hash, "000000")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, match)
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
// Package mail provides runtime mail-delivery adapters for the auth/session
|
||||
// service.
|
||||
package mail
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
const sendLoginCodePath = "/api/v1/internal/login-code-deliveries"
|
||||
|
||||
// Config configures one HTTP-based mail-delivery client.
|
||||
type Config struct {
|
||||
// BaseURL is the absolute base URL of the internal mail-service HTTP API.
|
||||
BaseURL string
|
||||
|
||||
// RequestTimeout bounds each outbound mail-service request.
|
||||
RequestTimeout time.Duration
|
||||
}
|
||||
|
||||
// RESTClient implements ports.MailSender over the frozen internal REST mail
|
||||
// contract.
|
||||
type RESTClient struct {
|
||||
baseURL string
|
||||
requestTimeout time.Duration
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewRESTClient constructs a REST-backed MailSender adapter from cfg.
|
||||
func NewRESTClient(cfg Config) (*RESTClient, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
return newRESTClient(cfg, &http.Client{Transport: transport})
|
||||
}
|
||||
|
||||
func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.BaseURL) == "":
|
||||
return nil, errors.New("new mail service REST client: base URL must not be empty")
|
||||
case cfg.RequestTimeout <= 0:
|
||||
return nil, errors.New("new mail service REST client: request timeout must be positive")
|
||||
case httpClient == nil:
|
||||
return nil, errors.New("new mail service REST client: http client must not be nil")
|
||||
}
|
||||
|
||||
parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new mail service REST client: parse base URL: %w", err)
|
||||
}
|
||||
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
|
||||
return nil, errors.New("new mail service REST client: base URL must be absolute")
|
||||
}
|
||||
|
||||
return &RESTClient{
|
||||
baseURL: parsedBaseURL.String(),
|
||||
requestTimeout: cfg.RequestTimeout,
|
||||
httpClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases idle HTTP connections owned by the client transport.
|
||||
func (c *RESTClient) Close() error {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type idleCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendLoginCode submits one delivery request to the internal mail service
|
||||
// without retrying transport or upstream failures.
|
||||
func (c *RESTClient) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
|
||||
if err := validateRESTContext(ctx, "send login code"); err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
|
||||
}
|
||||
|
||||
payload, statusCode, err := c.doRequest(ctx, "send login code", map[string]string{
|
||||
"email": input.Email.String(),
|
||||
"code": input.Code,
|
||||
})
|
||||
if err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: unexpected HTTP status %d", statusCode)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Outcome ports.SendLoginCodeOutcome `json:"outcome"`
|
||||
}
|
||||
if err := decodeJSONPayload(payload, &response); err != nil {
|
||||
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
|
||||
}
|
||||
|
||||
result := ports.SendLoginCodeResult{Outcome: response.Outcome}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *RESTClient) doRequest(ctx context.Context, operation string, requestBody any) ([]byte, int, error) {
|
||||
bodyBytes, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s: marshal request body: %w", operation, err)
|
||||
}
|
||||
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, c.requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
request, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, c.baseURL+sendLoginCodePath, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s: build request: %w", operation, err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
response, err := c.httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s: %w", operation, err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
payload, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s: read response body: %w", operation, err)
|
||||
}
|
||||
|
||||
return payload, response.StatusCode, nil
|
||||
}
|
||||
|
||||
func decodeJSONPayload(payload []byte, target any) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
if err := decoder.Decode(target); err != nil {
|
||||
return fmt.Errorf("decode response body: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return errors.New("decode response body: unexpected trailing JSON input")
|
||||
}
|
||||
|
||||
return fmt.Errorf("decode response body: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRESTContext(ctx context.Context, operation string) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fmt.Errorf("%s: %w", operation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ports.MailSender = (*RESTClient)(nil)
|
||||
@@ -0,0 +1,394 @@
|
||||
package mail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewRESTClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
BaseURL: "http://127.0.0.1:8080",
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty base url",
|
||||
cfg: Config{
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
wantErr: "base URL must not be empty",
|
||||
},
|
||||
{
|
||||
name: "relative base url",
|
||||
cfg: Config{
|
||||
BaseURL: "/relative",
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
wantErr: "base URL must be absolute",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
BaseURL: "http://127.0.0.1:8080",
|
||||
},
|
||||
wantErr: "request timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, err := NewRESTClient(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientSendLoginCodeSuccessCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response string
|
||||
wantOutcome ports.SendLoginCodeOutcome
|
||||
}{
|
||||
{
|
||||
name: "sent",
|
||||
response: `{"outcome":"sent"}`,
|
||||
wantOutcome: ports.SendLoginCodeOutcomeSent,
|
||||
},
|
||||
{
|
||||
name: "suppressed",
|
||||
response: `{"outcome":"suppressed"}`,
|
||||
wantOutcome: ports.SendLoginCodeOutcomeSuppressed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestsMu sync.Mutex
|
||||
var requests []capturedRequest
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, captureRequest(t, r))
|
||||
requestsMu.Unlock()
|
||||
|
||||
writeJSON(t, w, http.StatusOK, json.RawMessage(tt.response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
result, err := client.SendLoginCode(context.Background(), validInput())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantOutcome, result.Outcome)
|
||||
|
||||
requestsMu.Lock()
|
||||
defer requestsMu.Unlock()
|
||||
|
||||
require.Len(t, requests, 1)
|
||||
assert.Equal(t, http.MethodPost, requests[0].Method)
|
||||
assert.Equal(t, sendLoginCodePath, requests[0].Path)
|
||||
assert.Equal(t, "application/json", requests[0].ContentType)
|
||||
assert.JSONEq(t, `{"email":"pilot@example.com","code":"654321"}`, requests[0].Body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientPreservesNormalizedEmailAndCodeExactly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var captured string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured = captureRequest(t, r).Body
|
||||
writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
result, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
||||
Email: common.Email("Pilot+Alias@Example.com"),
|
||||
Code: "123456",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)
|
||||
assert.JSONEq(t, `{"email":"Pilot+Alias@Example.com","code":"123456"}`, captured)
|
||||
}
|
||||
|
||||
func TestRESTClientSendLoginCodeDoesNotRetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("no retry on 503", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var calls atomic.Int64
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls.Add(1)
|
||||
http.Error(w, "temporary", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
_, err := client.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "unexpected HTTP status 503")
|
||||
assert.EqualValues(t, 1, calls.Load())
|
||||
})
|
||||
|
||||
t.Run("no retry on transport failure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var calls atomic.Int64
|
||||
client, err := newRESTClient(Config{
|
||||
BaseURL: "http://127.0.0.1:8080",
|
||||
RequestTimeout: 250 * time.Millisecond,
|
||||
}, &http.Client{
|
||||
Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
|
||||
calls.Add(1)
|
||||
return nil, errors.New("temporary transport failure")
|
||||
}),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "temporary transport failure")
|
||||
assert.EqualValues(t, 1, calls.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "rejects unknown field",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"outcome":"sent","extra":true}`,
|
||||
wantErrText: "decode response body",
|
||||
},
|
||||
{
|
||||
name: "rejects unsupported outcome",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"outcome":"queued"}`,
|
||||
wantErrText: "unsupported",
|
||||
},
|
||||
{
|
||||
name: "rejects missing outcome",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{}`,
|
||||
wantErrText: "unsupported",
|
||||
},
|
||||
{
|
||||
name: "rejects trailing json",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"outcome":"sent"}{}`,
|
||||
wantErrText: "unexpected trailing JSON input",
|
||||
},
|
||||
{
|
||||
name: "rejects unexpected status",
|
||||
statusCode: http.StatusBadGateway,
|
||||
body: `{"error":"temporary"}`,
|
||||
wantErrText: "unexpected HTTP status 502",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, err := io.WriteString(w, tt.body)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
_, err := client.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientRequestTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 10*time.Millisecond)
|
||||
|
||||
_, err := client.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "context deadline exceeded")
|
||||
}
|
||||
|
||||
func TestRESTClientContextAndValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("unexpected upstream call")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
cancelledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func() error
|
||||
}{
|
||||
{
|
||||
name: "nil context",
|
||||
run: func() error {
|
||||
_, err := client.SendLoginCode(nil, validInput())
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled context",
|
||||
run: func() error {
|
||||
_, err := client.SendLoginCode(cancelledCtx, validInput())
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
run: func() error {
|
||||
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
||||
Email: common.Email(" bad@example.com "),
|
||||
Code: "123456",
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code",
|
||||
run: func() error {
|
||||
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
Code: " 123456 ",
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.run()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type capturedRequest struct {
|
||||
Method string
|
||||
Path string
|
||||
ContentType string
|
||||
Body string
|
||||
}
|
||||
|
||||
func captureRequest(t *testing.T, request *http.Request) capturedRequest {
|
||||
t.Helper()
|
||||
|
||||
body, err := io.ReadAll(request.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
return capturedRequest{
|
||||
Method: request.Method,
|
||||
Path: request.URL.Path,
|
||||
ContentType: request.Header.Get("Content-Type"),
|
||||
Body: strings.TrimSpace(string(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.WriteHeader(statusCode)
|
||||
_, err = writer.Write(payload)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
|
||||
t.Helper()
|
||||
|
||||
client, err := NewRESTClient(Config{
|
||||
BaseURL: baseURL,
|
||||
RequestTimeout: timeout,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return fn(request)
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
// Package mail provides runtime mail-delivery adapters for the auth/session
|
||||
// service.
|
||||
package mail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
var errForcedFailure = errors.New("stub mail sender: forced failure")
|
||||
|
||||
// StubMode identifies the deterministic outcome used by StubSender for one
|
||||
// delivery attempt.
|
||||
type StubMode string
|
||||
|
||||
const (
|
||||
// StubModeSent reports that the adapter accepts delivery and returns the
|
||||
// stable sent outcome expected by the auth flow.
|
||||
StubModeSent StubMode = "sent"
|
||||
|
||||
// StubModeSuppressed reports that the adapter intentionally suppresses
|
||||
// outward delivery while still returning a successful suppressed outcome.
|
||||
StubModeSuppressed StubMode = "suppressed"
|
||||
|
||||
// StubModeFailed reports that the adapter returns an explicit delivery
|
||||
// failure instead of a successful outcome.
|
||||
StubModeFailed StubMode = "failed"
|
||||
)
|
||||
|
||||
// IsKnown reports whether mode is one of the supported stub delivery modes.
|
||||
func (mode StubMode) IsKnown() bool {
|
||||
switch mode {
|
||||
case StubModeSent, StubModeSuppressed, StubModeFailed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// StubStep overrides the default stub behavior for one queued delivery
|
||||
// attempt.
|
||||
type StubStep struct {
|
||||
// Mode selects the delivery behavior for this queued step.
|
||||
Mode StubMode
|
||||
|
||||
// Err optionally overrides the failure returned when Mode is StubModeFailed.
|
||||
Err error
|
||||
}
|
||||
|
||||
// Validate reports whether step contains one supported queued behavior.
|
||||
func (step StubStep) Validate() error {
|
||||
if !step.Mode.IsKnown() {
|
||||
return fmt.Errorf("stub mail step mode %q is unsupported", step.Mode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Attempt records one validated delivery request handled by StubSender.
|
||||
type Attempt struct {
|
||||
// Input stores the validated cleartext mail-delivery request exactly as it
|
||||
// was passed into SendLoginCode.
|
||||
Input ports.SendLoginCodeInput
|
||||
|
||||
// Mode stores the resolved stub mode after queued overrides were applied.
|
||||
Mode StubMode
|
||||
}
|
||||
|
||||
// StubSender is a deterministic runtime MailSender implementation intended
|
||||
// for development, local integration, and explicit stub-based tests.
|
||||
//
|
||||
// The zero value is ready to use and defaults to StubModeSent.
|
||||
type StubSender struct {
|
||||
// DefaultMode controls the fallback behavior when Script is empty. The zero
|
||||
// value is treated as StubModeSent so the zero-value sender is usable
|
||||
// without extra configuration.
|
||||
DefaultMode StubMode
|
||||
|
||||
// DefaultError optionally overrides the failure returned when DefaultMode
|
||||
// resolves to StubModeFailed.
|
||||
DefaultError error
|
||||
|
||||
// Script stores queued one-shot overrides consumed in FIFO order before the
|
||||
// default behavior is used.
|
||||
Script []StubStep
|
||||
|
||||
mu sync.Mutex
|
||||
attempts []Attempt
|
||||
}
|
||||
|
||||
// SendLoginCode records one validated delivery request and returns the
|
||||
// deterministic stub outcome selected by the queued script or the default
|
||||
// mode.
|
||||
func (s *StubSender) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
|
||||
if ctx == nil {
|
||||
return ports.SendLoginCodeResult{}, errors.New("stub mail sender: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
mode, errOverride, err := s.resolveNextStepLocked()
|
||||
if err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
|
||||
s.attempts = append(s.attempts, Attempt{
|
||||
Input: input,
|
||||
Mode: mode,
|
||||
})
|
||||
|
||||
switch mode {
|
||||
case StubModeSent:
|
||||
return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSent}, nil
|
||||
case StubModeSuppressed:
|
||||
return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSuppressed}, nil
|
||||
case StubModeFailed:
|
||||
if errOverride != nil {
|
||||
return ports.SendLoginCodeResult{}, errOverride
|
||||
}
|
||||
return ports.SendLoginCodeResult{}, errForcedFailure
|
||||
default:
|
||||
return ports.SendLoginCodeResult{}, fmt.Errorf("stub mail sender: unsupported resolved mode %q", mode)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordedAttempts returns a stable defensive copy of every validated delivery
|
||||
// attempt handled by the stub.
|
||||
func (s *StubSender) RecordedAttempts() []Attempt {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return append([]Attempt(nil), s.attempts...)
|
||||
}
|
||||
|
||||
func (s *StubSender) resolveNextStepLocked() (StubMode, error, error) {
|
||||
if len(s.Script) > 0 {
|
||||
step := s.Script[0]
|
||||
s.Script = append([]StubStep(nil), s.Script[1:]...)
|
||||
if err := step.Validate(); err != nil {
|
||||
return "", nil, fmt.Errorf("stub mail sender: %w", err)
|
||||
}
|
||||
if step.Mode == StubModeFailed {
|
||||
if step.Err != nil {
|
||||
return step.Mode, step.Err, nil
|
||||
}
|
||||
return step.Mode, errForcedFailure, nil
|
||||
}
|
||||
return step.Mode, nil, nil
|
||||
}
|
||||
|
||||
mode := s.DefaultMode
|
||||
if mode == "" {
|
||||
mode = StubModeSent
|
||||
}
|
||||
if !mode.IsKnown() {
|
||||
return "", nil, fmt.Errorf("stub mail sender: default mode %q is unsupported", mode)
|
||||
}
|
||||
if mode == StubModeFailed {
|
||||
if s.DefaultError != nil {
|
||||
return mode, s.DefaultError, nil
|
||||
}
|
||||
return mode, errForcedFailure, nil
|
||||
}
|
||||
|
||||
return mode, nil, nil
|
||||
}
|
||||
|
||||
var _ ports.MailSender = (*StubSender)(nil)
|
||||
@@ -0,0 +1,198 @@
|
||||
package mail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStubSenderSendLoginCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("zero value defaults to sent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
|
||||
result, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)
|
||||
|
||||
attempts := sender.RecordedAttempts()
|
||||
require.Len(t, attempts, 1)
|
||||
assert.Equal(t, StubModeSent, attempts[0].Mode)
|
||||
assert.Equal(t, validInput(), attempts[0].Input)
|
||||
})
|
||||
|
||||
t.Run("default suppressed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{DefaultMode: StubModeSuppressed}
|
||||
|
||||
result, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, result.Outcome)
|
||||
|
||||
attempts := sender.RecordedAttempts()
|
||||
require.Len(t, attempts, 1)
|
||||
assert.Equal(t, StubModeSuppressed, attempts[0].Mode)
|
||||
})
|
||||
|
||||
t.Run("default failed uses configured error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wantErr := errors.New("delivery refused")
|
||||
sender := &StubSender{
|
||||
DefaultMode: StubModeFailed,
|
||||
DefaultError: wantErr,
|
||||
}
|
||||
|
||||
result, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, wantErr)
|
||||
assert.Equal(t, ports.SendLoginCodeResult{}, result)
|
||||
|
||||
attempts := sender.RecordedAttempts()
|
||||
require.Len(t, attempts, 1)
|
||||
assert.Equal(t, StubModeFailed, attempts[0].Mode)
|
||||
})
|
||||
|
||||
t.Run("default failed uses stable fallback error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{DefaultMode: StubModeFailed}
|
||||
|
||||
_, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, "stub mail sender: forced failure")
|
||||
})
|
||||
|
||||
t.Run("script overrides default and is consumed fifo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wantErr := errors.New("step failed")
|
||||
sender := &StubSender{
|
||||
DefaultMode: StubModeSent,
|
||||
Script: []StubStep{
|
||||
{Mode: StubModeSuppressed},
|
||||
{Mode: StubModeFailed, Err: wantErr},
|
||||
},
|
||||
}
|
||||
|
||||
first, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, first.Outcome)
|
||||
|
||||
second, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, wantErr)
|
||||
assert.Equal(t, ports.SendLoginCodeResult{}, second)
|
||||
|
||||
third, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendLoginCodeOutcomeSent, third.Outcome)
|
||||
|
||||
attempts := sender.RecordedAttempts()
|
||||
require.Len(t, attempts, 3)
|
||||
assert.Equal(t, []StubMode{StubModeSuppressed, StubModeFailed, StubModeSent}, []StubMode{
|
||||
attempts[0].Mode,
|
||||
attempts[1].Mode,
|
||||
attempts[2].Mode,
|
||||
})
|
||||
assert.Empty(t, sender.Script)
|
||||
})
|
||||
|
||||
t.Run("invalid default mode returns adapter error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{DefaultMode: StubMode("queued")}
|
||||
|
||||
_, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `default mode "queued" is unsupported`)
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
})
|
||||
|
||||
t.Run("invalid scripted mode returns adapter error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{
|
||||
Script: []StubStep{
|
||||
{Mode: StubMode("queued")},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `mode "queued" is unsupported`)
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
assert.Empty(t, sender.Script)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStubSenderRecordedAttemptsAreDefensive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
|
||||
_, err := sender.SendLoginCode(context.Background(), validInput())
|
||||
require.NoError(t, err)
|
||||
|
||||
attempts := sender.RecordedAttempts()
|
||||
require.Len(t, attempts, 1)
|
||||
attempts[0].Mode = StubModeFailed
|
||||
attempts[0].Input.Code = "000000"
|
||||
|
||||
again := sender.RecordedAttempts()
|
||||
require.Len(t, again, 1)
|
||||
assert.Equal(t, StubModeSent, again[0].Mode)
|
||||
assert.Equal(t, "654321", again[0].Input.Code)
|
||||
}
|
||||
|
||||
func TestStubSenderSendLoginCodeNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
|
||||
_, err := sender.SendLoginCode(nil, validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
}
|
||||
|
||||
func TestStubSenderSendLoginCodeCancelledContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := sender.SendLoginCode(ctx, validInput())
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
}
|
||||
|
||||
func TestStubSenderSendLoginCodeInvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := &StubSender{}
|
||||
|
||||
_, err := sender.SendLoginCode(context.Background(), ports.SendLoginCodeInput{})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "send login code input email")
|
||||
assert.Empty(t, sender.RecordedAttempts())
|
||||
}
|
||||
|
||||
func validInput() ports.SendLoginCodeInput {
|
||||
return ports.SendLoginCodeInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
Code: "654321",
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,484 @@
|
||||
// Package challengestore implements ports.ChallengeStore with Redis-backed
|
||||
// strict JSON challenge records.
|
||||
package challengestore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const expirationGracePeriod = 5 * time.Minute
|
||||
|
||||
// Config configures one Redis-backed challenge store instance.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// KeyPrefix is the namespace prefix applied to every challenge key.
|
||||
KeyPrefix string
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Store persists challenges as one strict JSON value per Redis key.
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
keyPrefix string
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
type redisRecord struct {
|
||||
ChallengeID string `json:"challenge_id"`
|
||||
Email string `json:"email"`
|
||||
CodeHashBase64 string `json:"code_hash_base64"`
|
||||
Status challenge.Status `json:"status"`
|
||||
DeliveryState challenge.DeliveryState `json:"delivery_state"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
SendAttemptCount int `json:"send_attempt_count"`
|
||||
ConfirmAttemptCount int `json:"confirm_attempt_count"`
|
||||
LastAttemptAt *string `json:"last_attempt_at,omitempty"`
|
||||
ConfirmedSessionID string `json:"confirmed_session_id,omitempty"`
|
||||
ConfirmedClientPublicKey string `json:"confirmed_client_public_key,omitempty"`
|
||||
ConfirmedAt *string `json:"confirmed_at,omitempty"`
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed challenge store from cfg.
|
||||
func New(cfg Config) (*Store, error) {
|
||||
if strings.TrimSpace(cfg.Addr) == "" {
|
||||
return nil, errors.New("new redis challenge store: redis addr must not be empty")
|
||||
}
|
||||
if cfg.DB < 0 {
|
||||
return nil, errors.New("new redis challenge store: redis db must not be negative")
|
||||
}
|
||||
if strings.TrimSpace(cfg.KeyPrefix) == "" {
|
||||
return nil, errors.New("new redis challenge store: redis key prefix must not be empty")
|
||||
}
|
||||
if cfg.OperationTimeout <= 0 {
|
||||
return nil, errors.New("new redis challenge store: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Store{
|
||||
client: redis.NewClient(options),
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *Store) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (s *Store) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "ping redis challenge store")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis challenge store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the stored challenge for challengeID.
|
||||
func (s *Store) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
|
||||
if err := challengeID.Validate(); err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "get challenge from redis")
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
payload, err := s.client.Get(operationCtx, s.lookupKey(challengeID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, ports.ErrNotFound)
|
||||
case err != nil:
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
|
||||
}
|
||||
|
||||
record, err := decodeChallengeRecord(challengeID, payload)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Create persists record as a new challenge.
|
||||
func (s *Store) Create(ctx context.Context, record challenge.Challenge) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
payload, err := marshalChallengeRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "create challenge in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
created, err := s.client.SetNX(operationCtx, s.lookupKey(record.ID), payload, redisTTL(record.ExpiresAt)).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create challenge %q in redis: %w", record.ID, err)
|
||||
}
|
||||
if !created {
|
||||
return fmt.Errorf("create challenge %q in redis: %w", record.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompareAndSwap replaces previous with next when the currently stored
|
||||
// challenge matches previous exactly in canonical Redis representation.
|
||||
func (s *Store) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
|
||||
if err := ports.ValidateComparableChallenges(previous, next); err != nil {
|
||||
return fmt.Errorf("compare and swap challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
nextPayload, err := marshalChallengeRecord(next)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge in redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "compare and swap challenge in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
key := s.lookupKey(previous.ID)
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
payload, err := tx.Get(operationCtx, key).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrNotFound)
|
||||
case err != nil:
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
|
||||
current, err := decodeChallengeRecord(previous.ID, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
|
||||
matches, err := equalStoredChallenges(current, previous)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
if !matches {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, key, nextPayload, redisTTL(next.ExpiresAt))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, key)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if s == nil || s.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil store", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func (s *Store) lookupKey(challengeID common.ChallengeID) string {
|
||||
return s.keyPrefix + encodeKeyComponent(challengeID.String())
|
||||
}
|
||||
|
||||
func encodeKeyComponent(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func marshalChallengeRecord(record challenge.Challenge) ([]byte, error) {
|
||||
stored, err := redisRecordFromChallenge(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(stored)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func redisRecordFromChallenge(record challenge.Challenge) (redisRecord, error) {
|
||||
if err := record.Validate(); err != nil {
|
||||
return redisRecord{}, fmt.Errorf("encode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
stored := redisRecord{
|
||||
ChallengeID: record.ID.String(),
|
||||
Email: record.Email.String(),
|
||||
CodeHashBase64: base64.StdEncoding.EncodeToString(record.CodeHash),
|
||||
Status: record.Status,
|
||||
DeliveryState: record.DeliveryState,
|
||||
CreatedAt: formatTimestamp(record.CreatedAt),
|
||||
ExpiresAt: formatTimestamp(record.ExpiresAt),
|
||||
SendAttemptCount: record.Attempts.Send,
|
||||
ConfirmAttemptCount: record.Attempts.Confirm,
|
||||
LastAttemptAt: formatOptionalTimestamp(record.Abuse.LastAttemptAt),
|
||||
}
|
||||
if record.Confirmation != nil {
|
||||
stored.ConfirmedSessionID = record.Confirmation.SessionID.String()
|
||||
stored.ConfirmedClientPublicKey = record.Confirmation.ClientPublicKey.String()
|
||||
stored.ConfirmedAt = formatOptionalTimestamp(&record.Confirmation.ConfirmedAt)
|
||||
}
|
||||
|
||||
return stored, nil
|
||||
}
|
||||
|
||||
func decodeChallengeRecord(expectedChallengeID common.ChallengeID, payload []byte) (challenge.Challenge, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return challenge.Challenge{}, errors.New("decode redis challenge record: unexpected trailing JSON input")
|
||||
}
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
record, err := challengeFromRedisRecord(stored)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
if record.ID != expectedChallengeID {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: challenge_id %q does not match requested %q", record.ID, expectedChallengeID)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func challengeFromRedisRecord(stored redisRecord) (challenge.Challenge, error) {
|
||||
createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
expiresAt, err := parseTimestamp("expires_at", stored.ExpiresAt)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
lastAttemptAt, err := parseOptionalTimestamp("last_attempt_at", stored.LastAttemptAt)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
|
||||
codeHash, err := base64.StdEncoding.Strict().DecodeString(stored.CodeHashBase64)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: code_hash_base64: %w", err)
|
||||
}
|
||||
|
||||
record := challenge.Challenge{
|
||||
ID: common.ChallengeID(stored.ChallengeID),
|
||||
Email: common.Email(stored.Email),
|
||||
CodeHash: codeHash,
|
||||
Status: stored.Status,
|
||||
DeliveryState: stored.DeliveryState,
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: expiresAt,
|
||||
Attempts: challenge.AttemptCounters{
|
||||
Send: stored.SendAttemptCount,
|
||||
Confirm: stored.ConfirmAttemptCount,
|
||||
},
|
||||
Abuse: challenge.AbuseMetadata{
|
||||
LastAttemptAt: lastAttemptAt,
|
||||
},
|
||||
}
|
||||
|
||||
confirmation, err := parseConfirmation(stored)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
record.Confirmation = confirmation
|
||||
|
||||
if err := record.Validate(); err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func parseConfirmation(stored redisRecord) (*challenge.Confirmation, error) {
|
||||
hasSessionID := strings.TrimSpace(stored.ConfirmedSessionID) != ""
|
||||
hasClientPublicKey := strings.TrimSpace(stored.ConfirmedClientPublicKey) != ""
|
||||
hasConfirmedAt := stored.ConfirmedAt != nil
|
||||
|
||||
if !hasSessionID && !hasClientPublicKey && !hasConfirmedAt {
|
||||
return nil, nil
|
||||
}
|
||||
if !hasSessionID || !hasClientPublicKey || !hasConfirmedAt {
|
||||
return nil, errors.New("decode redis challenge record: confirmation metadata must be either fully present or fully absent")
|
||||
}
|
||||
|
||||
confirmedAt, err := parseTimestamp("confirmed_at", *stored.ConfirmedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ConfirmedClientPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
|
||||
}
|
||||
clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
|
||||
}
|
||||
|
||||
return &challenge.Confirmation{
|
||||
SessionID: common.DeviceSessionID(stored.ConfirmedSessionID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
ConfirmedAt: confirmedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseOptionalTimestamp(fieldName string, value *string) (*time.Time, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
parsed, err := parseTimestamp(fieldName, *value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func parseTimestamp(fieldName string, value string) (time.Time, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must not be empty", fieldName)
|
||||
}
|
||||
|
||||
parsed, err := time.Parse(time.RFC3339Nano, value)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("decode redis challenge record: %s: %w", fieldName, err)
|
||||
}
|
||||
|
||||
canonical := parsed.UTC().Format(time.RFC3339Nano)
|
||||
if value != canonical {
|
||||
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
|
||||
}
|
||||
|
||||
return parsed.UTC(), nil
|
||||
}
|
||||
|
||||
func formatTimestamp(value time.Time) string {
|
||||
return value.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func formatOptionalTimestamp(value *time.Time) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
formatted := formatTimestamp(*value)
|
||||
return &formatted
|
||||
}
|
||||
|
||||
func redisTTL(expiresAt time.Time) time.Duration {
|
||||
ttl := time.Until(expiresAt.UTC())
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
|
||||
return ttl + expirationGracePeriod
|
||||
}
|
||||
|
||||
func equalStoredChallenges(left challenge.Challenge, right challenge.Challenge) (bool, error) {
|
||||
leftRecord, err := redisRecordFromChallenge(left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
rightRecord, err := redisRecordFromChallenge(right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(leftRecord, rightRecord), nil
|
||||
}
|
||||
|
||||
var _ ports.ChallengeStore = (*Store)(nil)
|
||||
@@ -0,0 +1,531 @@
|
||||
package challengestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/adapters/contracttest"
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStoreContract(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
contracttest.RunChallengeStoreContractTests(t, func(t *testing.T) ports.ChallengeStore {
|
||||
t.Helper()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
return newTestStore(t, server, Config{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 2,
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty key prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non-positive operation timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
KeyPrefix: "authsession:challenge:",
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorePing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
require.NoError(t, store.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
now := time.Unix(1_775_130_000, 0).UTC()
|
||||
|
||||
record := testChallenge(now)
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
|
||||
got.CodeHash[0] = 0xFF
|
||||
keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
|
||||
keyBytes[0] = 0xFE
|
||||
|
||||
again, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.CodeHash, again.CodeHash)
|
||||
require.NotNil(t, again.Confirmation)
|
||||
assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetPendingChallenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
now := time.Unix(1_775_130_100, 0).UTC()
|
||||
|
||||
record := testPendingChallenge(now)
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetThrottledChallenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
now := time.Unix(1_775_130_150, 0).UTC()
|
||||
|
||||
record := testPendingChallenge(now)
|
||||
record.Status = challenge.StatusDeliveryThrottled
|
||||
record.DeliveryState = challenge.DeliveryThrottled
|
||||
record.Attempts.Send = 1
|
||||
record.Abuse.LastAttemptAt = timePointer(now)
|
||||
require.NoError(t, record.Validate())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
}
|
||||
|
||||
func TestStoreGetStrictDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_130_200, 0).UTC()
|
||||
baseRecord := testChallenge(now)
|
||||
baseStored, err := redisRecordFromChallenge(baseRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(redisRecord) string
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "malformed json",
|
||||
mutate: func(_ redisRecord) string {
|
||||
return "{"
|
||||
},
|
||||
wantErrText: "decode redis challenge record",
|
||||
},
|
||||
{
|
||||
name: "trailing json input",
|
||||
mutate: func(record redisRecord) string {
|
||||
return mustMarshalJSON(t, record) + "{}"
|
||||
},
|
||||
wantErrText: "unexpected trailing JSON input",
|
||||
},
|
||||
{
|
||||
name: "unknown field",
|
||||
mutate: func(record redisRecord) string {
|
||||
payload := map[string]any{
|
||||
"challenge_id": record.ChallengeID,
|
||||
"email": record.Email,
|
||||
"code_hash_base64": record.CodeHashBase64,
|
||||
"status": record.Status,
|
||||
"delivery_state": record.DeliveryState,
|
||||
"created_at": record.CreatedAt,
|
||||
"expires_at": record.ExpiresAt,
|
||||
"send_attempt_count": record.SendAttemptCount,
|
||||
"confirm_attempt_count": record.ConfirmAttemptCount,
|
||||
"last_attempt_at": record.LastAttemptAt,
|
||||
"confirmed_session_id": record.ConfirmedSessionID,
|
||||
"confirmed_client_public_key": record.ConfirmedClientPublicKey,
|
||||
"confirmed_at": record.ConfirmedAt,
|
||||
"unexpected": true,
|
||||
}
|
||||
return mustMarshalJSON(t, payload)
|
||||
},
|
||||
wantErrText: "unknown field",
|
||||
},
|
||||
{
|
||||
name: "unsupported status",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.Status = challenge.Status("paused")
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `status "paused" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "unsupported delivery state",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.DeliveryState = challenge.DeliveryState("queued")
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `delivery state "queued" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "missing required email",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.Email = ""
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "challenge email",
|
||||
},
|
||||
{
|
||||
name: "challenge id mismatch",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.ChallengeID = "other-challenge"
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `does not match requested`,
|
||||
},
|
||||
{
|
||||
name: "non canonical utc timestamp",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.CreatedAt = "2026-04-04T12:00:00+03:00"
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "canonical UTC RFC3339Nano timestamp",
|
||||
},
|
||||
{
|
||||
name: "partial confirmation metadata",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.ConfirmedAt = nil
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "confirmation metadata must be either fully present or fully absent",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
server.Set(store.lookupKey(baseRecord.ID), tt.mutate(baseStored))
|
||||
|
||||
_, err := store.Get(context.Background(), baseRecord.ID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreKeySchemeAndTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{KeyPrefix: "authsession:challenge:"})
|
||||
now := time.Now().UTC()
|
||||
|
||||
prefixed := testPendingChallenge(now)
|
||||
prefixed.ID = common.ChallengeID("challenge:opaque/id?value")
|
||||
require.NoError(t, store.Create(context.Background(), prefixed))
|
||||
|
||||
key := store.lookupKey(prefixed.ID)
|
||||
assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key)
|
||||
assert.True(t, server.Exists(key))
|
||||
|
||||
freshTTL := server.TTL(key)
|
||||
assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod)
|
||||
assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second)
|
||||
|
||||
expired := testPendingChallenge(now.Add(-10 * time.Minute))
|
||||
expired.ID = common.ChallengeID("expired-challenge")
|
||||
expired.CreatedAt = now.Add(-20 * time.Minute)
|
||||
expired.ExpiresAt = now.Add(-1 * time.Minute)
|
||||
require.NoError(t, store.Create(context.Background(), expired))
|
||||
|
||||
expiredTTL := server.TTL(store.lookupKey(expired.ID))
|
||||
assert.LessOrEqual(t, expiredTTL, expirationGracePeriod)
|
||||
assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second)
|
||||
}
|
||||
|
||||
func TestStoreCreateConflict(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
err := store.Create(context.Background(), record)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
}
|
||||
|
||||
func TestStoreGetNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestStoreCompareAndSwap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_130_400, 0).UTC()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
previous := testPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
next.Attempts.Send = 1
|
||||
next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute))
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), previous))
|
||||
require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))
|
||||
|
||||
got, err := store.Get(context.Background(), previous.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, next, got)
|
||||
})
|
||||
|
||||
t.Run("conflict when stored record differs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
stored := testPendingChallenge(now)
|
||||
previous := stored
|
||||
previous.Attempts.Send = 99
|
||||
next := stored
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), stored))
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
previous := testPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
|
||||
t.Run("corrupt stored record returns adapter error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
previous := testPendingChallenge(now)
|
||||
next := previous
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
server.Set(store.lookupKey(previous.ID), "{")
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
require.Error(t, err)
|
||||
assert.NotErrorIs(t, err, ports.ErrConflict)
|
||||
assert.ErrorContains(t, err, "decode redis challenge record")
|
||||
})
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "authsession:challenge:"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func testPendingChallenge(now time.Time) challenge.Challenge {
|
||||
return challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-pending"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-pending-code"),
|
||||
Status: challenge.StatusPendingSend,
|
||||
DeliveryState: challenge.DeliveryPending,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.InitialTTL),
|
||||
}
|
||||
}
|
||||
|
||||
func testChallenge(now time.Time) challenge.Challenge {
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return challenge.Challenge{
|
||||
ID: common.ChallengeID("challenge-confirmed"),
|
||||
Email: common.Email("pilot@example.com"),
|
||||
CodeHash: []byte("hashed-code"),
|
||||
Status: challenge.StatusConfirmedPendingExpire,
|
||||
DeliveryState: challenge.DeliverySent,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(challenge.ConfirmedRetention),
|
||||
Attempts: challenge.AttemptCounters{
|
||||
Send: 1,
|
||||
Confirm: 2,
|
||||
},
|
||||
Abuse: challenge.AbuseMetadata{
|
||||
LastAttemptAt: timePointer(now.Add(30 * time.Second)),
|
||||
},
|
||||
Confirmation: &challenge.Confirmation{
|
||||
SessionID: common.DeviceSessionID("device-session-1"),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
ConfirmedAt: now.Add(1 * time.Minute),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func timePointer(value time.Time) *time.Time {
|
||||
return &value
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func TestStorePingNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
err := store.Ping(nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func TestStoreGetNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Get(nil, common.ChallengeID("challenge"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
// Package configprovider implements ports.ConfigProvider with Redis-backed
|
||||
// dynamic auth/session configuration.
|
||||
package configprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Config configures one Redis-backed config provider instance.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// SessionLimitKey identifies the single Redis string key that stores the
|
||||
// active-session-limit configuration value.
|
||||
SessionLimitKey string
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Store reads dynamic auth/session configuration from Redis.
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
sessionLimitKey string
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed config provider from cfg.
|
||||
func New(cfg Config) (*Store, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.Addr) == "":
|
||||
return nil, errors.New("new redis config provider: redis addr must not be empty")
|
||||
case cfg.DB < 0:
|
||||
return nil, errors.New("new redis config provider: redis db must not be negative")
|
||||
case strings.TrimSpace(cfg.SessionLimitKey) == "":
|
||||
return nil, errors.New("new redis config provider: session limit key must not be empty")
|
||||
case cfg.OperationTimeout <= 0:
|
||||
return nil, errors.New("new redis config provider: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Store{
|
||||
client: redis.NewClient(options),
|
||||
sessionLimitKey: cfg.SessionLimitKey,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *Store) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (s *Store) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "ping redis config provider")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis config provider: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadSessionLimit returns the current active-session-limit configuration.
|
||||
// Missing or invalid Redis values are treated as “limit absent” by policy.
|
||||
func (s *Store) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "load session limit from redis")
|
||||
if err != nil {
|
||||
return ports.SessionLimitConfig{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
value, err := s.client.Get(operationCtx, s.sessionLimitKey).Result()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return ports.SessionLimitConfig{}, nil
|
||||
case err != nil:
|
||||
return ports.SessionLimitConfig{}, fmt.Errorf("load session limit from redis: %w", err)
|
||||
}
|
||||
|
||||
config, valid := parseSessionLimitConfig(value)
|
||||
if !valid {
|
||||
return ports.SessionLimitConfig{}, nil
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return ports.SessionLimitConfig{}, nil
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if s == nil || s.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil store", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func parseSessionLimitConfig(raw string) (ports.SessionLimitConfig, bool) {
|
||||
if strings.TrimSpace(raw) == "" || strings.TrimSpace(raw) != raw {
|
||||
return ports.SessionLimitConfig{}, false
|
||||
}
|
||||
for _, symbol := range raw {
|
||||
if symbol < '0' || symbol > '9' {
|
||||
return ports.SessionLimitConfig{}, false
|
||||
}
|
||||
}
|
||||
|
||||
parsed, err := strconv.ParseInt(raw, 10, strconv.IntSize)
|
||||
if err != nil || parsed <= 0 {
|
||||
return ports.SessionLimitConfig{}, false
|
||||
}
|
||||
|
||||
limit := int(parsed)
|
||||
return ports.SessionLimitConfig{
|
||||
ActiveSessionLimit: &limit,
|
||||
}, true
|
||||
}
|
||||
|
||||
var _ ports.ConfigProvider = (*Store)(nil)
|
||||
@@ -0,0 +1,283 @@
|
||||
package configprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/adapters/contracttest"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStoreContract(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
contracttest.RunConfigProviderContractTests(t, func(t *testing.T) contracttest.ConfigProviderHarness {
|
||||
t.Helper()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
return contracttest.ConfigProviderHarness{
|
||||
Provider: store,
|
||||
SeedDisabled: func(t *testing.T) {
|
||||
t.Helper()
|
||||
server.Del(store.sessionLimitKey)
|
||||
},
|
||||
SeedLimit: func(t *testing.T, limit int) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, strconv.Itoa(limit))
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 2,
|
||||
SessionLimitKey: "authsession:config:active-session-limit",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
SessionLimitKey: "authsession:config:active-session-limit",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
SessionLimitKey: "authsession:config:active-session-limit",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty session limit key",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session limit key must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionLimitKey: "authsession:config:active-session-limit",
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorePing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
require.NoError(t, store.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestStoreLoadSessionLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
seed func(*testing.T, *miniredis.Miniredis, *Store)
|
||||
wantConfig ports.SessionLimitConfig
|
||||
}{
|
||||
{
|
||||
name: "missing key means disabled",
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
{
|
||||
name: "valid positive integer",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, "5")
|
||||
},
|
||||
wantConfig: configWithLimit(5),
|
||||
},
|
||||
{
|
||||
name: "empty string is invalid and disabled",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, "")
|
||||
},
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
{
|
||||
name: "whitespace only is invalid and disabled",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, " ")
|
||||
},
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
{
|
||||
name: "whitespace padded integer is invalid and disabled",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, " 5 ")
|
||||
},
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
{
|
||||
name: "non integer text is invalid and disabled",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, "five")
|
||||
},
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
{
|
||||
name: "zero is invalid and disabled",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, "0")
|
||||
},
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
{
|
||||
name: "negative integer is invalid and disabled",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, "-3")
|
||||
},
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
{
|
||||
name: "overflow is invalid and disabled",
|
||||
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
|
||||
t.Helper()
|
||||
server.Set(store.sessionLimitKey, "999999999999999999999999999999")
|
||||
},
|
||||
wantConfig: ports.SessionLimitConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
if tt.seed != nil {
|
||||
tt.seed(t, server, store)
|
||||
}
|
||||
|
||||
got, err := store.LoadSessionLimit(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantConfig, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreLoadSessionLimitBackendFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
server.Close()
|
||||
|
||||
_, err := store.LoadSessionLimit(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "load session limit from redis")
|
||||
}
|
||||
|
||||
func TestStoreLoadSessionLimitNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.LoadSessionLimit(nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func TestStorePingNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
err := store.Ping(nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.SessionLimitKey == "" {
|
||||
cfg.SessionLimitKey = "authsession:config:active-session-limit"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func configWithLimit(limit int) ports.SessionLimitConfig {
|
||||
return ports.SessionLimitConfig{
|
||||
ActiveSessionLimit: &limit,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
// Package projectionpublisher implements
|
||||
// ports.GatewaySessionProjectionPublisher with Redis-backed gateway-compatible
|
||||
// cache snapshots and session lifecycle events.
|
||||
package projectionpublisher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Config configures one Redis-backed gateway session projection publisher.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// SessionCacheKeyPrefix is the namespace prefix applied to gateway session
|
||||
// cache keys. The raw device session identifier is appended directly.
|
||||
SessionCacheKeyPrefix string
|
||||
|
||||
// SessionEventsStream identifies the gateway session lifecycle Redis Stream.
|
||||
SessionEventsStream string
|
||||
|
||||
// StreamMaxLen bounds the session lifecycle stream with approximate
|
||||
// trimming via XADD MAXLEN ~.
|
||||
StreamMaxLen int64
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Publisher publishes gateway-compatible session projections into Redis cache
|
||||
// and stream namespaces.
|
||||
type Publisher struct {
|
||||
client *redis.Client
|
||||
sessionCacheKeyPrefix string
|
||||
sessionEventsStream string
|
||||
streamMaxLen int64
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
type cacheRecord struct {
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ClientPublicKey string `json:"client_public_key"`
|
||||
Status gatewayprojection.Status `json:"status"`
|
||||
RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed gateway session projection publisher from
|
||||
// cfg.
|
||||
func New(cfg Config) (*Publisher, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.Addr) == "":
|
||||
return nil, errors.New("new redis projection publisher: redis addr must not be empty")
|
||||
case cfg.DB < 0:
|
||||
return nil, errors.New("new redis projection publisher: redis db must not be negative")
|
||||
case strings.TrimSpace(cfg.SessionCacheKeyPrefix) == "":
|
||||
return nil, errors.New("new redis projection publisher: session cache key prefix must not be empty")
|
||||
case strings.TrimSpace(cfg.SessionEventsStream) == "":
|
||||
return nil, errors.New("new redis projection publisher: session events stream must not be empty")
|
||||
case cfg.StreamMaxLen <= 0:
|
||||
return nil, errors.New("new redis projection publisher: stream max len must be positive")
|
||||
case cfg.OperationTimeout <= 0:
|
||||
return nil, errors.New("new redis projection publisher: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Publisher{
|
||||
client: redis.NewClient(options),
|
||||
sessionCacheKeyPrefix: cfg.SessionCacheKeyPrefix,
|
||||
sessionEventsStream: cfg.SessionEventsStream,
|
||||
streamMaxLen: cfg.StreamMaxLen,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (p *Publisher) Close() error {
|
||||
if p == nil || p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (p *Publisher) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := p.operationContext(ctx, "ping redis projection publisher")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := p.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis projection publisher: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishSession writes one gateway-compatible session snapshot into the
|
||||
// gateway cache namespace and appends the same snapshot to the gateway session
|
||||
// event stream within one Redis transaction.
|
||||
func (p *Publisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error {
|
||||
if err := snapshot.Validate(); err != nil {
|
||||
return fmt.Errorf("publish session projection to redis: %w", err)
|
||||
}
|
||||
|
||||
payload, err := marshalCacheRecord(snapshot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("publish session projection to redis: %w", err)
|
||||
}
|
||||
values := buildStreamValues(snapshot)
|
||||
|
||||
operationCtx, cancel, err := p.operationContext(ctx, "publish session projection to redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
key := p.sessionCacheKey(snapshot.DeviceSessionID)
|
||||
_, err = p.client.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, key, payload, 0)
|
||||
pipe.XAdd(operationCtx, &redis.XAddArgs{
|
||||
Stream: p.sessionEventsStream,
|
||||
MaxLen: p.streamMaxLen,
|
||||
Approx: true,
|
||||
Values: values,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("publish session projection %q to redis: %w", snapshot.DeviceSessionID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Publisher) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if p == nil || p.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil publisher", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func (p *Publisher) sessionCacheKey(deviceSessionID interface{ String() string }) string {
|
||||
return p.sessionCacheKeyPrefix + deviceSessionID.String()
|
||||
}
|
||||
|
||||
func marshalCacheRecord(snapshot gatewayprojection.Snapshot) ([]byte, error) {
|
||||
record := cacheRecord{
|
||||
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
||||
UserID: snapshot.UserID.String(),
|
||||
ClientPublicKey: snapshot.ClientPublicKey,
|
||||
Status: snapshot.Status,
|
||||
}
|
||||
if snapshot.RevokedAt != nil {
|
||||
revokedAtMS := snapshot.RevokedAt.UTC().UnixMilli()
|
||||
record.RevokedAtMS = &revokedAtMS
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal gateway session cache record: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func buildStreamValues(snapshot gatewayprojection.Snapshot) map[string]any {
|
||||
values := map[string]any{
|
||||
"device_session_id": snapshot.DeviceSessionID.String(),
|
||||
"user_id": snapshot.UserID.String(),
|
||||
"client_public_key": snapshot.ClientPublicKey,
|
||||
"status": string(snapshot.Status),
|
||||
}
|
||||
if snapshot.RevokedAt != nil {
|
||||
values["revoked_at_ms"] = fmt.Sprint(snapshot.RevokedAt.UTC().UnixMilli())
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
var _ ports.GatewaySessionProjectionPublisher = (*Publisher)(nil)
|
||||
@@ -0,0 +1,442 @@
|
||||
package projectionpublisher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 3,
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty session cache key prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session cache key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty session events stream",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
StreamMaxLen: 1024,
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session events stream must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non positive stream max len",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "stream max len must be positive",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionCacheKeyPrefix: "gateway:session:",
|
||||
SessionEventsStream: "gateway:session_events",
|
||||
StreamMaxLen: 1024,
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publisher, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, publisher.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublisherPing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
|
||||
require.NoError(t, publisher.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionActive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
|
||||
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
|
||||
assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key)
|
||||
assert.True(t, server.Exists(key))
|
||||
assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String())))
|
||||
|
||||
payload, err := server.Get(key)
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
assert.Equal(t, cacheRecord{
|
||||
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
||||
UserID: snapshot.UserID.String(),
|
||||
ClientPublicKey: snapshot.ClientPublicKey,
|
||||
Status: gatewayprojection.StatusActive,
|
||||
}, record)
|
||||
assert.Zero(t, server.TTL(key))
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": snapshot.DeviceSessionID.String(),
|
||||
"user_id": snapshot.UserID.String(),
|
||||
"client_public_key": snapshot.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusActive),
|
||||
}, stringifyValues(entries[0].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionRevoked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
|
||||
snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
|
||||
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
|
||||
payload, err := server.Get(key)
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
|
||||
assert.Equal(t, cacheRecord{
|
||||
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
||||
UserID: snapshot.UserID.String(),
|
||||
ClientPublicKey: snapshot.ClientPublicKey,
|
||||
Status: gatewayprojection.StatusRevoked,
|
||||
RevokedAtMS: int64Pointer(revokedAt.UnixMilli()),
|
||||
}, record)
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": snapshot.DeviceSessionID.String(),
|
||||
"user_id": snapshot.UserID.String(),
|
||||
"client_public_key": snapshot.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusRevoked),
|
||||
"revoked_at_ms": "1776000123456",
|
||||
}, stringifyValues(entries[0].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
|
||||
deviceSessionID := "device-session-456"
|
||||
|
||||
active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
|
||||
revokedAt := time.Unix(1_776_010_000, 0).UTC()
|
||||
revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), active))
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), revoked))
|
||||
|
||||
payload, err := server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID))
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
|
||||
assert.Equal(t, gatewayprojection.StatusRevoked, record.Status)
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": active.DeviceSessionID.String(),
|
||||
"user_id": active.UserID.String(),
|
||||
"client_public_key": active.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusActive),
|
||||
}, stringifyValues(entries[0].Values))
|
||||
assert.Equal(t, map[string]string{
|
||||
"device_session_id": revoked.DeviceSessionID.String(),
|
||||
"user_id": revoked.UserID.String(),
|
||||
"client_public_key": revoked.ClientPublicKey,
|
||||
"status": string(gatewayprojection.StatusRevoked),
|
||||
"revoked_at_ms": "1776010000000",
|
||||
}, stringifyValues(entries[1].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
|
||||
snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)
|
||||
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
|
||||
payload, err := server.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID))
|
||||
require.NoError(t, err)
|
||||
record := decodeCachePayload(t, payload)
|
||||
assert.Equal(t, cacheRecord{
|
||||
DeviceSessionID: snapshot.DeviceSessionID.String(),
|
||||
UserID: snapshot.UserID.String(),
|
||||
ClientPublicKey: snapshot.ClientPublicKey,
|
||||
Status: gatewayprojection.StatusActive,
|
||||
}, record)
|
||||
|
||||
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 2})
|
||||
|
||||
for index := range 6 {
|
||||
snapshot := testSnapshot(
|
||||
common.DeviceSessionID("device-session-"+string(rune('a'+index))).String(),
|
||||
gatewayprojection.StatusActive,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
|
||||
}
|
||||
|
||||
streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result()
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, streamLength, int64(2))
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
snapshot := gatewayprojection.Snapshot{
|
||||
DeviceSessionID: common.DeviceSessionID("device-session-123"),
|
||||
UserID: common.UserID("user-123"),
|
||||
Status: gatewayprojection.StatusActive,
|
||||
}
|
||||
|
||||
err := publisher.PublishSession(context.Background(), snapshot)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "gateway projection client public key")
|
||||
assert.Empty(t, server.Keys())
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
|
||||
err := publisher.PublishSession(nil, testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func TestPublisherPublishSessionBackendFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
server.Close()
|
||||
|
||||
err := publisher.PublishSession(context.Background(), testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "publish session projection")
|
||||
}
|
||||
|
||||
func TestPublisherPingNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
publisher := newTestPublisher(t, server, Config{})
|
||||
|
||||
err := publisher.Ping(nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.SessionCacheKeyPrefix == "" {
|
||||
cfg.SessionCacheKeyPrefix = "gateway:session:"
|
||||
}
|
||||
if cfg.SessionEventsStream == "" {
|
||||
cfg.SessionEventsStream = "gateway:session_events"
|
||||
}
|
||||
if cfg.StreamMaxLen == 0 {
|
||||
cfg.StreamMaxLen = 1024
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
publisher, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, publisher.Close())
|
||||
})
|
||||
|
||||
return publisher
|
||||
}
|
||||
|
||||
func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
|
||||
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
|
||||
for index := range raw {
|
||||
raw[index] = byte(index + 1)
|
||||
}
|
||||
|
||||
snapshot := gatewayprojection.Snapshot{
|
||||
DeviceSessionID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID("user-123"),
|
||||
ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
|
||||
Status: status,
|
||||
RevokedAt: revokedAt,
|
||||
}
|
||||
if status == gatewayprojection.StatusRevoked {
|
||||
snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
|
||||
snapshot.RevokeActorType = common.RevokeActorType("system")
|
||||
}
|
||||
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func decodeCachePayload(t *testing.T, payload string) cacheRecord {
|
||||
t.Helper()
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader([]byte(payload)))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var record cacheRecord
|
||||
require.NoError(t, decoder.Decode(&record))
|
||||
err := decoder.Decode(&struct{}{})
|
||||
if err == nil {
|
||||
require.FailNow(t, "expected cache payload EOF after first JSON value")
|
||||
}
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
var fieldSet map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet))
|
||||
expectedFields := map[string]struct{}{
|
||||
"device_session_id": {},
|
||||
"user_id": {},
|
||||
"client_public_key": {},
|
||||
"status": {},
|
||||
}
|
||||
if record.RevokedAtMS != nil {
|
||||
expectedFields["revoked_at_ms"] = struct{}{}
|
||||
}
|
||||
assert.Equal(t, len(expectedFields), len(fieldSet))
|
||||
for field := range fieldSet {
|
||||
_, ok := expectedFields[field]
|
||||
assert.Truef(t, ok, "unexpected cache payload field %q", field)
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func stringifyValues(values map[string]any) map[string]string {
|
||||
stringified := make(map[string]string, len(values))
|
||||
for key, value := range values {
|
||||
stringified[key] = fmt.Sprint(value)
|
||||
}
|
||||
return stringified
|
||||
}
|
||||
|
||||
func encodeBase64URL(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func int64Pointer(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
// Package sendemailcodeabuse implements ports.SendEmailCodeAbuseProtector with
|
||||
// one Redis TTL key per normalized e-mail address.
|
||||
package sendemailcodeabuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Config configures one Redis-backed send-email-code abuse protector.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// KeyPrefix is the namespace prefix applied to every resend-throttle key.
|
||||
KeyPrefix string
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Protector applies the fixed resend cooldown with one Redis key per
|
||||
// normalized e-mail address.
|
||||
type Protector struct {
|
||||
client *redis.Client
|
||||
keyPrefix string
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed resend-throttle protector from cfg.
|
||||
func New(cfg Config) (*Protector, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.Addr) == "":
|
||||
return nil, errors.New("new redis send email code abuse protector: redis addr must not be empty")
|
||||
case cfg.DB < 0:
|
||||
return nil, errors.New("new redis send email code abuse protector: redis db must not be negative")
|
||||
case strings.TrimSpace(cfg.KeyPrefix) == "":
|
||||
return nil, errors.New("new redis send email code abuse protector: redis key prefix must not be empty")
|
||||
case cfg.OperationTimeout <= 0:
|
||||
return nil, errors.New("new redis send email code abuse protector: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Protector{
|
||||
client: redis.NewClient(options),
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (p *Protector) Close() error {
|
||||
if p == nil || p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (p *Protector) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := p.operationContext(ctx, "ping redis send email code abuse protector")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := p.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis send email code abuse protector: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckAndReserve applies the fixed resend cooldown using one TTL key per
|
||||
// normalized e-mail address.
|
||||
func (p *Protector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := p.operationContext(ctx, "check and reserve send email code abuse")
|
||||
if err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
key := p.lookupKey(input.Email)
|
||||
value := input.Now.UTC().Add(challenge.ResendThrottleCooldown).Format(time.RFC3339Nano)
|
||||
created, err := p.client.SetNX(operationCtx, key, value, challenge.ResendThrottleCooldown).Result()
|
||||
if err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse for %q: %w", input.Email, err)
|
||||
}
|
||||
if created {
|
||||
return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeAllowed}, nil
|
||||
}
|
||||
|
||||
return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeThrottled}, nil
|
||||
}
|
||||
|
||||
func (p *Protector) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if p == nil || p.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil protector", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func (p *Protector) lookupKey(email common.Email) string {
|
||||
return p.keyPrefix + base64.RawURLEncoding.EncodeToString([]byte(email.String()))
|
||||
}
|
||||
|
||||
var _ ports.SendEmailCodeAbuseProtector = (*Protector)(nil)
|
||||
@@ -0,0 +1,176 @@
|
||||
package sendemailcodeabuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 1,
|
||||
KeyPrefix: "authsession:send-email-code-throttle:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
KeyPrefix: "authsession:send-email-code-throttle:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
KeyPrefix: "authsession:send-email-code-throttle:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty key prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non-positive timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
KeyPrefix: "authsession:send-email-code-throttle:",
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
protector, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, protector.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectorPing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
protector := newTestProtector(t, server, Config{})
|
||||
|
||||
require.NoError(t, protector.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestProtectorCheckAndReserve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
protector := newTestProtector(t, server, Config{})
|
||||
email := common.Email("pilot@example.com")
|
||||
now := time.Unix(10, 0).UTC()
|
||||
|
||||
result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
|
||||
Email: email,
|
||||
Now: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
|
||||
|
||||
key := protector.lookupKey(email)
|
||||
assert.True(t, server.Exists(key))
|
||||
ttl := server.TTL(key)
|
||||
assert.LessOrEqual(t, ttl, challenge.ResendThrottleCooldown)
|
||||
assert.GreaterOrEqual(t, ttl, challenge.ResendThrottleCooldown-2*time.Second)
|
||||
|
||||
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
|
||||
Email: email,
|
||||
Now: now.Add(30 * time.Second),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
|
||||
ttlAfterThrottle := server.TTL(key)
|
||||
assert.LessOrEqual(t, ttlAfterThrottle, ttl)
|
||||
|
||||
server.FastForward(challenge.ResendThrottleCooldown)
|
||||
|
||||
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
|
||||
Email: email,
|
||||
Now: now.Add(challenge.ResendThrottleCooldown),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
|
||||
}
|
||||
|
||||
func TestProtectorNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
protector := newTestProtector(t, server, Config{})
|
||||
|
||||
_, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
Now: time.Unix(10, 0).UTC(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func newTestProtector(t *testing.T, server *miniredis.Miniredis, cfg Config) *Protector {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "authsession:send-email-code-throttle:"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
protector, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, protector.Close())
|
||||
})
|
||||
|
||||
return protector
|
||||
}
|
||||
@@ -0,0 +1,723 @@
|
||||
// Package sessionstore implements ports.SessionStore with Redis-backed strict
|
||||
// JSON source-of-truth session records and per-user indexes.
|
||||
package sessionstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const mutationRetryLimit = 3
|
||||
|
||||
// Config configures one Redis-backed session store instance.
|
||||
type Config struct {
|
||||
// Addr is the Redis network address in host:port form.
|
||||
Addr string
|
||||
|
||||
// Username is the optional Redis ACL username.
|
||||
Username string
|
||||
|
||||
// Password is the optional Redis ACL password.
|
||||
Password string
|
||||
|
||||
// DB is the Redis logical database index.
|
||||
DB int
|
||||
|
||||
// TLSEnabled enables TLS with a conservative minimum protocol version.
|
||||
TLSEnabled bool
|
||||
|
||||
// SessionKeyPrefix is the namespace prefix applied to primary session keys.
|
||||
SessionKeyPrefix string
|
||||
|
||||
// UserSessionsKeyPrefix is the namespace prefix applied to all-session user
|
||||
// indexes.
|
||||
UserSessionsKeyPrefix string
|
||||
|
||||
// UserActiveSessionsKeyPrefix is the namespace prefix applied to active
|
||||
// session user indexes.
|
||||
UserActiveSessionsKeyPrefix string
|
||||
|
||||
// OperationTimeout bounds each Redis round trip performed by the adapter.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// Store persists source-of-truth sessions in Redis and maintains user-scoped
|
||||
// indexes for list and count operations.
|
||||
type Store struct {
|
||||
client *redis.Client
|
||||
sessionKeyPrefix string
|
||||
userSessionsKeyPrefix string
|
||||
userActiveSessionsKeyPrefix string
|
||||
operationTimeout time.Duration
|
||||
}
|
||||
|
||||
type redisRecord struct {
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ClientPublicKeyBase64 string `json:"client_public_key_base64"`
|
||||
Status devicesession.Status `json:"status"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
RevokedAt *string `json:"revoked_at,omitempty"`
|
||||
RevokeReasonCode string `json:"revoke_reason_code,omitempty"`
|
||||
RevokeActorType string `json:"revoke_actor_type,omitempty"`
|
||||
RevokeActorID string `json:"revoke_actor_id,omitempty"`
|
||||
}
|
||||
|
||||
// New constructs a Redis-backed session store from cfg.
|
||||
func New(cfg Config) (*Store, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.Addr) == "":
|
||||
return nil, errors.New("new redis session store: redis addr must not be empty")
|
||||
case cfg.DB < 0:
|
||||
return nil, errors.New("new redis session store: redis db must not be negative")
|
||||
case strings.TrimSpace(cfg.SessionKeyPrefix) == "":
|
||||
return nil, errors.New("new redis session store: session key prefix must not be empty")
|
||||
case strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "":
|
||||
return nil, errors.New("new redis session store: user sessions key prefix must not be empty")
|
||||
case strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "":
|
||||
return nil, errors.New("new redis session store: user active sessions key prefix must not be empty")
|
||||
case cfg.OperationTimeout <= 0:
|
||||
return nil, errors.New("new redis session store: operation timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &Store{
|
||||
client: redis.NewClient(options),
|
||||
sessionKeyPrefix: cfg.SessionKeyPrefix,
|
||||
userSessionsKeyPrefix: cfg.UserSessionsKeyPrefix,
|
||||
userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix,
|
||||
operationTimeout: cfg.OperationTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *Store) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// adapter operation timeout budget.
|
||||
func (s *Store) Ping(ctx context.Context) error {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "ping redis session store")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(operationCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis session store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the stored session for deviceSessionID.
|
||||
func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
|
||||
if err := deviceSessionID.Validate(); err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "get session from redis")
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
record, err := s.loadSession(operationCtx, deviceSessionID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, ports.ErrNotFound)
|
||||
default:
|
||||
return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// ListByUserID returns every stored session for userID in newest-first order.
|
||||
func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("list sessions by user id from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
deviceSessionIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err)
|
||||
}
|
||||
if len(deviceSessionIDs) == 0 {
|
||||
return []devicesession.Session{}, nil
|
||||
}
|
||||
|
||||
records := make([]devicesession.Session, 0, len(deviceSessionIDs))
|
||||
for _, rawDeviceSessionID := range deviceSessionIDs {
|
||||
deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
|
||||
record, err := s.loadSession(operationCtx, deviceSessionID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, deviceSessionID)
|
||||
default:
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, deviceSessionID, err)
|
||||
}
|
||||
}
|
||||
if record.UserID != userID {
|
||||
return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, deviceSessionID, record.UserID)
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
sortSessionsNewestFirst(records)
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// CountActiveByUserID returns the number of active sessions currently stored
|
||||
// for userID.
|
||||
func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return 0, fmt.Errorf("count active sessions by user id from redis: %w", err)
|
||||
}
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
count, err := s.client.ZCard(operationCtx, s.userActiveSessionsKey(userID)).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err)
|
||||
}
|
||||
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
// Create persists record as a new device session.
|
||||
func (s *Store) Create(ctx context.Context, record devicesession.Session) error {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create session in redis: %w", err)
|
||||
}
|
||||
|
||||
payload, err := marshalSessionRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session in redis: %w", err)
|
||||
}
|
||||
|
||||
deviceSessionKey := s.sessionKey(record.ID)
|
||||
allSessionsKey := s.userSessionsKey(record.UserID)
|
||||
activeSessionsKey := s.userActiveSessionsKey(record.UserID)
|
||||
|
||||
operationCtx, cancel, err := s.operationContext(ctx, "create session in redis")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
_, err := tx.Get(operationCtx, deviceSessionKey).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
case err != nil:
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, err)
|
||||
default:
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, deviceSessionKey, payload, 0)
|
||||
pipe.ZAdd(operationCtx, allSessionsKey, redis.Z{
|
||||
Score: createdAtScore(record.CreatedAt),
|
||||
Member: record.ID.String(),
|
||||
})
|
||||
if record.Status == devicesession.StatusActive {
|
||||
pipe.ZAdd(operationCtx, activeSessionsKey, redis.Z{
|
||||
Score: createdAtScore(record.CreatedAt),
|
||||
Member: record.ID.String(),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, deviceSessionKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke stores a revoked view of one target session.
|
||||
func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err)
|
||||
}
|
||||
|
||||
var result ports.RevokeSessionResult
|
||||
err := s.runMutation(ctx, "revoke session in redis", func(operationCtx context.Context) error {
|
||||
deviceSessionKey := s.sessionKey(input.DeviceSessionID)
|
||||
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
current, err := s.loadSessionWithGetter(operationCtx, input.DeviceSessionID, tx.Get)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, ports.ErrNotFound)
|
||||
default:
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Status == devicesession.StatusRevoked {
|
||||
result = ports.RevokeSessionResult{
|
||||
Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
|
||||
Session: current,
|
||||
}
|
||||
return result.Validate()
|
||||
}
|
||||
|
||||
next := current
|
||||
next.Status = devicesession.StatusRevoked
|
||||
revocation := input.Revocation
|
||||
next.Revocation = &revocation
|
||||
if err := next.Validate(); err != nil {
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
|
||||
payload, err := marshalSessionRecord(next)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Set(operationCtx, deviceSessionKey, payload, 0)
|
||||
pipe.ZRem(operationCtx, s.userActiveSessionsKey(current.UserID), current.ID.String())
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
|
||||
}
|
||||
|
||||
result = ports.RevokeSessionResult{
|
||||
Outcome: ports.RevokeSessionOutcomeRevoked,
|
||||
Session: next,
|
||||
}
|
||||
return result.Validate()
|
||||
}, deviceSessionKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return errRetryMutation
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return ports.RevokeSessionResult{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RevokeAllByUserID stores revoked views for all currently active sessions
|
||||
// owned by input.UserID.
|
||||
func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err)
|
||||
}
|
||||
|
||||
var result ports.RevokeUserSessionsResult
|
||||
err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error {
|
||||
activeSessionsKey := s.userActiveSessionsKey(input.UserID)
|
||||
|
||||
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
|
||||
deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
|
||||
}
|
||||
if len(deviceSessionIDs) == 0 {
|
||||
// Force EXEC so WATCH observes concurrent active-index changes even
|
||||
// for the no-op path.
|
||||
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
pipe.ZCard(operationCtx, activeSessionsKey)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
|
||||
}
|
||||
|
||||
result = ports.RevokeUserSessionsResult{
|
||||
Outcome: ports.RevokeUserSessionsOutcomeNoActiveSessions,
|
||||
UserID: input.UserID,
|
||||
Sessions: []devicesession.Session{},
|
||||
}
|
||||
return result.Validate()
|
||||
}
|
||||
|
||||
records := make([]devicesession.Session, 0, len(deviceSessionIDs))
|
||||
for _, rawDeviceSessionID := range deviceSessionIDs {
|
||||
deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
|
||||
record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, tx.Get)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID)
|
||||
default:
|
||||
return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
|
||||
}
|
||||
}
|
||||
if record.UserID != input.UserID {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID)
|
||||
}
|
||||
if record.Status != devicesession.StatusActive {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status)
|
||||
}
|
||||
|
||||
next := record
|
||||
next.Status = devicesession.StatusRevoked
|
||||
revocation := input.Revocation
|
||||
next.Revocation = &revocation
|
||||
if err := next.Validate(); err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
|
||||
}
|
||||
records = append(records, next)
|
||||
}
|
||||
|
||||
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
|
||||
for _, record := range records {
|
||||
payload, err := marshalSessionRecord(record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("session %q: %w", record.ID, err)
|
||||
}
|
||||
pipe.Set(operationCtx, s.sessionKey(record.ID), payload, 0)
|
||||
pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
|
||||
}
|
||||
|
||||
sortSessionsNewestFirst(records)
|
||||
result = ports.RevokeUserSessionsResult{
|
||||
Outcome: ports.RevokeUserSessionsOutcomeRevoked,
|
||||
UserID: input.UserID,
|
||||
Sessions: records,
|
||||
}
|
||||
return result.Validate()
|
||||
}, activeSessionsKey)
|
||||
|
||||
switch {
|
||||
case errors.Is(watchErr, redis.TxFailedErr):
|
||||
return errRetryMutation
|
||||
case watchErr != nil:
|
||||
return watchErr
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return ports.RevokeUserSessionsResult{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var errRetryMutation = errors.New("redis session store: retry mutation")
|
||||
|
||||
func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error {
|
||||
for attempt := 0; attempt < mutationRetryLimit; attempt++ {
|
||||
operationCtx, cancel, err := s.operationContext(ctx, operation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = execute(operationCtx)
|
||||
cancel()
|
||||
|
||||
switch {
|
||||
case errors.Is(err, errRetryMutation):
|
||||
if attempt == mutationRetryLimit-1 {
|
||||
return fmt.Errorf("%s: mutation retry limit exceeded", operation)
|
||||
}
|
||||
continue
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s: mutation retry limit exceeded", operation)
|
||||
}
|
||||
|
||||
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
|
||||
if s == nil || s.client == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil store", operation)
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, nil, fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
|
||||
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
|
||||
return operationCtx, cancel, nil
|
||||
}
|
||||
|
||||
func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
|
||||
return s.loadSessionWithGetter(ctx, deviceSessionID, s.client.Get)
|
||||
}
|
||||
|
||||
func (s *Store) loadSessionWithGetter(
|
||||
ctx context.Context,
|
||||
deviceSessionID common.DeviceSessionID,
|
||||
getter func(context.Context, string) *redis.StringCmd,
|
||||
) (devicesession.Session, error) {
|
||||
payload, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return devicesession.Session{}, ports.ErrNotFound
|
||||
case err != nil:
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
|
||||
record, err := decodeSessionRecord(deviceSessionID, payload)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string {
|
||||
return s.sessionKeyPrefix + encodeKeyComponent(deviceSessionID.String())
|
||||
}
|
||||
|
||||
func (s *Store) userSessionsKey(userID common.UserID) string {
|
||||
return s.userSessionsKeyPrefix + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
func (s *Store) userActiveSessionsKey(userID common.UserID) string {
|
||||
return s.userActiveSessionsKeyPrefix + encodeKeyComponent(userID.String())
|
||||
}
|
||||
|
||||
func encodeKeyComponent(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func marshalSessionRecord(record devicesession.Session) ([]byte, error) {
|
||||
stored, err := redisRecordFromSession(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(stored)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode redis session record: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func redisRecordFromSession(record devicesession.Session) (redisRecord, error) {
|
||||
if err := record.Validate(); err != nil {
|
||||
return redisRecord{}, fmt.Errorf("encode redis session record: %w", err)
|
||||
}
|
||||
|
||||
stored := redisRecord{
|
||||
DeviceSessionID: record.ID.String(),
|
||||
UserID: record.UserID.String(),
|
||||
ClientPublicKeyBase64: record.ClientPublicKey.String(),
|
||||
Status: record.Status,
|
||||
CreatedAt: formatTimestamp(record.CreatedAt),
|
||||
}
|
||||
if record.Revocation != nil {
|
||||
stored.RevokedAt = formatOptionalTimestamp(&record.Revocation.At)
|
||||
stored.RevokeReasonCode = record.Revocation.ReasonCode.String()
|
||||
stored.RevokeActorType = record.Revocation.ActorType.String()
|
||||
stored.RevokeActorID = record.Revocation.ActorID
|
||||
}
|
||||
|
||||
return stored, nil
|
||||
}
|
||||
|
||||
func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input")
|
||||
}
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
record, err := sessionFromRedisRecord(stored)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
if record.ID != expectedDeviceSessionID {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", record.ID, expectedDeviceSessionID)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func sessionFromRedisRecord(stored redisRecord) (devicesession.Session, error) {
|
||||
createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
|
||||
rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ClientPublicKeyBase64)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
|
||||
}
|
||||
clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
|
||||
}
|
||||
|
||||
record := devicesession.Session{
|
||||
ID: common.DeviceSessionID(stored.DeviceSessionID),
|
||||
UserID: common.UserID(stored.UserID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: stored.Status,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
revocation, err := parseRevocation(stored)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
record.Revocation = revocation
|
||||
|
||||
if err := record.Validate(); err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func parseRevocation(stored redisRecord) (*devicesession.Revocation, error) {
|
||||
hasRevokedAt := stored.RevokedAt != nil
|
||||
hasReasonCode := strings.TrimSpace(stored.RevokeReasonCode) != ""
|
||||
hasActorType := strings.TrimSpace(stored.RevokeActorType) != ""
|
||||
hasActorID := strings.TrimSpace(stored.RevokeActorID) != ""
|
||||
|
||||
if !hasRevokedAt && !hasReasonCode && !hasActorType && !hasActorID {
|
||||
return nil, nil
|
||||
}
|
||||
if !hasRevokedAt || !hasReasonCode || !hasActorType {
|
||||
return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent")
|
||||
}
|
||||
|
||||
revokedAt, err := parseTimestamp("revoked_at", *stored.RevokedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &devicesession.Revocation{
|
||||
At: revokedAt,
|
||||
ReasonCode: common.RevokeReasonCode(stored.RevokeReasonCode),
|
||||
ActorType: common.RevokeActorType(stored.RevokeActorType),
|
||||
ActorID: stored.RevokeActorID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseTimestamp(fieldName string, value string) (time.Time, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName)
|
||||
}
|
||||
|
||||
parsed, err := time.Parse(time.RFC3339Nano, value)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err)
|
||||
}
|
||||
|
||||
canonical := parsed.UTC().Format(time.RFC3339Nano)
|
||||
if value != canonical {
|
||||
return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
|
||||
}
|
||||
|
||||
return parsed.UTC(), nil
|
||||
}
|
||||
|
||||
func formatTimestamp(value time.Time) string {
|
||||
return value.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func formatOptionalTimestamp(value *time.Time) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
formatted := formatTimestamp(*value)
|
||||
return &formatted
|
||||
}
|
||||
|
||||
func createdAtScore(createdAt time.Time) float64 {
|
||||
return float64(createdAt.UTC().UnixMicro())
|
||||
}
|
||||
|
||||
func sortSessionsNewestFirst(records []devicesession.Session) {
|
||||
slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int {
|
||||
switch {
|
||||
case left.CreatedAt.Equal(right.CreatedAt):
|
||||
return strings.Compare(left.ID.String(), right.ID.String())
|
||||
case left.CreatedAt.After(right.CreatedAt):
|
||||
return -1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var _ ports.SessionStore = (*Store)(nil)
|
||||
@@ -0,0 +1,635 @@
|
||||
package sessionstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/adapters/contracttest"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStoreContract(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
contracttest.RunSessionStoreContractTests(t, func(t *testing.T) ports.SessionStore {
|
||||
t.Helper()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
return newTestStore(t, server, Config{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: 1,
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty addr",
|
||||
cfg: Config{
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis addr must not be empty",
|
||||
},
|
||||
{
|
||||
name: "negative db",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
DB: -1,
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "redis db must not be negative",
|
||||
},
|
||||
{
|
||||
name: "empty session prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "session key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty all sessions prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "user sessions key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty active sessions prefix",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
OperationTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
wantErr: "user active sessions key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
Addr: server.Addr(),
|
||||
SessionKeyPrefix: "authsession:session:",
|
||||
UserSessionsKeyPrefix: "authsession:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
|
||||
},
|
||||
wantErr: "operation timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := New(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorePing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
require.NoError(t, store.Ping(context.Background()))
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetActive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
|
||||
got.Revocation = &devicesession.Revocation{
|
||||
At: got.CreatedAt.Add(time.Minute),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
}
|
||||
|
||||
again, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, again.Revocation)
|
||||
assert.Equal(t, record, again)
|
||||
}
|
||||
|
||||
func TestStoreCreateAndGetRevoked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(1_775_240_100, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record, got)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
}
|
||||
|
||||
func TestStoreGetNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestStoreCreateConflict(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_200, 0).UTC())
|
||||
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
err := store.Create(context.Background(), record)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrConflict)
|
||||
}
|
||||
|
||||
func TestStoreIndexesAndOrdering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
older := activeSessionFixture("device-session-old", "user-1", time.Unix(10, 0).UTC())
|
||||
newer := activeSessionFixture("device-session-new", "user-1", time.Unix(20, 0).UTC())
|
||||
revoked := revokedSessionFixture("device-session-revoked", "user-1", time.Unix(15, 0).UTC())
|
||||
otherUser := activeSessionFixture("device-session-other", "user-2", time.Unix(30, 0).UTC())
|
||||
|
||||
for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
}
|
||||
|
||||
got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got, 3)
|
||||
assert.Equal(t, []common.DeviceSessionID{newer.ID, revoked.ID, older.ID}, []common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID})
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
unknown, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unknown)
|
||||
}
|
||||
|
||||
func TestStoreKeyPrefixesAndEncodedPrimaryKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{
|
||||
SessionKeyPrefix: "custom:session:",
|
||||
UserSessionsKeyPrefix: "custom:user-sessions:",
|
||||
UserActiveSessionsKeyPrefix: "custom:user-active-sessions:",
|
||||
})
|
||||
|
||||
record := activeSessionFixture("device/session:opaque?1", "user/opaque:1", time.Unix(40, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
primaryKey := store.sessionKey(record.ID)
|
||||
assert.Equal(t, "custom:session:"+encodeKeyComponent(record.ID.String()), primaryKey)
|
||||
assert.True(t, server.Exists(primaryKey))
|
||||
|
||||
allSessionsKey := store.userSessionsKey(record.UserID)
|
||||
activeSessionsKey := store.userActiveSessionsKey(record.UserID)
|
||||
assert.Equal(t, "custom:user-sessions:"+encodeKeyComponent(record.UserID.String()), allSessionsKey)
|
||||
assert.Equal(t, "custom:user-active-sessions:"+encodeKeyComponent(record.UserID.String()), activeSessionsKey)
|
||||
|
||||
allMembers, err := server.ZMembers(allSessionsKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{record.ID.String()}, allMembers)
|
||||
|
||||
activeMembers, err := server.ZMembers(activeSessionsKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{record.ID.String()}, activeMembers)
|
||||
}
|
||||
|
||||
func TestStoreRevoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("active session", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
revocation := devicesession.Revocation{
|
||||
At: time.Unix(200, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonLogoutAll,
|
||||
ActorType: common.RevokeActorType("system"),
|
||||
}
|
||||
|
||||
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: record.ID,
|
||||
Revocation: revocation,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
|
||||
require.NotNil(t, result.Session.Revocation)
|
||||
assert.Equal(t, revocation, *result.Session.Revocation)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
})
|
||||
|
||||
t.Run("already revoked keeps stored revocation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: record.ID,
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(300, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
ActorID: "admin-1",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
|
||||
require.NotNil(t, result.Session.Revocation)
|
||||
assert.Equal(t, *record.Revocation, *result.Session.Revocation)
|
||||
})
|
||||
|
||||
t.Run("unknown session", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
|
||||
DeviceSessionID: common.DeviceSessionID("missing-session"),
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(200, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonLogoutAll,
|
||||
ActorType: common.RevokeActorType("system"),
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRevokeAllByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("revokes active sessions newest first and clears active index", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
|
||||
older := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
newer := activeSessionFixture("device-session-2", "user-1", time.Unix(200, 0).UTC())
|
||||
alreadyRevoked := revokedSessionFixture("device-session-3", "user-1", time.Unix(150, 0).UTC())
|
||||
otherUser := activeSessionFixture("device-session-4", "user-2", time.Unix(250, 0).UTC())
|
||||
|
||||
for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
}
|
||||
|
||||
revocation := devicesession.Revocation{
|
||||
At: time.Unix(300, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
ActorID: "admin-1",
|
||||
}
|
||||
|
||||
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
Revocation: revocation,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
|
||||
require.Len(t, result.Sessions, 2)
|
||||
assert.Equal(t, []common.DeviceSessionID{newer.ID, older.ID}, []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID})
|
||||
assert.Equal(t, revocation, *result.Sessions[0].Revocation)
|
||||
assert.Equal(t, revocation, *result.Sessions[1].Revocation)
|
||||
|
||||
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Zero(t, count)
|
||||
|
||||
otherCount, err := store.CountActiveByUserID(context.Background(), common.UserID("user-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, otherCount)
|
||||
})
|
||||
|
||||
t.Run("no active sessions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-5", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, store.Create(context.Background(), record))
|
||||
|
||||
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(400, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
|
||||
assert.Empty(t, result.Sessions)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreStrictDecodeCorruption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(1_775_240_300, 0).UTC()
|
||||
baseRecord := revokedSessionFixture("device-session-corrupt", "user-1", now)
|
||||
stored, err := redisRecordFromSession(baseRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(redisRecord) string
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "malformed json",
|
||||
mutate: func(_ redisRecord) string {
|
||||
return "{"
|
||||
},
|
||||
wantErrText: "decode redis session record",
|
||||
},
|
||||
{
|
||||
name: "trailing json input",
|
||||
mutate: func(record redisRecord) string {
|
||||
return mustMarshalJSON(t, record) + "{}"
|
||||
},
|
||||
wantErrText: "unexpected trailing JSON input",
|
||||
},
|
||||
{
|
||||
name: "unknown field",
|
||||
mutate: func(record redisRecord) string {
|
||||
payload := map[string]any{
|
||||
"device_session_id": record.DeviceSessionID,
|
||||
"user_id": record.UserID,
|
||||
"client_public_key_base64": record.ClientPublicKeyBase64,
|
||||
"status": record.Status,
|
||||
"created_at": record.CreatedAt,
|
||||
"revoked_at": record.RevokedAt,
|
||||
"revoke_reason_code": record.RevokeReasonCode,
|
||||
"revoke_actor_type": record.RevokeActorType,
|
||||
"revoke_actor_id": record.RevokeActorID,
|
||||
"unexpected": true,
|
||||
}
|
||||
return mustMarshalJSON(t, payload)
|
||||
},
|
||||
wantErrText: "unknown field",
|
||||
},
|
||||
{
|
||||
name: "unsupported status",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.Status = devicesession.Status("paused")
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: `status "paused" is unsupported`,
|
||||
},
|
||||
{
|
||||
name: "non canonical timestamp",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.CreatedAt = "2026-04-04T12:00:00+03:00"
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "canonical UTC RFC3339Nano timestamp",
|
||||
},
|
||||
{
|
||||
name: "incomplete revocation metadata",
|
||||
mutate: func(record redisRecord) string {
|
||||
record.RevokeActorType = ""
|
||||
return mustMarshalJSON(t, record)
|
||||
},
|
||||
wantErrText: "revocation metadata must be either fully present or fully absent",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
server.Set(store.sessionKey(baseRecord.ID), tt.mutate(stored))
|
||||
|
||||
_, err := store.Get(context.Background(), baseRecord.ID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreListByUserIDDetectsCorruptIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("missing primary record", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
userID := common.UserID("user-1")
|
||||
_, err := server.ZAdd(store.userSessionsKey(userID), 100, "missing-session")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.ListByUserID(context.Background(), userID)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "references missing session")
|
||||
})
|
||||
|
||||
t.Run("wrong user id in primary record", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := activeSessionFixture("device-session-1", "user-2", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
|
||||
_, err := server.ZAdd(store.userSessionsKey(common.UserID("user-1")), createdAtScore(record.CreatedAt), record.ID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.ListByUserID(context.Background(), common.UserID("user-1"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `belongs to "user-2"`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := newTestStore(t, server, Config{})
|
||||
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
|
||||
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
|
||||
_, err := server.ZAdd(store.userActiveSessionsKey(record.UserID), createdAtScore(record.CreatedAt), record.ID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
|
||||
UserID: record.UserID,
|
||||
Revocation: devicesession.Revocation{
|
||||
At: time.Unix(200, 0).UTC(),
|
||||
ReasonCode: devicesession.RevokeReasonAdminRevoke,
|
||||
ActorType: common.RevokeActorType("admin"),
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, `is "revoked"`)
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.SessionKeyPrefix == "" {
|
||||
cfg.SessionKeyPrefix = "authsession:session:"
|
||||
}
|
||||
if cfg.UserSessionsKeyPrefix == "" {
|
||||
cfg.UserSessionsKeyPrefix = "authsession:user-sessions:"
|
||||
}
|
||||
if cfg.UserActiveSessionsKeyPrefix == "" {
|
||||
cfg.UserActiveSessionsKeyPrefix = "authsession:user-active-sessions:"
|
||||
}
|
||||
if cfg.OperationTimeout == 0 {
|
||||
cfg.OperationTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
|
||||
0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
|
||||
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
record := activeSessionFixture(deviceSessionID, userID, createdAt)
|
||||
record.Status = devicesession.StatusRevoked
|
||||
record.Revocation = &devicesession.Revocation{
|
||||
At: createdAt.Add(time.Minute),
|
||||
ReasonCode: devicesession.RevokeReasonDeviceLogout,
|
||||
ActorType: common.RevokeActorType("user"),
|
||||
ActorID: "user-actor",
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
func seedSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record devicesession.Session) error {
|
||||
t.Helper()
|
||||
|
||||
stored, err := redisRecordFromSession(record)
|
||||
require.NoError(t, err)
|
||||
server.Set(key, mustMarshalJSON(t, stored))
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(payload)
|
||||
}
|
||||
@@ -0,0 +1,382 @@
|
||||
// Package userservice provides runtime user-directory adapters for the
|
||||
// auth/session service.
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
resolveByEmailPath = "/api/v1/internal/user-resolutions/by-email"
|
||||
existsByUserIDPath = "/api/v1/internal/users/%s/exists"
|
||||
ensureByEmailPath = "/api/v1/internal/users/ensure-by-email"
|
||||
blockByUserIDPath = "/api/v1/internal/users/%s/block"
|
||||
blockByEmailPath = "/api/v1/internal/user-blocks/by-email"
|
||||
)
|
||||
|
||||
// Config configures one HTTP-based UserDirectory client.
|
||||
type Config struct {
|
||||
// BaseURL is the absolute base URL of the future user-service internal
|
||||
// HTTP API.
|
||||
BaseURL string
|
||||
|
||||
// RequestTimeout bounds each outbound user-service request.
|
||||
RequestTimeout time.Duration
|
||||
}
|
||||
|
||||
// RESTClient implements ports.UserDirectory over a frozen internal REST
|
||||
// contract.
|
||||
type RESTClient struct {
|
||||
baseURL string
|
||||
requestTimeout time.Duration
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewRESTClient constructs a REST-backed UserDirectory adapter from cfg.
|
||||
func NewRESTClient(cfg Config) (*RESTClient, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
return newRESTClient(cfg, &http.Client{Transport: transport})
|
||||
}
|
||||
|
||||
func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) {
|
||||
switch {
|
||||
case strings.TrimSpace(cfg.BaseURL) == "":
|
||||
return nil, errors.New("new user service REST client: base URL must not be empty")
|
||||
case cfg.RequestTimeout <= 0:
|
||||
return nil, errors.New("new user service REST client: request timeout must be positive")
|
||||
case httpClient == nil:
|
||||
return nil, errors.New("new user service REST client: http client must not be nil")
|
||||
}
|
||||
|
||||
parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new user service REST client: parse base URL: %w", err)
|
||||
}
|
||||
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
|
||||
return nil, errors.New("new user service REST client: base URL must be absolute")
|
||||
}
|
||||
|
||||
return &RESTClient{
|
||||
baseURL: parsedBaseURL.String(),
|
||||
requestTimeout: cfg.RequestTimeout,
|
||||
httpClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases idle HTTP connections owned by the client transport.
|
||||
func (c *RESTClient) Close() error {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type idleCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveByEmail returns the current coarse user-resolution state for email
|
||||
// without creating any new user record.
|
||||
func (c *RESTClient) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
|
||||
if err := validateContext(ctx, "resolve by email"); err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
if err := email.Validate(); err != nil {
|
||||
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Kind userresolution.Kind `json:"kind"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"`
|
||||
}
|
||||
|
||||
if err := c.doJSON(ctx, "resolve by email", http.MethodPost, resolveByEmailPath, map[string]string{
|
||||
"email": email.String(),
|
||||
}, &response, true); err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
|
||||
result := userresolution.Result{
|
||||
Kind: response.Kind,
|
||||
UserID: common.UserID(response.UserID),
|
||||
BlockReasonCode: response.BlockReasonCode,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExistsByUserID reports whether userID currently identifies a stored user
|
||||
// record.
|
||||
func (c *RESTClient) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
|
||||
if err := validateContext(ctx, "exists by user id"); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := userID.Validate(); err != nil {
|
||||
return false, fmt.Errorf("exists by user id: %w", err)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Exists bool `json:"exists"`
|
||||
}
|
||||
|
||||
if err := c.doJSON(ctx, "exists by user id", http.MethodGet, fmt.Sprintf(existsByUserIDPath, url.PathEscape(userID.String())), nil, &response, true); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return response.Exists, nil
|
||||
}
|
||||
|
||||
// EnsureUserByEmail returns an existing user for email, creates a new user
|
||||
// when registration is allowed, or reports a blocked outcome.
|
||||
func (c *RESTClient) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
|
||||
if err := validateContext(ctx, "ensure user by email"); err != nil {
|
||||
return ports.EnsureUserResult{}, err
|
||||
}
|
||||
if err := email.Validate(); err != nil {
|
||||
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Outcome ports.EnsureUserOutcome `json:"outcome"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"`
|
||||
}
|
||||
|
||||
if err := c.doJSON(ctx, "ensure user by email", http.MethodPost, ensureByEmailPath, map[string]string{
|
||||
"email": email.String(),
|
||||
}, &response, false); err != nil {
|
||||
return ports.EnsureUserResult{}, err
|
||||
}
|
||||
|
||||
result := ports.EnsureUserResult{
|
||||
Outcome: response.Outcome,
|
||||
UserID: common.UserID(response.UserID),
|
||||
BlockReasonCode: response.BlockReasonCode,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BlockByUserID applies a block state to the user identified by input.UserID.
|
||||
// Unknown user ids wrap ports.ErrNotFound.
|
||||
func (c *RESTClient) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
|
||||
if err := validateContext(ctx, "block by user id"); err != nil {
|
||||
return ports.BlockUserResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
|
||||
}
|
||||
|
||||
payload, statusCode, err := c.doRequest(ctx, "block by user id", http.MethodPost, fmt.Sprintf(blockByUserIDPath, url.PathEscape(input.UserID.String())), map[string]string{
|
||||
"reason_code": input.ReasonCode.String(),
|
||||
}, false)
|
||||
if err != nil {
|
||||
return ports.BlockUserResult{}, err
|
||||
}
|
||||
if statusCode == http.StatusNotFound {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id: unexpected HTTP status %d", statusCode)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Outcome ports.BlockUserOutcome `json:"outcome"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
if err := decodeJSONPayload(payload, &response); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
|
||||
}
|
||||
|
||||
result := ports.BlockUserResult{
|
||||
Outcome: response.Outcome,
|
||||
UserID: common.UserID(response.UserID),
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BlockByEmail applies a block state to input.Email even when no user record
|
||||
// currently exists for that e-mail address.
|
||||
func (c *RESTClient) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
|
||||
if err := validateContext(ctx, "block by email"); err != nil {
|
||||
return ports.BlockUserResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Outcome ports.BlockUserOutcome `json:"outcome"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
if err := c.doJSON(ctx, "block by email", http.MethodPost, blockByEmailPath, map[string]string{
|
||||
"email": input.Email.String(),
|
||||
"reason_code": input.ReasonCode.String(),
|
||||
}, &response, false); err != nil {
|
||||
return ports.BlockUserResult{}, err
|
||||
}
|
||||
|
||||
result := ports.BlockUserResult{
|
||||
Outcome: response.Outcome,
|
||||
UserID: common.UserID(response.UserID),
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *RESTClient) doJSON(ctx context.Context, operation string, method string, requestPath string, requestBody any, responseTarget any, retryRead bool) error {
|
||||
payload, statusCode, err := c.doRequest(ctx, operation, method, requestPath, requestBody, retryRead)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return fmt.Errorf("%s: unexpected HTTP status %d", operation, statusCode)
|
||||
}
|
||||
if err := decodeJSONPayload(payload, responseTarget); err != nil {
|
||||
return fmt.Errorf("%s: %w", operation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RESTClient) doRequest(ctx context.Context, operation string, method string, requestPath string, requestBody any, retryRead bool) ([]byte, int, error) {
|
||||
bodyBytes, err := marshalOptionalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s: %w", operation, err)
|
||||
}
|
||||
|
||||
attempts := 1
|
||||
if retryRead {
|
||||
attempts = 2
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, c.requestTimeout)
|
||||
|
||||
request, err := http.NewRequestWithContext(attemptCtx, method, c.baseURL+requestPath, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, 0, fmt.Errorf("%s: build request: %w", operation, err)
|
||||
}
|
||||
if method == http.MethodPost {
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
response, err := c.httpClient.Do(request)
|
||||
if err != nil {
|
||||
cancel()
|
||||
lastErr = fmt.Errorf("%s: %w", operation, err)
|
||||
if retryRead && attempt == 0 && ctx.Err() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, 0, lastErr
|
||||
}
|
||||
|
||||
payload, readErr := io.ReadAll(response.Body)
|
||||
closeErr := response.Body.Close()
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
lastErr = fmt.Errorf("%s: read response body: %w", operation, readErr)
|
||||
if retryRead && attempt == 0 && ctx.Err() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, 0, lastErr
|
||||
}
|
||||
if closeErr != nil {
|
||||
lastErr = fmt.Errorf("%s: close response body: %w", operation, closeErr)
|
||||
if retryRead && attempt == 0 && ctx.Err() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, 0, lastErr
|
||||
}
|
||||
|
||||
if retryRead && attempt == 0 && isRetriableUserServiceStatus(response.StatusCode) {
|
||||
lastErr = fmt.Errorf("%s: unexpected HTTP status %d", operation, response.StatusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
return payload, response.StatusCode, nil
|
||||
}
|
||||
|
||||
return nil, 0, lastErr
|
||||
}
|
||||
|
||||
func marshalOptionalRequestBody(value any) ([]byte, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func decodeJSONPayload(payload []byte, target any) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
if err := decoder.Decode(target); err != nil {
|
||||
return fmt.Errorf("decode response body: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return errors.New("decode response body: unexpected trailing JSON input")
|
||||
}
|
||||
|
||||
return fmt.Errorf("decode response body: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRetriableUserServiceStatus(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var _ ports.UserDirectory = (*RESTClient)(nil)
|
||||
@@ -0,0 +1,622 @@
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewRESTClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: Config{
|
||||
BaseURL: "http://127.0.0.1:8080",
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty base url",
|
||||
cfg: Config{
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
wantErr: "base URL must not be empty",
|
||||
},
|
||||
{
|
||||
name: "relative base url",
|
||||
cfg: Config{
|
||||
BaseURL: "/relative",
|
||||
RequestTimeout: time.Second,
|
||||
},
|
||||
wantErr: "base URL must be absolute",
|
||||
},
|
||||
{
|
||||
name: "non positive timeout",
|
||||
cfg: Config{
|
||||
BaseURL: "http://127.0.0.1:8080",
|
||||
},
|
||||
wantErr: "request timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, err := NewRESTClient(tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientEndpointSuccessCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*testing.T, *RESTClient)
|
||||
}{
|
||||
{
|
||||
name: "resolve by email",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Case@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userresolution.Result{
|
||||
Kind: userresolution.KindExisting,
|
||||
UserID: common.UserID("user-123"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exists by user id",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ensure user by email",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.EnsureUserResult{
|
||||
Outcome: ports.EnsureUserOutcomeCreated,
|
||||
UserID: common.UserID("user-234"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block by user id",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-123"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeBlocked,
|
||||
UserID: common.UserID("user-123"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block by email",
|
||||
run: func(t *testing.T, client *RESTClient) {
|
||||
result, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("blocked@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
|
||||
UserID: common.UserID("user-345"),
|
||||
}, result)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestsMu sync.Mutex
|
||||
var requests []capturedRequest
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, captureRequest(t, r))
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == resolveByEmailPath:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"kind": "existing",
|
||||
"user_id": "user-123",
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v1/internal/users/user-123/exists":
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"exists": true})
|
||||
case r.Method == http.MethodPost && r.URL.Path == ensureByEmailPath:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"outcome": "created",
|
||||
"user_id": "user-234",
|
||||
})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/internal/users/user-123/block":
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"outcome": "blocked",
|
||||
"user_id": "user-123",
|
||||
})
|
||||
case r.Method == http.MethodPost && r.URL.Path == blockByEmailPath:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"outcome": "already_blocked",
|
||||
"user_id": "user-345",
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
tt.run(t, client)
|
||||
|
||||
requestsMu.Lock()
|
||||
defer requestsMu.Unlock()
|
||||
|
||||
require.Len(t, requests, 1)
|
||||
switch tt.name {
|
||||
case "resolve by email":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: resolveByEmailPath,
|
||||
ContentType: "application/json",
|
||||
Body: `{"email":"Pilot+Case@example.com"}`,
|
||||
}, requests[0])
|
||||
case "exists by user id":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/v1/internal/users/user-123/exists",
|
||||
}, requests[0])
|
||||
case "ensure user by email":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: ensureByEmailPath,
|
||||
ContentType: "application/json",
|
||||
Body: `{"email":"created@example.com"}`,
|
||||
}, requests[0])
|
||||
case "block by user id":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/v1/internal/users/user-123/block",
|
||||
ContentType: "application/json",
|
||||
Body: `{"reason_code":"policy_blocked"}`,
|
||||
}, requests[0])
|
||||
case "block by email":
|
||||
assert.Equal(t, capturedRequest{
|
||||
Method: http.MethodPost,
|
||||
Path: blockByEmailPath,
|
||||
ContentType: "application/json",
|
||||
Body: `{"email":"blocked@example.com","reason_code":"policy_blocked"}`,
|
||||
}, requests[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientPreservesNormalizedEmailExactly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var captured string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
request := captureRequest(t, r)
|
||||
captured = request.Body
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
_, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Alias@Example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{"email":"Pilot+Alias@Example.com"}`, captured)
|
||||
}
|
||||
|
||||
func TestRESTClientBlockByUserIDNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("missing-user"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRESTClientReadMethodsRetryOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("resolve by email retries on 503", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var calls int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls++
|
||||
if calls == 1 {
|
||||
http.Error(w, "temporary", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
result, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userresolution.KindCreatable, result.Kind)
|
||||
assert.Equal(t, 2, calls)
|
||||
})
|
||||
|
||||
t.Run("exists by user id retries on transport failure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"exists": true})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
baseTransport := server.Client().Transport
|
||||
client, err := newRESTClient(Config{
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: 250 * time.Millisecond,
|
||||
}, &http.Client{
|
||||
Transport: &failOnceRoundTripper{
|
||||
next: baseTransport,
|
||||
err: errors.New("temporary transport failure"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRESTClientMutationMethodsDoNotRetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*RESTClient) error
|
||||
}{
|
||||
{
|
||||
name: "ensure user by email",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block by user id",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-123"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block by email",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var calls int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls++
|
||||
http.Error(w, "temporary", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
err := tt.run(client)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 1, calls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
wantErrText string
|
||||
run func(*RESTClient) error
|
||||
}{
|
||||
{
|
||||
name: "resolve by email rejects unknown field",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"kind":"creatable","extra":true}`,
|
||||
wantErrText: "decode response body",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ensure user by email rejects malformed outcome",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"outcome":"mystery"}`,
|
||||
wantErrText: "unsupported",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ensure user by email rejects missing user id for created outcome",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"outcome":"created"}`,
|
||||
wantErrText: "user id",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exists by user id rejects trailing json",
|
||||
statusCode: http.StatusOK,
|
||||
body: `{"exists":true}{}`,
|
||||
wantErrText: "unexpected trailing JSON input",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block by email rejects unexpected status",
|
||||
statusCode: http.StatusBadGateway,
|
||||
body: `{"error":"temporary"}`,
|
||||
wantErrText: "unexpected HTTP status 502",
|
||||
run: func(client *RESTClient) error {
|
||||
_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, err := io.WriteString(w, tt.body)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
|
||||
err := tt.run(client)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErrText)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClientRequestTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 10*time.Millisecond)
|
||||
|
||||
_, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "context deadline exceeded")
|
||||
}
|
||||
|
||||
func TestRESTClientContextAndValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("unexpected upstream call")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
||||
cancelledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func() error
|
||||
}{
|
||||
{
|
||||
name: "nil context",
|
||||
run: func() error {
|
||||
_, err := client.ResolveByEmail(nil, common.Email("pilot@example.com"))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled context",
|
||||
run: func() error {
|
||||
_, err := client.ExistsByUserID(cancelledCtx, common.UserID("user-123"))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
run: func() error {
|
||||
_, err := client.EnsureUserByEmail(context.Background(), common.Email(" bad@example.com "))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid user id",
|
||||
run: func() error {
|
||||
_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID(" bad "),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid reason code",
|
||||
run: func() error {
|
||||
_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode(" bad "),
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.run()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type capturedRequest struct {
|
||||
Method string
|
||||
Path string
|
||||
ContentType string
|
||||
Body string
|
||||
}
|
||||
|
||||
func captureRequest(t *testing.T, request *http.Request) capturedRequest {
|
||||
t.Helper()
|
||||
|
||||
body, err := io.ReadAll(request.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
return capturedRequest{
|
||||
Method: request.Method,
|
||||
Path: request.URL.Path,
|
||||
ContentType: request.Header.Get("Content-Type"),
|
||||
Body: strings.TrimSpace(string(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.WriteHeader(statusCode)
|
||||
_, err = writer.Write(payload)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
|
||||
t.Helper()
|
||||
|
||||
client, err := NewRESTClient(Config{
|
||||
BaseURL: baseURL,
|
||||
RequestTimeout: timeout,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
type failOnceRoundTripper struct {
|
||||
mu sync.Mutex
|
||||
next http.RoundTripper
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
func (rt *failOnceRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
rt.mu.Lock()
|
||||
if !rt.done {
|
||||
rt.done = true
|
||||
err := rt.err
|
||||
rt.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
next := rt.next
|
||||
rt.mu.Unlock()
|
||||
|
||||
return next.RoundTrip(request)
|
||||
}
|
||||
@@ -0,0 +1,361 @@
|
||||
// Package userservice provides runtime user-directory adapters for the
|
||||
// auth/session service.
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
userID common.UserID
|
||||
blockReasonCode userresolution.BlockReasonCode
|
||||
}
|
||||
|
||||
// StubDirectory is a concurrency-safe in-process UserDirectory stub intended
|
||||
// for development, local integration, and explicit stub-based tests.
|
||||
//
|
||||
// The zero value is ready to use. Unknown e-mail addresses resolve as
|
||||
// creatable, unknown user identifiers do not exist, and EnsureUserByEmail
|
||||
// creates deterministic user ids such as "user-1", "user-2", and so on.
|
||||
type StubDirectory struct {
|
||||
mu sync.Mutex
|
||||
byEmail map[common.Email]entry
|
||||
emailByUserID map[common.UserID]common.Email
|
||||
createdUserIDs []common.UserID
|
||||
nextUserNumber int
|
||||
}
|
||||
|
||||
// ResolveByEmail returns the current coarse user-resolution state for email
|
||||
// without creating any new user record.
|
||||
func (d *StubDirectory) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
|
||||
if err := validateContext(ctx, "resolve by email"); err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
if err := email.Validate(); err != nil {
|
||||
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
result, err := d.resolveLocked(email)
|
||||
if err != nil {
|
||||
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExistsByUserID reports whether userID currently identifies a stored user
|
||||
// record.
|
||||
func (d *StubDirectory) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
|
||||
if err := validateContext(ctx, "exists by user id"); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := userID.Validate(); err != nil {
|
||||
return false, fmt.Errorf("exists by user id: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
_, ok := d.emailByUserID[userID]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// EnsureUserByEmail returns an existing user for email, creates a new user
|
||||
// when registration is allowed, or reports a blocked outcome.
|
||||
func (d *StubDirectory) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
|
||||
if err := validateContext(ctx, "ensure user by email"); err != nil {
|
||||
return ports.EnsureUserResult{}, err
|
||||
}
|
||||
if err := email.Validate(); err != nil {
|
||||
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.ensureMapsLocked()
|
||||
|
||||
stored, ok := d.byEmail[email]
|
||||
if ok {
|
||||
if !stored.blockReasonCode.IsZero() {
|
||||
result := ports.EnsureUserResult{
|
||||
Outcome: ports.EnsureUserOutcomeBlocked,
|
||||
BlockReasonCode: stored.blockReasonCode,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result := ports.EnsureUserResult{
|
||||
Outcome: ports.EnsureUserOutcomeExisting,
|
||||
UserID: stored.userID,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
userID, err := d.nextCreatedUserIDLocked()
|
||||
if err != nil {
|
||||
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
|
||||
}
|
||||
d.byEmail[email] = entry{userID: userID}
|
||||
d.emailByUserID[userID] = email
|
||||
|
||||
result := ports.EnsureUserResult{
|
||||
Outcome: ports.EnsureUserOutcomeCreated,
|
||||
UserID: userID,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BlockByUserID applies a block state to the user identified by input.UserID.
|
||||
// Unknown user ids wrap ports.ErrNotFound.
|
||||
func (d *StubDirectory) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
|
||||
if err := validateContext(ctx, "block by user id"); err != nil {
|
||||
return ports.BlockUserResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
email, ok := d.emailByUserID[input.UserID]
|
||||
if !ok {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
stored := d.byEmail[email]
|
||||
if !stored.blockReasonCode.IsZero() {
|
||||
result := ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
|
||||
UserID: input.UserID,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
stored.blockReasonCode = input.ReasonCode
|
||||
d.byEmail[email] = stored
|
||||
|
||||
result := ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeBlocked,
|
||||
UserID: input.UserID,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BlockByEmail applies a block state to input.Email even when no user record
|
||||
// currently exists for that e-mail address.
|
||||
func (d *StubDirectory) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
|
||||
if err := validateContext(ctx, "block by email"); err != nil {
|
||||
return ports.BlockUserResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.ensureMapsLocked()
|
||||
|
||||
stored := d.byEmail[input.Email]
|
||||
if !stored.blockReasonCode.IsZero() {
|
||||
result := ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
|
||||
UserID: stored.userID,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
stored.blockReasonCode = input.ReasonCode
|
||||
d.byEmail[input.Email] = stored
|
||||
if !stored.userID.IsZero() {
|
||||
d.emailByUserID[stored.userID] = input.Email
|
||||
}
|
||||
|
||||
result := ports.BlockUserResult{
|
||||
Outcome: ports.BlockUserOutcomeBlocked,
|
||||
UserID: stored.userID,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SeedExisting preloads one existing unblocked user record into the runtime
|
||||
// stub.
|
||||
func (d *StubDirectory) SeedExisting(email common.Email, userID common.UserID) error {
|
||||
if err := email.Validate(); err != nil {
|
||||
return fmt.Errorf("seed existing email: %w", err)
|
||||
}
|
||||
if err := userID.Validate(); err != nil {
|
||||
return fmt.Errorf("seed existing user id: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.ensureMapsLocked()
|
||||
d.byEmail[email] = entry{userID: userID}
|
||||
d.emailByUserID[userID] = email
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SeedBlockedEmail preloads one blocked e-mail address that does not
|
||||
// necessarily belong to an existing user record.
|
||||
func (d *StubDirectory) SeedBlockedEmail(email common.Email, reasonCode userresolution.BlockReasonCode) error {
|
||||
if err := email.Validate(); err != nil {
|
||||
return fmt.Errorf("seed blocked email: %w", err)
|
||||
}
|
||||
if err := reasonCode.Validate(); err != nil {
|
||||
return fmt.Errorf("seed blocked email reason code: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.ensureMapsLocked()
|
||||
d.byEmail[email] = entry{blockReasonCode: reasonCode}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SeedBlockedUser preloads one blocked existing user record into the runtime
|
||||
// stub.
|
||||
func (d *StubDirectory) SeedBlockedUser(email common.Email, userID common.UserID, reasonCode userresolution.BlockReasonCode) error {
|
||||
if err := d.SeedExisting(email, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
stored := d.byEmail[email]
|
||||
stored.blockReasonCode = reasonCode
|
||||
d.byEmail[email] = stored
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueCreatedUserIDs appends deterministic user identifiers that
|
||||
// EnsureUserByEmail consumes before falling back to generated ids.
|
||||
func (d *StubDirectory) QueueCreatedUserIDs(userIDs ...common.UserID) error {
|
||||
for index, userID := range userIDs {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return fmt.Errorf("queue created user id %d: %w", index, err)
|
||||
}
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.createdUserIDs = append(d.createdUserIDs, userIDs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *StubDirectory) ensureMapsLocked() {
|
||||
if d.byEmail == nil {
|
||||
d.byEmail = make(map[common.Email]entry)
|
||||
}
|
||||
if d.emailByUserID == nil {
|
||||
d.emailByUserID = make(map[common.UserID]common.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *StubDirectory) resolveLocked(email common.Email) (userresolution.Result, error) {
|
||||
stored, ok := d.byEmail[email]
|
||||
if !ok {
|
||||
result := userresolution.Result{Kind: userresolution.KindCreatable}
|
||||
if err := result.Validate(); err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
if !stored.blockReasonCode.IsZero() {
|
||||
result := userresolution.Result{
|
||||
Kind: userresolution.KindBlocked,
|
||||
BlockReasonCode: stored.blockReasonCode,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result := userresolution.Result{
|
||||
Kind: userresolution.KindExisting,
|
||||
UserID: stored.userID,
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (d *StubDirectory) nextCreatedUserIDLocked() (common.UserID, error) {
|
||||
if len(d.createdUserIDs) > 0 {
|
||||
userID := d.createdUserIDs[0]
|
||||
d.createdUserIDs = d.createdUserIDs[1:]
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
d.nextUserNumber++
|
||||
userID := common.UserID(fmt.Sprintf("user-%d", d.nextUserNumber))
|
||||
if err := userID.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func validateContext(ctx context.Context, operation string) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("%s: nil context", operation)
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fmt.Errorf("%s: %w", operation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ports.UserDirectory = (*StubDirectory)(nil)
|
||||
@@ -0,0 +1,329 @@
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStubDirectoryResolveByEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))
|
||||
require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email common.Email
|
||||
wantKind userresolution.Kind
|
||||
wantUserID common.UserID
|
||||
wantReasonCode userresolution.BlockReasonCode
|
||||
}{
|
||||
{
|
||||
name: "zero value unknown email is creatable",
|
||||
email: common.Email("new@example.com"),
|
||||
wantKind: userresolution.KindCreatable,
|
||||
},
|
||||
{
|
||||
name: "existing email",
|
||||
email: common.Email("existing@example.com"),
|
||||
wantKind: userresolution.KindExisting,
|
||||
wantUserID: common.UserID("user-existing"),
|
||||
},
|
||||
{
|
||||
name: "blocked email",
|
||||
email: common.Email("blocked@example.com"),
|
||||
wantKind: userresolution.KindBlocked,
|
||||
wantReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := directory.ResolveByEmail(context.Background(), tt.email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantKind, result.Kind)
|
||||
assert.Equal(t, tt.wantUserID, result.UserID)
|
||||
assert.Equal(t, tt.wantReasonCode, result.BlockReasonCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStubDirectoryEnsureUserByEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("existing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))
|
||||
|
||||
result, err := directory.EnsureUserByEmail(context.Background(), common.Email("existing@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.EnsureUserOutcomeExisting, result.Outcome)
|
||||
assert.Equal(t, common.UserID("user-existing"), result.UserID)
|
||||
})
|
||||
|
||||
t.Run("blocked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")))
|
||||
|
||||
result, err := directory.EnsureUserByEmail(context.Background(), common.Email("blocked@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.EnsureUserOutcomeBlocked, result.Outcome)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), result.BlockReasonCode)
|
||||
})
|
||||
|
||||
t.Run("created queued then existing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.QueueCreatedUserIDs(common.UserID("user-created")))
|
||||
|
||||
first, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.EnsureUserOutcomeCreated, first.Outcome)
|
||||
assert.Equal(t, common.UserID("user-created"), first.UserID)
|
||||
|
||||
second, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.EnsureUserOutcomeExisting, second.Outcome)
|
||||
assert.Equal(t, common.UserID("user-created"), second.UserID)
|
||||
})
|
||||
|
||||
t.Run("created fallback id", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
|
||||
result, err := directory.EnsureUserByEmail(context.Background(), common.Email("fallback@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.EnsureUserOutcomeCreated, result.Outcome)
|
||||
assert.Equal(t, common.UserID("user-1"), result.UserID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStubDirectoryExistsByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))
|
||||
|
||||
exists, err := directory.ExistsByUserID(context.Background(), common.UserID("user-existing"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
exists, err = directory.ExistsByUserID(context.Background(), common.UserID("missing"))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestStubDirectoryBlockByEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("unknown email becomes blocked without user id", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
|
||||
result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("blocked@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserOutcomeBlocked, result.Outcome)
|
||||
assert.True(t, result.UserID.IsZero())
|
||||
|
||||
resolution, err := directory.ResolveByEmail(context.Background(), common.Email("blocked@example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
})
|
||||
|
||||
t.Run("existing user preserves linked user id and repeat is already blocked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
|
||||
first, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome)
|
||||
assert.Equal(t, common.UserID("user-1"), first.UserID)
|
||||
|
||||
second, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome)
|
||||
assert.Equal(t, common.UserID("user-1"), second.UserID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStubDirectoryBlockByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("unknown user wraps ErrNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
|
||||
_, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("missing"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ports.ErrNotFound)
|
||||
})
|
||||
|
||||
t.Run("existing user blocks then returns already blocked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
|
||||
first, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome)
|
||||
assert.Equal(t, common.UserID("user-1"), first.UserID)
|
||||
|
||||
second, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome)
|
||||
assert.Equal(t, common.UserID("user-1"), second.UserID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStubDirectoryContextAndValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
cancelledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func() error
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "resolve nil context",
|
||||
run: func() error {
|
||||
_, err := directory.ResolveByEmail(nil, common.Email("pilot@example.com"))
|
||||
return err
|
||||
},
|
||||
want: "nil context",
|
||||
},
|
||||
{
|
||||
name: "ensure cancelled context",
|
||||
run: func() error {
|
||||
_, err := directory.EnsureUserByEmail(cancelledCtx, common.Email("pilot@example.com"))
|
||||
return err
|
||||
},
|
||||
want: context.Canceled.Error(),
|
||||
},
|
||||
{
|
||||
name: "exists invalid user id",
|
||||
run: func() error {
|
||||
_, err := directory.ExistsByUserID(context.Background(), common.UserID(" bad "))
|
||||
return err
|
||||
},
|
||||
want: "exists by user id",
|
||||
},
|
||||
{
|
||||
name: "block by email invalid email",
|
||||
run: func() error {
|
||||
_, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("bad"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
return err
|
||||
},
|
||||
want: "block by email",
|
||||
},
|
||||
{
|
||||
name: "seed invalid user id",
|
||||
run: func() error {
|
||||
return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID(" bad "))
|
||||
},
|
||||
want: "seed existing user id",
|
||||
},
|
||||
{
|
||||
name: "queue invalid created user id",
|
||||
run: func() error {
|
||||
return directory.QueueCreatedUserIDs(common.UserID(" bad "))
|
||||
},
|
||||
want: "queue created user id 0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := tt.run()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStubDirectorySeedBlockedUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
require.NoError(t, directory.SeedBlockedUser(
|
||||
common.Email("pilot@example.com"),
|
||||
common.UserID("user-1"),
|
||||
userresolution.BlockReasonCode("policy_block"),
|
||||
))
|
||||
|
||||
result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("pilot@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, result.Outcome)
|
||||
assert.Equal(t, common.UserID("user-1"), result.UserID)
|
||||
}
|
||||
|
||||
func TestStubDirectoryCancelledContextWrapsContextError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &StubDirectory{}
|
||||
cancelledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := directory.BlockByUserID(cancelledCtx, ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.ErrorContains(t, err, "block by user id")
|
||||
}
|
||||
Reference in New Issue
Block a user