feat: authsession service

This commit is contained in:
Ilia Denisov
2026-04-08 16:23:07 +02:00
committed by GitHub
parent 28f04916af
commit 86a68ed9d0
174 changed files with 31732 additions and 112 deletions
@@ -0,0 +1,122 @@
package testkit
import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
)
// InMemoryChallengeStore is a deterministic map-backed ChallengeStore double
// suitable for service tests.
type InMemoryChallengeStore struct {
mu sync.Mutex
records map[common.ChallengeID]challenge.Challenge
}
// Get returns the stored challenge for challengeID.
func (s *InMemoryChallengeStore) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
if err := ctx.Err(); err != nil {
return challenge.Challenge{}, err
}
if err := challengeID.Validate(); err != nil {
return challenge.Challenge{}, fmt.Errorf("get challenge: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
record, ok := s.records[challengeID]
if !ok {
return challenge.Challenge{}, fmt.Errorf("get challenge %q: %w", challengeID, ports.ErrNotFound)
}
cloned, err := cloneChallenge(record)
if err != nil {
return challenge.Challenge{}, err
}
return cloned, nil
}
// Create stores record as a new challenge.
func (s *InMemoryChallengeStore) Create(ctx context.Context, record challenge.Challenge) error {
if err := ctx.Err(); err != nil {
return err
}
if err := record.Validate(); err != nil {
return fmt.Errorf("create challenge: %w", err)
}
cloned, err := cloneChallenge(record)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.records == nil {
s.records = make(map[common.ChallengeID]challenge.Challenge)
}
if _, exists := s.records[record.ID]; exists {
return fmt.Errorf("create challenge %q: %w", record.ID, ports.ErrConflict)
}
s.records[record.ID] = cloned
return nil
}
// CompareAndSwap replaces previous with next when the currently stored
// challenge matches previous exactly.
func (s *InMemoryChallengeStore) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
if err := ctx.Err(); err != nil {
return err
}
if err := ports.ValidateComparableChallenges(previous, next); err != nil {
return fmt.Errorf("compare and swap challenge: %w", err)
}
clonedPrevious, err := cloneChallenge(previous)
if err != nil {
return err
}
clonedNext, err := cloneChallenge(next)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
current, ok := s.records[previous.ID]
if !ok {
return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrNotFound)
}
if !reflect.DeepEqual(current, clonedPrevious) {
return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrConflict)
}
s.records[next.ID] = clonedNext
return nil
}
var _ ports.ChallengeStore = (*InMemoryChallengeStore)(nil)
func mustGetChallenge(store *InMemoryChallengeStore, challengeID common.ChallengeID) challenge.Challenge {
record, err := store.Get(context.Background(), challengeID)
if err != nil {
panic(err)
}
return record
}
func isNotFound(err error) bool {
return errors.Is(err, ports.ErrNotFound)
}
@@ -0,0 +1,80 @@
package testkit
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
)
func TestInMemoryChallengeStoreCreateAndGet(t *testing.T) {
t.Parallel()
store := &InMemoryChallengeStore{}
record := challengeFixture()
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
got, err := store.Get(context.Background(), record.ID)
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if got.ID != record.ID {
require.Failf(t, "test failed", "Get().ID = %q, want %q", got.ID, record.ID)
}
if &got.CodeHash[0] == &record.CodeHash[0] {
require.FailNow(t, "Get() returned aliased code hash slice")
}
}
func TestInMemoryChallengeStoreGetNotFound(t *testing.T) {
t.Parallel()
store := &InMemoryChallengeStore{}
_, err := store.Get(context.Background(), common.ChallengeID("missing"))
if !errors.Is(err, ports.ErrNotFound) {
require.Failf(t, "test failed", "Get() error = %v, want ErrNotFound", err)
}
}
func TestInMemoryChallengeStoreCompareAndSwapConflict(t *testing.T) {
t.Parallel()
store := &InMemoryChallengeStore{}
record := challengeFixture()
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
previous := record
previous.Attempts.Confirm = 1
next := record
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
err := store.CompareAndSwap(context.Background(), previous, next)
if !errors.Is(err, ports.ErrConflict) {
require.Failf(t, "test failed", "CompareAndSwap() error = %v, want ErrConflict", err)
}
}
func challengeFixture() challenge.Challenge {
timestamp := time.Unix(20, 0).UTC()
return challenge.Challenge{
ID: common.ChallengeID("challenge-1"),
Email: common.Email("pilot@example.com"),
CodeHash: []byte("hash"),
Status: challenge.StatusPendingSend,
DeliveryState: challenge.DeliveryPending,
CreatedAt: timestamp,
ExpiresAt: timestamp.Add(10 * time.Minute),
}
}
+15
View File
@@ -0,0 +1,15 @@
package testkit
import "time"
// FixedClock is a deterministic Clock double that always returns the same
// instant.
type FixedClock struct {
// Time is the instant returned by Now.
Time time.Time
}
// Now returns the configured instant.
func (c FixedClock) Now() time.Time {
return c.Time
}
+130
View File
@@ -0,0 +1,130 @@
package testkit
import (
"bytes"
"fmt"
"slices"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
)
func cloneChallenge(record challenge.Challenge) (challenge.Challenge, error) {
cloned := record
cloned.CodeHash = bytes.Clone(record.CodeHash)
cloned.Abuse = cloneAbuseMetadata(record.Abuse)
if record.Confirmation != nil {
confirmation, err := cloneChallengeConfirmation(*record.Confirmation)
if err != nil {
return challenge.Challenge{}, err
}
cloned.Confirmation = &confirmation
}
return cloned, nil
}
func cloneChallengeConfirmation(value challenge.Confirmation) (challenge.Confirmation, error) {
cloned := value
if value.ClientPublicKey.IsZero() {
cloned.ClientPublicKey = common.ClientPublicKey{}
return cloned, nil
}
key, err := common.NewClientPublicKey(value.ClientPublicKey.PublicKey())
if err != nil {
return challenge.Confirmation{}, fmt.Errorf("clone challenge confirmation client public key: %w", err)
}
cloned.ClientPublicKey = key
return cloned, nil
}
func cloneAbuseMetadata(value challenge.AbuseMetadata) challenge.AbuseMetadata {
cloned := value
if value.LastAttemptAt != nil {
lastAttemptAt := *value.LastAttemptAt
cloned.LastAttemptAt = &lastAttemptAt
}
return cloned
}
func cloneSession(record devicesession.Session) (devicesession.Session, error) {
cloned := record
if !record.ClientPublicKey.IsZero() {
key, err := common.NewClientPublicKey(record.ClientPublicKey.PublicKey())
if err != nil {
return devicesession.Session{}, fmt.Errorf("clone session client public key: %w", err)
}
cloned.ClientPublicKey = key
}
if record.Revocation != nil {
revocation := *record.Revocation
cloned.Revocation = &revocation
}
return cloned, nil
}
func cloneSessions(records []devicesession.Session) ([]devicesession.Session, error) {
cloned := make([]devicesession.Session, 0, len(records))
for _, record := range records {
session, err := cloneSession(record)
if err != nil {
return nil, err
}
cloned = append(cloned, session)
}
return cloned, nil
}
func cloneProjectionSnapshot(snapshot gatewayprojection.Snapshot) gatewayprojection.Snapshot {
cloned := snapshot
if snapshot.RevokedAt != nil {
revokedAt := *snapshot.RevokedAt
cloned.RevokedAt = &revokedAt
}
return cloned
}
func sortSessionsNewestFirst(records []devicesession.Session) {
slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int {
switch {
case left.CreatedAt.Equal(right.CreatedAt):
return compareStrings(left.ID.String(), right.ID.String())
case left.CreatedAt.After(right.CreatedAt):
return -1
default:
return 1
}
})
}
func compareStrings(left string, right string) int {
switch {
case left < right:
return -1
case left > right:
return 1
default:
return 0
}
}
func cloneTimePointer(value *time.Time) *time.Time {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
@@ -0,0 +1,35 @@
package testkit
import (
"errors"
"strings"
"galaxy/authsession/internal/ports"
)
// FixedCodeGenerator is a deterministic CodeGenerator double that always
// returns the same code or error.
type FixedCodeGenerator struct {
// Code stores the fixed code returned by Generate when Err is nil.
Code string
// Err is returned directly from Generate when set.
Err error
}
// Generate returns the configured fixed code.
func (g FixedCodeGenerator) Generate() (string, error) {
if g.Err != nil {
return "", g.Err
}
switch {
case strings.TrimSpace(g.Code) == "":
return "", errors.New("fixed code generator code must not be empty")
case strings.TrimSpace(g.Code) != g.Code:
return "", errors.New("fixed code generator code must not contain surrounding whitespace")
default:
return g.Code, nil
}
}
var _ ports.CodeGenerator = FixedCodeGenerator{}
@@ -0,0 +1,51 @@
package testkit
import (
"crypto/sha256"
"crypto/subtle"
"errors"
"strings"
"galaxy/authsession/internal/ports"
)
// DeterministicCodeHasher is a deterministic CodeHasher double backed by
// SHA-256 for test stability.
type DeterministicCodeHasher struct{}
// Hash returns the SHA-256 digest of code.
func (DeterministicCodeHasher) Hash(code string) ([]byte, error) {
if err := validateCode(code); err != nil {
return nil, err
}
sum := sha256.Sum256([]byte(code))
return sum[:], nil
}
// Compare reports whether hash equals the deterministic hash of code.
func (h DeterministicCodeHasher) Compare(hash []byte, code string) (bool, error) {
if err := validateCode(code); err != nil {
return false, err
}
expected, err := h.Hash(code)
if err != nil {
return false, err
}
return subtle.ConstantTimeCompare(hash, expected) == 1, nil
}
var _ ports.CodeHasher = DeterministicCodeHasher{}
func validateCode(code string) error {
switch {
case strings.TrimSpace(code) == "":
return errors.New("code must not be empty")
case strings.TrimSpace(code) != code:
return errors.New("code must not contain surrounding whitespace")
default:
return nil
}
}
@@ -0,0 +1,34 @@
package testkit
import (
"context"
"galaxy/authsession/internal/ports"
)
// StaticConfigProvider is a deterministic ConfigProvider double that returns a
// preconfigured session-limit value or error.
type StaticConfigProvider struct {
// Config stores the configuration returned when Err is nil.
Config ports.SessionLimitConfig
// Err is returned directly from LoadSessionLimit when set.
Err error
}
// LoadSessionLimit returns the preconfigured session-limit result.
func (p StaticConfigProvider) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) {
if err := ctx.Err(); err != nil {
return ports.SessionLimitConfig{}, err
}
if p.Err != nil {
return ports.SessionLimitConfig{}, p.Err
}
if err := p.Config.Validate(); err != nil {
return ports.SessionLimitConfig{}, err
}
return p.Config, nil
}
var _ ports.ConfigProvider = StaticConfigProvider{}
+4
View File
@@ -0,0 +1,4 @@
// Package testkit provides deterministic in-memory doubles for auth/session
// service ports so later service tests can run without Redis, HTTP, or other
// external dependencies.
package testkit
@@ -0,0 +1,101 @@
package testkit
import (
"fmt"
"sync"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
)
// SequenceIDGenerator is a deterministic IDGenerator double that consumes
// queued identifiers before falling back to monotonic generated ids.
type SequenceIDGenerator struct {
mu sync.Mutex
// ChallengeIDs stores queued challenge identifiers returned by
// NewChallengeID before generated ids are used.
ChallengeIDs []common.ChallengeID
// DeviceSessionIDs stores queued device-session identifiers returned by
// NewDeviceSessionID before generated ids are used.
DeviceSessionIDs []common.DeviceSessionID
// ChallengeErr is returned directly from NewChallengeID when set.
ChallengeErr error
// DeviceSessionErr is returned directly from NewDeviceSessionID when set.
DeviceSessionErr error
ChallengePrefix string
DeviceSessionPrefix string
nextChallengeNumber int
nextSessionNumber int
}
// NewChallengeID returns the next deterministic challenge identifier.
func (g *SequenceIDGenerator) NewChallengeID() (common.ChallengeID, error) {
if g.ChallengeErr != nil {
return "", g.ChallengeErr
}
g.mu.Lock()
defer g.mu.Unlock()
if len(g.ChallengeIDs) > 0 {
id := g.ChallengeIDs[0]
g.ChallengeIDs = g.ChallengeIDs[1:]
if err := id.Validate(); err != nil {
return "", err
}
return id, nil
}
g.nextChallengeNumber++
prefix := g.ChallengePrefix
if prefix == "" {
prefix = "challenge-"
}
id := common.ChallengeID(fmt.Sprintf("%s%d", prefix, g.nextChallengeNumber))
if err := id.Validate(); err != nil {
return "", err
}
return id, nil
}
// NewDeviceSessionID returns the next deterministic device-session
// identifier.
func (g *SequenceIDGenerator) NewDeviceSessionID() (common.DeviceSessionID, error) {
if g.DeviceSessionErr != nil {
return "", g.DeviceSessionErr
}
g.mu.Lock()
defer g.mu.Unlock()
if len(g.DeviceSessionIDs) > 0 {
id := g.DeviceSessionIDs[0]
g.DeviceSessionIDs = g.DeviceSessionIDs[1:]
if err := id.Validate(); err != nil {
return "", err
}
return id, nil
}
g.nextSessionNumber++
prefix := g.DeviceSessionPrefix
if prefix == "" {
prefix = "device-session-"
}
id := common.DeviceSessionID(fmt.Sprintf("%s%d", prefix, g.nextSessionNumber))
if err := id.Validate(); err != nil {
return "", err
}
return id, nil
}
var _ ports.IDGenerator = (*SequenceIDGenerator)(nil)
@@ -0,0 +1,73 @@
package testkit
import (
"context"
"sync"
"galaxy/authsession/internal/ports"
)
// RecordingMailSender is a deterministic MailSender double that records every
// delivery request and returns preconfigured outcomes or errors.
type RecordingMailSender struct {
mu sync.Mutex
// Results stores queued results consumed by SendLoginCode before
// DefaultResult is used.
Results []ports.SendLoginCodeResult
// DefaultResult stores the result used when Results is empty.
DefaultResult ports.SendLoginCodeResult
// Err is returned directly from SendLoginCode when set.
Err error
recordedInputs []ports.SendLoginCodeInput
}
// SendLoginCode records input and returns the next configured result.
func (s *RecordingMailSender) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
if err := ctx.Err(); err != nil {
return ports.SendLoginCodeResult{}, err
}
if err := input.Validate(); err != nil {
return ports.SendLoginCodeResult{}, err
}
s.mu.Lock()
defer s.mu.Unlock()
s.recordedInputs = append(s.recordedInputs, input)
if s.Err != nil {
return ports.SendLoginCodeResult{}, s.Err
}
if len(s.Results) > 0 {
result := s.Results[0]
s.Results = s.Results[1:]
if err := result.Validate(); err != nil {
return ports.SendLoginCodeResult{}, err
}
return result, nil
}
result := s.DefaultResult
if result.Outcome == "" {
result.Outcome = ports.SendLoginCodeOutcomeSent
}
if err := result.Validate(); err != nil {
return ports.SendLoginCodeResult{}, err
}
return result, nil
}
// RecordedInputs returns a stable snapshot of every recorded mail request.
func (s *RecordingMailSender) RecordedInputs() []ports.SendLoginCodeInput {
s.mu.Lock()
defer s.mu.Unlock()
return append([]ports.SendLoginCodeInput(nil), s.recordedInputs...)
}
var _ ports.MailSender = (*RecordingMailSender)(nil)
@@ -0,0 +1,62 @@
package testkit
import (
"context"
"sync"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/ports"
)
// RecordingProjectionPublisher is a deterministic
// GatewaySessionProjectionPublisher double that records every published
// snapshot.
type RecordingProjectionPublisher struct {
mu sync.Mutex
// Err is returned directly from PublishSession when set.
Err error
// Errors is an optional FIFO error script consumed before Err. Nil entries
// represent successful publish attempts.
Errors []error
published []gatewayprojection.Snapshot
}
// PublishSession records snapshot and returns the configured error, if any.
func (p *RecordingProjectionPublisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error {
if err := ctx.Err(); err != nil {
return err
}
if err := snapshot.Validate(); err != nil {
return err
}
p.mu.Lock()
defer p.mu.Unlock()
p.published = append(p.published, cloneProjectionSnapshot(snapshot))
if len(p.Errors) > 0 {
err := p.Errors[0]
p.Errors = append([]error(nil), p.Errors[1:]...)
return err
}
return p.Err
}
// PublishedSnapshots returns a stable snapshot of every published projection.
func (p *RecordingProjectionPublisher) PublishedSnapshots() []gatewayprojection.Snapshot {
p.mu.Lock()
defer p.mu.Unlock()
snapshots := make([]gatewayprojection.Snapshot, 0, len(p.published))
for _, snapshot := range p.published {
snapshots = append(snapshots, cloneProjectionSnapshot(snapshot))
}
return snapshots
}
var _ ports.GatewaySessionProjectionPublisher = (*RecordingProjectionPublisher)(nil)
@@ -0,0 +1,48 @@
package testkit
import (
"context"
"errors"
"testing"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/gatewayprojection"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRecordingProjectionPublisherConsumesScriptedErrorsAndRecordsAttempts(t *testing.T) {
t.Parallel()
publisher := &RecordingProjectionPublisher{
Errors: []error{errors.New("first publish failed"), nil},
}
snapshot := projectionSnapshotFixture()
err := publisher.PublishSession(context.Background(), snapshot)
require.Error(t, err)
err = publisher.PublishSession(context.Background(), snapshot)
require.NoError(t, err)
published := publisher.PublishedSnapshots()
require.Len(t, published, 2)
assert.Equal(t, snapshot.DeviceSessionID, published[0].DeviceSessionID)
assert.Equal(t, snapshot.DeviceSessionID, published[1].DeviceSessionID)
published[0].ClientPublicKey = "mutated"
stable := publisher.PublishedSnapshots()
require.Len(t, stable, 2)
assert.Equal(t, snapshot.ClientPublicKey, stable[0].ClientPublicKey)
}
func projectionSnapshotFixture() gatewayprojection.Snapshot {
return gatewayprojection.Snapshot{
DeviceSessionID: common.DeviceSessionID("device-session-1"),
UserID: common.UserID("user-1"),
ClientPublicKey: "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
Status: gatewayprojection.StatusActive,
}
}
@@ -0,0 +1,58 @@
package testkit
import (
"context"
"fmt"
"sync"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
)
// InMemorySendEmailCodeAbuseProtector is a deterministic map-backed
// SendEmailCodeAbuseProtector double suitable for service tests.
type InMemorySendEmailCodeAbuseProtector struct {
mu sync.Mutex
// Err is returned directly from CheckAndReserve when set.
Err error
reservedUntil map[common.Email]time.Time
}
// CheckAndReserve applies the fixed resend cooldown using input.Now as the
// authoritative decision timestamp.
func (p *InMemorySendEmailCodeAbuseProtector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
if err := ctx.Err(); err != nil {
return ports.SendEmailCodeAbuseResult{}, err
}
if err := input.Validate(); err != nil {
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check send email code abuse: %w", err)
}
if p.Err != nil {
return ports.SendEmailCodeAbuseResult{}, p.Err
}
p.mu.Lock()
defer p.mu.Unlock()
if p.reservedUntil == nil {
p.reservedUntil = make(map[common.Email]time.Time)
}
reservedUntil, exists := p.reservedUntil[input.Email]
if exists && input.Now.Before(reservedUntil) {
return ports.SendEmailCodeAbuseResult{
Outcome: ports.SendEmailCodeAbuseOutcomeThrottled,
}, nil
}
p.reservedUntil[input.Email] = input.Now.UTC().Add(challenge.ResendThrottleCooldown)
return ports.SendEmailCodeAbuseResult{
Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
}, nil
}
var _ ports.SendEmailCodeAbuseProtector = (*InMemorySendEmailCodeAbuseProtector)(nil)
@@ -0,0 +1,42 @@
package testkit
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInMemorySendEmailCodeAbuseProtector(t *testing.T) {
t.Parallel()
protector := &InMemorySendEmailCodeAbuseProtector{}
email := common.Email("pilot@example.com")
now := time.Unix(10, 0).UTC()
result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now,
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now.Add(30 * time.Second),
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now.Add(time.Minute),
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
}
@@ -0,0 +1,229 @@
package testkit
import (
"context"
"fmt"
"slices"
"sync"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
)
// InMemorySessionStore is a deterministic map-backed SessionStore double
// suitable for service tests.
type InMemorySessionStore struct {
mu sync.Mutex
records map[common.DeviceSessionID]devicesession.Session
}
// Get returns the stored device session for deviceSessionID.
func (s *InMemorySessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
if err := ctx.Err(); err != nil {
return devicesession.Session{}, err
}
if err := deviceSessionID.Validate(); err != nil {
return devicesession.Session{}, fmt.Errorf("get session: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
record, ok := s.records[deviceSessionID]
if !ok {
return devicesession.Session{}, fmt.Errorf("get session %q: %w", deviceSessionID, ports.ErrNotFound)
}
cloned, err := cloneSession(record)
if err != nil {
return devicesession.Session{}, err
}
return cloned, nil
}
// ListByUserID returns every stored session for userID in newest-first order.
func (s *InMemorySessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
if err := userID.Validate(); err != nil {
return nil, fmt.Errorf("list sessions by user id: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
var records []devicesession.Session
for _, record := range s.records {
if record.UserID == userID {
cloned, err := cloneSession(record)
if err != nil {
return nil, err
}
records = append(records, cloned)
}
}
sortSessionsNewestFirst(records)
return records, nil
}
// CountActiveByUserID returns the number of active sessions currently stored
// for userID.
func (s *InMemorySessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
if err := ctx.Err(); err != nil {
return 0, err
}
if err := userID.Validate(); err != nil {
return 0, fmt.Errorf("count active sessions by user id: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
count := 0
for _, record := range s.records {
if record.UserID == userID && record.Status == devicesession.StatusActive {
count++
}
}
return count, nil
}
// Create stores record as a new device session.
func (s *InMemorySessionStore) Create(ctx context.Context, record devicesession.Session) error {
if err := ctx.Err(); err != nil {
return err
}
if err := record.Validate(); err != nil {
return fmt.Errorf("create session: %w", err)
}
cloned, err := cloneSession(record)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.records == nil {
s.records = make(map[common.DeviceSessionID]devicesession.Session)
}
if _, exists := s.records[record.ID]; exists {
return fmt.Errorf("create session %q: %w", record.ID, ports.ErrConflict)
}
s.records[record.ID] = cloned
return nil
}
// Revoke stores a revoked view of one target session.
func (s *InMemorySessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
if err := ctx.Err(); err != nil {
return ports.RevokeSessionResult{}, err
}
if err := input.Validate(); err != nil {
return ports.RevokeSessionResult{}, fmt.Errorf("revoke session: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
record, ok := s.records[input.DeviceSessionID]
if !ok {
return ports.RevokeSessionResult{}, fmt.Errorf("revoke session %q: %w", input.DeviceSessionID, ports.ErrNotFound)
}
if record.Status == devicesession.StatusRevoked {
cloned, err := cloneSession(record)
if err != nil {
return ports.RevokeSessionResult{}, err
}
result := ports.RevokeSessionResult{
Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
Session: cloned,
}
if err := result.Validate(); err != nil {
return ports.RevokeSessionResult{}, err
}
return result, nil
}
record.Status = devicesession.StatusRevoked
revocation := input.Revocation
record.Revocation = &revocation
cloned, err := cloneSession(record)
if err != nil {
return ports.RevokeSessionResult{}, err
}
s.records[input.DeviceSessionID] = cloned
result := ports.RevokeSessionResult{
Outcome: ports.RevokeSessionOutcomeRevoked,
Session: cloned,
}
if err := result.Validate(); err != nil {
return ports.RevokeSessionResult{}, err
}
return result, nil
}
// RevokeAllByUserID stores revoked views for all currently active sessions
// owned by input.UserID.
func (s *InMemorySessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
if err := ctx.Err(); err != nil {
return ports.RevokeUserSessionsResult{}, err
}
if err := input.Validate(); err != nil {
return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
var affected []devicesession.Session
for id, record := range s.records {
if record.UserID != input.UserID || record.Status != devicesession.StatusActive {
continue
}
record.Status = devicesession.StatusRevoked
revocation := input.Revocation
record.Revocation = &revocation
cloned, err := cloneSession(record)
if err != nil {
return ports.RevokeUserSessionsResult{}, err
}
s.records[id] = cloned
affected = append(affected, cloned)
}
sortSessionsNewestFirst(affected)
outcome := ports.RevokeUserSessionsOutcomeNoActiveSessions
if len(affected) > 0 {
outcome = ports.RevokeUserSessionsOutcomeRevoked
}
result := ports.RevokeUserSessionsResult{
Outcome: outcome,
UserID: input.UserID,
Sessions: slices.Clone(affected),
}
if err := result.Validate(); err != nil {
return ports.RevokeUserSessionsResult{}, err
}
return result, nil
}
var _ ports.SessionStore = (*InMemorySessionStore)(nil)
@@ -0,0 +1,182 @@
package testkit
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
)
func TestInMemorySessionStoreCreateAndGet(t *testing.T) {
t.Parallel()
store := &InMemorySessionStore{}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
got, err := store.Get(context.Background(), record.ID)
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if got.ID != record.ID {
require.Failf(t, "test failed", "Get().ID = %q, want %q", got.ID, record.ID)
}
}
func TestInMemorySessionStoreListByUserIDNewestFirst(t *testing.T) {
t.Parallel()
store := &InMemorySessionStore{}
older := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
newer := activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
otherUser := activeSessionFixture("device-session-3", "user-2", time.Unix(30, 0).UTC())
for _, record := range []devicesession.Session{older, newer, otherUser} {
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
}
got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
if err != nil {
require.Failf(t, "test failed", "ListByUserID() returned error: %v", err)
}
if len(got) != 2 {
require.Failf(t, "test failed", "ListByUserID() length = %d, want 2", len(got))
}
if got[0].ID != newer.ID || got[1].ID != older.ID {
require.Failf(t, "test failed", "ListByUserID() order = [%q %q], want [%q %q]", got[0].ID, got[1].ID, newer.ID, older.ID)
}
}
func TestInMemorySessionStoreCountActiveByUserID(t *testing.T) {
t.Parallel()
store := &InMemorySessionStore{}
active := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
revoked := revokedSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
for _, record := range []devicesession.Session{active, revoked} {
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
}
got, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
if err != nil {
require.Failf(t, "test failed", "CountActiveByUserID() returned error: %v", err)
}
if got != 1 {
require.Failf(t, "test failed", "CountActiveByUserID() = %d, want 1", got)
}
}
func TestInMemorySessionStoreRevokeIsIdempotent(t *testing.T) {
t.Parallel()
store := &InMemorySessionStore{}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
input := ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: devicesession.Revocation{
At: time.Unix(30, 0).UTC(),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
},
}
first, err := store.Revoke(context.Background(), input)
if err != nil {
require.Failf(t, "test failed", "first Revoke() returned error: %v", err)
}
if first.Outcome != ports.RevokeSessionOutcomeRevoked {
require.Failf(t, "test failed", "first Revoke() outcome = %q, want %q", first.Outcome, ports.RevokeSessionOutcomeRevoked)
}
second, err := store.Revoke(context.Background(), input)
if err != nil {
require.Failf(t, "test failed", "second Revoke() returned error: %v", err)
}
if second.Outcome != ports.RevokeSessionOutcomeAlreadyRevoked {
require.Failf(t, "test failed", "second Revoke() outcome = %q, want %q", second.Outcome, ports.RevokeSessionOutcomeAlreadyRevoked)
}
}
func TestInMemorySessionStoreRevokeAllNoActiveSessions(t *testing.T) {
t.Parallel()
store := &InMemorySessionStore{}
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
input := ports.RevokeUserSessionsInput{
UserID: common.UserID("user-1"),
Revocation: devicesession.Revocation{
At: time.Unix(40, 0).UTC(),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
},
}
result, err := store.RevokeAllByUserID(context.Background(), input)
if err != nil {
require.Failf(t, "test failed", "RevokeAllByUserID() returned error: %v", err)
}
if result.Outcome != ports.RevokeUserSessionsOutcomeNoActiveSessions {
require.Failf(t, "test failed", "RevokeAllByUserID() outcome = %q, want %q", result.Outcome, ports.RevokeUserSessionsOutcomeNoActiveSessions)
}
if len(result.Sessions) != 0 {
require.Failf(t, "test failed", "RevokeAllByUserID() session count = %d, want 0", len(result.Sessions))
}
}
func TestInMemorySessionStoreGetNotFound(t *testing.T) {
t.Parallel()
store := &InMemorySessionStore{}
_, err := store.Get(context.Background(), common.DeviceSessionID("missing"))
if !errors.Is(err, ports.ErrNotFound) {
require.Failf(t, "test failed", "Get() error = %v, want ErrNotFound", err)
}
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
record := activeSessionFixture(deviceSessionID, userID, createdAt)
record.Status = devicesession.StatusRevoked
record.Revocation = &devicesession.Revocation{
At: createdAt.Add(time.Minute),
ReasonCode: devicesession.RevokeReasonDeviceLogout,
ActorType: common.RevokeActorType("user"),
}
return record
}
@@ -0,0 +1,147 @@
package testkit
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/ports"
)
func TestStaticConfigProvider(t *testing.T) {
t.Parallel()
limit := 4
provider := StaticConfigProvider{
Config: ports.SessionLimitConfig{ActiveSessionLimit: &limit},
}
got, err := provider.LoadSessionLimit(context.Background())
if err != nil {
require.Failf(t, "test failed", "LoadSessionLimit() returned error: %v", err)
}
if got.ActiveSessionLimit == nil || *got.ActiveSessionLimit != limit {
require.Failf(t, "test failed", "LoadSessionLimit() = %+v, want limit %d", got, limit)
}
}
func TestSequenceIDGenerator(t *testing.T) {
t.Parallel()
generator := &SequenceIDGenerator{
ChallengeIDs: []common.ChallengeID{"challenge-queue"},
DeviceSessionIDs: []common.DeviceSessionID{"device-session-queue"},
}
challengeID, err := generator.NewChallengeID()
if err != nil {
require.Failf(t, "test failed", "NewChallengeID() returned error: %v", err)
}
if challengeID != common.ChallengeID("challenge-queue") {
require.Failf(t, "test failed", "NewChallengeID() = %q, want queued id", challengeID)
}
deviceSessionID, err := generator.NewDeviceSessionID()
if err != nil {
require.Failf(t, "test failed", "NewDeviceSessionID() returned error: %v", err)
}
if deviceSessionID != common.DeviceSessionID("device-session-queue") {
require.Failf(t, "test failed", "NewDeviceSessionID() = %q, want queued id", deviceSessionID)
}
}
func TestFixedCodeGenerator(t *testing.T) {
t.Parallel()
generator := FixedCodeGenerator{Code: "123456"}
got, err := generator.Generate()
if err != nil {
require.Failf(t, "test failed", "Generate() returned error: %v", err)
}
if got != "123456" {
require.Failf(t, "test failed", "Generate() = %q, want %q", got, "123456")
}
}
func TestDeterministicCodeHasher(t *testing.T) {
t.Parallel()
hasher := DeterministicCodeHasher{}
hash, err := hasher.Hash("123456")
if err != nil {
require.Failf(t, "test failed", "Hash() returned error: %v", err)
}
match, err := hasher.Compare(hash, "123456")
if err != nil {
require.Failf(t, "test failed", "Compare() returned error: %v", err)
}
if !match {
require.FailNow(t, "Compare() = false, want true")
}
}
func TestRecordingMailSender(t *testing.T) {
t.Parallel()
sender := &RecordingMailSender{
Results: []ports.SendLoginCodeResult{
{Outcome: ports.SendLoginCodeOutcomeSuppressed},
},
}
result, err := sender.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
Email: common.Email("pilot@example.com"),
Code: "654321",
})
if err != nil {
require.Failf(t, "test failed", "SendLoginCode() returned error: %v", err)
}
if result.Outcome != ports.SendLoginCodeOutcomeSuppressed {
require.Failf(t, "test failed", "SendLoginCode().Outcome = %q, want %q", result.Outcome, ports.SendLoginCodeOutcomeSuppressed)
}
if len(sender.RecordedInputs()) != 1 {
require.Failf(t, "test failed", "RecordedInputs() length = %d, want 1", len(sender.RecordedInputs()))
}
}
func TestRecordingProjectionPublisher(t *testing.T) {
t.Parallel()
publisher := &RecordingProjectionPublisher{}
revokedAt := time.Unix(30, 0).UTC()
snapshot := gatewayprojection.Snapshot{
DeviceSessionID: common.DeviceSessionID("device-session-1"),
UserID: common.UserID("user-1"),
ClientPublicKey: "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
Status: gatewayprojection.StatusRevoked,
RevokedAt: &revokedAt,
RevokeReasonCode: common.RevokeReasonCode("logout_all"),
RevokeActorType: common.RevokeActorType("system"),
}
if err := publisher.PublishSession(context.Background(), snapshot); err != nil {
require.Failf(t, "test failed", "PublishSession() returned error: %v", err)
}
if len(publisher.PublishedSnapshots()) != 1 {
require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(publisher.PublishedSnapshots()))
}
}
func TestStaticConfigProviderReturnsConfiguredError(t *testing.T) {
t.Parallel()
wantErr := errors.New("config failed")
provider := StaticConfigProvider{Err: wantErr}
_, err := provider.LoadSessionLimit(context.Background())
if !errors.Is(err, wantErr) {
require.Failf(t, "test failed", "LoadSessionLimit() error = %v, want %v", err, wantErr)
}
}
@@ -0,0 +1,309 @@
package testkit
import (
"context"
"fmt"
"sync"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
)
type userDirectoryEntry struct {
UserID common.UserID
BlockReasonCode userresolution.BlockReasonCode
}
// InMemoryUserDirectory is a deterministic map-backed UserDirectory double
// suitable for service tests.
type InMemoryUserDirectory struct {
mu sync.Mutex
byEmail map[common.Email]userDirectoryEntry
emailByUserID map[common.UserID]common.Email
createdUserIDs []common.UserID
nextUserNumber int
}
// ResolveByEmail returns the current resolution state for email without
// creating a new user.
func (d *InMemoryUserDirectory) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
if err := ctx.Err(); err != nil {
return userresolution.Result{}, err
}
if err := email.Validate(); err != nil {
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
result, err := d.resolveLocked(email)
if err != nil {
return userresolution.Result{}, err
}
return result, nil
}
// ExistsByUserID reports whether userID currently identifies a stored user
// record.
func (d *InMemoryUserDirectory) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
if err := ctx.Err(); err != nil {
return false, err
}
if err := userID.Validate(); err != nil {
return false, fmt.Errorf("exists by user id: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
_, ok := d.emailByUserID[userID]
return ok, nil
}
// EnsureUserByEmail returns an existing user for email, creates a new user
// when registration is allowed, or reports a blocked outcome.
func (d *InMemoryUserDirectory) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
if err := ctx.Err(); err != nil {
return ports.EnsureUserResult{}, err
}
if err := email.Validate(); err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
if d.byEmail == nil {
d.byEmail = make(map[common.Email]userDirectoryEntry)
}
if d.emailByUserID == nil {
d.emailByUserID = make(map[common.UserID]common.Email)
}
entry, ok := d.byEmail[email]
if ok {
if !entry.BlockReasonCode.IsZero() {
result := ports.EnsureUserResult{
Outcome: ports.EnsureUserOutcomeBlocked,
BlockReasonCode: entry.BlockReasonCode,
}
return result, result.Validate()
}
result := ports.EnsureUserResult{
Outcome: ports.EnsureUserOutcomeExisting,
UserID: entry.UserID,
}
return result, result.Validate()
}
userID, err := d.nextCreatedUserIDLocked()
if err != nil {
return ports.EnsureUserResult{}, err
}
d.byEmail[email] = userDirectoryEntry{UserID: userID}
d.emailByUserID[userID] = email
result := ports.EnsureUserResult{
Outcome: ports.EnsureUserOutcomeCreated,
UserID: userID,
}
return result, result.Validate()
}
// BlockByUserID applies a block state to the user identified by input.UserID.
func (d *InMemoryUserDirectory) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
if err := ctx.Err(); err != nil {
return ports.BlockUserResult{}, err
}
if err := input.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
email, ok := d.emailByUserID[input.UserID]
if !ok {
return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
}
entry := d.byEmail[email]
if !entry.BlockReasonCode.IsZero() {
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
UserID: input.UserID,
}
return result, result.Validate()
}
entry.BlockReasonCode = input.ReasonCode
d.byEmail[email] = entry
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeBlocked,
UserID: input.UserID,
}
return result, result.Validate()
}
// BlockByEmail applies a block state to input.Email even when no user record
// currently exists for that e-mail address.
func (d *InMemoryUserDirectory) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
if err := ctx.Err(); err != nil {
return ports.BlockUserResult{}, err
}
if err := input.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
if d.byEmail == nil {
d.byEmail = make(map[common.Email]userDirectoryEntry)
}
if d.emailByUserID == nil {
d.emailByUserID = make(map[common.UserID]common.Email)
}
entry := d.byEmail[input.Email]
if !entry.BlockReasonCode.IsZero() {
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
UserID: entry.UserID,
}
return result, result.Validate()
}
entry.BlockReasonCode = input.ReasonCode
d.byEmail[input.Email] = entry
if !entry.UserID.IsZero() {
d.emailByUserID[entry.UserID] = input.Email
}
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeBlocked,
UserID: entry.UserID,
}
return result, result.Validate()
}
// SeedExisting preloads one existing unblocked user record for service tests.
func (d *InMemoryUserDirectory) SeedExisting(email common.Email, userID common.UserID) error {
if err := email.Validate(); err != nil {
return fmt.Errorf("seed existing email: %w", err)
}
if err := userID.Validate(); err != nil {
return fmt.Errorf("seed existing user id: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
if d.byEmail == nil {
d.byEmail = make(map[common.Email]userDirectoryEntry)
}
if d.emailByUserID == nil {
d.emailByUserID = make(map[common.UserID]common.Email)
}
d.byEmail[email] = userDirectoryEntry{UserID: userID}
d.emailByUserID[userID] = email
return nil
}
// SeedBlockedEmail preloads one blocked e-mail address that does not
// necessarily belong to an existing user record.
func (d *InMemoryUserDirectory) SeedBlockedEmail(email common.Email, reasonCode userresolution.BlockReasonCode) error {
if err := email.Validate(); err != nil {
return fmt.Errorf("seed blocked email: %w", err)
}
if err := reasonCode.Validate(); err != nil {
return fmt.Errorf("seed blocked email reason code: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
if d.byEmail == nil {
d.byEmail = make(map[common.Email]userDirectoryEntry)
}
d.byEmail[email] = userDirectoryEntry{BlockReasonCode: reasonCode}
return nil
}
// SeedBlockedUser preloads one blocked existing user record for service tests.
func (d *InMemoryUserDirectory) SeedBlockedUser(email common.Email, userID common.UserID, reasonCode userresolution.BlockReasonCode) error {
if err := d.SeedExisting(email, userID); err != nil {
return err
}
d.mu.Lock()
defer d.mu.Unlock()
entry := d.byEmail[email]
entry.BlockReasonCode = reasonCode
d.byEmail[email] = entry
return nil
}
// QueueCreatedUserIDs appends deterministic user identifiers that
// EnsureUserByEmail will consume before falling back to generated ids.
func (d *InMemoryUserDirectory) QueueCreatedUserIDs(userIDs ...common.UserID) error {
for index, userID := range userIDs {
if err := userID.Validate(); err != nil {
return fmt.Errorf("queue created user id %d: %w", index, err)
}
}
d.mu.Lock()
defer d.mu.Unlock()
d.createdUserIDs = append(d.createdUserIDs, userIDs...)
return nil
}
var _ ports.UserDirectory = (*InMemoryUserDirectory)(nil)
func (d *InMemoryUserDirectory) resolveLocked(email common.Email) (userresolution.Result, error) {
entry, ok := d.byEmail[email]
if !ok {
result := userresolution.Result{Kind: userresolution.KindCreatable}
return result, result.Validate()
}
if !entry.BlockReasonCode.IsZero() {
result := userresolution.Result{
Kind: userresolution.KindBlocked,
BlockReasonCode: entry.BlockReasonCode,
}
return result, result.Validate()
}
result := userresolution.Result{
Kind: userresolution.KindExisting,
UserID: entry.UserID,
}
return result, result.Validate()
}
func (d *InMemoryUserDirectory) nextCreatedUserIDLocked() (common.UserID, error) {
if len(d.createdUserIDs) > 0 {
userID := d.createdUserIDs[0]
d.createdUserIDs = d.createdUserIDs[1:]
return userID, nil
}
d.nextUserNumber++
userID := common.UserID(fmt.Sprintf("user-%d", d.nextUserNumber))
if err := userID.Validate(); err != nil {
return "", err
}
return userID, nil
}
@@ -0,0 +1,203 @@
package testkit
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"testing"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
)
func TestInMemoryUserDirectoryResolveExistingCreatableAndBlocked(t *testing.T) {
t.Parallel()
directory := &InMemoryUserDirectory{}
if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
}
tests := []struct {
name string
email common.Email
wantKind userresolution.Kind
}{
{name: "existing", email: common.Email("existing@example.com"), wantKind: userresolution.KindExisting},
{name: "creatable", email: common.Email("new@example.com"), wantKind: userresolution.KindCreatable},
{name: "blocked", email: common.Email("blocked@example.com"), wantKind: userresolution.KindBlocked},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := directory.ResolveByEmail(context.Background(), tt.email)
if err != nil {
require.Failf(t, "test failed", "ResolveByEmail() returned error: %v", err)
}
if got.Kind != tt.wantKind {
require.Failf(t, "test failed", "ResolveByEmail().Kind = %q, want %q", got.Kind, tt.wantKind)
}
})
}
}
func TestInMemoryUserDirectoryEnsureUserExistingCreatedAndBlocked(t *testing.T) {
t.Parallel()
directory := &InMemoryUserDirectory{}
if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
}
if err := directory.QueueCreatedUserIDs(common.UserID("user-created")); err != nil {
require.Failf(t, "test failed", "QueueCreatedUserIDs() returned error: %v", err)
}
tests := []struct {
name string
email common.Email
wantOutcome ports.EnsureUserOutcome
wantUserID common.UserID
}{
{
name: "existing",
email: common.Email("existing@example.com"),
wantOutcome: ports.EnsureUserOutcomeExisting,
wantUserID: common.UserID("user-existing"),
},
{
name: "created",
email: common.Email("created@example.com"),
wantOutcome: ports.EnsureUserOutcomeCreated,
wantUserID: common.UserID("user-created"),
},
{
name: "blocked",
email: common.Email("blocked@example.com"),
wantOutcome: ports.EnsureUserOutcomeBlocked,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := directory.EnsureUserByEmail(context.Background(), tt.email)
if err != nil {
require.Failf(t, "test failed", "EnsureUserByEmail() returned error: %v", err)
}
if got.Outcome != tt.wantOutcome {
require.Failf(t, "test failed", "EnsureUserByEmail().Outcome = %q, want %q", got.Outcome, tt.wantOutcome)
}
if got.UserID != tt.wantUserID {
require.Failf(t, "test failed", "EnsureUserByEmail().UserID = %q, want %q", got.UserID, tt.wantUserID)
}
})
}
}
func TestInMemoryUserDirectoryExistsByUserID(t *testing.T) {
t.Parallel()
directory := &InMemoryUserDirectory{}
if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
exists, err := directory.ExistsByUserID(context.Background(), common.UserID("user-existing"))
if err != nil {
require.Failf(t, "test failed", "ExistsByUserID() returned error: %v", err)
}
if !exists {
require.FailNow(t, "ExistsByUserID() = false, want true")
}
exists, err = directory.ExistsByUserID(context.Background(), common.UserID("missing"))
if err != nil {
require.Failf(t, "test failed", "ExistsByUserID() returned error: %v", err)
}
if exists {
require.FailNow(t, "ExistsByUserID() = true, want false")
}
}
func TestInMemoryUserDirectoryBlockByEmail(t *testing.T) {
t.Parallel()
directory := &InMemoryUserDirectory{}
result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("blocked@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
if err != nil {
require.Failf(t, "test failed", "BlockByEmail() returned error: %v", err)
}
if result.Outcome != ports.BlockUserOutcomeBlocked {
require.Failf(t, "test failed", "BlockByEmail().Outcome = %q, want %q", result.Outcome, ports.BlockUserOutcomeBlocked)
}
resolution, err := directory.ResolveByEmail(context.Background(), common.Email("blocked@example.com"))
if err != nil {
require.Failf(t, "test failed", "ResolveByEmail() returned error: %v", err)
}
if resolution.Kind != userresolution.KindBlocked {
require.Failf(t, "test failed", "ResolveByEmail().Kind = %q, want %q", resolution.Kind, userresolution.KindBlocked)
}
}
func TestInMemoryUserDirectoryBlockByUserID(t *testing.T) {
t.Parallel()
directory := &InMemoryUserDirectory{}
if err := directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
result, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("user-1"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
if err != nil {
require.Failf(t, "test failed", "BlockByUserID() returned error: %v", err)
}
if result.Outcome != ports.BlockUserOutcomeBlocked {
require.Failf(t, "test failed", "BlockByUserID().Outcome = %q, want %q", result.Outcome, ports.BlockUserOutcomeBlocked)
}
second, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("user-1"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
if err != nil {
require.Failf(t, "test failed", "second BlockByUserID() returned error: %v", err)
}
if second.Outcome != ports.BlockUserOutcomeAlreadyBlocked {
require.Failf(t, "test failed", "second BlockByUserID().Outcome = %q, want %q", second.Outcome, ports.BlockUserOutcomeAlreadyBlocked)
}
}
func TestInMemoryUserDirectoryBlockByUserIDNotFound(t *testing.T) {
t.Parallel()
directory := &InMemoryUserDirectory{}
_, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("missing"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
if !errors.Is(err, ports.ErrNotFound) {
require.Failf(t, "test failed", "BlockByUserID() error = %v, want ErrNotFound", err)
}
}