feat: authsession service
This commit is contained in:
@@ -0,0 +1,122 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// InMemoryChallengeStore is a deterministic map-backed ChallengeStore double
// suitable for service tests.
//
// The zero value is ready to use; the records map is allocated lazily on the
// first Create call.
type InMemoryChallengeStore struct {
	// mu guards records against concurrent access.
	mu sync.Mutex
	// records maps challenge identifiers to stored (cloned) challenges.
	records map[common.ChallengeID]challenge.Challenge
}
|
||||
|
||||
// Get returns the stored challenge for challengeID.
|
||||
func (s *InMemoryChallengeStore) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
if err := challengeID.Validate(); err != nil {
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
record, ok := s.records[challengeID]
|
||||
if !ok {
|
||||
return challenge.Challenge{}, fmt.Errorf("get challenge %q: %w", challengeID, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
cloned, err := cloneChallenge(record)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
|
||||
return cloned, nil
|
||||
}
|
||||
|
||||
// Create stores record as a new challenge.
|
||||
func (s *InMemoryChallengeStore) Create(ctx context.Context, record challenge.Challenge) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create challenge: %w", err)
|
||||
}
|
||||
|
||||
cloned, err := cloneChallenge(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.records == nil {
|
||||
s.records = make(map[common.ChallengeID]challenge.Challenge)
|
||||
}
|
||||
if _, exists := s.records[record.ID]; exists {
|
||||
return fmt.Errorf("create challenge %q: %w", record.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
s.records[record.ID] = cloned
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompareAndSwap replaces previous with next when the currently stored
|
||||
// challenge matches previous exactly.
|
||||
func (s *InMemoryChallengeStore) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ports.ValidateComparableChallenges(previous, next); err != nil {
|
||||
return fmt.Errorf("compare and swap challenge: %w", err)
|
||||
}
|
||||
|
||||
clonedPrevious, err := cloneChallenge(previous)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clonedNext, err := cloneChallenge(next)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
current, ok := s.records[previous.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrNotFound)
|
||||
}
|
||||
if !reflect.DeepEqual(current, clonedPrevious) {
|
||||
return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
s.records[next.ID] = clonedNext
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ports.ChallengeStore = (*InMemoryChallengeStore)(nil)
|
||||
|
||||
func mustGetChallenge(store *InMemoryChallengeStore, challengeID common.ChallengeID) challenge.Challenge {
|
||||
record, err := store.Get(context.Background(), challengeID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
// isNotFound reports whether err wraps ports.ErrNotFound.
func isNotFound(err error) bool {
	return errors.Is(err, ports.ErrNotFound)
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
func TestInMemoryChallengeStoreCreateAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &InMemoryChallengeStore{}
|
||||
record := challengeFixture()
|
||||
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
got, err := store.Get(context.Background(), record.ID)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if got.ID != record.ID {
|
||||
require.Failf(t, "test failed", "Get().ID = %q, want %q", got.ID, record.ID)
|
||||
}
|
||||
if &got.CodeHash[0] == &record.CodeHash[0] {
|
||||
require.FailNow(t, "Get() returned aliased code hash slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryChallengeStoreGetNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &InMemoryChallengeStore{}
|
||||
|
||||
_, err := store.Get(context.Background(), common.ChallengeID("missing"))
|
||||
if !errors.Is(err, ports.ErrNotFound) {
|
||||
require.Failf(t, "test failed", "Get() error = %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryChallengeStoreCompareAndSwapConflict(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &InMemoryChallengeStore{}
|
||||
record := challengeFixture()
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
previous := record
|
||||
previous.Attempts.Confirm = 1
|
||||
next := record
|
||||
next.Status = challenge.StatusSent
|
||||
next.DeliveryState = challenge.DeliverySent
|
||||
|
||||
err := store.CompareAndSwap(context.Background(), previous, next)
|
||||
if !errors.Is(err, ports.ErrConflict) {
|
||||
require.Failf(t, "test failed", "CompareAndSwap() error = %v, want ErrConflict", err)
|
||||
}
|
||||
}
|
||||
|
||||
// challengeFixture returns a valid pending challenge with deterministic
// identifiers and timestamps for use across store tests.
func challengeFixture() challenge.Challenge {
	timestamp := time.Unix(20, 0).UTC()
	return challenge.Challenge{
		ID:            common.ChallengeID("challenge-1"),
		Email:         common.Email("pilot@example.com"),
		CodeHash:      []byte("hash"),
		Status:        challenge.StatusPendingSend,
		DeliveryState: challenge.DeliveryPending,
		CreatedAt:     timestamp,
		ExpiresAt:     timestamp.Add(10 * time.Minute),
	}
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package testkit
|
||||
|
||||
import "time"
|
||||
|
||||
// FixedClock is a deterministic Clock double that always returns the same
// instant.
type FixedClock struct {
	// Time is the instant returned by Now.
	Time time.Time
}

// Now returns the configured instant. The value never advances, which keeps
// time-dependent tests deterministic.
func (c FixedClock) Now() time.Time {
	return c.Time
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
)
|
||||
|
||||
func cloneChallenge(record challenge.Challenge) (challenge.Challenge, error) {
|
||||
cloned := record
|
||||
cloned.CodeHash = bytes.Clone(record.CodeHash)
|
||||
cloned.Abuse = cloneAbuseMetadata(record.Abuse)
|
||||
|
||||
if record.Confirmation != nil {
|
||||
confirmation, err := cloneChallengeConfirmation(*record.Confirmation)
|
||||
if err != nil {
|
||||
return challenge.Challenge{}, err
|
||||
}
|
||||
cloned.Confirmation = &confirmation
|
||||
}
|
||||
|
||||
return cloned, nil
|
||||
}
|
||||
|
||||
func cloneChallengeConfirmation(value challenge.Confirmation) (challenge.Confirmation, error) {
|
||||
cloned := value
|
||||
|
||||
if value.ClientPublicKey.IsZero() {
|
||||
cloned.ClientPublicKey = common.ClientPublicKey{}
|
||||
return cloned, nil
|
||||
}
|
||||
|
||||
key, err := common.NewClientPublicKey(value.ClientPublicKey.PublicKey())
|
||||
if err != nil {
|
||||
return challenge.Confirmation{}, fmt.Errorf("clone challenge confirmation client public key: %w", err)
|
||||
}
|
||||
cloned.ClientPublicKey = key
|
||||
|
||||
return cloned, nil
|
||||
}
|
||||
|
||||
func cloneAbuseMetadata(value challenge.AbuseMetadata) challenge.AbuseMetadata {
|
||||
cloned := value
|
||||
if value.LastAttemptAt != nil {
|
||||
lastAttemptAt := *value.LastAttemptAt
|
||||
cloned.LastAttemptAt = &lastAttemptAt
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneSession(record devicesession.Session) (devicesession.Session, error) {
|
||||
cloned := record
|
||||
|
||||
if !record.ClientPublicKey.IsZero() {
|
||||
key, err := common.NewClientPublicKey(record.ClientPublicKey.PublicKey())
|
||||
if err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("clone session client public key: %w", err)
|
||||
}
|
||||
cloned.ClientPublicKey = key
|
||||
}
|
||||
if record.Revocation != nil {
|
||||
revocation := *record.Revocation
|
||||
cloned.Revocation = &revocation
|
||||
}
|
||||
|
||||
return cloned, nil
|
||||
}
|
||||
|
||||
func cloneSessions(records []devicesession.Session) ([]devicesession.Session, error) {
|
||||
cloned := make([]devicesession.Session, 0, len(records))
|
||||
for _, record := range records {
|
||||
session, err := cloneSession(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cloned = append(cloned, session)
|
||||
}
|
||||
|
||||
return cloned, nil
|
||||
}
|
||||
|
||||
func cloneProjectionSnapshot(snapshot gatewayprojection.Snapshot) gatewayprojection.Snapshot {
|
||||
cloned := snapshot
|
||||
if snapshot.RevokedAt != nil {
|
||||
revokedAt := *snapshot.RevokedAt
|
||||
cloned.RevokedAt = &revokedAt
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
func sortSessionsNewestFirst(records []devicesession.Session) {
|
||||
slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int {
|
||||
switch {
|
||||
case left.CreatedAt.Equal(right.CreatedAt):
|
||||
return compareStrings(left.ID.String(), right.ID.String())
|
||||
case left.CreatedAt.After(right.CreatedAt):
|
||||
return -1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// compareStrings reports -1 when left < right, 1 when left > right, and 0
// when they are equal (three-way lexicographic comparison).
func compareStrings(left string, right string) int {
	if left == right {
		return 0
	}
	if left < right {
		return -1
	}
	return 1
}
|
||||
|
||||
func cloneTimePointer(value *time.Time) *time.Time {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// FixedCodeGenerator is a deterministic CodeGenerator double that always
|
||||
// returns the same code or error.
|
||||
type FixedCodeGenerator struct {
|
||||
// Code stores the fixed code returned by Generate when Err is nil.
|
||||
Code string
|
||||
|
||||
// Err is returned directly from Generate when set.
|
||||
Err error
|
||||
}
|
||||
|
||||
// Generate returns the configured fixed code.
|
||||
func (g FixedCodeGenerator) Generate() (string, error) {
|
||||
if g.Err != nil {
|
||||
return "", g.Err
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(g.Code) == "":
|
||||
return "", errors.New("fixed code generator code must not be empty")
|
||||
case strings.TrimSpace(g.Code) != g.Code:
|
||||
return "", errors.New("fixed code generator code must not contain surrounding whitespace")
|
||||
default:
|
||||
return g.Code, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time check that FixedCodeGenerator satisfies ports.CodeGenerator.
var _ ports.CodeGenerator = FixedCodeGenerator{}
|
||||
@@ -0,0 +1,51 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// DeterministicCodeHasher is a deterministic CodeHasher double backed by
// SHA-256 for test stability. It is stateless, so the zero value is ready
// to use.
type DeterministicCodeHasher struct{}
|
||||
|
||||
// Hash returns the SHA-256 digest of code.
//
// The code is validated first: empty or whitespace-padded codes are
// rejected. A plain SHA-256 digest keeps test doubles fast and
// deterministic.
func (DeterministicCodeHasher) Hash(code string) ([]byte, error) {
	if err := validateCode(code); err != nil {
		return nil, err
	}

	sum := sha256.Sum256([]byte(code))
	return sum[:], nil
}
|
||||
|
||||
// Compare reports whether hash equals the deterministic hash of code.
|
||||
func (h DeterministicCodeHasher) Compare(hash []byte, code string) (bool, error) {
|
||||
if err := validateCode(code); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
expected, err := h.Hash(code)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare(hash, expected) == 1, nil
|
||||
}
|
||||
|
||||
var _ ports.CodeHasher = DeterministicCodeHasher{}
|
||||
|
||||
// validateCode rejects codes that are empty (or all whitespace) or that
// carry surrounding whitespace.
func validateCode(code string) error {
	trimmed := strings.TrimSpace(code)
	if trimmed == "" {
		return errors.New("code must not be empty")
	}
	if trimmed != code {
		return errors.New("code must not contain surrounding whitespace")
	}
	return nil
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// StaticConfigProvider is a deterministic ConfigProvider double that returns a
// preconfigured session-limit value or error.
type StaticConfigProvider struct {
	// Config stores the configuration returned when Err is nil. It is
	// validated on every LoadSessionLimit call.
	Config ports.SessionLimitConfig

	// Err is returned directly from LoadSessionLimit when set.
	Err error
}
|
||||
|
||||
// LoadSessionLimit returns the preconfigured session-limit result.
|
||||
func (p StaticConfigProvider) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return ports.SessionLimitConfig{}, err
|
||||
}
|
||||
if p.Err != nil {
|
||||
return ports.SessionLimitConfig{}, p.Err
|
||||
}
|
||||
if err := p.Config.Validate(); err != nil {
|
||||
return ports.SessionLimitConfig{}, err
|
||||
}
|
||||
|
||||
return p.Config, nil
|
||||
}
|
||||
|
||||
var _ ports.ConfigProvider = StaticConfigProvider{}
|
||||
@@ -0,0 +1,4 @@
|
||||
// Package testkit provides deterministic in-memory doubles for auth/session
|
||||
// service ports so later service tests can run without Redis, HTTP, or other
|
||||
// external dependencies.
|
||||
package testkit
|
||||
@@ -0,0 +1,101 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// SequenceIDGenerator is a deterministic IDGenerator double that consumes
// queued identifiers before falling back to monotonic generated ids.
type SequenceIDGenerator struct {
	// mu guards the queues and counters against concurrent use.
	mu sync.Mutex

	// ChallengeIDs stores queued challenge identifiers returned by
	// NewChallengeID before generated ids are used.
	ChallengeIDs []common.ChallengeID

	// DeviceSessionIDs stores queued device-session identifiers returned by
	// NewDeviceSessionID before generated ids are used.
	DeviceSessionIDs []common.DeviceSessionID

	// ChallengeErr is returned directly from NewChallengeID when set.
	ChallengeErr error

	// DeviceSessionErr is returned directly from NewDeviceSessionID when set.
	DeviceSessionErr error

	// ChallengePrefix overrides the default "challenge-" prefix used for
	// generated challenge ids.
	ChallengePrefix string

	// DeviceSessionPrefix overrides the default "device-session-" prefix
	// used for generated device-session ids.
	DeviceSessionPrefix string

	// nextChallengeNumber counts generated (non-queued) challenge ids.
	nextChallengeNumber int

	// nextSessionNumber counts generated (non-queued) device-session ids.
	nextSessionNumber int
}
|
||||
|
||||
// NewChallengeID returns the next deterministic challenge identifier.
|
||||
func (g *SequenceIDGenerator) NewChallengeID() (common.ChallengeID, error) {
|
||||
if g.ChallengeErr != nil {
|
||||
return "", g.ChallengeErr
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if len(g.ChallengeIDs) > 0 {
|
||||
id := g.ChallengeIDs[0]
|
||||
g.ChallengeIDs = g.ChallengeIDs[1:]
|
||||
if err := id.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
g.nextChallengeNumber++
|
||||
prefix := g.ChallengePrefix
|
||||
if prefix == "" {
|
||||
prefix = "challenge-"
|
||||
}
|
||||
|
||||
id := common.ChallengeID(fmt.Sprintf("%s%d", prefix, g.nextChallengeNumber))
|
||||
if err := id.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// NewDeviceSessionID returns the next deterministic device-session
|
||||
// identifier.
|
||||
func (g *SequenceIDGenerator) NewDeviceSessionID() (common.DeviceSessionID, error) {
|
||||
if g.DeviceSessionErr != nil {
|
||||
return "", g.DeviceSessionErr
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if len(g.DeviceSessionIDs) > 0 {
|
||||
id := g.DeviceSessionIDs[0]
|
||||
g.DeviceSessionIDs = g.DeviceSessionIDs[1:]
|
||||
if err := id.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
g.nextSessionNumber++
|
||||
prefix := g.DeviceSessionPrefix
|
||||
if prefix == "" {
|
||||
prefix = "device-session-"
|
||||
}
|
||||
|
||||
id := common.DeviceSessionID(fmt.Sprintf("%s%d", prefix, g.nextSessionNumber))
|
||||
if err := id.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
var _ ports.IDGenerator = (*SequenceIDGenerator)(nil)
|
||||
@@ -0,0 +1,73 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// RecordingMailSender is a deterministic MailSender double that records every
// delivery request and returns preconfigured outcomes or errors.
type RecordingMailSender struct {
	// mu guards Results and recordedInputs against concurrent access.
	mu sync.Mutex

	// Results stores queued results consumed by SendLoginCode before
	// DefaultResult is used.
	Results []ports.SendLoginCodeResult

	// DefaultResult stores the result used when Results is empty.
	DefaultResult ports.SendLoginCodeResult

	// Err is returned directly from SendLoginCode when set.
	Err error

	// recordedInputs accumulates every input passed to SendLoginCode,
	// including calls that returned Err.
	recordedInputs []ports.SendLoginCodeInput
}
|
||||
|
||||
// SendLoginCode records input and returns the next configured result.
|
||||
func (s *RecordingMailSender) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.recordedInputs = append(s.recordedInputs, input)
|
||||
if s.Err != nil {
|
||||
return ports.SendLoginCodeResult{}, s.Err
|
||||
}
|
||||
|
||||
if len(s.Results) > 0 {
|
||||
result := s.Results[0]
|
||||
s.Results = s.Results[1:]
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result := s.DefaultResult
|
||||
if result.Outcome == "" {
|
||||
result.Outcome = ports.SendLoginCodeOutcomeSent
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.SendLoginCodeResult{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RecordedInputs returns a stable snapshot of every recorded mail request.
|
||||
func (s *RecordingMailSender) RecordedInputs() []ports.SendLoginCodeInput {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return append([]ports.SendLoginCodeInput(nil), s.recordedInputs...)
|
||||
}
|
||||
|
||||
var _ ports.MailSender = (*RecordingMailSender)(nil)
|
||||
@@ -0,0 +1,62 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// RecordingProjectionPublisher is a deterministic
// GatewaySessionProjectionPublisher double that records every published
// snapshot.
type RecordingProjectionPublisher struct {
	// mu guards Errors and published against concurrent access.
	mu sync.Mutex

	// Err is returned directly from PublishSession when set.
	Err error

	// Errors is an optional FIFO error script consumed before Err. Nil entries
	// represent successful publish attempts.
	Errors []error

	// published accumulates a cloned copy of every snapshot passed to
	// PublishSession, including attempts that returned an error.
	published []gatewayprojection.Snapshot
}
|
||||
|
||||
// PublishSession records snapshot and returns the configured error, if any.
|
||||
func (p *RecordingProjectionPublisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := snapshot.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.published = append(p.published, cloneProjectionSnapshot(snapshot))
|
||||
if len(p.Errors) > 0 {
|
||||
err := p.Errors[0]
|
||||
p.Errors = append([]error(nil), p.Errors[1:]...)
|
||||
return err
|
||||
}
|
||||
|
||||
return p.Err
|
||||
}
|
||||
|
||||
// PublishedSnapshots returns a stable snapshot of every published projection.
|
||||
func (p *RecordingProjectionPublisher) PublishedSnapshots() []gatewayprojection.Snapshot {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
snapshots := make([]gatewayprojection.Snapshot, 0, len(p.published))
|
||||
for _, snapshot := range p.published {
|
||||
snapshots = append(snapshots, cloneProjectionSnapshot(snapshot))
|
||||
}
|
||||
|
||||
return snapshots
|
||||
}
|
||||
|
||||
var _ ports.GatewaySessionProjectionPublisher = (*RecordingProjectionPublisher)(nil)
|
||||
@@ -0,0 +1,48 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRecordingProjectionPublisherConsumesScriptedErrorsAndRecordsAttempts
// verifies that scripted errors are consumed FIFO, that failed attempts are
// still recorded, and that returned snapshots are defensive copies.
func TestRecordingProjectionPublisherConsumesScriptedErrorsAndRecordsAttempts(t *testing.T) {
	t.Parallel()

	publisher := &RecordingProjectionPublisher{
		Errors: []error{errors.New("first publish failed"), nil},
	}
	snapshot := projectionSnapshotFixture()

	// First attempt consumes the scripted error.
	err := publisher.PublishSession(context.Background(), snapshot)
	require.Error(t, err)

	// Second attempt consumes the scripted nil entry and succeeds.
	err = publisher.PublishSession(context.Background(), snapshot)
	require.NoError(t, err)

	// Both attempts — including the failed one — must have been recorded.
	published := publisher.PublishedSnapshots()
	require.Len(t, published, 2)
	assert.Equal(t, snapshot.DeviceSessionID, published[0].DeviceSessionID)
	assert.Equal(t, snapshot.DeviceSessionID, published[1].DeviceSessionID)

	// Mutating a returned snapshot must not affect the publisher's state.
	published[0].ClientPublicKey = "mutated"

	stable := publisher.PublishedSnapshots()
	require.Len(t, stable, 2)
	assert.Equal(t, snapshot.ClientPublicKey, stable[0].ClientPublicKey)
}
|
||||
|
||||
// projectionSnapshotFixture returns a valid active projection snapshot with
// deterministic field values for publisher tests.
func projectionSnapshotFixture() gatewayprojection.Snapshot {
	return gatewayprojection.Snapshot{
		DeviceSessionID: common.DeviceSessionID("device-session-1"),
		UserID:          common.UserID("user-1"),
		ClientPublicKey: "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
		Status:          gatewayprojection.StatusActive,
	}
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// InMemorySendEmailCodeAbuseProtector is a deterministic map-backed
// SendEmailCodeAbuseProtector double suitable for service tests.
type InMemorySendEmailCodeAbuseProtector struct {
	// mu guards reservedUntil against concurrent access.
	mu sync.Mutex

	// Err is returned directly from CheckAndReserve when set.
	Err error

	// reservedUntil maps each email to the instant its resend cooldown
	// expires; allocated lazily on first use.
	reservedUntil map[common.Email]time.Time
}
|
||||
|
||||
// CheckAndReserve applies the fixed resend cooldown using input.Now as the
|
||||
// authoritative decision timestamp.
|
||||
func (p *InMemorySendEmailCodeAbuseProtector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, err
|
||||
}
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check send email code abuse: %w", err)
|
||||
}
|
||||
if p.Err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, p.Err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.reservedUntil == nil {
|
||||
p.reservedUntil = make(map[common.Email]time.Time)
|
||||
}
|
||||
|
||||
reservedUntil, exists := p.reservedUntil[input.Email]
|
||||
if exists && input.Now.Before(reservedUntil) {
|
||||
return ports.SendEmailCodeAbuseResult{
|
||||
Outcome: ports.SendEmailCodeAbuseOutcomeThrottled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
p.reservedUntil[input.Email] = input.Now.UTC().Add(challenge.ResendThrottleCooldown)
|
||||
return ports.SendEmailCodeAbuseResult{
|
||||
Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ ports.SendEmailCodeAbuseProtector = (*InMemorySendEmailCodeAbuseProtector)(nil)
|
||||
@@ -0,0 +1,42 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestInMemorySendEmailCodeAbuseProtector walks one full cooldown cycle:
// the first send is allowed, a resend inside the cooldown is throttled, and
// a resend after the cooldown is allowed again.
func TestInMemorySendEmailCodeAbuseProtector(t *testing.T) {
	t.Parallel()

	protector := &InMemorySendEmailCodeAbuseProtector{}
	email := common.Email("pilot@example.com")
	now := time.Unix(10, 0).UTC()

	// First request reserves the cooldown window.
	result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now,
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)

	// 30 seconds later is still inside the reserved window.
	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(30 * time.Second),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)

	// One minute later the cooldown has elapsed and the send is allowed.
	result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now.Add(time.Minute),
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
}
|
||||
@@ -0,0 +1,229 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// InMemorySessionStore is a deterministic map-backed SessionStore double
// suitable for service tests.
//
// The zero value is ready to use; the records map is allocated lazily on the
// first Create call.
type InMemorySessionStore struct {
	// mu guards records against concurrent access.
	mu sync.Mutex
	// records maps device-session identifiers to stored (cloned) sessions.
	records map[common.DeviceSessionID]devicesession.Session
}
|
||||
|
||||
// Get returns the stored device session for deviceSessionID.
|
||||
func (s *InMemorySessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
if err := deviceSessionID.Validate(); err != nil {
|
||||
return devicesession.Session{}, fmt.Errorf("get session: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
record, ok := s.records[deviceSessionID]
|
||||
if !ok {
|
||||
return devicesession.Session{}, fmt.Errorf("get session %q: %w", deviceSessionID, ports.ErrNotFound)
|
||||
}
|
||||
|
||||
cloned, err := cloneSession(record)
|
||||
if err != nil {
|
||||
return devicesession.Session{}, err
|
||||
}
|
||||
|
||||
return cloned, nil
|
||||
}
|
||||
|
||||
// ListByUserID returns every stored session for userID in newest-first order.
|
||||
func (s *InMemorySessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := userID.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("list sessions by user id: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var records []devicesession.Session
|
||||
for _, record := range s.records {
|
||||
if record.UserID == userID {
|
||||
cloned, err := cloneSession(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records = append(records, cloned)
|
||||
}
|
||||
}
|
||||
sortSessionsNewestFirst(records)
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// CountActiveByUserID returns the number of active sessions currently stored
|
||||
// for userID.
|
||||
func (s *InMemorySessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := userID.Validate(); err != nil {
|
||||
return 0, fmt.Errorf("count active sessions by user id: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
count := 0
|
||||
for _, record := range s.records {
|
||||
if record.UserID == userID && record.Status == devicesession.StatusActive {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Create stores record as a new device session.
|
||||
func (s *InMemorySessionStore) Create(ctx context.Context, record devicesession.Session) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
cloned, err := cloneSession(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.records == nil {
|
||||
s.records = make(map[common.DeviceSessionID]devicesession.Session)
|
||||
}
|
||||
if _, exists := s.records[record.ID]; exists {
|
||||
return fmt.Errorf("create session %q: %w", record.ID, ports.ErrConflict)
|
||||
}
|
||||
|
||||
s.records[record.ID] = cloned
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke stores a revoked view of one target session.
//
// Revoking an already-revoked session is idempotent: it reports
// AlreadyRevoked and leaves the stored record (including its original
// revocation metadata) untouched.
func (s *InMemorySessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
	if err := ctx.Err(); err != nil {
		return ports.RevokeSessionResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.RevokeSessionResult{}, fmt.Errorf("revoke session: %w", err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	record, ok := s.records[input.DeviceSessionID]
	if !ok {
		return ports.RevokeSessionResult{}, fmt.Errorf("revoke session %q: %w", input.DeviceSessionID, ports.ErrNotFound)
	}

	// Idempotent path: return a cloned view of the existing revoked record.
	if record.Status == devicesession.StatusRevoked {
		cloned, err := cloneSession(record)
		if err != nil {
			return ports.RevokeSessionResult{}, err
		}

		result := ports.RevokeSessionResult{
			Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
			Session: cloned,
		}
		if err := result.Validate(); err != nil {
			return ports.RevokeSessionResult{}, err
		}

		return result, nil
	}

	// record is a map-value copy, so these mutations do not touch the map
	// until the clone is written back below.
	record.Status = devicesession.StatusRevoked
	// Copy to a local so the stored pointer does not alias caller-owned input.
	revocation := input.Revocation
	record.Revocation = &revocation

	cloned, err := cloneSession(record)
	if err != nil {
		return ports.RevokeSessionResult{}, err
	}
	s.records[input.DeviceSessionID] = cloned

	result := ports.RevokeSessionResult{
		Outcome: ports.RevokeSessionOutcomeRevoked,
		Session: cloned,
	}
	if err := result.Validate(); err != nil {
		return ports.RevokeSessionResult{}, err
	}

	return result, nil
}
|
||||
|
||||
// RevokeAllByUserID stores revoked views for all currently active sessions
// owned by input.UserID.
//
// NOTE(review): a cloneSession error mid-loop returns after some sessions
// were already rewritten in the map, so the store can be left partially
// revoked — acceptable for a test double, but worth knowing.
func (s *InMemorySessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
	if err := ctx.Err(); err != nil {
		return ports.RevokeUserSessionsResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions: %w", err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	var affected []devicesession.Session
	for id, record := range s.records {
		// Only the target user's active sessions are rewritten.
		if record.UserID != input.UserID || record.Status != devicesession.StatusActive {
			continue
		}

		record.Status = devicesession.StatusRevoked
		// Copy per iteration so each stored record owns its own pointer.
		revocation := input.Revocation
		record.Revocation = &revocation

		cloned, err := cloneSession(record)
		if err != nil {
			return ports.RevokeUserSessionsResult{}, err
		}
		s.records[id] = cloned
		affected = append(affected, cloned)
	}

	// Map iteration order is random; sort for a deterministic result.
	sortSessionsNewestFirst(affected)

	outcome := ports.RevokeUserSessionsOutcomeNoActiveSessions
	if len(affected) > 0 {
		outcome = ports.RevokeUserSessionsOutcomeRevoked
	}

	result := ports.RevokeUserSessionsResult{
		Outcome:  outcome,
		UserID:   input.UserID,
		Sessions: slices.Clone(affected),
	}
	if err := result.Validate(); err != nil {
		return ports.RevokeUserSessionsResult{}, err
	}

	return result, nil
}
|
||||
|
||||
// Compile-time check that InMemorySessionStore implements ports.SessionStore.
var _ ports.SessionStore = (*InMemorySessionStore)(nil)
|
||||
@@ -0,0 +1,182 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// TestInMemorySessionStoreCreateAndGet verifies that a created session can be
// read back by its identifier.
func TestInMemorySessionStoreCreateAndGet(t *testing.T) {
	t.Parallel()

	store := &InMemorySessionStore{}
	record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())

	if err := store.Create(context.Background(), record); err != nil {
		require.Failf(t, "test failed", "Create() returned error: %v", err)
	}

	got, err := store.Get(context.Background(), record.ID)
	if err != nil {
		require.Failf(t, "test failed", "Get() returned error: %v", err)
	}
	if got.ID != record.ID {
		require.Failf(t, "test failed", "Get().ID = %q, want %q", got.ID, record.ID)
	}
}
|
||||
|
||||
// TestInMemorySessionStoreListByUserIDNewestFirst verifies that listing
// filters by user and orders sessions newest-first.
func TestInMemorySessionStoreListByUserIDNewestFirst(t *testing.T) {
	t.Parallel()

	store := &InMemorySessionStore{}
	older := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
	newer := activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
	otherUser := activeSessionFixture("device-session-3", "user-2", time.Unix(30, 0).UTC())

	for _, record := range []devicesession.Session{older, newer, otherUser} {
		if err := store.Create(context.Background(), record); err != nil {
			require.Failf(t, "test failed", "Create() returned error: %v", err)
		}
	}

	got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
	if err != nil {
		require.Failf(t, "test failed", "ListByUserID() returned error: %v", err)
	}
	// otherUser belongs to user-2 and must be excluded.
	if len(got) != 2 {
		require.Failf(t, "test failed", "ListByUserID() length = %d, want 2", len(got))
	}
	if got[0].ID != newer.ID || got[1].ID != older.ID {
		require.Failf(t, "test failed", "ListByUserID() order = [%q %q], want [%q %q]", got[0].ID, got[1].ID, newer.ID, older.ID)
	}
}
|
||||
|
||||
// TestInMemorySessionStoreCountActiveByUserID verifies that only active (not
// revoked) sessions are counted.
func TestInMemorySessionStoreCountActiveByUserID(t *testing.T) {
	t.Parallel()

	store := &InMemorySessionStore{}
	active := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
	revoked := revokedSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())

	for _, record := range []devicesession.Session{active, revoked} {
		if err := store.Create(context.Background(), record); err != nil {
			require.Failf(t, "test failed", "Create() returned error: %v", err)
		}
	}

	got, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
	if err != nil {
		require.Failf(t, "test failed", "CountActiveByUserID() returned error: %v", err)
	}
	if got != 1 {
		require.Failf(t, "test failed", "CountActiveByUserID() = %d, want 1", got)
	}
}
|
||||
|
||||
// TestInMemorySessionStoreRevokeIsIdempotent verifies that the first Revoke
// reports Revoked and a repeat with the same input reports AlreadyRevoked.
func TestInMemorySessionStoreRevokeIsIdempotent(t *testing.T) {
	t.Parallel()

	store := &InMemorySessionStore{}
	record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
	if err := store.Create(context.Background(), record); err != nil {
		require.Failf(t, "test failed", "Create() returned error: %v", err)
	}

	input := ports.RevokeSessionInput{
		DeviceSessionID: record.ID,
		Revocation: devicesession.Revocation{
			At:         time.Unix(30, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonLogoutAll,
			ActorType:  common.RevokeActorType("system"),
		},
	}

	first, err := store.Revoke(context.Background(), input)
	if err != nil {
		require.Failf(t, "test failed", "first Revoke() returned error: %v", err)
	}
	if first.Outcome != ports.RevokeSessionOutcomeRevoked {
		require.Failf(t, "test failed", "first Revoke() outcome = %q, want %q", first.Outcome, ports.RevokeSessionOutcomeRevoked)
	}

	second, err := store.Revoke(context.Background(), input)
	if err != nil {
		require.Failf(t, "test failed", "second Revoke() returned error: %v", err)
	}
	if second.Outcome != ports.RevokeSessionOutcomeAlreadyRevoked {
		require.Failf(t, "test failed", "second Revoke() outcome = %q, want %q", second.Outcome, ports.RevokeSessionOutcomeAlreadyRevoked)
	}
}
|
||||
|
||||
// TestInMemorySessionStoreRevokeAllNoActiveSessions verifies that bulk
// revocation of a user with only already-revoked sessions reports
// NoActiveSessions and touches nothing.
func TestInMemorySessionStoreRevokeAllNoActiveSessions(t *testing.T) {
	t.Parallel()

	store := &InMemorySessionStore{}
	record := revokedSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
	if err := store.Create(context.Background(), record); err != nil {
		require.Failf(t, "test failed", "Create() returned error: %v", err)
	}

	input := ports.RevokeUserSessionsInput{
		UserID: common.UserID("user-1"),
		Revocation: devicesession.Revocation{
			At:         time.Unix(40, 0).UTC(),
			ReasonCode: devicesession.RevokeReasonAdminRevoke,
			ActorType:  common.RevokeActorType("admin"),
		},
	}

	result, err := store.RevokeAllByUserID(context.Background(), input)
	if err != nil {
		require.Failf(t, "test failed", "RevokeAllByUserID() returned error: %v", err)
	}
	if result.Outcome != ports.RevokeUserSessionsOutcomeNoActiveSessions {
		require.Failf(t, "test failed", "RevokeAllByUserID() outcome = %q, want %q", result.Outcome, ports.RevokeUserSessionsOutcomeNoActiveSessions)
	}
	if len(result.Sessions) != 0 {
		require.Failf(t, "test failed", "RevokeAllByUserID() session count = %d, want 0", len(result.Sessions))
	}
}
|
||||
|
||||
// TestInMemorySessionStoreGetNotFound verifies that Get on an unknown id
// reports ports.ErrNotFound through the error chain.
func TestInMemorySessionStoreGetNotFound(t *testing.T) {
	t.Parallel()

	store := &InMemorySessionStore{}

	_, err := store.Get(context.Background(), common.DeviceSessionID("missing"))
	if !errors.Is(err, ports.ErrNotFound) {
		require.Failf(t, "test failed", "Get() error = %v, want ErrNotFound", err)
	}
}
|
||||
|
||||
// activeSessionFixture builds an active session with the given identifiers
// and creation time, using a 32-byte zero client public key. It panics if
// that key is rejected, which indicates a broken fixture, not a test failure.
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	key, err := common.NewClientPublicKey(make([]byte, 32))
	if err != nil {
		panic(err)
	}

	return devicesession.Session{
		ID:              common.DeviceSessionID(deviceSessionID),
		UserID:          common.UserID(userID),
		ClientPublicKey: key,
		Status:          devicesession.StatusActive,
		CreatedAt:       createdAt,
	}
}
|
||||
|
||||
// revokedSessionFixture builds a session that is already revoked one minute
// after its creation time, by a "user" actor with the device-logout reason.
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
	record := activeSessionFixture(deviceSessionID, userID, createdAt)
	record.Status = devicesession.StatusRevoked
	record.Revocation = &devicesession.Revocation{
		At:         createdAt.Add(time.Minute),
		ReasonCode: devicesession.RevokeReasonDeviceLogout,
		ActorType:  common.RevokeActorType("user"),
	}
	return record
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// TestStaticConfigProvider verifies that the provider returns exactly the
// configured session limit.
func TestStaticConfigProvider(t *testing.T) {
	t.Parallel()

	limit := 4
	provider := StaticConfigProvider{
		Config: ports.SessionLimitConfig{ActiveSessionLimit: &limit},
	}

	got, err := provider.LoadSessionLimit(context.Background())
	if err != nil {
		require.Failf(t, "test failed", "LoadSessionLimit() returned error: %v", err)
	}
	if got.ActiveSessionLimit == nil || *got.ActiveSessionLimit != limit {
		require.Failf(t, "test failed", "LoadSessionLimit() = %+v, want limit %d", got, limit)
	}
}
|
||||
|
||||
// TestSequenceIDGenerator verifies that queued challenge and device-session
// ids are handed out in order.
func TestSequenceIDGenerator(t *testing.T) {
	t.Parallel()

	generator := &SequenceIDGenerator{
		ChallengeIDs:     []common.ChallengeID{"challenge-queue"},
		DeviceSessionIDs: []common.DeviceSessionID{"device-session-queue"},
	}

	challengeID, err := generator.NewChallengeID()
	if err != nil {
		require.Failf(t, "test failed", "NewChallengeID() returned error: %v", err)
	}
	if challengeID != common.ChallengeID("challenge-queue") {
		require.Failf(t, "test failed", "NewChallengeID() = %q, want queued id", challengeID)
	}

	deviceSessionID, err := generator.NewDeviceSessionID()
	if err != nil {
		require.Failf(t, "test failed", "NewDeviceSessionID() returned error: %v", err)
	}
	if deviceSessionID != common.DeviceSessionID("device-session-queue") {
		require.Failf(t, "test failed", "NewDeviceSessionID() = %q, want queued id", deviceSessionID)
	}
}
|
||||
|
||||
// TestFixedCodeGenerator verifies that the generator always emits the
// configured code.
func TestFixedCodeGenerator(t *testing.T) {
	t.Parallel()

	generator := FixedCodeGenerator{Code: "123456"}

	got, err := generator.Generate()
	if err != nil {
		require.Failf(t, "test failed", "Generate() returned error: %v", err)
	}
	if got != "123456" {
		require.Failf(t, "test failed", "Generate() = %q, want %q", got, "123456")
	}
}
|
||||
|
||||
// TestDeterministicCodeHasher verifies that a hash produced by Hash compares
// true against the original code via Compare.
func TestDeterministicCodeHasher(t *testing.T) {
	t.Parallel()

	hasher := DeterministicCodeHasher{}

	hash, err := hasher.Hash("123456")
	if err != nil {
		require.Failf(t, "test failed", "Hash() returned error: %v", err)
	}

	match, err := hasher.Compare(hash, "123456")
	if err != nil {
		require.Failf(t, "test failed", "Compare() returned error: %v", err)
	}
	if !match {
		require.FailNow(t, "Compare() = false, want true")
	}
}
|
||||
|
||||
// TestRecordingMailSender verifies that the sender returns the queued result
// and records the input it was called with.
func TestRecordingMailSender(t *testing.T) {
	t.Parallel()

	sender := &RecordingMailSender{
		Results: []ports.SendLoginCodeResult{
			{Outcome: ports.SendLoginCodeOutcomeSuppressed},
		},
	}

	result, err := sender.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
		Email: common.Email("pilot@example.com"),
		Code:  "654321",
	})
	if err != nil {
		require.Failf(t, "test failed", "SendLoginCode() returned error: %v", err)
	}
	if result.Outcome != ports.SendLoginCodeOutcomeSuppressed {
		require.Failf(t, "test failed", "SendLoginCode().Outcome = %q, want %q", result.Outcome, ports.SendLoginCodeOutcomeSuppressed)
	}
	if len(sender.RecordedInputs()) != 1 {
		require.Failf(t, "test failed", "RecordedInputs() length = %d, want 1", len(sender.RecordedInputs()))
	}
}
|
||||
|
||||
// TestRecordingProjectionPublisher verifies that a published snapshot is
// recorded by the publisher double.
func TestRecordingProjectionPublisher(t *testing.T) {
	t.Parallel()

	publisher := &RecordingProjectionPublisher{}
	revokedAt := time.Unix(30, 0).UTC()
	snapshot := gatewayprojection.Snapshot{
		DeviceSessionID:  common.DeviceSessionID("device-session-1"),
		UserID:           common.UserID("user-1"),
		ClientPublicKey:  "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
		Status:           gatewayprojection.StatusRevoked,
		RevokedAt:        &revokedAt,
		RevokeReasonCode: common.RevokeReasonCode("logout_all"),
		RevokeActorType:  common.RevokeActorType("system"),
	}

	if err := publisher.PublishSession(context.Background(), snapshot); err != nil {
		require.Failf(t, "test failed", "PublishSession() returned error: %v", err)
	}
	if len(publisher.PublishedSnapshots()) != 1 {
		require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(publisher.PublishedSnapshots()))
	}
}
|
||||
|
||||
// TestStaticConfigProviderReturnsConfiguredError verifies that a configured
// error is surfaced unchanged through the error chain.
func TestStaticConfigProviderReturnsConfiguredError(t *testing.T) {
	t.Parallel()

	wantErr := errors.New("config failed")
	provider := StaticConfigProvider{Err: wantErr}

	_, err := provider.LoadSessionLimit(context.Background())
	if !errors.Is(err, wantErr) {
		require.Failf(t, "test failed", "LoadSessionLimit() error = %v, want %v", err, wantErr)
	}
}
|
||||
@@ -0,0 +1,309 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// userDirectoryEntry is the stored per-email state: the owning user (when
// one exists) and an optional block reason.
type userDirectoryEntry struct {
	UserID          common.UserID                  // zero when the e-mail is blocked without a user record
	BlockReasonCode userresolution.BlockReasonCode // zero value means "not blocked"
}
|
||||
|
||||
// InMemoryUserDirectory is a deterministic map-backed UserDirectory double
// suitable for service tests.
//
// The zero value is ready to use; maps are lazily initialized under mu.
type InMemoryUserDirectory struct {
	mu             sync.Mutex
	byEmail        map[common.Email]userDirectoryEntry // primary index: e-mail -> entry
	emailByUserID  map[common.UserID]common.Email      // reverse index for user-id lookups
	createdUserIDs []common.UserID                     // queued ids consumed first by EnsureUserByEmail
	nextUserNumber int                                 // counter for generated fallback ids ("user-N")
}
|
||||
|
||||
// ResolveByEmail returns the current resolution state for email without
|
||||
// creating a new user.
|
||||
func (d *InMemoryUserDirectory) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
if err := email.Validate(); err != nil {
|
||||
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
result, err := d.resolveLocked(email)
|
||||
if err != nil {
|
||||
return userresolution.Result{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExistsByUserID reports whether userID currently identifies a stored user
|
||||
// record.
|
||||
func (d *InMemoryUserDirectory) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := userID.Validate(); err != nil {
|
||||
return false, fmt.Errorf("exists by user id: %w", err)
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
_, ok := d.emailByUserID[userID]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// EnsureUserByEmail returns an existing user for email, creates a new user
// when registration is allowed, or reports a blocked outcome.
//
// Blocked results carry the block reason but no user id; created users take
// the next queued id (QueueCreatedUserIDs) or a generated fallback.
func (d *InMemoryUserDirectory) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
	if err := ctx.Err(); err != nil {
		return ports.EnsureUserResult{}, err
	}
	if err := email.Validate(); err != nil {
		return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	// Lazy initialization keeps the zero value usable.
	if d.byEmail == nil {
		d.byEmail = make(map[common.Email]userDirectoryEntry)
	}
	if d.emailByUserID == nil {
		d.emailByUserID = make(map[common.UserID]common.Email)
	}

	entry, ok := d.byEmail[email]
	if ok {
		if !entry.BlockReasonCode.IsZero() {
			result := ports.EnsureUserResult{
				Outcome:         ports.EnsureUserOutcomeBlocked,
				BlockReasonCode: entry.BlockReasonCode,
			}
			return result, result.Validate()
		}

		result := ports.EnsureUserResult{
			Outcome: ports.EnsureUserOutcomeExisting,
			UserID:  entry.UserID,
		}
		return result, result.Validate()
	}

	userID, err := d.nextCreatedUserIDLocked()
	if err != nil {
		return ports.EnsureUserResult{}, err
	}
	// Wire both indexes so ExistsByUserID and BlockByUserID see the new user.
	d.byEmail[email] = userDirectoryEntry{UserID: userID}
	d.emailByUserID[userID] = email

	result := ports.EnsureUserResult{
		Outcome: ports.EnsureUserOutcomeCreated,
		UserID:  userID,
	}
	return result, result.Validate()
}
|
||||
|
||||
// BlockByUserID applies a block state to the user identified by input.UserID.
//
// Blocking an already-blocked user reports AlreadyBlocked and keeps the
// original reason code.
func (d *InMemoryUserDirectory) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
	if err := ctx.Err(); err != nil {
		return ports.BlockUserResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	// Unlike BlockByEmail, this requires an existing user record.
	email, ok := d.emailByUserID[input.UserID]
	if !ok {
		return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
	}
	entry := d.byEmail[email]
	if !entry.BlockReasonCode.IsZero() {
		result := ports.BlockUserResult{
			Outcome: ports.BlockUserOutcomeAlreadyBlocked,
			UserID:  input.UserID,
		}
		return result, result.Validate()
	}

	entry.BlockReasonCode = input.ReasonCode
	d.byEmail[email] = entry

	result := ports.BlockUserResult{
		Outcome: ports.BlockUserOutcomeBlocked,
		UserID:  input.UserID,
	}
	return result, result.Validate()
}
|
||||
|
||||
// BlockByEmail applies a block state to input.Email even when no user record
// currently exists for that e-mail address.
//
// When no user record exists, the result's UserID is the zero value.
func (d *InMemoryUserDirectory) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
	if err := ctx.Err(); err != nil {
		return ports.BlockUserResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	// Lazy initialization keeps the zero value usable.
	if d.byEmail == nil {
		d.byEmail = make(map[common.Email]userDirectoryEntry)
	}
	if d.emailByUserID == nil {
		d.emailByUserID = make(map[common.UserID]common.Email)
	}

	// A missing key yields a zero entry, so an unknown e-mail simply gains a
	// blocked entry with no user id.
	entry := d.byEmail[input.Email]
	if !entry.BlockReasonCode.IsZero() {
		result := ports.BlockUserResult{
			Outcome: ports.BlockUserOutcomeAlreadyBlocked,
			UserID:  entry.UserID,
		}
		return result, result.Validate()
	}

	entry.BlockReasonCode = input.ReasonCode
	d.byEmail[input.Email] = entry
	// Keep the reverse index consistent when a user record exists.
	if !entry.UserID.IsZero() {
		d.emailByUserID[entry.UserID] = input.Email
	}

	result := ports.BlockUserResult{
		Outcome: ports.BlockUserOutcomeBlocked,
		UserID:  entry.UserID,
	}
	return result, result.Validate()
}
|
||||
|
||||
// SeedExisting preloads one existing unblocked user record for service tests.
func (d *InMemoryUserDirectory) SeedExisting(email common.Email, userID common.UserID) error {
	if err := email.Validate(); err != nil {
		return fmt.Errorf("seed existing email: %w", err)
	}
	if err := userID.Validate(); err != nil {
		return fmt.Errorf("seed existing user id: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	// Lazy initialization keeps the zero value usable.
	if d.byEmail == nil {
		d.byEmail = make(map[common.Email]userDirectoryEntry)
	}
	if d.emailByUserID == nil {
		d.emailByUserID = make(map[common.UserID]common.Email)
	}

	// Overwrites any prior entry for this e-mail, clearing a block flag.
	d.byEmail[email] = userDirectoryEntry{UserID: userID}
	d.emailByUserID[userID] = email

	return nil
}
|
||||
|
||||
// SeedBlockedEmail preloads one blocked e-mail address that does not
// necessarily belong to an existing user record.
func (d *InMemoryUserDirectory) SeedBlockedEmail(email common.Email, reasonCode userresolution.BlockReasonCode) error {
	if err := email.Validate(); err != nil {
		return fmt.Errorf("seed blocked email: %w", err)
	}
	if err := reasonCode.Validate(); err != nil {
		return fmt.Errorf("seed blocked email reason code: %w", err)
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	if d.byEmail == nil {
		d.byEmail = make(map[common.Email]userDirectoryEntry)
	}

	// Entry has no UserID: the e-mail is blocked without a user record.
	d.byEmail[email] = userDirectoryEntry{BlockReasonCode: reasonCode}
	return nil
}
|
||||
|
||||
// SeedBlockedUser preloads one blocked existing user record for service tests.
|
||||
func (d *InMemoryUserDirectory) SeedBlockedUser(email common.Email, userID common.UserID, reasonCode userresolution.BlockReasonCode) error {
|
||||
if err := d.SeedExisting(email, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
entry := d.byEmail[email]
|
||||
entry.BlockReasonCode = reasonCode
|
||||
d.byEmail[email] = entry
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueCreatedUserIDs appends deterministic user identifiers that
|
||||
// EnsureUserByEmail will consume before falling back to generated ids.
|
||||
func (d *InMemoryUserDirectory) QueueCreatedUserIDs(userIDs ...common.UserID) error {
|
||||
for index, userID := range userIDs {
|
||||
if err := userID.Validate(); err != nil {
|
||||
return fmt.Errorf("queue created user id %d: %w", index, err)
|
||||
}
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.createdUserIDs = append(d.createdUserIDs, userIDs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compile-time check that InMemoryUserDirectory implements ports.UserDirectory.
var _ ports.UserDirectory = (*InMemoryUserDirectory)(nil)
|
||||
|
||||
// resolveLocked classifies email as blocked, existing, or creatable.
// The caller must hold d.mu.
func (d *InMemoryUserDirectory) resolveLocked(email common.Email) (userresolution.Result, error) {
	entry, ok := d.byEmail[email]
	if !ok {
		result := userresolution.Result{Kind: userresolution.KindCreatable}
		return result, result.Validate()
	}
	if !entry.BlockReasonCode.IsZero() {
		result := userresolution.Result{
			Kind:            userresolution.KindBlocked,
			BlockReasonCode: entry.BlockReasonCode,
		}
		return result, result.Validate()
	}

	result := userresolution.Result{
		Kind:   userresolution.KindExisting,
		UserID: entry.UserID,
	}
	return result, result.Validate()
}
|
||||
|
||||
// nextCreatedUserIDLocked pops the next queued user id, or generates a
// sequential "user-N" fallback. The caller must hold d.mu.
func (d *InMemoryUserDirectory) nextCreatedUserIDLocked() (common.UserID, error) {
	if len(d.createdUserIDs) > 0 {
		userID := d.createdUserIDs[0]
		d.createdUserIDs = d.createdUserIDs[1:]
		return userID, nil
	}

	d.nextUserNumber++
	userID := common.UserID(fmt.Sprintf("user-%d", d.nextUserNumber))
	if err := userID.Validate(); err != nil {
		return "", err
	}

	return userID, nil
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package testkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// TestInMemoryUserDirectoryResolveExistingCreatableAndBlocked verifies the
// three resolution kinds: existing, creatable, and blocked.
func TestInMemoryUserDirectoryResolveExistingCreatableAndBlocked(t *testing.T) {
	t.Parallel()

	directory := &InMemoryUserDirectory{}
	if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil {
		require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
	}
	if err := directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
		require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
	}

	tests := []struct {
		name     string
		email    common.Email
		wantKind userresolution.Kind
	}{
		{name: "existing", email: common.Email("existing@example.com"), wantKind: userresolution.KindExisting},
		{name: "creatable", email: common.Email("new@example.com"), wantKind: userresolution.KindCreatable},
		{name: "blocked", email: common.Email("blocked@example.com"), wantKind: userresolution.KindBlocked},
	}

	for _, tt := range tests {
		tt := tt // loop-variable capture guard for Go < 1.22

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			got, err := directory.ResolveByEmail(context.Background(), tt.email)
			if err != nil {
				require.Failf(t, "test failed", "ResolveByEmail() returned error: %v", err)
			}
			if got.Kind != tt.wantKind {
				require.Failf(t, "test failed", "ResolveByEmail().Kind = %q, want %q", got.Kind, tt.wantKind)
			}
		})
	}
}
|
||||
|
||||
// TestInMemoryUserDirectoryEnsureUserExistingCreatedAndBlocked verifies the
// three ensure outcomes; the created case consumes the queued id, and the
// blocked case carries no user id.
func TestInMemoryUserDirectoryEnsureUserExistingCreatedAndBlocked(t *testing.T) {
	t.Parallel()

	directory := &InMemoryUserDirectory{}
	if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil {
		require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
	}
	if err := directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
		require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
	}
	if err := directory.QueueCreatedUserIDs(common.UserID("user-created")); err != nil {
		require.Failf(t, "test failed", "QueueCreatedUserIDs() returned error: %v", err)
	}

	tests := []struct {
		name        string
		email       common.Email
		wantOutcome ports.EnsureUserOutcome
		wantUserID  common.UserID
	}{
		{
			name:        "existing",
			email:       common.Email("existing@example.com"),
			wantOutcome: ports.EnsureUserOutcomeExisting,
			wantUserID:  common.UserID("user-existing"),
		},
		{
			name:        "created",
			email:       common.Email("created@example.com"),
			wantOutcome: ports.EnsureUserOutcomeCreated,
			wantUserID:  common.UserID("user-created"),
		},
		{
			// wantUserID is intentionally zero: blocked results carry no id.
			name:        "blocked",
			email:       common.Email("blocked@example.com"),
			wantOutcome: ports.EnsureUserOutcomeBlocked,
		},
	}

	for _, tt := range tests {
		tt := tt // loop-variable capture guard for Go < 1.22

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			got, err := directory.EnsureUserByEmail(context.Background(), tt.email)
			if err != nil {
				require.Failf(t, "test failed", "EnsureUserByEmail() returned error: %v", err)
			}
			if got.Outcome != tt.wantOutcome {
				require.Failf(t, "test failed", "EnsureUserByEmail().Outcome = %q, want %q", got.Outcome, tt.wantOutcome)
			}
			if got.UserID != tt.wantUserID {
				require.Failf(t, "test failed", "EnsureUserByEmail().UserID = %q, want %q", got.UserID, tt.wantUserID)
			}
		})
	}
}
|
||||
|
||||
// TestInMemoryUserDirectoryExistsByUserID verifies existence lookups for a
// seeded user and for an unknown id.
func TestInMemoryUserDirectoryExistsByUserID(t *testing.T) {
	t.Parallel()

	directory := &InMemoryUserDirectory{}
	if err := directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")); err != nil {
		require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
	}

	exists, err := directory.ExistsByUserID(context.Background(), common.UserID("user-existing"))
	if err != nil {
		require.Failf(t, "test failed", "ExistsByUserID() returned error: %v", err)
	}
	if !exists {
		require.FailNow(t, "ExistsByUserID() = false, want true")
	}

	exists, err = directory.ExistsByUserID(context.Background(), common.UserID("missing"))
	if err != nil {
		require.Failf(t, "test failed", "ExistsByUserID() returned error: %v", err)
	}
	if exists {
		require.FailNow(t, "ExistsByUserID() = true, want false")
	}
}
|
||||
|
||||
func TestInMemoryUserDirectoryBlockByEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &InMemoryUserDirectory{}
|
||||
result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
|
||||
Email: common.Email("blocked@example.com"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "BlockByEmail() returned error: %v", err)
|
||||
}
|
||||
if result.Outcome != ports.BlockUserOutcomeBlocked {
|
||||
require.Failf(t, "test failed", "BlockByEmail().Outcome = %q, want %q", result.Outcome, ports.BlockUserOutcomeBlocked)
|
||||
}
|
||||
|
||||
resolution, err := directory.ResolveByEmail(context.Background(), common.Email("blocked@example.com"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "ResolveByEmail() returned error: %v", err)
|
||||
}
|
||||
if resolution.Kind != userresolution.KindBlocked {
|
||||
require.Failf(t, "test failed", "ResolveByEmail().Kind = %q, want %q", resolution.Kind, userresolution.KindBlocked)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryUserDirectoryBlockByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &InMemoryUserDirectory{}
|
||||
if err := directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
|
||||
result, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "BlockByUserID() returned error: %v", err)
|
||||
}
|
||||
if result.Outcome != ports.BlockUserOutcomeBlocked {
|
||||
require.Failf(t, "test failed", "BlockByUserID().Outcome = %q, want %q", result.Outcome, ports.BlockUserOutcomeBlocked)
|
||||
}
|
||||
|
||||
second, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("user-1"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "second BlockByUserID() returned error: %v", err)
|
||||
}
|
||||
if second.Outcome != ports.BlockUserOutcomeAlreadyBlocked {
|
||||
require.Failf(t, "test failed", "second BlockByUserID().Outcome = %q, want %q", second.Outcome, ports.BlockUserOutcomeAlreadyBlocked)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryUserDirectoryBlockByUserIDNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := &InMemoryUserDirectory{}
|
||||
|
||||
_, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
|
||||
UserID: common.UserID("missing"),
|
||||
ReasonCode: userresolution.BlockReasonCode("policy_block"),
|
||||
})
|
||||
if !errors.Is(err, ports.ErrNotFound) {
|
||||
require.Failf(t, "test failed", "BlockByUserID() error = %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user