feat: authsession service

This commit is contained in:
Ilia Denisov
2026-04-08 16:23:07 +02:00
committed by GitHub
parent 28f04916af
commit 86a68ed9d0
174 changed files with 31732 additions and 112 deletions
@@ -0,0 +1,56 @@
// Package antiabuse provides runtime in-process adapters for auth-specific
// public abuse controls.
package antiabuse
import (
"context"
"fmt"
"sync"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
)
// SendEmailCodeProtector is a concurrency-safe in-process resend-throttle
// adapter for public send-email-code attempts.
type SendEmailCodeProtector struct {
mu sync.Mutex
reservedUntil map[common.Email]time.Time
}
// CheckAndReserve applies the fixed Stage-17 resend cooldown using input.Now
// as the authoritative decision timestamp.
func (p *SendEmailCodeProtector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
if ctx == nil {
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: nil context")
}
if err := ctx.Err(); err != nil {
return ports.SendEmailCodeAbuseResult{}, err
}
if err := input.Validate(); err != nil {
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err)
}
p.mu.Lock()
defer p.mu.Unlock()
if p.reservedUntil == nil {
p.reservedUntil = make(map[common.Email]time.Time)
}
reservedUntil, exists := p.reservedUntil[input.Email]
if exists && input.Now.Before(reservedUntil) {
return ports.SendEmailCodeAbuseResult{
Outcome: ports.SendEmailCodeAbuseOutcomeThrottled,
}, nil
}
p.reservedUntil[input.Email] = input.Now.UTC().Add(challenge.ResendThrottleCooldown)
return ports.SendEmailCodeAbuseResult{
Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
}, nil
}
var _ ports.SendEmailCodeAbuseProtector = (*SendEmailCodeProtector)(nil)
@@ -0,0 +1,64 @@
package antiabuse
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSendEmailCodeProtectorCheckAndReserve(t *testing.T) {
t.Parallel()
protector := &SendEmailCodeProtector{}
email := common.Email("pilot@example.com")
now := time.Unix(10, 0).UTC()
result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now,
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now.Add(30 * time.Second),
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now.Add(time.Minute),
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
}
func TestSendEmailCodeProtectorNilOrCanceledContext(t *testing.T) {
t.Parallel()
protector := &SendEmailCodeProtector{}
_, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{
Email: common.Email("pilot@example.com"),
Now: time.Unix(10, 0).UTC(),
})
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = protector.CheckAndReserve(ctx, ports.SendEmailCodeAbuseInput{
Email: common.Email("pilot@example.com"),
Now: time.Unix(10, 0).UTC(),
})
require.Error(t, err)
assert.ErrorIs(t, err, context.Canceled)
}
@@ -0,0 +1,206 @@
// Package contracttest provides reusable adapter conformance suites that
// exercise storage-agnostic port contracts without depending on one concrete
// backend implementation.
package contracttest
import (
"context"
"crypto/ed25519"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ChallengeStoreFactory constructs a fresh ChallengeStore instance suitable
// for one isolated contract subtest.
type ChallengeStoreFactory func(t *testing.T) ports.ChallengeStore
// RunChallengeStoreContractTests executes the backend-agnostic ChallengeStore
// contract suite against newStore.
func RunChallengeStoreContractTests(t *testing.T, newStore ChallengeStoreFactory) {
t.Helper()
t.Run("create and get", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractConfirmedChallenge(t, time.Unix(1_775_130_000, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
})
t.Run("get not found", func(t *testing.T) {
t.Parallel()
store := newStore(t)
_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
t.Run("create conflict", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractPendingChallenge(time.Unix(1_775_130_100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
err := store.Create(context.Background(), record)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrConflict)
})
t.Run("compare and swap success", func(t *testing.T) {
t.Parallel()
store := newStore(t)
now := time.Unix(1_775_130_200, 0).UTC()
previous := contractPendingChallenge(now)
next := previous
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
next.Attempts.Send = 1
next.Abuse.LastAttemptAt = contractTimePointer(now.Add(time.Minute))
require.NoError(t, next.Validate())
require.NoError(t, store.Create(context.Background(), previous))
require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))
got, err := store.Get(context.Background(), previous.ID)
require.NoError(t, err)
assert.Equal(t, next, got)
})
t.Run("compare and swap conflict", func(t *testing.T) {
t.Parallel()
store := newStore(t)
now := time.Unix(1_775_130_300, 0).UTC()
stored := contractPendingChallenge(now)
previous := stored
previous.Attempts.Send = 99
require.NoError(t, previous.Validate())
next := stored
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
require.NoError(t, next.Validate())
require.NoError(t, store.Create(context.Background(), stored))
err := store.CompareAndSwap(context.Background(), previous, next)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrConflict)
})
t.Run("compare and swap not found", func(t *testing.T) {
t.Parallel()
store := newStore(t)
now := time.Unix(1_775_130_400, 0).UTC()
previous := contractPendingChallenge(now)
next := previous
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
require.NoError(t, next.Validate())
err := store.CompareAndSwap(context.Background(), previous, next)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
t.Run("get returns defensive copies", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractConfirmedChallenge(t, time.Unix(1_775_130_500, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
require.NotEmpty(t, got.CodeHash)
got.CodeHash[0] = 0xFF
if got.Confirmation != nil {
keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
if len(keyBytes) > 0 {
keyBytes[0] = 0xFE
}
}
again, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record.CodeHash, again.CodeHash)
require.NotNil(t, again.Confirmation)
assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
})
}
func contractPendingChallenge(now time.Time) challenge.Challenge {
record := challenge.Challenge{
ID: common.ChallengeID("challenge-pending"),
Email: common.Email("pilot@example.com"),
CodeHash: []byte("hashed-pending-code"),
Status: challenge.StatusPendingSend,
DeliveryState: challenge.DeliveryPending,
CreatedAt: now,
ExpiresAt: now.Add(challenge.InitialTTL),
}
if err := record.Validate(); err != nil {
panic(err)
}
return record
}
func contractConfirmedChallenge(t *testing.T, now time.Time) challenge.Challenge {
t.Helper()
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31,
})
require.NoError(t, err)
record := challenge.Challenge{
ID: common.ChallengeID("challenge-confirmed"),
Email: common.Email("pilot@example.com"),
CodeHash: []byte("hashed-code"),
Status: challenge.StatusConfirmedPendingExpire,
DeliveryState: challenge.DeliverySent,
CreatedAt: now,
ExpiresAt: now.Add(challenge.ConfirmedRetention),
Attempts: challenge.AttemptCounters{
Send: 1,
Confirm: 2,
},
Abuse: challenge.AbuseMetadata{
LastAttemptAt: contractTimePointer(now.Add(30 * time.Second)),
},
Confirmation: &challenge.Confirmation{
SessionID: common.DeviceSessionID("device-session-1"),
ClientPublicKey: clientPublicKey,
ConfirmedAt: now.Add(time.Minute),
},
}
require.NoError(t, record.Validate())
return record
}
func contractTimePointer(value time.Time) *time.Time {
return &value
}
@@ -0,0 +1,65 @@
package contracttest
import (
"context"
"testing"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ConfigProviderHarnessFactory constructs a fresh semantic ConfigProvider
// harness suitable for one isolated contract subtest.
type ConfigProviderHarnessFactory func(t *testing.T) ConfigProviderHarness
// ConfigProviderHarness bundles one semantic ConfigProvider instance with the
// seed hooks needed by the backend-agnostic contract suite.
type ConfigProviderHarness struct {
// Provider is the semantic ConfigProvider under test.
Provider ports.ConfigProvider
// SeedDisabled prepares storage so LoadSessionLimit observes “limit absent”.
SeedDisabled func(t *testing.T)
// SeedLimit prepares storage so LoadSessionLimit observes a valid positive
// configured limit.
SeedLimit func(t *testing.T, limit int)
}
// RunConfigProviderContractTests executes the backend-agnostic ConfigProvider
// semantic contract suite against newHarness.
func RunConfigProviderContractTests(t *testing.T, newHarness ConfigProviderHarnessFactory) {
t.Helper()
t.Run("limit absent means disabled", func(t *testing.T) {
t.Parallel()
harness := newHarness(t)
require.NotNil(t, harness.Provider)
require.NotNil(t, harness.SeedDisabled)
harness.SeedDisabled(t)
got, err := harness.Provider.LoadSessionLimit(context.Background())
require.NoError(t, err)
assert.Equal(t, ports.SessionLimitConfig{}, got)
})
t.Run("valid positive limit means configured", func(t *testing.T) {
t.Parallel()
harness := newHarness(t)
require.NotNil(t, harness.Provider)
require.NotNil(t, harness.SeedLimit)
want := 5
harness.SeedLimit(t, want)
got, err := harness.Provider.LoadSessionLimit(context.Background())
require.NoError(t, err)
require.NotNil(t, got.ActiveSessionLimit)
assert.Equal(t, want, *got.ActiveSessionLimit)
})
}
@@ -0,0 +1,283 @@
package contracttest
import (
"context"
"crypto/ed25519"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// SessionStoreFactory constructs a fresh SessionStore instance suitable for
// one isolated contract subtest.
type SessionStoreFactory func(t *testing.T) ports.SessionStore
// RunSessionStoreContractTests executes the backend-agnostic SessionStore
// contract suite against newStore.
func RunSessionStoreContractTests(t *testing.T, newStore SessionStoreFactory) {
t.Helper()
t.Run("create and get", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
})
t.Run("create conflict", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(1_775_240_050, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
err := store.Create(context.Background(), record)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrConflict)
})
t.Run("get not found", func(t *testing.T) {
t.Parallel()
store := newStore(t)
_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
t.Run("list by user id returns newest first", func(t *testing.T) {
t.Parallel()
store := newStore(t)
older := contractActiveSession(t, "device-session-old", "user-1", time.Unix(10, 0).UTC())
newer := contractActiveSession(t, "device-session-new", "user-1", time.Unix(20, 0).UTC())
revoked := contractRevokedSession(t, "device-session-revoked", "user-1", time.Unix(15, 0).UTC())
otherUser := contractActiveSession(t, "device-session-other", "user-2", time.Unix(30, 0).UTC())
for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
require.NoError(t, store.Create(context.Background(), record))
}
got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
require.Len(t, got, 3)
assert.Equal(
t,
[]common.DeviceSessionID{newer.ID, revoked.ID, older.ID},
[]common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID},
)
})
t.Run("list by user id returns empty slice for unknown user", func(t *testing.T) {
t.Parallel()
store := newStore(t)
got, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
require.NoError(t, err)
require.NotNil(t, got)
assert.Empty(t, got)
})
t.Run("count active by user id", func(t *testing.T) {
t.Parallel()
store := newStore(t)
activeOne := contractActiveSession(t, "device-session-1", "user-1", time.Unix(40, 0).UTC())
activeTwo := contractActiveSession(t, "device-session-2", "user-1", time.Unix(50, 0).UTC())
revoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(60, 0).UTC())
otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(70, 0).UTC())
for _, record := range []devicesession.Session{activeOne, activeTwo, revoked, otherUser} {
require.NoError(t, store.Create(context.Background(), record))
}
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
assert.Equal(t, 2, count)
})
t.Run("revoke active session", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
revocation := contractRevocation(time.Unix(200, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", "")
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: revocation,
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
require.NotNil(t, result.Session.Revocation)
assert.Equal(t, revocation, *result.Session.Revocation)
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
require.NoError(t, err)
assert.Zero(t, count)
})
t.Run("revoke already revoked preserves stored revocation", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractRevokedSession(t, "device-session-2", "user-1", time.Unix(110, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1"),
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
require.NotNil(t, result.Session.Revocation)
assert.Equal(t, *record.Revocation, *result.Session.Revocation)
})
t.Run("revoke not found", func(t *testing.T) {
t.Parallel()
store := newStore(t)
_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: common.DeviceSessionID("missing-session"),
Revocation: contractRevocation(time.Unix(210, 0).UTC(), devicesession.RevokeReasonLogoutAll, "system", ""),
})
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
t.Run("revoke all by user id revokes active sessions newest first", func(t *testing.T) {
t.Parallel()
store := newStore(t)
older := contractActiveSession(t, "device-session-1", "user-1", time.Unix(100, 0).UTC())
newer := contractActiveSession(t, "device-session-2", "user-1", time.Unix(200, 0).UTC())
alreadyRevoked := contractRevokedSession(t, "device-session-3", "user-1", time.Unix(150, 0).UTC())
otherUser := contractActiveSession(t, "device-session-4", "user-2", time.Unix(250, 0).UTC())
for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
require.NoError(t, store.Create(context.Background(), record))
}
revocation := contractRevocation(time.Unix(300, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", "admin-1")
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
UserID: common.UserID("user-1"),
Revocation: revocation,
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
require.Len(t, result.Sessions, 2)
assert.Equal(
t,
[]common.DeviceSessionID{newer.ID, older.ID},
[]common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID},
)
assert.Equal(t, revocation, *result.Sessions[0].Revocation)
assert.Equal(t, revocation, *result.Sessions[1].Revocation)
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
assert.Zero(t, count)
})
t.Run("revoke all by user id reports no active sessions", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractRevokedSession(t, "device-session-5", "user-1", time.Unix(120, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
UserID: common.UserID("user-1"),
Revocation: contractRevocation(time.Unix(400, 0).UTC(), devicesession.RevokeReasonAdminRevoke, "admin", ""),
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
require.NotNil(t, result.Sessions)
assert.Empty(t, result.Sessions)
})
t.Run("get returns defensive copies", func(t *testing.T) {
t.Parallel()
store := newStore(t)
record := contractRevokedSession(t, "device-session-copy", "user-1", time.Unix(130, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
require.NotNil(t, got.Revocation)
got.Revocation.ActorID = "mutated"
again, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
require.NotNil(t, again.Revocation)
assert.Equal(t, record, again)
})
}
func contractActiveSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
t.Helper()
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31,
})
require.NoError(t, err)
record := devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: clientPublicKey,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
require.NoError(t, record.Validate())
return record
}
func contractRevokedSession(t *testing.T, deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
t.Helper()
record := contractActiveSession(t, deviceSessionID, userID, createdAt)
revocation := contractRevocation(createdAt.Add(time.Minute), devicesession.RevokeReasonDeviceLogout, "user", "user-actor")
record.Status = devicesession.StatusRevoked
record.Revocation = &revocation
require.NoError(t, record.Validate())
return record
}
func contractRevocation(at time.Time, reasonCode common.RevokeReasonCode, actorType string, actorID string) devicesession.Revocation {
record := devicesession.Revocation{
At: at,
ReasonCode: reasonCode,
ActorType: common.RevokeActorType(actorType),
ActorID: actorID,
}
if err := record.Validate(); err != nil {
panic(err)
}
return record
}
@@ -0,0 +1,139 @@
// Package local provides small in-process runtime implementations for
// authsession ports that do not require network dependencies.
package local
import (
"crypto/rand"
"encoding/base64"
"fmt"
"math/big"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"golang.org/x/crypto/bcrypt"
)
const (
challengeIDPrefix = "challenge-"
deviceSessionIDPrefix = "device-session-"
codeDigits = 6
)
// Clock implements ports.Clock using the local system clock in UTC.
type Clock struct{}
// Now returns the current system time normalized to UTC.
func (Clock) Now() time.Time {
return time.Now().UTC()
}
// IDGenerator implements ports.IDGenerator with cryptographically random
// opaque identifiers.
type IDGenerator struct{}
// NewChallengeID returns a fresh random challenge identifier.
func (IDGenerator) NewChallengeID() (common.ChallengeID, error) {
value, err := newOpaqueIDString(challengeIDPrefix)
if err != nil {
return "", err
}
return common.ChallengeID(value), nil
}
// NewDeviceSessionID returns a fresh random device-session identifier.
func (IDGenerator) NewDeviceSessionID() (common.DeviceSessionID, error) {
value, err := newOpaqueIDString(deviceSessionIDPrefix)
if err != nil {
return "", err
}
return common.DeviceSessionID(value), nil
}
// CodeGenerator implements ports.CodeGenerator with random 6-digit decimal
// confirmation codes.
type CodeGenerator struct{}
// Generate returns one fresh random 6-digit decimal code.
func (CodeGenerator) Generate() (string, error) {
var builder strings.Builder
builder.Grow(codeDigits)
for idx := 0; idx < codeDigits; idx++ {
digit, err := rand.Int(rand.Reader, big.NewInt(10))
if err != nil {
return "", fmt.Errorf("generate confirmation code: %w", err)
}
builder.WriteByte(byte('0' + digit.Int64()))
}
return builder.String(), nil
}
// CodeHasher implements ports.CodeHasher with bcrypt-backed hashes.
type CodeHasher struct{}
// Hash returns the bcrypt hash of code.
func (CodeHasher) Hash(code string) ([]byte, error) {
if err := validateCode(code); err != nil {
return nil, err
}
hash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash confirmation code: %w", err)
}
return hash, nil
}
// Compare reports whether hash matches code.
func (CodeHasher) Compare(hash []byte, code string) (bool, error) {
if err := validateCode(code); err != nil {
return false, err
}
if len(hash) == 0 {
return false, nil
}
err := bcrypt.CompareHashAndPassword(hash, []byte(code))
switch err {
case nil:
return true, nil
case bcrypt.ErrMismatchedHashAndPassword:
return false, nil
default:
return false, fmt.Errorf("compare confirmation code hash: %w", err)
}
}
func newOpaqueIDString(prefix string) (string, error) {
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("generate opaque identifier: %w", err)
}
return prefix + base64.RawURLEncoding.EncodeToString(randomBytes), nil
}
func validateCode(code string) error {
switch {
case strings.TrimSpace(code) == "":
return fmt.Errorf("code must not be empty")
case strings.TrimSpace(code) != code:
return fmt.Errorf("code must not contain surrounding whitespace")
default:
return nil
}
}
var (
_ ports.Clock = Clock{}
_ ports.IDGenerator = IDGenerator{}
_ ports.CodeGenerator = CodeGenerator{}
_ ports.CodeHasher = CodeHasher{}
)
@@ -0,0 +1,60 @@
package local
import (
"regexp"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClockNowReturnsUTC(t *testing.T) {
t.Parallel()
now := Clock{}.Now()
assert.Equal(t, time.UTC, now.Location())
}
func TestIDGeneratorProducesValidOpaqueIDs(t *testing.T) {
t.Parallel()
generator := IDGenerator{}
challengeID, err := generator.NewChallengeID()
require.NoError(t, err)
require.NoError(t, challengeID.Validate())
assert.Regexp(t, regexp.MustCompile(`^challenge-[A-Za-z0-9_-]+$`), challengeID.String())
deviceSessionID, err := generator.NewDeviceSessionID()
require.NoError(t, err)
require.NoError(t, deviceSessionID.Validate())
assert.Regexp(t, regexp.MustCompile(`^device-session-[A-Za-z0-9_-]+$`), deviceSessionID.String())
}
func TestCodeGeneratorProducesSixDigitNumericCodes(t *testing.T) {
t.Parallel()
code, err := CodeGenerator{}.Generate()
require.NoError(t, err)
assert.Regexp(t, regexp.MustCompile(`^\d{6}$`), code)
}
func TestCodeHasherHashesAndComparesCodes(t *testing.T) {
t.Parallel()
hasher := CodeHasher{}
hash, err := hasher.Hash("123456")
require.NoError(t, err)
require.NotEmpty(t, hash)
match, err := hasher.Compare(hash, "123456")
require.NoError(t, err)
assert.True(t, match)
match, err = hasher.Compare(hash, "000000")
require.NoError(t, err)
assert.False(t, match)
}
@@ -0,0 +1,182 @@
// Package mail provides runtime mail-delivery adapters for the auth/session
// service.
package mail
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"galaxy/authsession/internal/ports"
)
const sendLoginCodePath = "/api/v1/internal/login-code-deliveries"
// Config configures one HTTP-based mail-delivery client.
type Config struct {
// BaseURL is the absolute base URL of the internal mail-service HTTP API.
BaseURL string
// RequestTimeout bounds each outbound mail-service request.
RequestTimeout time.Duration
}
// RESTClient implements ports.MailSender over the frozen internal REST mail
// contract.
type RESTClient struct {
baseURL string
requestTimeout time.Duration
httpClient *http.Client
}
// NewRESTClient constructs a REST-backed MailSender adapter from cfg.
func NewRESTClient(cfg Config) (*RESTClient, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
return newRESTClient(cfg, &http.Client{Transport: transport})
}
func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) {
switch {
case strings.TrimSpace(cfg.BaseURL) == "":
return nil, errors.New("new mail service REST client: base URL must not be empty")
case cfg.RequestTimeout <= 0:
return nil, errors.New("new mail service REST client: request timeout must be positive")
case httpClient == nil:
return nil, errors.New("new mail service REST client: http client must not be nil")
}
parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
if err != nil {
return nil, fmt.Errorf("new mail service REST client: parse base URL: %w", err)
}
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
return nil, errors.New("new mail service REST client: base URL must be absolute")
}
return &RESTClient{
baseURL: parsedBaseURL.String(),
requestTimeout: cfg.RequestTimeout,
httpClient: httpClient,
}, nil
}
// Close releases idle HTTP connections owned by the client transport.
func (c *RESTClient) Close() error {
if c == nil || c.httpClient == nil {
return nil
}
type idleCloser interface {
CloseIdleConnections()
}
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
transport.CloseIdleConnections()
}
return nil
}
// SendLoginCode submits one delivery request to the internal mail service
// without retrying transport or upstream failures.
func (c *RESTClient) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
if err := validateRESTContext(ctx, "send login code"); err != nil {
return ports.SendLoginCodeResult{}, err
}
if err := input.Validate(); err != nil {
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
}
payload, statusCode, err := c.doRequest(ctx, "send login code", map[string]string{
"email": input.Email.String(),
"code": input.Code,
})
if err != nil {
return ports.SendLoginCodeResult{}, err
}
if statusCode != http.StatusOK {
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: unexpected HTTP status %d", statusCode)
}
var response struct {
Outcome ports.SendLoginCodeOutcome `json:"outcome"`
}
if err := decodeJSONPayload(payload, &response); err != nil {
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
}
result := ports.SendLoginCodeResult{Outcome: response.Outcome}
if err := result.Validate(); err != nil {
return ports.SendLoginCodeResult{}, fmt.Errorf("send login code: %w", err)
}
return result, nil
}
func (c *RESTClient) doRequest(ctx context.Context, operation string, requestBody any) ([]byte, int, error) {
bodyBytes, err := json.Marshal(requestBody)
if err != nil {
return nil, 0, fmt.Errorf("%s: marshal request body: %w", operation, err)
}
attemptCtx, cancel := context.WithTimeout(ctx, c.requestTimeout)
defer cancel()
request, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, c.baseURL+sendLoginCodePath, bytes.NewReader(bodyBytes))
if err != nil {
return nil, 0, fmt.Errorf("%s: build request: %w", operation, err)
}
request.Header.Set("Content-Type", "application/json")
response, err := c.httpClient.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("%s: %w", operation, err)
}
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
if err != nil {
return nil, 0, fmt.Errorf("%s: read response body: %w", operation, err)
}
return payload, response.StatusCode, nil
}
func decodeJSONPayload(payload []byte, target any) error {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return fmt.Errorf("decode response body: %w", err)
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return errors.New("decode response body: unexpected trailing JSON input")
}
return fmt.Errorf("decode response body: %w", err)
}
return nil
}
func validateRESTContext(ctx context.Context, operation string) error {
if ctx == nil {
return fmt.Errorf("%s: nil context", operation)
}
if err := ctx.Err(); err != nil {
return fmt.Errorf("%s: %w", operation, err)
}
return nil
}
var _ ports.MailSender = (*RESTClient)(nil)
@@ -0,0 +1,394 @@
package mail
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewRESTClient(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cfg Config
wantErr string
}{
{
name: "valid config",
cfg: Config{
BaseURL: "http://127.0.0.1:8080",
RequestTimeout: time.Second,
},
},
{
name: "empty base url",
cfg: Config{
RequestTimeout: time.Second,
},
wantErr: "base URL must not be empty",
},
{
name: "relative base url",
cfg: Config{
BaseURL: "/relative",
RequestTimeout: time.Second,
},
wantErr: "base URL must be absolute",
},
{
name: "non positive timeout",
cfg: Config{
BaseURL: "http://127.0.0.1:8080",
},
wantErr: "request timeout must be positive",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client, err := NewRESTClient(tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
assert.NoError(t, client.Close())
})
}
}
func TestRESTClientSendLoginCodeSuccessCases(t *testing.T) {
t.Parallel()
tests := []struct {
name string
response string
wantOutcome ports.SendLoginCodeOutcome
}{
{
name: "sent",
response: `{"outcome":"sent"}`,
wantOutcome: ports.SendLoginCodeOutcomeSent,
},
{
name: "suppressed",
response: `{"outcome":"suppressed"}`,
wantOutcome: ports.SendLoginCodeOutcomeSuppressed,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var requestsMu sync.Mutex
var requests []capturedRequest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestsMu.Lock()
requests = append(requests, captureRequest(t, r))
requestsMu.Unlock()
writeJSON(t, w, http.StatusOK, json.RawMessage(tt.response))
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
result, err := client.SendLoginCode(context.Background(), validInput())
require.NoError(t, err)
assert.Equal(t, tt.wantOutcome, result.Outcome)
requestsMu.Lock()
defer requestsMu.Unlock()
require.Len(t, requests, 1)
assert.Equal(t, http.MethodPost, requests[0].Method)
assert.Equal(t, sendLoginCodePath, requests[0].Path)
assert.Equal(t, "application/json", requests[0].ContentType)
assert.JSONEq(t, `{"email":"pilot@example.com","code":"654321"}`, requests[0].Body)
})
}
}
func TestRESTClientPreservesNormalizedEmailAndCodeExactly(t *testing.T) {
t.Parallel()
var captured string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured = captureRequest(t, r).Body
writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
result, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
Email: common.Email("Pilot+Alias@Example.com"),
Code: "123456",
})
require.NoError(t, err)
assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)
assert.JSONEq(t, `{"email":"Pilot+Alias@Example.com","code":"123456"}`, captured)
}
func TestRESTClientSendLoginCodeDoesNotRetry(t *testing.T) {
t.Parallel()
t.Run("no retry on 503", func(t *testing.T) {
t.Parallel()
var calls atomic.Int64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls.Add(1)
http.Error(w, "temporary", http.StatusServiceUnavailable)
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
_, err := client.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorContains(t, err, "unexpected HTTP status 503")
assert.EqualValues(t, 1, calls.Load())
})
t.Run("no retry on transport failure", func(t *testing.T) {
t.Parallel()
var calls atomic.Int64
client, err := newRESTClient(Config{
BaseURL: "http://127.0.0.1:8080",
RequestTimeout: 250 * time.Millisecond,
}, &http.Client{
Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
calls.Add(1)
return nil, errors.New("temporary transport failure")
}),
})
require.NoError(t, err)
_, err = client.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorContains(t, err, "temporary transport failure")
assert.EqualValues(t, 1, calls.Load())
})
}
func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
body string
wantErrText string
}{
{
name: "rejects unknown field",
statusCode: http.StatusOK,
body: `{"outcome":"sent","extra":true}`,
wantErrText: "decode response body",
},
{
name: "rejects unsupported outcome",
statusCode: http.StatusOK,
body: `{"outcome":"queued"}`,
wantErrText: "unsupported",
},
{
name: "rejects missing outcome",
statusCode: http.StatusOK,
body: `{}`,
wantErrText: "unsupported",
},
{
name: "rejects trailing json",
statusCode: http.StatusOK,
body: `{"outcome":"sent"}{}`,
wantErrText: "unexpected trailing JSON input",
},
{
name: "rejects unexpected status",
statusCode: http.StatusBadGateway,
body: `{"error":"temporary"}`,
wantErrText: "unexpected HTTP status 502",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.statusCode)
_, err := io.WriteString(w, tt.body)
require.NoError(t, err)
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
_, err := client.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErrText)
})
}
}
func TestRESTClientRequestTimeout(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(40 * time.Millisecond)
writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 10*time.Millisecond)
_, err := client.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorContains(t, err, "context deadline exceeded")
}
func TestRESTClientContextAndValidation(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("unexpected upstream call")
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
tests := []struct {
name string
run func() error
}{
{
name: "nil context",
run: func() error {
_, err := client.SendLoginCode(nil, validInput())
return err
},
},
{
name: "cancelled context",
run: func() error {
_, err := client.SendLoginCode(cancelledCtx, validInput())
return err
},
},
{
name: "invalid email",
run: func() error {
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
Email: common.Email(" bad@example.com "),
Code: "123456",
})
return err
},
},
{
name: "invalid code",
run: func() error {
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
Email: common.Email("pilot@example.com"),
Code: " 123456 ",
})
return err
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
err := tt.run()
require.Error(t, err)
})
}
}
type capturedRequest struct {
Method string
Path string
ContentType string
Body string
}
func captureRequest(t *testing.T, request *http.Request) capturedRequest {
t.Helper()
body, err := io.ReadAll(request.Body)
require.NoError(t, err)
return capturedRequest{
Method: request.Method,
Path: request.URL.Path,
ContentType: request.Header.Get("Content-Type"),
Body: strings.TrimSpace(string(body)),
}
}
func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
t.Helper()
payload, err := json.Marshal(value)
require.NoError(t, err)
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(statusCode)
_, err = writer.Write(payload)
require.NoError(t, err)
}
func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
t.Helper()
client, err := NewRESTClient(Config{
BaseURL: baseURL,
RequestTimeout: timeout,
})
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return client
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (fn roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return fn(request)
}
@@ -0,0 +1,179 @@
// Package mail provides runtime mail-delivery adapters for the auth/session
// service.
package mail
import (
"context"
"errors"
"fmt"
"sync"
"galaxy/authsession/internal/ports"
)
var errForcedFailure = errors.New("stub mail sender: forced failure")
// StubMode identifies the deterministic outcome used by StubSender for one
// delivery attempt.
type StubMode string
const (
// StubModeSent reports that the adapter accepts delivery and returns the
// stable sent outcome expected by the auth flow.
StubModeSent StubMode = "sent"
// StubModeSuppressed reports that the adapter intentionally suppresses
// outward delivery while still returning a successful suppressed outcome.
StubModeSuppressed StubMode = "suppressed"
// StubModeFailed reports that the adapter returns an explicit delivery
// failure instead of a successful outcome.
StubModeFailed StubMode = "failed"
)
// IsKnown reports whether mode is one of the supported stub delivery modes.
func (mode StubMode) IsKnown() bool {
switch mode {
case StubModeSent, StubModeSuppressed, StubModeFailed:
return true
default:
return false
}
}
// StubStep overrides the default stub behavior for one queued delivery
// attempt.
type StubStep struct {
// Mode selects the delivery behavior for this queued step.
Mode StubMode
// Err optionally overrides the failure returned when Mode is StubModeFailed.
Err error
}
// Validate reports whether step contains one supported queued behavior.
func (step StubStep) Validate() error {
if !step.Mode.IsKnown() {
return fmt.Errorf("stub mail step mode %q is unsupported", step.Mode)
}
return nil
}
// Attempt records one validated delivery request handled by StubSender.
type Attempt struct {
// Input stores the validated cleartext mail-delivery request exactly as it
// was passed into SendLoginCode.
Input ports.SendLoginCodeInput
// Mode stores the resolved stub mode after queued overrides were applied.
Mode StubMode
}
// StubSender is a deterministic runtime MailSender implementation intended
// for development, local integration, and explicit stub-based tests.
//
// The zero value is ready to use and defaults to StubModeSent.
type StubSender struct {
// DefaultMode controls the fallback behavior when Script is empty. The zero
// value is treated as StubModeSent so the zero-value sender is usable
// without extra configuration.
DefaultMode StubMode
// DefaultError optionally overrides the failure returned when DefaultMode
// resolves to StubModeFailed.
DefaultError error
// Script stores queued one-shot overrides consumed in FIFO order before the
// default behavior is used.
Script []StubStep
mu sync.Mutex
attempts []Attempt
}
// SendLoginCode records one validated delivery request and returns the
// deterministic stub outcome selected by the queued script or the default
// mode.
func (s *StubSender) SendLoginCode(ctx context.Context, input ports.SendLoginCodeInput) (ports.SendLoginCodeResult, error) {
if ctx == nil {
return ports.SendLoginCodeResult{}, errors.New("stub mail sender: nil context")
}
if err := ctx.Err(); err != nil {
return ports.SendLoginCodeResult{}, err
}
if err := input.Validate(); err != nil {
return ports.SendLoginCodeResult{}, err
}
s.mu.Lock()
defer s.mu.Unlock()
mode, errOverride, err := s.resolveNextStepLocked()
if err != nil {
return ports.SendLoginCodeResult{}, err
}
s.attempts = append(s.attempts, Attempt{
Input: input,
Mode: mode,
})
switch mode {
case StubModeSent:
return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSent}, nil
case StubModeSuppressed:
return ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSuppressed}, nil
case StubModeFailed:
if errOverride != nil {
return ports.SendLoginCodeResult{}, errOverride
}
return ports.SendLoginCodeResult{}, errForcedFailure
default:
return ports.SendLoginCodeResult{}, fmt.Errorf("stub mail sender: unsupported resolved mode %q", mode)
}
}
// RecordedAttempts returns a stable defensive copy of every validated delivery
// attempt handled by the stub.
func (s *StubSender) RecordedAttempts() []Attempt {
s.mu.Lock()
defer s.mu.Unlock()
return append([]Attempt(nil), s.attempts...)
}
func (s *StubSender) resolveNextStepLocked() (StubMode, error, error) {
if len(s.Script) > 0 {
step := s.Script[0]
s.Script = append([]StubStep(nil), s.Script[1:]...)
if err := step.Validate(); err != nil {
return "", nil, fmt.Errorf("stub mail sender: %w", err)
}
if step.Mode == StubModeFailed {
if step.Err != nil {
return step.Mode, step.Err, nil
}
return step.Mode, errForcedFailure, nil
}
return step.Mode, nil, nil
}
mode := s.DefaultMode
if mode == "" {
mode = StubModeSent
}
if !mode.IsKnown() {
return "", nil, fmt.Errorf("stub mail sender: default mode %q is unsupported", mode)
}
if mode == StubModeFailed {
if s.DefaultError != nil {
return mode, s.DefaultError, nil
}
return mode, errForcedFailure, nil
}
return mode, nil, nil
}
var _ ports.MailSender = (*StubSender)(nil)
@@ -0,0 +1,198 @@
package mail
import (
"context"
"errors"
"testing"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStubSenderSendLoginCode(t *testing.T) {
t.Parallel()
t.Run("zero value defaults to sent", func(t *testing.T) {
t.Parallel()
sender := &StubSender{}
result, err := sender.SendLoginCode(context.Background(), validInput())
require.NoError(t, err)
assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)
attempts := sender.RecordedAttempts()
require.Len(t, attempts, 1)
assert.Equal(t, StubModeSent, attempts[0].Mode)
assert.Equal(t, validInput(), attempts[0].Input)
})
t.Run("default suppressed", func(t *testing.T) {
t.Parallel()
sender := &StubSender{DefaultMode: StubModeSuppressed}
result, err := sender.SendLoginCode(context.Background(), validInput())
require.NoError(t, err)
assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, result.Outcome)
attempts := sender.RecordedAttempts()
require.Len(t, attempts, 1)
assert.Equal(t, StubModeSuppressed, attempts[0].Mode)
})
t.Run("default failed uses configured error", func(t *testing.T) {
t.Parallel()
wantErr := errors.New("delivery refused")
sender := &StubSender{
DefaultMode: StubModeFailed,
DefaultError: wantErr,
}
result, err := sender.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorIs(t, err, wantErr)
assert.Equal(t, ports.SendLoginCodeResult{}, result)
attempts := sender.RecordedAttempts()
require.Len(t, attempts, 1)
assert.Equal(t, StubModeFailed, attempts[0].Mode)
})
t.Run("default failed uses stable fallback error", func(t *testing.T) {
t.Parallel()
sender := &StubSender{DefaultMode: StubModeFailed}
_, err := sender.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.EqualError(t, err, "stub mail sender: forced failure")
})
t.Run("script overrides default and is consumed fifo", func(t *testing.T) {
t.Parallel()
wantErr := errors.New("step failed")
sender := &StubSender{
DefaultMode: StubModeSent,
Script: []StubStep{
{Mode: StubModeSuppressed},
{Mode: StubModeFailed, Err: wantErr},
},
}
first, err := sender.SendLoginCode(context.Background(), validInput())
require.NoError(t, err)
assert.Equal(t, ports.SendLoginCodeOutcomeSuppressed, first.Outcome)
second, err := sender.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorIs(t, err, wantErr)
assert.Equal(t, ports.SendLoginCodeResult{}, second)
third, err := sender.SendLoginCode(context.Background(), validInput())
require.NoError(t, err)
assert.Equal(t, ports.SendLoginCodeOutcomeSent, third.Outcome)
attempts := sender.RecordedAttempts()
require.Len(t, attempts, 3)
assert.Equal(t, []StubMode{StubModeSuppressed, StubModeFailed, StubModeSent}, []StubMode{
attempts[0].Mode,
attempts[1].Mode,
attempts[2].Mode,
})
assert.Empty(t, sender.Script)
})
t.Run("invalid default mode returns adapter error", func(t *testing.T) {
t.Parallel()
sender := &StubSender{DefaultMode: StubMode("queued")}
_, err := sender.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorContains(t, err, `default mode "queued" is unsupported`)
assert.Empty(t, sender.RecordedAttempts())
})
t.Run("invalid scripted mode returns adapter error", func(t *testing.T) {
t.Parallel()
sender := &StubSender{
Script: []StubStep{
{Mode: StubMode("queued")},
},
}
_, err := sender.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorContains(t, err, `mode "queued" is unsupported`)
assert.Empty(t, sender.RecordedAttempts())
assert.Empty(t, sender.Script)
})
}
func TestStubSenderRecordedAttemptsAreDefensive(t *testing.T) {
t.Parallel()
sender := &StubSender{}
_, err := sender.SendLoginCode(context.Background(), validInput())
require.NoError(t, err)
attempts := sender.RecordedAttempts()
require.Len(t, attempts, 1)
attempts[0].Mode = StubModeFailed
attempts[0].Input.Code = "000000"
again := sender.RecordedAttempts()
require.Len(t, again, 1)
assert.Equal(t, StubModeSent, again[0].Mode)
assert.Equal(t, "654321", again[0].Input.Code)
}
func TestStubSenderSendLoginCodeNilContext(t *testing.T) {
t.Parallel()
sender := &StubSender{}
_, err := sender.SendLoginCode(nil, validInput())
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
assert.Empty(t, sender.RecordedAttempts())
}
func TestStubSenderSendLoginCodeCancelledContext(t *testing.T) {
t.Parallel()
sender := &StubSender{}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := sender.SendLoginCode(ctx, validInput())
require.Error(t, err)
assert.ErrorIs(t, err, context.Canceled)
assert.Empty(t, sender.RecordedAttempts())
}
func TestStubSenderSendLoginCodeInvalidInput(t *testing.T) {
t.Parallel()
sender := &StubSender{}
_, err := sender.SendLoginCode(context.Background(), ports.SendLoginCodeInput{})
require.Error(t, err)
assert.ErrorContains(t, err, "send login code input email")
assert.Empty(t, sender.RecordedAttempts())
}
func validInput() ports.SendLoginCodeInput {
return ports.SendLoginCodeInput{
Email: common.Email("pilot@example.com"),
Code: "654321",
}
}
@@ -0,0 +1,484 @@
// Package challengestore implements ports.ChallengeStore with Redis-backed
// strict JSON challenge records.
package challengestore
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
"strings"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
const expirationGracePeriod = 5 * time.Minute
// Config configures one Redis-backed challenge store instance.
type Config struct {
// Addr is the Redis network address in host:port form.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled enables TLS with a conservative minimum protocol version.
TLSEnabled bool
// KeyPrefix is the namespace prefix applied to every challenge key.
KeyPrefix string
// OperationTimeout bounds each Redis round trip performed by the adapter.
OperationTimeout time.Duration
}
// Store persists challenges as one strict JSON value per Redis key.
type Store struct {
client *redis.Client
keyPrefix string
operationTimeout time.Duration
}
type redisRecord struct {
ChallengeID string `json:"challenge_id"`
Email string `json:"email"`
CodeHashBase64 string `json:"code_hash_base64"`
Status challenge.Status `json:"status"`
DeliveryState challenge.DeliveryState `json:"delivery_state"`
CreatedAt string `json:"created_at"`
ExpiresAt string `json:"expires_at"`
SendAttemptCount int `json:"send_attempt_count"`
ConfirmAttemptCount int `json:"confirm_attempt_count"`
LastAttemptAt *string `json:"last_attempt_at,omitempty"`
ConfirmedSessionID string `json:"confirmed_session_id,omitempty"`
ConfirmedClientPublicKey string `json:"confirmed_client_public_key,omitempty"`
ConfirmedAt *string `json:"confirmed_at,omitempty"`
}
// New constructs a Redis-backed challenge store from cfg.
func New(cfg Config) (*Store, error) {
if strings.TrimSpace(cfg.Addr) == "" {
return nil, errors.New("new redis challenge store: redis addr must not be empty")
}
if cfg.DB < 0 {
return nil, errors.New("new redis challenge store: redis db must not be negative")
}
if strings.TrimSpace(cfg.KeyPrefix) == "" {
return nil, errors.New("new redis challenge store: redis key prefix must not be empty")
}
if cfg.OperationTimeout <= 0 {
return nil, errors.New("new redis challenge store: operation timeout must be positive")
}
options := &redis.Options{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.DB,
Protocol: 2,
DisableIdentity: true,
}
if cfg.TLSEnabled {
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
return &Store{
client: redis.NewClient(options),
keyPrefix: cfg.KeyPrefix,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close releases the underlying Redis client resources.
func (s *Store) Close() error {
if s == nil || s.client == nil {
return nil
}
return s.client.Close()
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
operationCtx, cancel, err := s.operationContext(ctx, "ping redis challenge store")
if err != nil {
return err
}
defer cancel()
if err := s.client.Ping(operationCtx).Err(); err != nil {
return fmt.Errorf("ping redis challenge store: %w", err)
}
return nil
}
// Get returns the stored challenge for challengeID.
func (s *Store) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
if err := challengeID.Validate(); err != nil {
return challenge.Challenge{}, fmt.Errorf("get challenge from redis: %w", err)
}
operationCtx, cancel, err := s.operationContext(ctx, "get challenge from redis")
if err != nil {
return challenge.Challenge{}, err
}
defer cancel()
payload, err := s.client.Get(operationCtx, s.lookupKey(challengeID)).Bytes()
switch {
case errors.Is(err, redis.Nil):
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, ports.ErrNotFound)
case err != nil:
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
}
record, err := decodeChallengeRecord(challengeID, payload)
if err != nil {
return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err)
}
return record, nil
}
// Create persists record as a new challenge.
func (s *Store) Create(ctx context.Context, record challenge.Challenge) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("create challenge in redis: %w", err)
}
payload, err := marshalChallengeRecord(record)
if err != nil {
return fmt.Errorf("create challenge in redis: %w", err)
}
operationCtx, cancel, err := s.operationContext(ctx, "create challenge in redis")
if err != nil {
return err
}
defer cancel()
created, err := s.client.SetNX(operationCtx, s.lookupKey(record.ID), payload, redisTTL(record.ExpiresAt)).Result()
if err != nil {
return fmt.Errorf("create challenge %q in redis: %w", record.ID, err)
}
if !created {
return fmt.Errorf("create challenge %q in redis: %w", record.ID, ports.ErrConflict)
}
return nil
}
// CompareAndSwap replaces previous with next when the currently stored
// challenge matches previous exactly in canonical Redis representation.
func (s *Store) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
if err := ports.ValidateComparableChallenges(previous, next); err != nil {
return fmt.Errorf("compare and swap challenge in redis: %w", err)
}
nextPayload, err := marshalChallengeRecord(next)
if err != nil {
return fmt.Errorf("compare and swap challenge in redis: %w", err)
}
operationCtx, cancel, err := s.operationContext(ctx, "compare and swap challenge in redis")
if err != nil {
return err
}
defer cancel()
key := s.lookupKey(previous.ID)
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
payload, err := tx.Get(operationCtx, key).Bytes()
switch {
case errors.Is(err, redis.Nil):
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrNotFound)
case err != nil:
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
}
current, err := decodeChallengeRecord(previous.ID, payload)
if err != nil {
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
}
matches, err := equalStoredChallenges(current, previous)
if err != nil {
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
}
if !matches {
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, key, nextPayload, redisTTL(next.ExpiresAt))
return nil
})
if err != nil {
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err)
}
return nil
}, key)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
if s == nil || s.client == nil {
return nil, nil, fmt.Errorf("%s: nil store", operation)
}
if ctx == nil {
return nil, nil, fmt.Errorf("%s: nil context", operation)
}
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
return operationCtx, cancel, nil
}
func (s *Store) lookupKey(challengeID common.ChallengeID) string {
return s.keyPrefix + encodeKeyComponent(challengeID.String())
}
func encodeKeyComponent(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
func marshalChallengeRecord(record challenge.Challenge) ([]byte, error) {
stored, err := redisRecordFromChallenge(record)
if err != nil {
return nil, err
}
payload, err := json.Marshal(stored)
if err != nil {
return nil, fmt.Errorf("encode redis challenge record: %w", err)
}
return payload, nil
}
func redisRecordFromChallenge(record challenge.Challenge) (redisRecord, error) {
if err := record.Validate(); err != nil {
return redisRecord{}, fmt.Errorf("encode redis challenge record: %w", err)
}
stored := redisRecord{
ChallengeID: record.ID.String(),
Email: record.Email.String(),
CodeHashBase64: base64.StdEncoding.EncodeToString(record.CodeHash),
Status: record.Status,
DeliveryState: record.DeliveryState,
CreatedAt: formatTimestamp(record.CreatedAt),
ExpiresAt: formatTimestamp(record.ExpiresAt),
SendAttemptCount: record.Attempts.Send,
ConfirmAttemptCount: record.Attempts.Confirm,
LastAttemptAt: formatOptionalTimestamp(record.Abuse.LastAttemptAt),
}
if record.Confirmation != nil {
stored.ConfirmedSessionID = record.Confirmation.SessionID.String()
stored.ConfirmedClientPublicKey = record.Confirmation.ClientPublicKey.String()
stored.ConfirmedAt = formatOptionalTimestamp(&record.Confirmation.ConfirmedAt)
}
return stored, nil
}
func decodeChallengeRecord(expectedChallengeID common.ChallengeID, payload []byte) (challenge.Challenge, error) {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
var stored redisRecord
if err := decoder.Decode(&stored); err != nil {
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return challenge.Challenge{}, errors.New("decode redis challenge record: unexpected trailing JSON input")
}
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
}
record, err := challengeFromRedisRecord(stored)
if err != nil {
return challenge.Challenge{}, err
}
if record.ID != expectedChallengeID {
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: challenge_id %q does not match requested %q", record.ID, expectedChallengeID)
}
return record, nil
}
func challengeFromRedisRecord(stored redisRecord) (challenge.Challenge, error) {
createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
if err != nil {
return challenge.Challenge{}, err
}
expiresAt, err := parseTimestamp("expires_at", stored.ExpiresAt)
if err != nil {
return challenge.Challenge{}, err
}
lastAttemptAt, err := parseOptionalTimestamp("last_attempt_at", stored.LastAttemptAt)
if err != nil {
return challenge.Challenge{}, err
}
codeHash, err := base64.StdEncoding.Strict().DecodeString(stored.CodeHashBase64)
if err != nil {
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: code_hash_base64: %w", err)
}
record := challenge.Challenge{
ID: common.ChallengeID(stored.ChallengeID),
Email: common.Email(stored.Email),
CodeHash: codeHash,
Status: stored.Status,
DeliveryState: stored.DeliveryState,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
Attempts: challenge.AttemptCounters{
Send: stored.SendAttemptCount,
Confirm: stored.ConfirmAttemptCount,
},
Abuse: challenge.AbuseMetadata{
LastAttemptAt: lastAttemptAt,
},
}
confirmation, err := parseConfirmation(stored)
if err != nil {
return challenge.Challenge{}, err
}
record.Confirmation = confirmation
if err := record.Validate(); err != nil {
return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err)
}
return record, nil
}
func parseConfirmation(stored redisRecord) (*challenge.Confirmation, error) {
hasSessionID := strings.TrimSpace(stored.ConfirmedSessionID) != ""
hasClientPublicKey := strings.TrimSpace(stored.ConfirmedClientPublicKey) != ""
hasConfirmedAt := stored.ConfirmedAt != nil
if !hasSessionID && !hasClientPublicKey && !hasConfirmedAt {
return nil, nil
}
if !hasSessionID || !hasClientPublicKey || !hasConfirmedAt {
return nil, errors.New("decode redis challenge record: confirmation metadata must be either fully present or fully absent")
}
confirmedAt, err := parseTimestamp("confirmed_at", *stored.ConfirmedAt)
if err != nil {
return nil, err
}
rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ConfirmedClientPublicKey)
if err != nil {
return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
}
clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
if err != nil {
return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err)
}
return &challenge.Confirmation{
SessionID: common.DeviceSessionID(stored.ConfirmedSessionID),
ClientPublicKey: clientPublicKey,
ConfirmedAt: confirmedAt,
}, nil
}
func parseOptionalTimestamp(fieldName string, value *string) (*time.Time, error) {
if value == nil {
return nil, nil
}
parsed, err := parseTimestamp(fieldName, *value)
if err != nil {
return nil, err
}
return &parsed, nil
}
func parseTimestamp(fieldName string, value string) (time.Time, error) {
if strings.TrimSpace(value) == "" {
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must not be empty", fieldName)
}
parsed, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return time.Time{}, fmt.Errorf("decode redis challenge record: %s: %w", fieldName, err)
}
canonical := parsed.UTC().Format(time.RFC3339Nano)
if value != canonical {
return time.Time{}, fmt.Errorf("decode redis challenge record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
}
return parsed.UTC(), nil
}
func formatTimestamp(value time.Time) string {
return value.UTC().Format(time.RFC3339Nano)
}
func formatOptionalTimestamp(value *time.Time) *string {
if value == nil {
return nil
}
formatted := formatTimestamp(*value)
return &formatted
}
func redisTTL(expiresAt time.Time) time.Duration {
ttl := time.Until(expiresAt.UTC())
if ttl < 0 {
ttl = 0
}
return ttl + expirationGracePeriod
}
func equalStoredChallenges(left challenge.Challenge, right challenge.Challenge) (bool, error) {
leftRecord, err := redisRecordFromChallenge(left)
if err != nil {
return false, err
}
rightRecord, err := redisRecordFromChallenge(right)
if err != nil {
return false, err
}
return reflect.DeepEqual(leftRecord, rightRecord), nil
}
var _ ports.ChallengeStore = (*Store)(nil)
@@ -0,0 +1,531 @@
package challengestore
import (
"context"
"crypto/ed25519"
"encoding/json"
"testing"
"time"
"galaxy/authsession/internal/adapters/contracttest"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStoreContract(t *testing.T) {
t.Parallel()
contracttest.RunChallengeStoreContractTests(t, func(t *testing.T) ports.ChallengeStore {
t.Helper()
server := miniredis.RunT(t)
return newTestStore(t, server, Config{})
})
}
func TestNew(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
tests := []struct {
name string
cfg Config
wantErr string
}{
{
name: "valid config",
cfg: Config{
Addr: server.Addr(),
DB: 2,
KeyPrefix: "authsession:challenge:",
OperationTimeout: 250 * time.Millisecond,
},
},
{
name: "empty addr",
cfg: Config{
KeyPrefix: "authsession:challenge:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis addr must not be empty",
},
{
name: "negative db",
cfg: Config{
Addr: server.Addr(),
DB: -1,
KeyPrefix: "authsession:challenge:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis db must not be negative",
},
{
name: "empty key prefix",
cfg: Config{
Addr: server.Addr(),
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis key prefix must not be empty",
},
{
name: "non-positive operation timeout",
cfg: Config{
Addr: server.Addr(),
KeyPrefix: "authsession:challenge:",
},
wantErr: "operation timeout must be positive",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
store, err := New(tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, store.Close())
})
})
}
}
func TestStorePing(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
require.NoError(t, store.Ping(context.Background()))
}
func TestStoreCreateAndGet(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
now := time.Unix(1_775_130_000, 0).UTC()
record := testChallenge(now)
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
got.CodeHash[0] = 0xFF
keyBytes := got.Confirmation.ClientPublicKey.PublicKey()
keyBytes[0] = 0xFE
again, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record.CodeHash, again.CodeHash)
require.NotNil(t, again.Confirmation)
assert.Equal(t, record.Confirmation.ClientPublicKey.String(), again.Confirmation.ClientPublicKey.String())
}
func TestStoreCreateAndGetPendingChallenge(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
now := time.Unix(1_775_130_100, 0).UTC()
record := testPendingChallenge(now)
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
}
func TestStoreCreateAndGetThrottledChallenge(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
now := time.Unix(1_775_130_150, 0).UTC()
record := testPendingChallenge(now)
record.Status = challenge.StatusDeliveryThrottled
record.DeliveryState = challenge.DeliveryThrottled
record.Attempts.Send = 1
record.Abuse.LastAttemptAt = timePointer(now)
require.NoError(t, record.Validate())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
}
func TestStoreGetStrictDecode(t *testing.T) {
t.Parallel()
now := time.Unix(1_775_130_200, 0).UTC()
baseRecord := testChallenge(now)
baseStored, err := redisRecordFromChallenge(baseRecord)
require.NoError(t, err)
tests := []struct {
name string
mutate func(redisRecord) string
wantErrText string
}{
{
name: "malformed json",
mutate: func(_ redisRecord) string {
return "{"
},
wantErrText: "decode redis challenge record",
},
{
name: "trailing json input",
mutate: func(record redisRecord) string {
return mustMarshalJSON(t, record) + "{}"
},
wantErrText: "unexpected trailing JSON input",
},
{
name: "unknown field",
mutate: func(record redisRecord) string {
payload := map[string]any{
"challenge_id": record.ChallengeID,
"email": record.Email,
"code_hash_base64": record.CodeHashBase64,
"status": record.Status,
"delivery_state": record.DeliveryState,
"created_at": record.CreatedAt,
"expires_at": record.ExpiresAt,
"send_attempt_count": record.SendAttemptCount,
"confirm_attempt_count": record.ConfirmAttemptCount,
"last_attempt_at": record.LastAttemptAt,
"confirmed_session_id": record.ConfirmedSessionID,
"confirmed_client_public_key": record.ConfirmedClientPublicKey,
"confirmed_at": record.ConfirmedAt,
"unexpected": true,
}
return mustMarshalJSON(t, payload)
},
wantErrText: "unknown field",
},
{
name: "unsupported status",
mutate: func(record redisRecord) string {
record.Status = challenge.Status("paused")
return mustMarshalJSON(t, record)
},
wantErrText: `status "paused" is unsupported`,
},
{
name: "unsupported delivery state",
mutate: func(record redisRecord) string {
record.DeliveryState = challenge.DeliveryState("queued")
return mustMarshalJSON(t, record)
},
wantErrText: `delivery state "queued" is unsupported`,
},
{
name: "missing required email",
mutate: func(record redisRecord) string {
record.Email = ""
return mustMarshalJSON(t, record)
},
wantErrText: "challenge email",
},
{
name: "challenge id mismatch",
mutate: func(record redisRecord) string {
record.ChallengeID = "other-challenge"
return mustMarshalJSON(t, record)
},
wantErrText: `does not match requested`,
},
{
name: "non canonical utc timestamp",
mutate: func(record redisRecord) string {
record.CreatedAt = "2026-04-04T12:00:00+03:00"
return mustMarshalJSON(t, record)
},
wantErrText: "canonical UTC RFC3339Nano timestamp",
},
{
name: "partial confirmation metadata",
mutate: func(record redisRecord) string {
record.ConfirmedAt = nil
return mustMarshalJSON(t, record)
},
wantErrText: "confirmation metadata must be either fully present or fully absent",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
server.Set(store.lookupKey(baseRecord.ID), tt.mutate(baseStored))
_, err := store.Get(context.Background(), baseRecord.ID)
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErrText)
})
}
}
func TestStoreKeySchemeAndTTL(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{KeyPrefix: "authsession:challenge:"})
now := time.Now().UTC()
prefixed := testPendingChallenge(now)
prefixed.ID = common.ChallengeID("challenge:opaque/id?value")
require.NoError(t, store.Create(context.Background(), prefixed))
key := store.lookupKey(prefixed.ID)
assert.Equal(t, "authsession:challenge:"+encodeKeyComponent(prefixed.ID.String()), key)
assert.True(t, server.Exists(key))
freshTTL := server.TTL(key)
assert.LessOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod)
assert.GreaterOrEqual(t, freshTTL, challenge.InitialTTL+expirationGracePeriod-2*time.Second)
expired := testPendingChallenge(now.Add(-10 * time.Minute))
expired.ID = common.ChallengeID("expired-challenge")
expired.CreatedAt = now.Add(-20 * time.Minute)
expired.ExpiresAt = now.Add(-1 * time.Minute)
require.NoError(t, store.Create(context.Background(), expired))
expiredTTL := server.TTL(store.lookupKey(expired.ID))
assert.LessOrEqual(t, expiredTTL, expirationGracePeriod)
assert.GreaterOrEqual(t, expiredTTL, expirationGracePeriod-2*time.Second)
}
func TestStoreCreateConflict(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := testPendingChallenge(time.Unix(1_775_130_300, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
err := store.Create(context.Background(), record)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrConflict)
}
func TestStoreGetNotFound(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
_, err := store.Get(context.Background(), common.ChallengeID("missing-challenge"))
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
}
func TestStoreCompareAndSwap(t *testing.T) {
t.Parallel()
now := time.Unix(1_775_130_400, 0).UTC()
t.Run("success", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
previous := testPendingChallenge(now)
next := previous
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
next.Attempts.Send = 1
next.Abuse.LastAttemptAt = timePointer(now.Add(1 * time.Minute))
require.NoError(t, store.Create(context.Background(), previous))
require.NoError(t, store.CompareAndSwap(context.Background(), previous, next))
got, err := store.Get(context.Background(), previous.ID)
require.NoError(t, err)
assert.Equal(t, next, got)
})
t.Run("conflict when stored record differs", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
stored := testPendingChallenge(now)
previous := stored
previous.Attempts.Send = 99
next := stored
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
require.NoError(t, store.Create(context.Background(), stored))
err := store.CompareAndSwap(context.Background(), previous, next)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrConflict)
})
t.Run("not found", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
previous := testPendingChallenge(now)
next := previous
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
err := store.CompareAndSwap(context.Background(), previous, next)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
t.Run("corrupt stored record returns adapter error", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
previous := testPendingChallenge(now)
next := previous
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
server.Set(store.lookupKey(previous.ID), "{")
err := store.CompareAndSwap(context.Background(), previous, next)
require.Error(t, err)
assert.NotErrorIs(t, err, ports.ErrConflict)
assert.ErrorContains(t, err, "decode redis challenge record")
})
}
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
t.Helper()
if cfg.Addr == "" {
cfg.Addr = server.Addr()
}
if cfg.KeyPrefix == "" {
cfg.KeyPrefix = "authsession:challenge:"
}
if cfg.OperationTimeout == 0 {
cfg.OperationTimeout = 250 * time.Millisecond
}
store, err := New(cfg)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, store.Close())
})
return store
}
func testPendingChallenge(now time.Time) challenge.Challenge {
return challenge.Challenge{
ID: common.ChallengeID("challenge-pending"),
Email: common.Email("pilot@example.com"),
CodeHash: []byte("hashed-pending-code"),
Status: challenge.StatusPendingSend,
DeliveryState: challenge.DeliveryPending,
CreatedAt: now,
ExpiresAt: now.Add(challenge.InitialTTL),
}
}
func testChallenge(now time.Time) challenge.Challenge {
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31,
})
if err != nil {
panic(err)
}
return challenge.Challenge{
ID: common.ChallengeID("challenge-confirmed"),
Email: common.Email("pilot@example.com"),
CodeHash: []byte("hashed-code"),
Status: challenge.StatusConfirmedPendingExpire,
DeliveryState: challenge.DeliverySent,
CreatedAt: now,
ExpiresAt: now.Add(challenge.ConfirmedRetention),
Attempts: challenge.AttemptCounters{
Send: 1,
Confirm: 2,
},
Abuse: challenge.AbuseMetadata{
LastAttemptAt: timePointer(now.Add(30 * time.Second)),
},
Confirmation: &challenge.Confirmation{
SessionID: common.DeviceSessionID("device-session-1"),
ClientPublicKey: clientPublicKey,
ConfirmedAt: now.Add(1 * time.Minute),
},
}
}
func timePointer(value time.Time) *time.Time {
return &value
}
func mustMarshalJSON(t *testing.T, value any) string {
t.Helper()
payload, err := json.Marshal(value)
require.NoError(t, err)
return string(payload)
}
func TestStorePingNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
err := store.Ping(nil)
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
func TestStoreGetNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
_, err := store.Get(nil, common.ChallengeID("challenge"))
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
@@ -0,0 +1,169 @@
// Package configprovider implements ports.ConfigProvider with Redis-backed
// dynamic auth/session configuration.
package configprovider
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
"strings"
"time"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// Config configures one Redis-backed config provider instance.
type Config struct {
// Addr is the Redis network address in host:port form.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled enables TLS with a conservative minimum protocol version.
TLSEnabled bool
// SessionLimitKey identifies the single Redis string key that stores the
// active-session-limit configuration value.
SessionLimitKey string
// OperationTimeout bounds each Redis round trip performed by the adapter.
OperationTimeout time.Duration
}
// Store reads dynamic auth/session configuration from Redis.
type Store struct {
client *redis.Client
sessionLimitKey string
operationTimeout time.Duration
}
// New constructs a Redis-backed config provider from cfg.
func New(cfg Config) (*Store, error) {
switch {
case strings.TrimSpace(cfg.Addr) == "":
return nil, errors.New("new redis config provider: redis addr must not be empty")
case cfg.DB < 0:
return nil, errors.New("new redis config provider: redis db must not be negative")
case strings.TrimSpace(cfg.SessionLimitKey) == "":
return nil, errors.New("new redis config provider: session limit key must not be empty")
case cfg.OperationTimeout <= 0:
return nil, errors.New("new redis config provider: operation timeout must be positive")
}
options := &redis.Options{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.DB,
Protocol: 2,
DisableIdentity: true,
}
if cfg.TLSEnabled {
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
return &Store{
client: redis.NewClient(options),
sessionLimitKey: cfg.SessionLimitKey,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close releases the underlying Redis client resources.
func (s *Store) Close() error {
if s == nil || s.client == nil {
return nil
}
return s.client.Close()
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
operationCtx, cancel, err := s.operationContext(ctx, "ping redis config provider")
if err != nil {
return err
}
defer cancel()
if err := s.client.Ping(operationCtx).Err(); err != nil {
return fmt.Errorf("ping redis config provider: %w", err)
}
return nil
}
// LoadSessionLimit returns the current active-session-limit configuration.
// Missing or invalid Redis values are treated as “limit absent” by policy.
func (s *Store) LoadSessionLimit(ctx context.Context) (ports.SessionLimitConfig, error) {
operationCtx, cancel, err := s.operationContext(ctx, "load session limit from redis")
if err != nil {
return ports.SessionLimitConfig{}, err
}
defer cancel()
value, err := s.client.Get(operationCtx, s.sessionLimitKey).Result()
switch {
case errors.Is(err, redis.Nil):
return ports.SessionLimitConfig{}, nil
case err != nil:
return ports.SessionLimitConfig{}, fmt.Errorf("load session limit from redis: %w", err)
}
config, valid := parseSessionLimitConfig(value)
if !valid {
return ports.SessionLimitConfig{}, nil
}
if err := config.Validate(); err != nil {
return ports.SessionLimitConfig{}, nil
}
return config, nil
}
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
if s == nil || s.client == nil {
return nil, nil, fmt.Errorf("%s: nil store", operation)
}
if ctx == nil {
return nil, nil, fmt.Errorf("%s: nil context", operation)
}
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
return operationCtx, cancel, nil
}
func parseSessionLimitConfig(raw string) (ports.SessionLimitConfig, bool) {
if strings.TrimSpace(raw) == "" || strings.TrimSpace(raw) != raw {
return ports.SessionLimitConfig{}, false
}
for _, symbol := range raw {
if symbol < '0' || symbol > '9' {
return ports.SessionLimitConfig{}, false
}
}
parsed, err := strconv.ParseInt(raw, 10, strconv.IntSize)
if err != nil || parsed <= 0 {
return ports.SessionLimitConfig{}, false
}
limit := int(parsed)
return ports.SessionLimitConfig{
ActiveSessionLimit: &limit,
}, true
}
var _ ports.ConfigProvider = (*Store)(nil)
@@ -0,0 +1,283 @@
package configprovider
import (
"context"
"strconv"
"testing"
"time"
"galaxy/authsession/internal/adapters/contracttest"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStoreContract(t *testing.T) {
t.Parallel()
contracttest.RunConfigProviderContractTests(t, func(t *testing.T) contracttest.ConfigProviderHarness {
t.Helper()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
return contracttest.ConfigProviderHarness{
Provider: store,
SeedDisabled: func(t *testing.T) {
t.Helper()
server.Del(store.sessionLimitKey)
},
SeedLimit: func(t *testing.T, limit int) {
t.Helper()
server.Set(store.sessionLimitKey, strconv.Itoa(limit))
},
}
})
}
func TestNew(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
tests := []struct {
name string
cfg Config
wantErr string
}{
{
name: "valid config",
cfg: Config{
Addr: server.Addr(),
DB: 2,
SessionLimitKey: "authsession:config:active-session-limit",
OperationTimeout: 250 * time.Millisecond,
},
},
{
name: "empty addr",
cfg: Config{
SessionLimitKey: "authsession:config:active-session-limit",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis addr must not be empty",
},
{
name: "negative db",
cfg: Config{
Addr: server.Addr(),
DB: -1,
SessionLimitKey: "authsession:config:active-session-limit",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis db must not be negative",
},
{
name: "empty session limit key",
cfg: Config{
Addr: server.Addr(),
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "session limit key must not be empty",
},
{
name: "non positive timeout",
cfg: Config{
Addr: server.Addr(),
SessionLimitKey: "authsession:config:active-session-limit",
},
wantErr: "operation timeout must be positive",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
store, err := New(tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, store.Close())
})
})
}
}
func TestStorePing(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
require.NoError(t, store.Ping(context.Background()))
}
func TestStoreLoadSessionLimit(t *testing.T) {
t.Parallel()
tests := []struct {
name string
seed func(*testing.T, *miniredis.Miniredis, *Store)
wantConfig ports.SessionLimitConfig
}{
{
name: "missing key means disabled",
wantConfig: ports.SessionLimitConfig{},
},
{
name: "valid positive integer",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, "5")
},
wantConfig: configWithLimit(5),
},
{
name: "empty string is invalid and disabled",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, "")
},
wantConfig: ports.SessionLimitConfig{},
},
{
name: "whitespace only is invalid and disabled",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, " ")
},
wantConfig: ports.SessionLimitConfig{},
},
{
name: "whitespace padded integer is invalid and disabled",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, " 5 ")
},
wantConfig: ports.SessionLimitConfig{},
},
{
name: "non integer text is invalid and disabled",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, "five")
},
wantConfig: ports.SessionLimitConfig{},
},
{
name: "zero is invalid and disabled",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, "0")
},
wantConfig: ports.SessionLimitConfig{},
},
{
name: "negative integer is invalid and disabled",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, "-3")
},
wantConfig: ports.SessionLimitConfig{},
},
{
name: "overflow is invalid and disabled",
seed: func(t *testing.T, server *miniredis.Miniredis, store *Store) {
t.Helper()
server.Set(store.sessionLimitKey, "999999999999999999999999999999")
},
wantConfig: ports.SessionLimitConfig{},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
if tt.seed != nil {
tt.seed(t, server, store)
}
got, err := store.LoadSessionLimit(context.Background())
require.NoError(t, err)
assert.Equal(t, tt.wantConfig, got)
})
}
}
func TestStoreLoadSessionLimitBackendFailure(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
server.Close()
_, err := store.LoadSessionLimit(context.Background())
require.Error(t, err)
assert.ErrorContains(t, err, "load session limit from redis")
}
func TestStoreLoadSessionLimitNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
_, err := store.LoadSessionLimit(nil)
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
func TestStorePingNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
err := store.Ping(nil)
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
t.Helper()
if cfg.Addr == "" {
cfg.Addr = server.Addr()
}
if cfg.SessionLimitKey == "" {
cfg.SessionLimitKey = "authsession:config:active-session-limit"
}
if cfg.OperationTimeout == 0 {
cfg.OperationTimeout = 250 * time.Millisecond
}
store, err := New(cfg)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, store.Close())
})
return store
}
func configWithLimit(limit int) ports.SessionLimitConfig {
return ports.SessionLimitConfig{
ActiveSessionLimit: &limit,
}
}
@@ -0,0 +1,223 @@
// Package projectionpublisher implements
// ports.GatewaySessionProjectionPublisher with Redis-backed gateway-compatible
// cache snapshots and session lifecycle events.
package projectionpublisher
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// Config configures one Redis-backed gateway session projection publisher.
type Config struct {
// Addr is the Redis network address in host:port form.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled enables TLS with a conservative minimum protocol version.
TLSEnabled bool
// SessionCacheKeyPrefix is the namespace prefix applied to gateway session
// cache keys. The raw device session identifier is appended directly.
SessionCacheKeyPrefix string
// SessionEventsStream identifies the gateway session lifecycle Redis Stream.
SessionEventsStream string
// StreamMaxLen bounds the session lifecycle stream with approximate
// trimming via XADD MAXLEN ~.
StreamMaxLen int64
// OperationTimeout bounds each Redis round trip performed by the adapter.
OperationTimeout time.Duration
}
// Publisher publishes gateway-compatible session projections into Redis cache
// and stream namespaces.
type Publisher struct {
client *redis.Client
sessionCacheKeyPrefix string
sessionEventsStream string
streamMaxLen int64
operationTimeout time.Duration
}
type cacheRecord struct {
DeviceSessionID string `json:"device_session_id"`
UserID string `json:"user_id"`
ClientPublicKey string `json:"client_public_key"`
Status gatewayprojection.Status `json:"status"`
RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
}
// New constructs a Redis-backed gateway session projection publisher from
// cfg.
func New(cfg Config) (*Publisher, error) {
switch {
case strings.TrimSpace(cfg.Addr) == "":
return nil, errors.New("new redis projection publisher: redis addr must not be empty")
case cfg.DB < 0:
return nil, errors.New("new redis projection publisher: redis db must not be negative")
case strings.TrimSpace(cfg.SessionCacheKeyPrefix) == "":
return nil, errors.New("new redis projection publisher: session cache key prefix must not be empty")
case strings.TrimSpace(cfg.SessionEventsStream) == "":
return nil, errors.New("new redis projection publisher: session events stream must not be empty")
case cfg.StreamMaxLen <= 0:
return nil, errors.New("new redis projection publisher: stream max len must be positive")
case cfg.OperationTimeout <= 0:
return nil, errors.New("new redis projection publisher: operation timeout must be positive")
}
options := &redis.Options{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.DB,
Protocol: 2,
DisableIdentity: true,
}
if cfg.TLSEnabled {
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
return &Publisher{
client: redis.NewClient(options),
sessionCacheKeyPrefix: cfg.SessionCacheKeyPrefix,
sessionEventsStream: cfg.SessionEventsStream,
streamMaxLen: cfg.StreamMaxLen,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close releases the underlying Redis client resources.
func (p *Publisher) Close() error {
if p == nil || p.client == nil {
return nil
}
return p.client.Close()
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (p *Publisher) Ping(ctx context.Context) error {
operationCtx, cancel, err := p.operationContext(ctx, "ping redis projection publisher")
if err != nil {
return err
}
defer cancel()
if err := p.client.Ping(operationCtx).Err(); err != nil {
return fmt.Errorf("ping redis projection publisher: %w", err)
}
return nil
}
// PublishSession writes one gateway-compatible session snapshot into the
// gateway cache namespace and appends the same snapshot to the gateway session
// event stream within one Redis transaction.
func (p *Publisher) PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error {
if err := snapshot.Validate(); err != nil {
return fmt.Errorf("publish session projection to redis: %w", err)
}
payload, err := marshalCacheRecord(snapshot)
if err != nil {
return fmt.Errorf("publish session projection to redis: %w", err)
}
values := buildStreamValues(snapshot)
operationCtx, cancel, err := p.operationContext(ctx, "publish session projection to redis")
if err != nil {
return err
}
defer cancel()
key := p.sessionCacheKey(snapshot.DeviceSessionID)
_, err = p.client.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, key, payload, 0)
pipe.XAdd(operationCtx, &redis.XAddArgs{
Stream: p.sessionEventsStream,
MaxLen: p.streamMaxLen,
Approx: true,
Values: values,
})
return nil
})
if err != nil {
return fmt.Errorf("publish session projection %q to redis: %w", snapshot.DeviceSessionID, err)
}
return nil
}
func (p *Publisher) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
if p == nil || p.client == nil {
return nil, nil, fmt.Errorf("%s: nil publisher", operation)
}
if ctx == nil {
return nil, nil, fmt.Errorf("%s: nil context", operation)
}
operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout)
return operationCtx, cancel, nil
}
func (p *Publisher) sessionCacheKey(deviceSessionID interface{ String() string }) string {
return p.sessionCacheKeyPrefix + deviceSessionID.String()
}
func marshalCacheRecord(snapshot gatewayprojection.Snapshot) ([]byte, error) {
record := cacheRecord{
DeviceSessionID: snapshot.DeviceSessionID.String(),
UserID: snapshot.UserID.String(),
ClientPublicKey: snapshot.ClientPublicKey,
Status: snapshot.Status,
}
if snapshot.RevokedAt != nil {
revokedAtMS := snapshot.RevokedAt.UTC().UnixMilli()
record.RevokedAtMS = &revokedAtMS
}
payload, err := json.Marshal(record)
if err != nil {
return nil, fmt.Errorf("marshal gateway session cache record: %w", err)
}
return payload, nil
}
func buildStreamValues(snapshot gatewayprojection.Snapshot) map[string]any {
values := map[string]any{
"device_session_id": snapshot.DeviceSessionID.String(),
"user_id": snapshot.UserID.String(),
"client_public_key": snapshot.ClientPublicKey,
"status": string(snapshot.Status),
}
if snapshot.RevokedAt != nil {
values["revoked_at_ms"] = fmt.Sprint(snapshot.RevokedAt.UTC().UnixMilli())
}
return values
}
var _ ports.GatewaySessionProjectionPublisher = (*Publisher)(nil)
@@ -0,0 +1,442 @@
package projectionpublisher
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/gatewayprojection"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNew(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
tests := []struct {
name string
cfg Config
wantErr string
}{
{
name: "valid config",
cfg: Config{
Addr: server.Addr(),
DB: 3,
SessionCacheKeyPrefix: "gateway:session:",
SessionEventsStream: "gateway:session_events",
StreamMaxLen: 1024,
OperationTimeout: 250 * time.Millisecond,
},
},
{
name: "empty addr",
cfg: Config{
SessionCacheKeyPrefix: "gateway:session:",
SessionEventsStream: "gateway:session_events",
StreamMaxLen: 1024,
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis addr must not be empty",
},
{
name: "negative db",
cfg: Config{
Addr: server.Addr(),
DB: -1,
SessionCacheKeyPrefix: "gateway:session:",
SessionEventsStream: "gateway:session_events",
StreamMaxLen: 1024,
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis db must not be negative",
},
{
name: "empty session cache key prefix",
cfg: Config{
Addr: server.Addr(),
SessionEventsStream: "gateway:session_events",
StreamMaxLen: 1024,
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "session cache key prefix must not be empty",
},
{
name: "empty session events stream",
cfg: Config{
Addr: server.Addr(),
SessionCacheKeyPrefix: "gateway:session:",
StreamMaxLen: 1024,
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "session events stream must not be empty",
},
{
name: "non positive stream max len",
cfg: Config{
Addr: server.Addr(),
SessionCacheKeyPrefix: "gateway:session:",
SessionEventsStream: "gateway:session_events",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "stream max len must be positive",
},
{
name: "non positive timeout",
cfg: Config{
Addr: server.Addr(),
SessionCacheKeyPrefix: "gateway:session:",
SessionEventsStream: "gateway:session_events",
StreamMaxLen: 1024,
},
wantErr: "operation timeout must be positive",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
publisher, err := New(tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, publisher.Close())
})
})
}
}
func TestPublisherPing(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{})
require.NoError(t, publisher.Ping(context.Background()))
}
func TestPublisherPublishSessionActive(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{})
snapshot := testSnapshot("device/session:opaque?1", gatewayprojection.StatusActive, nil)
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
assert.Equal(t, "gateway:session:"+snapshot.DeviceSessionID.String(), key)
assert.True(t, server.Exists(key))
assert.False(t, server.Exists("gateway:session:"+encodeBase64URL(snapshot.DeviceSessionID.String())))
payload, err := server.Get(key)
require.NoError(t, err)
record := decodeCachePayload(t, payload)
assert.Equal(t, cacheRecord{
DeviceSessionID: snapshot.DeviceSessionID.String(),
UserID: snapshot.UserID.String(),
ClientPublicKey: snapshot.ClientPublicKey,
Status: gatewayprojection.StatusActive,
}, record)
assert.Zero(t, server.TTL(key))
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
require.NoError(t, err)
require.Len(t, entries, 1)
assert.Equal(t, map[string]string{
"device_session_id": snapshot.DeviceSessionID.String(),
"user_id": snapshot.UserID.String(),
"client_public_key": snapshot.ClientPublicKey,
"status": string(gatewayprojection.StatusActive),
}, stringifyValues(entries[0].Values))
}
func TestPublisherPublishSessionRevoked(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{})
revokedAt := time.Unix(1_776_000_123, 456_000_000).UTC()
snapshot := testSnapshot("device-session-123", gatewayprojection.StatusRevoked, &revokedAt)
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
key := publisher.sessionCacheKey(snapshot.DeviceSessionID)
payload, err := server.Get(key)
require.NoError(t, err)
record := decodeCachePayload(t, payload)
require.NotNil(t, record.RevokedAtMS)
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
assert.Equal(t, cacheRecord{
DeviceSessionID: snapshot.DeviceSessionID.String(),
UserID: snapshot.UserID.String(),
ClientPublicKey: snapshot.ClientPublicKey,
Status: gatewayprojection.StatusRevoked,
RevokedAtMS: int64Pointer(revokedAt.UnixMilli()),
}, record)
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
require.NoError(t, err)
require.Len(t, entries, 1)
assert.Equal(t, map[string]string{
"device_session_id": snapshot.DeviceSessionID.String(),
"user_id": snapshot.UserID.String(),
"client_public_key": snapshot.ClientPublicKey,
"status": string(gatewayprojection.StatusRevoked),
"revoked_at_ms": "1776000123456",
}, stringifyValues(entries[0].Values))
}
func TestPublisherPublishSessionLaterSnapshotWinsInCache(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
deviceSessionID := "device-session-456"
active := testSnapshot(deviceSessionID, gatewayprojection.StatusActive, nil)
revokedAt := time.Unix(1_776_010_000, 0).UTC()
revoked := testSnapshot(deviceSessionID, gatewayprojection.StatusRevoked, &revokedAt)
require.NoError(t, publisher.PublishSession(context.Background(), active))
require.NoError(t, publisher.PublishSession(context.Background(), revoked))
payload, err := server.Get(publisher.sessionCacheKey(revoked.DeviceSessionID))
require.NoError(t, err)
record := decodeCachePayload(t, payload)
require.NotNil(t, record.RevokedAtMS)
assert.Equal(t, revokedAt.UnixMilli(), *record.RevokedAtMS)
assert.Equal(t, gatewayprojection.StatusRevoked, record.Status)
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
require.NoError(t, err)
require.Len(t, entries, 2)
assert.Equal(t, map[string]string{
"device_session_id": active.DeviceSessionID.String(),
"user_id": active.UserID.String(),
"client_public_key": active.ClientPublicKey,
"status": string(gatewayprojection.StatusActive),
}, stringifyValues(entries[0].Values))
assert.Equal(t, map[string]string{
"device_session_id": revoked.DeviceSessionID.String(),
"user_id": revoked.UserID.String(),
"client_public_key": revoked.ClientPublicKey,
"status": string(gatewayprojection.StatusRevoked),
"revoked_at_ms": "1776010000000",
}, stringifyValues(entries[1].Values))
}
func TestPublisherPublishSessionRepeatedPublishIsRetrySafe(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 8})
snapshot := testSnapshot("device-session-retry", gatewayprojection.StatusActive, nil)
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
payload, err := server.Get(publisher.sessionCacheKey(snapshot.DeviceSessionID))
require.NoError(t, err)
record := decodeCachePayload(t, payload)
assert.Equal(t, cacheRecord{
DeviceSessionID: snapshot.DeviceSessionID.String(),
UserID: snapshot.UserID.String(),
ClientPublicKey: snapshot.ClientPublicKey,
Status: gatewayprojection.StatusActive,
}, record)
entries, err := publisher.client.XRange(context.Background(), publisher.sessionEventsStream, "-", "+").Result()
require.NoError(t, err)
require.Len(t, entries, 2)
assert.Equal(t, stringifyValues(entries[0].Values), stringifyValues(entries[1].Values))
}
func TestPublisherPublishSessionStreamMaxLenApprox(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{StreamMaxLen: 2})
for index := range 6 {
snapshot := testSnapshot(
common.DeviceSessionID("device-session-"+string(rune('a'+index))).String(),
gatewayprojection.StatusActive,
nil,
)
require.NoError(t, publisher.PublishSession(context.Background(), snapshot))
}
streamLength, err := publisher.client.XLen(context.Background(), publisher.sessionEventsStream).Result()
require.NoError(t, err)
assert.LessOrEqual(t, streamLength, int64(2))
}
func TestPublisherPublishSessionInvalidSnapshot(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{})
snapshot := gatewayprojection.Snapshot{
DeviceSessionID: common.DeviceSessionID("device-session-123"),
UserID: common.UserID("user-123"),
Status: gatewayprojection.StatusActive,
}
err := publisher.PublishSession(context.Background(), snapshot)
require.Error(t, err)
assert.ErrorContains(t, err, "gateway projection client public key")
assert.Empty(t, server.Keys())
}
func TestPublisherPublishSessionNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{})
err := publisher.PublishSession(nil, testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
func TestPublisherPublishSessionBackendFailure(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{})
server.Close()
err := publisher.PublishSession(context.Background(), testSnapshot("device-session-123", gatewayprojection.StatusActive, nil))
require.Error(t, err)
assert.ErrorContains(t, err, "publish session projection")
}
func TestPublisherPingNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := newTestPublisher(t, server, Config{})
err := publisher.Ping(nil)
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
func newTestPublisher(t *testing.T, server *miniredis.Miniredis, cfg Config) *Publisher {
t.Helper()
if cfg.Addr == "" {
cfg.Addr = server.Addr()
}
if cfg.SessionCacheKeyPrefix == "" {
cfg.SessionCacheKeyPrefix = "gateway:session:"
}
if cfg.SessionEventsStream == "" {
cfg.SessionEventsStream = "gateway:session_events"
}
if cfg.StreamMaxLen == 0 {
cfg.StreamMaxLen = 1024
}
if cfg.OperationTimeout == 0 {
cfg.OperationTimeout = 250 * time.Millisecond
}
publisher, err := New(cfg)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, publisher.Close())
})
return publisher
}
func testSnapshot(deviceSessionID string, status gatewayprojection.Status, revokedAt *time.Time) gatewayprojection.Snapshot {
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
for index := range raw {
raw[index] = byte(index + 1)
}
snapshot := gatewayprojection.Snapshot{
DeviceSessionID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID("user-123"),
ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
Status: status,
RevokedAt: revokedAt,
}
if status == gatewayprojection.StatusRevoked {
snapshot.RevokeReasonCode = common.RevokeReasonCode("user_blocked")
snapshot.RevokeActorType = common.RevokeActorType("system")
}
return snapshot
}
func decodeCachePayload(t *testing.T, payload string) cacheRecord {
t.Helper()
decoder := json.NewDecoder(bytes.NewReader([]byte(payload)))
decoder.DisallowUnknownFields()
var record cacheRecord
require.NoError(t, decoder.Decode(&record))
err := decoder.Decode(&struct{}{})
if err == nil {
require.FailNow(t, "expected cache payload EOF after first JSON value")
}
require.ErrorIs(t, err, io.EOF)
var fieldSet map[string]json.RawMessage
require.NoError(t, json.Unmarshal([]byte(payload), &fieldSet))
expectedFields := map[string]struct{}{
"device_session_id": {},
"user_id": {},
"client_public_key": {},
"status": {},
}
if record.RevokedAtMS != nil {
expectedFields["revoked_at_ms"] = struct{}{}
}
assert.Equal(t, len(expectedFields), len(fieldSet))
for field := range fieldSet {
_, ok := expectedFields[field]
assert.Truef(t, ok, "unexpected cache payload field %q", field)
}
return record
}
func stringifyValues(values map[string]any) map[string]string {
stringified := make(map[string]string, len(values))
for key, value := range values {
stringified[key] = fmt.Sprint(value)
}
return stringified
}
func encodeBase64URL(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
func int64Pointer(value int64) *int64 {
return &value
}
@@ -0,0 +1,152 @@
// Package sendemailcodeabuse implements ports.SendEmailCodeAbuseProtector with
// one Redis TTL key per normalized e-mail address.
package sendemailcodeabuse
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
// Config configures one Redis-backed send-email-code abuse protector.
type Config struct {
// Addr is the Redis network address in host:port form.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled enables TLS with a conservative minimum protocol version.
TLSEnabled bool
// KeyPrefix is the namespace prefix applied to every resend-throttle key.
KeyPrefix string
// OperationTimeout bounds each Redis round trip performed by the adapter.
OperationTimeout time.Duration
}
// Protector applies the fixed resend cooldown with one Redis key per
// normalized e-mail address.
type Protector struct {
client *redis.Client
keyPrefix string
operationTimeout time.Duration
}
// New constructs a Redis-backed resend-throttle protector from cfg.
func New(cfg Config) (*Protector, error) {
switch {
case strings.TrimSpace(cfg.Addr) == "":
return nil, errors.New("new redis send email code abuse protector: redis addr must not be empty")
case cfg.DB < 0:
return nil, errors.New("new redis send email code abuse protector: redis db must not be negative")
case strings.TrimSpace(cfg.KeyPrefix) == "":
return nil, errors.New("new redis send email code abuse protector: redis key prefix must not be empty")
case cfg.OperationTimeout <= 0:
return nil, errors.New("new redis send email code abuse protector: operation timeout must be positive")
}
options := &redis.Options{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.DB,
Protocol: 2,
DisableIdentity: true,
}
if cfg.TLSEnabled {
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
return &Protector{
client: redis.NewClient(options),
keyPrefix: cfg.KeyPrefix,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close releases the underlying Redis client resources.
func (p *Protector) Close() error {
if p == nil || p.client == nil {
return nil
}
return p.client.Close()
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (p *Protector) Ping(ctx context.Context) error {
operationCtx, cancel, err := p.operationContext(ctx, "ping redis send email code abuse protector")
if err != nil {
return err
}
defer cancel()
if err := p.client.Ping(operationCtx).Err(); err != nil {
return fmt.Errorf("ping redis send email code abuse protector: %w", err)
}
return nil
}
// CheckAndReserve applies the fixed resend cooldown using one TTL key per
// normalized e-mail address.
func (p *Protector) CheckAndReserve(ctx context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
if err := input.Validate(); err != nil {
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse: %w", err)
}
operationCtx, cancel, err := p.operationContext(ctx, "check and reserve send email code abuse")
if err != nil {
return ports.SendEmailCodeAbuseResult{}, err
}
defer cancel()
key := p.lookupKey(input.Email)
value := input.Now.UTC().Add(challenge.ResendThrottleCooldown).Format(time.RFC3339Nano)
created, err := p.client.SetNX(operationCtx, key, value, challenge.ResendThrottleCooldown).Result()
if err != nil {
return ports.SendEmailCodeAbuseResult{}, fmt.Errorf("check and reserve send email code abuse for %q: %w", input.Email, err)
}
if created {
return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeAllowed}, nil
}
return ports.SendEmailCodeAbuseResult{Outcome: ports.SendEmailCodeAbuseOutcomeThrottled}, nil
}
func (p *Protector) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
if p == nil || p.client == nil {
return nil, nil, fmt.Errorf("%s: nil protector", operation)
}
if ctx == nil {
return nil, nil, fmt.Errorf("%s: nil context", operation)
}
operationCtx, cancel := context.WithTimeout(ctx, p.operationTimeout)
return operationCtx, cancel, nil
}
func (p *Protector) lookupKey(email common.Email) string {
return p.keyPrefix + base64.RawURLEncoding.EncodeToString([]byte(email.String()))
}
var _ ports.SendEmailCodeAbuseProtector = (*Protector)(nil)
@@ -0,0 +1,176 @@
package sendemailcodeabuse
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNew(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
tests := []struct {
name string
cfg Config
wantErr string
}{
{
name: "valid config",
cfg: Config{
Addr: server.Addr(),
DB: 1,
KeyPrefix: "authsession:send-email-code-throttle:",
OperationTimeout: 250 * time.Millisecond,
},
},
{
name: "empty addr",
cfg: Config{
KeyPrefix: "authsession:send-email-code-throttle:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis addr must not be empty",
},
{
name: "negative db",
cfg: Config{
Addr: server.Addr(),
DB: -1,
KeyPrefix: "authsession:send-email-code-throttle:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis db must not be negative",
},
{
name: "empty key prefix",
cfg: Config{
Addr: server.Addr(),
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis key prefix must not be empty",
},
{
name: "non-positive timeout",
cfg: Config{
Addr: server.Addr(),
KeyPrefix: "authsession:send-email-code-throttle:",
},
wantErr: "operation timeout must be positive",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
protector, err := New(tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, protector.Close())
})
})
}
}
func TestProtectorPing(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
protector := newTestProtector(t, server, Config{})
require.NoError(t, protector.Ping(context.Background()))
}
func TestProtectorCheckAndReserve(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
protector := newTestProtector(t, server, Config{})
email := common.Email("pilot@example.com")
now := time.Unix(10, 0).UTC()
result, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now,
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
key := protector.lookupKey(email)
assert.True(t, server.Exists(key))
ttl := server.TTL(key)
assert.LessOrEqual(t, ttl, challenge.ResendThrottleCooldown)
assert.GreaterOrEqual(t, ttl, challenge.ResendThrottleCooldown-2*time.Second)
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now.Add(30 * time.Second),
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeThrottled, result.Outcome)
ttlAfterThrottle := server.TTL(key)
assert.LessOrEqual(t, ttlAfterThrottle, ttl)
server.FastForward(challenge.ResendThrottleCooldown)
result, err = protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now.Add(challenge.ResendThrottleCooldown),
})
require.NoError(t, err)
assert.Equal(t, ports.SendEmailCodeAbuseOutcomeAllowed, result.Outcome)
}
func TestProtectorNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
protector := newTestProtector(t, server, Config{})
_, err := protector.CheckAndReserve(nil, ports.SendEmailCodeAbuseInput{
Email: common.Email("pilot@example.com"),
Now: time.Unix(10, 0).UTC(),
})
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
func newTestProtector(t *testing.T, server *miniredis.Miniredis, cfg Config) *Protector {
t.Helper()
if cfg.Addr == "" {
cfg.Addr = server.Addr()
}
if cfg.KeyPrefix == "" {
cfg.KeyPrefix = "authsession:send-email-code-throttle:"
}
if cfg.OperationTimeout == 0 {
cfg.OperationTimeout = 250 * time.Millisecond
}
protector, err := New(cfg)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, protector.Close())
})
return protector
}
@@ -0,0 +1,723 @@
// Package sessionstore implements ports.SessionStore with Redis-backed strict
// JSON source-of-truth session records and per-user indexes.
package sessionstore
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"github.com/redis/go-redis/v9"
)
const mutationRetryLimit = 3
// Config configures one Redis-backed session store instance.
type Config struct {
// Addr is the Redis network address in host:port form.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled enables TLS with a conservative minimum protocol version.
TLSEnabled bool
// SessionKeyPrefix is the namespace prefix applied to primary session keys.
SessionKeyPrefix string
// UserSessionsKeyPrefix is the namespace prefix applied to all-session user
// indexes.
UserSessionsKeyPrefix string
// UserActiveSessionsKeyPrefix is the namespace prefix applied to active
// session user indexes.
UserActiveSessionsKeyPrefix string
// OperationTimeout bounds each Redis round trip performed by the adapter.
OperationTimeout time.Duration
}
// Store persists source-of-truth sessions in Redis and maintains user-scoped
// indexes for list and count operations.
type Store struct {
client *redis.Client
sessionKeyPrefix string
userSessionsKeyPrefix string
userActiveSessionsKeyPrefix string
operationTimeout time.Duration
}
type redisRecord struct {
DeviceSessionID string `json:"device_session_id"`
UserID string `json:"user_id"`
ClientPublicKeyBase64 string `json:"client_public_key_base64"`
Status devicesession.Status `json:"status"`
CreatedAt string `json:"created_at"`
RevokedAt *string `json:"revoked_at,omitempty"`
RevokeReasonCode string `json:"revoke_reason_code,omitempty"`
RevokeActorType string `json:"revoke_actor_type,omitempty"`
RevokeActorID string `json:"revoke_actor_id,omitempty"`
}
// New constructs a Redis-backed session store from cfg.
func New(cfg Config) (*Store, error) {
switch {
case strings.TrimSpace(cfg.Addr) == "":
return nil, errors.New("new redis session store: redis addr must not be empty")
case cfg.DB < 0:
return nil, errors.New("new redis session store: redis db must not be negative")
case strings.TrimSpace(cfg.SessionKeyPrefix) == "":
return nil, errors.New("new redis session store: session key prefix must not be empty")
case strings.TrimSpace(cfg.UserSessionsKeyPrefix) == "":
return nil, errors.New("new redis session store: user sessions key prefix must not be empty")
case strings.TrimSpace(cfg.UserActiveSessionsKeyPrefix) == "":
return nil, errors.New("new redis session store: user active sessions key prefix must not be empty")
case cfg.OperationTimeout <= 0:
return nil, errors.New("new redis session store: operation timeout must be positive")
}
options := &redis.Options{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.DB,
Protocol: 2,
DisableIdentity: true,
}
if cfg.TLSEnabled {
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
return &Store{
client: redis.NewClient(options),
sessionKeyPrefix: cfg.SessionKeyPrefix,
userSessionsKeyPrefix: cfg.UserSessionsKeyPrefix,
userActiveSessionsKeyPrefix: cfg.UserActiveSessionsKeyPrefix,
operationTimeout: cfg.OperationTimeout,
}, nil
}
// Close releases the underlying Redis client resources.
func (s *Store) Close() error {
if s == nil || s.client == nil {
return nil
}
return s.client.Close()
}
// Ping verifies that the configured Redis backend is reachable within the
// adapter operation timeout budget.
func (s *Store) Ping(ctx context.Context) error {
operationCtx, cancel, err := s.operationContext(ctx, "ping redis session store")
if err != nil {
return err
}
defer cancel()
if err := s.client.Ping(operationCtx).Err(); err != nil {
return fmt.Errorf("ping redis session store: %w", err)
}
return nil
}
// Get returns the stored session for deviceSessionID.
func (s *Store) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
if err := deviceSessionID.Validate(); err != nil {
return devicesession.Session{}, fmt.Errorf("get session from redis: %w", err)
}
operationCtx, cancel, err := s.operationContext(ctx, "get session from redis")
if err != nil {
return devicesession.Session{}, err
}
defer cancel()
record, err := s.loadSession(operationCtx, deviceSessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, ports.ErrNotFound)
default:
return devicesession.Session{}, fmt.Errorf("get session %q from redis: %w", deviceSessionID, err)
}
}
return record, nil
}
// ListByUserID returns every stored session for userID in newest-first order.
func (s *Store) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
if err := userID.Validate(); err != nil {
return nil, fmt.Errorf("list sessions by user id from redis: %w", err)
}
operationCtx, cancel, err := s.operationContext(ctx, "list sessions by user id from redis")
if err != nil {
return nil, err
}
defer cancel()
deviceSessionIDs, err := s.client.ZRevRange(operationCtx, s.userSessionsKey(userID), 0, -1).Result()
if err != nil {
return nil, fmt.Errorf("list sessions by user id %q from redis: %w", userID, err)
}
if len(deviceSessionIDs) == 0 {
return []devicesession.Session{}, nil
}
records := make([]devicesession.Session, 0, len(deviceSessionIDs))
for _, rawDeviceSessionID := range deviceSessionIDs {
deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
record, err := s.loadSession(operationCtx, deviceSessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return nil, fmt.Errorf("list sessions by user id %q from redis: all-sessions index references missing session %q", userID, deviceSessionID)
default:
return nil, fmt.Errorf("list sessions by user id %q from redis: session %q: %w", userID, deviceSessionID, err)
}
}
if record.UserID != userID {
return nil, fmt.Errorf("list sessions by user id %q from redis: session %q belongs to %q", userID, deviceSessionID, record.UserID)
}
records = append(records, record)
}
sortSessionsNewestFirst(records)
return records, nil
}
// CountActiveByUserID returns the number of active sessions currently stored
// for userID.
func (s *Store) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
if err := userID.Validate(); err != nil {
return 0, fmt.Errorf("count active sessions by user id from redis: %w", err)
}
operationCtx, cancel, err := s.operationContext(ctx, "count active sessions by user id from redis")
if err != nil {
return 0, err
}
defer cancel()
count, err := s.client.ZCard(operationCtx, s.userActiveSessionsKey(userID)).Result()
if err != nil {
return 0, fmt.Errorf("count active sessions by user id %q from redis: %w", userID, err)
}
return int(count), nil
}
// Create persists record as a new device session.
func (s *Store) Create(ctx context.Context, record devicesession.Session) error {
if err := record.Validate(); err != nil {
return fmt.Errorf("create session in redis: %w", err)
}
payload, err := marshalSessionRecord(record)
if err != nil {
return fmt.Errorf("create session in redis: %w", err)
}
deviceSessionKey := s.sessionKey(record.ID)
allSessionsKey := s.userSessionsKey(record.UserID)
activeSessionsKey := s.userActiveSessionsKey(record.UserID)
operationCtx, cancel, err := s.operationContext(ctx, "create session in redis")
if err != nil {
return err
}
defer cancel()
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
_, err := tx.Get(operationCtx, deviceSessionKey).Bytes()
switch {
case errors.Is(err, redis.Nil):
case err != nil:
return fmt.Errorf("create session %q in redis: %w", record.ID, err)
default:
return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, deviceSessionKey, payload, 0)
pipe.ZAdd(operationCtx, allSessionsKey, redis.Z{
Score: createdAtScore(record.CreatedAt),
Member: record.ID.String(),
})
if record.Status == devicesession.StatusActive {
pipe.ZAdd(operationCtx, activeSessionsKey, redis.Z{
Score: createdAtScore(record.CreatedAt),
Member: record.ID.String(),
})
}
return nil
})
if err != nil {
return fmt.Errorf("create session %q in redis: %w", record.ID, err)
}
return nil
}, deviceSessionKey)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return fmt.Errorf("create session %q in redis: %w", record.ID, ports.ErrConflict)
case watchErr != nil:
return watchErr
default:
return nil
}
}
// Revoke stores a revoked view of one target session.
func (s *Store) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
if err := input.Validate(); err != nil {
return ports.RevokeSessionResult{}, fmt.Errorf("revoke session in redis: %w", err)
}
var result ports.RevokeSessionResult
err := s.runMutation(ctx, "revoke session in redis", func(operationCtx context.Context) error {
deviceSessionKey := s.sessionKey(input.DeviceSessionID)
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
current, err := s.loadSessionWithGetter(operationCtx, input.DeviceSessionID, tx.Get)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, ports.ErrNotFound)
default:
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
}
}
if current.Status == devicesession.StatusRevoked {
result = ports.RevokeSessionResult{
Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
Session: current,
}
return result.Validate()
}
next := current
next.Status = devicesession.StatusRevoked
revocation := input.Revocation
next.Revocation = &revocation
if err := next.Validate(); err != nil {
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
}
payload, err := marshalSessionRecord(next)
if err != nil {
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.Set(operationCtx, deviceSessionKey, payload, 0)
pipe.ZRem(operationCtx, s.userActiveSessionsKey(current.UserID), current.ID.String())
return nil
})
if err != nil {
return fmt.Errorf("revoke session %q in redis: %w", input.DeviceSessionID, err)
}
result = ports.RevokeSessionResult{
Outcome: ports.RevokeSessionOutcomeRevoked,
Session: next,
}
return result.Validate()
}, deviceSessionKey)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return errRetryMutation
case watchErr != nil:
return watchErr
default:
return nil
}
})
if err != nil {
return ports.RevokeSessionResult{}, err
}
return result, nil
}
// RevokeAllByUserID stores revoked views for all currently active sessions
// owned by input.UserID.
func (s *Store) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
if err := input.Validate(); err != nil {
return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions in redis: %w", err)
}
var result ports.RevokeUserSessionsResult
err := s.runMutation(ctx, "revoke user sessions in redis", func(operationCtx context.Context) error {
activeSessionsKey := s.userActiveSessionsKey(input.UserID)
watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error {
deviceSessionIDs, err := tx.ZRevRange(operationCtx, activeSessionsKey, 0, -1).Result()
if err != nil {
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
}
if len(deviceSessionIDs) == 0 {
// Force EXEC so WATCH observes concurrent active-index changes even
// for the no-op path.
_, err := tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
pipe.ZCard(operationCtx, activeSessionsKey)
return nil
})
if err != nil {
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
}
result = ports.RevokeUserSessionsResult{
Outcome: ports.RevokeUserSessionsOutcomeNoActiveSessions,
UserID: input.UserID,
Sessions: []devicesession.Session{},
}
return result.Validate()
}
records := make([]devicesession.Session, 0, len(deviceSessionIDs))
for _, rawDeviceSessionID := range deviceSessionIDs {
deviceSessionID := common.DeviceSessionID(rawDeviceSessionID)
record, err := s.loadSessionWithGetter(operationCtx, deviceSessionID, tx.Get)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return fmt.Errorf("revoke user sessions %q in redis: active index references missing session %q", input.UserID, deviceSessionID)
default:
return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
}
}
if record.UserID != input.UserID {
return fmt.Errorf("revoke user sessions %q in redis: active index session %q belongs to %q", input.UserID, deviceSessionID, record.UserID)
}
if record.Status != devicesession.StatusActive {
return fmt.Errorf("revoke user sessions %q in redis: active index session %q is %q", input.UserID, deviceSessionID, record.Status)
}
next := record
next.Status = devicesession.StatusRevoked
revocation := input.Revocation
next.Revocation = &revocation
if err := next.Validate(); err != nil {
return fmt.Errorf("revoke user sessions %q in redis: session %q: %w", input.UserID, deviceSessionID, err)
}
records = append(records, next)
}
_, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error {
for _, record := range records {
payload, err := marshalSessionRecord(record)
if err != nil {
return fmt.Errorf("session %q: %w", record.ID, err)
}
pipe.Set(operationCtx, s.sessionKey(record.ID), payload, 0)
pipe.ZRem(operationCtx, activeSessionsKey, record.ID.String())
}
return nil
})
if err != nil {
return fmt.Errorf("revoke user sessions %q in redis: %w", input.UserID, err)
}
sortSessionsNewestFirst(records)
result = ports.RevokeUserSessionsResult{
Outcome: ports.RevokeUserSessionsOutcomeRevoked,
UserID: input.UserID,
Sessions: records,
}
return result.Validate()
}, activeSessionsKey)
switch {
case errors.Is(watchErr, redis.TxFailedErr):
return errRetryMutation
case watchErr != nil:
return watchErr
default:
return nil
}
})
if err != nil {
return ports.RevokeUserSessionsResult{}, err
}
return result, nil
}
var errRetryMutation = errors.New("redis session store: retry mutation")
func (s *Store) runMutation(ctx context.Context, operation string, execute func(context.Context) error) error {
for attempt := 0; attempt < mutationRetryLimit; attempt++ {
operationCtx, cancel, err := s.operationContext(ctx, operation)
if err != nil {
return err
}
err = execute(operationCtx)
cancel()
switch {
case errors.Is(err, errRetryMutation):
if attempt == mutationRetryLimit-1 {
return fmt.Errorf("%s: mutation retry limit exceeded", operation)
}
continue
default:
return err
}
}
return fmt.Errorf("%s: mutation retry limit exceeded", operation)
}
func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) {
if s == nil || s.client == nil {
return nil, nil, fmt.Errorf("%s: nil store", operation)
}
if ctx == nil {
return nil, nil, fmt.Errorf("%s: nil context", operation)
}
operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout)
return operationCtx, cancel, nil
}
func (s *Store) loadSession(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
return s.loadSessionWithGetter(ctx, deviceSessionID, s.client.Get)
}
func (s *Store) loadSessionWithGetter(
ctx context.Context,
deviceSessionID common.DeviceSessionID,
getter func(context.Context, string) *redis.StringCmd,
) (devicesession.Session, error) {
payload, err := getter(ctx, s.sessionKey(deviceSessionID)).Bytes()
switch {
case errors.Is(err, redis.Nil):
return devicesession.Session{}, ports.ErrNotFound
case err != nil:
return devicesession.Session{}, err
}
record, err := decodeSessionRecord(deviceSessionID, payload)
if err != nil {
return devicesession.Session{}, err
}
return record, nil
}
func (s *Store) sessionKey(deviceSessionID common.DeviceSessionID) string {
return s.sessionKeyPrefix + encodeKeyComponent(deviceSessionID.String())
}
func (s *Store) userSessionsKey(userID common.UserID) string {
return s.userSessionsKeyPrefix + encodeKeyComponent(userID.String())
}
func (s *Store) userActiveSessionsKey(userID common.UserID) string {
return s.userActiveSessionsKeyPrefix + encodeKeyComponent(userID.String())
}
func encodeKeyComponent(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
func marshalSessionRecord(record devicesession.Session) ([]byte, error) {
stored, err := redisRecordFromSession(record)
if err != nil {
return nil, err
}
payload, err := json.Marshal(stored)
if err != nil {
return nil, fmt.Errorf("encode redis session record: %w", err)
}
return payload, nil
}
func redisRecordFromSession(record devicesession.Session) (redisRecord, error) {
if err := record.Validate(); err != nil {
return redisRecord{}, fmt.Errorf("encode redis session record: %w", err)
}
stored := redisRecord{
DeviceSessionID: record.ID.String(),
UserID: record.UserID.String(),
ClientPublicKeyBase64: record.ClientPublicKey.String(),
Status: record.Status,
CreatedAt: formatTimestamp(record.CreatedAt),
}
if record.Revocation != nil {
stored.RevokedAt = formatOptionalTimestamp(&record.Revocation.At)
stored.RevokeReasonCode = record.Revocation.ReasonCode.String()
stored.RevokeActorType = record.Revocation.ActorType.String()
stored.RevokeActorID = record.Revocation.ActorID
}
return stored, nil
}
func decodeSessionRecord(expectedDeviceSessionID common.DeviceSessionID, payload []byte) (devicesession.Session, error) {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
var stored redisRecord
if err := decoder.Decode(&stored); err != nil {
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return devicesession.Session{}, errors.New("decode redis session record: unexpected trailing JSON input")
}
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
}
record, err := sessionFromRedisRecord(stored)
if err != nil {
return devicesession.Session{}, err
}
if record.ID != expectedDeviceSessionID {
return devicesession.Session{}, fmt.Errorf("decode redis session record: device_session_id %q does not match requested %q", record.ID, expectedDeviceSessionID)
}
return record, nil
}
func sessionFromRedisRecord(stored redisRecord) (devicesession.Session, error) {
createdAt, err := parseTimestamp("created_at", stored.CreatedAt)
if err != nil {
return devicesession.Session{}, err
}
rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ClientPublicKeyBase64)
if err != nil {
return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
}
clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey)
if err != nil {
return devicesession.Session{}, fmt.Errorf("decode redis session record: client_public_key_base64: %w", err)
}
record := devicesession.Session{
ID: common.DeviceSessionID(stored.DeviceSessionID),
UserID: common.UserID(stored.UserID),
ClientPublicKey: clientPublicKey,
Status: stored.Status,
CreatedAt: createdAt,
}
revocation, err := parseRevocation(stored)
if err != nil {
return devicesession.Session{}, err
}
record.Revocation = revocation
if err := record.Validate(); err != nil {
return devicesession.Session{}, fmt.Errorf("decode redis session record: %w", err)
}
return record, nil
}
func parseRevocation(stored redisRecord) (*devicesession.Revocation, error) {
hasRevokedAt := stored.RevokedAt != nil
hasReasonCode := strings.TrimSpace(stored.RevokeReasonCode) != ""
hasActorType := strings.TrimSpace(stored.RevokeActorType) != ""
hasActorID := strings.TrimSpace(stored.RevokeActorID) != ""
if !hasRevokedAt && !hasReasonCode && !hasActorType && !hasActorID {
return nil, nil
}
if !hasRevokedAt || !hasReasonCode || !hasActorType {
return nil, errors.New("decode redis session record: revocation metadata must be either fully present or fully absent")
}
revokedAt, err := parseTimestamp("revoked_at", *stored.RevokedAt)
if err != nil {
return nil, err
}
return &devicesession.Revocation{
At: revokedAt,
ReasonCode: common.RevokeReasonCode(stored.RevokeReasonCode),
ActorType: common.RevokeActorType(stored.RevokeActorType),
ActorID: stored.RevokeActorID,
}, nil
}
func parseTimestamp(fieldName string, value string) (time.Time, error) {
if strings.TrimSpace(value) == "" {
return time.Time{}, fmt.Errorf("decode redis session record: %s must not be empty", fieldName)
}
parsed, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return time.Time{}, fmt.Errorf("decode redis session record: %s: %w", fieldName, err)
}
canonical := parsed.UTC().Format(time.RFC3339Nano)
if value != canonical {
return time.Time{}, fmt.Errorf("decode redis session record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName)
}
return parsed.UTC(), nil
}
func formatTimestamp(value time.Time) string {
return value.UTC().Format(time.RFC3339Nano)
}
func formatOptionalTimestamp(value *time.Time) *string {
if value == nil {
return nil
}
formatted := formatTimestamp(*value)
return &formatted
}
func createdAtScore(createdAt time.Time) float64 {
return float64(createdAt.UTC().UnixMicro())
}
func sortSessionsNewestFirst(records []devicesession.Session) {
slices.SortFunc(records, func(left devicesession.Session, right devicesession.Session) int {
switch {
case left.CreatedAt.Equal(right.CreatedAt):
return strings.Compare(left.ID.String(), right.ID.String())
case left.CreatedAt.After(right.CreatedAt):
return -1
default:
return 1
}
})
}
var _ ports.SessionStore = (*Store)(nil)
@@ -0,0 +1,635 @@
package sessionstore
import (
"context"
"crypto/ed25519"
"encoding/json"
"testing"
"time"
"galaxy/authsession/internal/adapters/contracttest"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStoreContract(t *testing.T) {
t.Parallel()
contracttest.RunSessionStoreContractTests(t, func(t *testing.T) ports.SessionStore {
t.Helper()
server := miniredis.RunT(t)
return newTestStore(t, server, Config{})
})
}
func TestNew(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
tests := []struct {
name string
cfg Config
wantErr string
}{
{
name: "valid config",
cfg: Config{
Addr: server.Addr(),
DB: 1,
SessionKeyPrefix: "authsession:session:",
UserSessionsKeyPrefix: "authsession:user-sessions:",
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
OperationTimeout: 250 * time.Millisecond,
},
},
{
name: "empty addr",
cfg: Config{
SessionKeyPrefix: "authsession:session:",
UserSessionsKeyPrefix: "authsession:user-sessions:",
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis addr must not be empty",
},
{
name: "negative db",
cfg: Config{
Addr: server.Addr(),
DB: -1,
SessionKeyPrefix: "authsession:session:",
UserSessionsKeyPrefix: "authsession:user-sessions:",
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "redis db must not be negative",
},
{
name: "empty session prefix",
cfg: Config{
Addr: server.Addr(),
UserSessionsKeyPrefix: "authsession:user-sessions:",
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "session key prefix must not be empty",
},
{
name: "empty all sessions prefix",
cfg: Config{
Addr: server.Addr(),
SessionKeyPrefix: "authsession:session:",
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "user sessions key prefix must not be empty",
},
{
name: "empty active sessions prefix",
cfg: Config{
Addr: server.Addr(),
SessionKeyPrefix: "authsession:session:",
UserSessionsKeyPrefix: "authsession:user-sessions:",
OperationTimeout: 250 * time.Millisecond,
},
wantErr: "user active sessions key prefix must not be empty",
},
{
name: "non positive timeout",
cfg: Config{
Addr: server.Addr(),
SessionKeyPrefix: "authsession:session:",
UserSessionsKeyPrefix: "authsession:user-sessions:",
UserActiveSessionsKeyPrefix: "authsession:user-active-sessions:",
},
wantErr: "operation timeout must be positive",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
store, err := New(tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, store.Close())
})
})
}
}
func TestStorePing(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
require.NoError(t, store.Ping(context.Background()))
}
func TestStoreCreateAndGetActive(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_000, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
got.Revocation = &devicesession.Revocation{
At: got.CreatedAt.Add(time.Minute),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
}
again, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Nil(t, again.Revocation)
assert.Equal(t, record, again)
}
func TestStoreCreateAndGetRevoked(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(1_775_240_100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
got, err := store.Get(context.Background(), record.ID)
require.NoError(t, err)
assert.Equal(t, record, got)
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
require.NoError(t, err)
assert.Zero(t, count)
}
func TestStoreGetNotFound(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
_, err := store.Get(context.Background(), common.DeviceSessionID("missing-session"))
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
}
func TestStoreCreateConflict(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := activeSessionFixture("device-session-1", "user-1", time.Unix(1_775_240_200, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
err := store.Create(context.Background(), record)
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrConflict)
}
func TestStoreIndexesAndOrdering(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
older := activeSessionFixture("device-session-old", "user-1", time.Unix(10, 0).UTC())
newer := activeSessionFixture("device-session-new", "user-1", time.Unix(20, 0).UTC())
revoked := revokedSessionFixture("device-session-revoked", "user-1", time.Unix(15, 0).UTC())
otherUser := activeSessionFixture("device-session-other", "user-2", time.Unix(30, 0).UTC())
for _, record := range []devicesession.Session{older, newer, revoked, otherUser} {
require.NoError(t, store.Create(context.Background(), record))
}
got, err := store.ListByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
require.Len(t, got, 3)
assert.Equal(t, []common.DeviceSessionID{newer.ID, revoked.ID, older.ID}, []common.DeviceSessionID{got[0].ID, got[1].ID, got[2].ID})
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
assert.Equal(t, 2, count)
unknown, err := store.ListByUserID(context.Background(), common.UserID("unknown-user"))
require.NoError(t, err)
assert.Empty(t, unknown)
}
func TestStoreKeyPrefixesAndEncodedPrimaryKey(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{
SessionKeyPrefix: "custom:session:",
UserSessionsKeyPrefix: "custom:user-sessions:",
UserActiveSessionsKeyPrefix: "custom:user-active-sessions:",
})
record := activeSessionFixture("device/session:opaque?1", "user/opaque:1", time.Unix(40, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
primaryKey := store.sessionKey(record.ID)
assert.Equal(t, "custom:session:"+encodeKeyComponent(record.ID.String()), primaryKey)
assert.True(t, server.Exists(primaryKey))
allSessionsKey := store.userSessionsKey(record.UserID)
activeSessionsKey := store.userActiveSessionsKey(record.UserID)
assert.Equal(t, "custom:user-sessions:"+encodeKeyComponent(record.UserID.String()), allSessionsKey)
assert.Equal(t, "custom:user-active-sessions:"+encodeKeyComponent(record.UserID.String()), activeSessionsKey)
allMembers, err := server.ZMembers(allSessionsKey)
require.NoError(t, err)
assert.Equal(t, []string{record.ID.String()}, allMembers)
activeMembers, err := server.ZMembers(activeSessionsKey)
require.NoError(t, err)
assert.Equal(t, []string{record.ID.String()}, activeMembers)
}
func TestStoreRevoke(t *testing.T) {
t.Parallel()
t.Run("active session", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
revocation := devicesession.Revocation{
At: time.Unix(200, 0).UTC(),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
}
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: revocation,
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeSessionOutcomeRevoked, result.Outcome)
require.NotNil(t, result.Session.Revocation)
assert.Equal(t, revocation, *result.Session.Revocation)
count, err := store.CountActiveByUserID(context.Background(), record.UserID)
require.NoError(t, err)
assert.Zero(t, count)
})
t.Run("already revoked keeps stored revocation", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := revokedSessionFixture("device-session-2", "user-1", time.Unix(100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
result, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: devicesession.Revocation{
At: time.Unix(300, 0).UTC(),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
ActorID: "admin-1",
},
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeSessionOutcomeAlreadyRevoked, result.Outcome)
require.NotNil(t, result.Session.Revocation)
assert.Equal(t, *record.Revocation, *result.Session.Revocation)
})
t.Run("unknown session", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
_, err := store.Revoke(context.Background(), ports.RevokeSessionInput{
DeviceSessionID: common.DeviceSessionID("missing-session"),
Revocation: devicesession.Revocation{
At: time.Unix(200, 0).UTC(),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
},
})
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
}
func TestStoreRevokeAllByUserID(t *testing.T) {
t.Parallel()
t.Run("revokes active sessions newest first and clears active index", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
older := activeSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
newer := activeSessionFixture("device-session-2", "user-1", time.Unix(200, 0).UTC())
alreadyRevoked := revokedSessionFixture("device-session-3", "user-1", time.Unix(150, 0).UTC())
otherUser := activeSessionFixture("device-session-4", "user-2", time.Unix(250, 0).UTC())
for _, record := range []devicesession.Session{older, newer, alreadyRevoked, otherUser} {
require.NoError(t, store.Create(context.Background(), record))
}
revocation := devicesession.Revocation{
At: time.Unix(300, 0).UTC(),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
ActorID: "admin-1",
}
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
UserID: common.UserID("user-1"),
Revocation: revocation,
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeUserSessionsOutcomeRevoked, result.Outcome)
require.Len(t, result.Sessions, 2)
assert.Equal(t, []common.DeviceSessionID{newer.ID, older.ID}, []common.DeviceSessionID{result.Sessions[0].ID, result.Sessions[1].ID})
assert.Equal(t, revocation, *result.Sessions[0].Revocation)
assert.Equal(t, revocation, *result.Sessions[1].Revocation)
count, err := store.CountActiveByUserID(context.Background(), common.UserID("user-1"))
require.NoError(t, err)
assert.Zero(t, count)
otherCount, err := store.CountActiveByUserID(context.Background(), common.UserID("user-2"))
require.NoError(t, err)
assert.Equal(t, 1, otherCount)
})
t.Run("no active sessions", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := revokedSessionFixture("device-session-5", "user-1", time.Unix(100, 0).UTC())
require.NoError(t, store.Create(context.Background(), record))
result, err := store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
UserID: common.UserID("user-1"),
Revocation: devicesession.Revocation{
At: time.Unix(400, 0).UTC(),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
},
})
require.NoError(t, err)
assert.Equal(t, ports.RevokeUserSessionsOutcomeNoActiveSessions, result.Outcome)
assert.Empty(t, result.Sessions)
})
}
func TestStoreStrictDecodeCorruption(t *testing.T) {
t.Parallel()
now := time.Unix(1_775_240_300, 0).UTC()
baseRecord := revokedSessionFixture("device-session-corrupt", "user-1", now)
stored, err := redisRecordFromSession(baseRecord)
require.NoError(t, err)
tests := []struct {
name string
mutate func(redisRecord) string
wantErrText string
}{
{
name: "malformed json",
mutate: func(_ redisRecord) string {
return "{"
},
wantErrText: "decode redis session record",
},
{
name: "trailing json input",
mutate: func(record redisRecord) string {
return mustMarshalJSON(t, record) + "{}"
},
wantErrText: "unexpected trailing JSON input",
},
{
name: "unknown field",
mutate: func(record redisRecord) string {
payload := map[string]any{
"device_session_id": record.DeviceSessionID,
"user_id": record.UserID,
"client_public_key_base64": record.ClientPublicKeyBase64,
"status": record.Status,
"created_at": record.CreatedAt,
"revoked_at": record.RevokedAt,
"revoke_reason_code": record.RevokeReasonCode,
"revoke_actor_type": record.RevokeActorType,
"revoke_actor_id": record.RevokeActorID,
"unexpected": true,
}
return mustMarshalJSON(t, payload)
},
wantErrText: "unknown field",
},
{
name: "unsupported status",
mutate: func(record redisRecord) string {
record.Status = devicesession.Status("paused")
return mustMarshalJSON(t, record)
},
wantErrText: `status "paused" is unsupported`,
},
{
name: "non canonical timestamp",
mutate: func(record redisRecord) string {
record.CreatedAt = "2026-04-04T12:00:00+03:00"
return mustMarshalJSON(t, record)
},
wantErrText: "canonical UTC RFC3339Nano timestamp",
},
{
name: "incomplete revocation metadata",
mutate: func(record redisRecord) string {
record.RevokeActorType = ""
return mustMarshalJSON(t, record)
},
wantErrText: "revocation metadata must be either fully present or fully absent",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
server.Set(store.sessionKey(baseRecord.ID), tt.mutate(stored))
_, err := store.Get(context.Background(), baseRecord.ID)
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErrText)
})
}
}
func TestStoreListByUserIDDetectsCorruptIndexes(t *testing.T) {
t.Parallel()
t.Run("missing primary record", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
userID := common.UserID("user-1")
_, err := server.ZAdd(store.userSessionsKey(userID), 100, "missing-session")
require.NoError(t, err)
_, err = store.ListByUserID(context.Background(), userID)
require.Error(t, err)
assert.ErrorContains(t, err, "references missing session")
})
t.Run("wrong user id in primary record", func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := activeSessionFixture("device-session-1", "user-2", time.Unix(100, 0).UTC())
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
_, err := server.ZAdd(store.userSessionsKey(common.UserID("user-1")), createdAtScore(record.CreatedAt), record.ID.String())
require.NoError(t, err)
_, err = store.ListByUserID(context.Background(), common.UserID("user-1"))
require.Error(t, err)
assert.ErrorContains(t, err, `belongs to "user-2"`)
})
}
func TestStoreRevokeAllByUserIDDetectsCorruptActiveIndex(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := newTestStore(t, server, Config{})
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(100, 0).UTC())
require.NoError(t, seedSessionRecord(t, server, store.sessionKey(record.ID), record))
_, err := server.ZAdd(store.userActiveSessionsKey(record.UserID), createdAtScore(record.CreatedAt), record.ID.String())
require.NoError(t, err)
_, err = store.RevokeAllByUserID(context.Background(), ports.RevokeUserSessionsInput{
UserID: record.UserID,
Revocation: devicesession.Revocation{
At: time.Unix(200, 0).UTC(),
ReasonCode: devicesession.RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
},
})
require.Error(t, err)
assert.ErrorContains(t, err, `is "revoked"`)
}
func newTestStore(t *testing.T, server *miniredis.Miniredis, cfg Config) *Store {
t.Helper()
if cfg.Addr == "" {
cfg.Addr = server.Addr()
}
if cfg.SessionKeyPrefix == "" {
cfg.SessionKeyPrefix = "authsession:session:"
}
if cfg.UserSessionsKeyPrefix == "" {
cfg.UserSessionsKeyPrefix = "authsession:user-sessions:"
}
if cfg.UserActiveSessionsKeyPrefix == "" {
cfg.UserActiveSessionsKeyPrefix = "authsession:user-active-sessions:"
}
if cfg.OperationTimeout == 0 {
cfg.OperationTimeout = 250 * time.Millisecond
}
store, err := New(cfg)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, store.Close())
})
return store
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
clientPublicKey, err := common.NewClientPublicKey(ed25519.PublicKey{
0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31,
})
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: clientPublicKey,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
record := activeSessionFixture(deviceSessionID, userID, createdAt)
record.Status = devicesession.StatusRevoked
record.Revocation = &devicesession.Revocation{
At: createdAt.Add(time.Minute),
ReasonCode: devicesession.RevokeReasonDeviceLogout,
ActorType: common.RevokeActorType("user"),
ActorID: "user-actor",
}
return record
}
func seedSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record devicesession.Session) error {
t.Helper()
stored, err := redisRecordFromSession(record)
require.NoError(t, err)
server.Set(key, mustMarshalJSON(t, stored))
return nil
}
func mustMarshalJSON(t *testing.T, value any) string {
t.Helper()
payload, err := json.Marshal(value)
require.NoError(t, err)
return string(payload)
}
@@ -0,0 +1,382 @@
// Package userservice provides runtime user-directory adapters for the
// auth/session service.
package userservice
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
)
const (
resolveByEmailPath = "/api/v1/internal/user-resolutions/by-email"
existsByUserIDPath = "/api/v1/internal/users/%s/exists"
ensureByEmailPath = "/api/v1/internal/users/ensure-by-email"
blockByUserIDPath = "/api/v1/internal/users/%s/block"
blockByEmailPath = "/api/v1/internal/user-blocks/by-email"
)
// Config configures one HTTP-based UserDirectory client.
type Config struct {
// BaseURL is the absolute base URL of the future user-service internal
// HTTP API.
BaseURL string
// RequestTimeout bounds each outbound user-service request.
RequestTimeout time.Duration
}
// RESTClient implements ports.UserDirectory over a frozen internal REST
// contract.
type RESTClient struct {
baseURL string
requestTimeout time.Duration
httpClient *http.Client
}
// NewRESTClient constructs a REST-backed UserDirectory adapter from cfg.
func NewRESTClient(cfg Config) (*RESTClient, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
return newRESTClient(cfg, &http.Client{Transport: transport})
}
func newRESTClient(cfg Config, httpClient *http.Client) (*RESTClient, error) {
switch {
case strings.TrimSpace(cfg.BaseURL) == "":
return nil, errors.New("new user service REST client: base URL must not be empty")
case cfg.RequestTimeout <= 0:
return nil, errors.New("new user service REST client: request timeout must be positive")
case httpClient == nil:
return nil, errors.New("new user service REST client: http client must not be nil")
}
parsedBaseURL, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
if err != nil {
return nil, fmt.Errorf("new user service REST client: parse base URL: %w", err)
}
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
return nil, errors.New("new user service REST client: base URL must be absolute")
}
return &RESTClient{
baseURL: parsedBaseURL.String(),
requestTimeout: cfg.RequestTimeout,
httpClient: httpClient,
}, nil
}
// Close releases idle HTTP connections owned by the client transport.
func (c *RESTClient) Close() error {
if c == nil || c.httpClient == nil {
return nil
}
type idleCloser interface {
CloseIdleConnections()
}
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
transport.CloseIdleConnections()
}
return nil
}
// ResolveByEmail returns the current coarse user-resolution state for email
// without creating any new user record.
func (c *RESTClient) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
if err := validateContext(ctx, "resolve by email"); err != nil {
return userresolution.Result{}, err
}
if err := email.Validate(); err != nil {
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
}
var response struct {
Kind userresolution.Kind `json:"kind"`
UserID string `json:"user_id,omitempty"`
BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"`
}
if err := c.doJSON(ctx, "resolve by email", http.MethodPost, resolveByEmailPath, map[string]string{
"email": email.String(),
}, &response, true); err != nil {
return userresolution.Result{}, err
}
result := userresolution.Result{
Kind: response.Kind,
UserID: common.UserID(response.UserID),
BlockReasonCode: response.BlockReasonCode,
}
if err := result.Validate(); err != nil {
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
}
return result, nil
}
// ExistsByUserID reports whether userID currently identifies a stored user
// record.
func (c *RESTClient) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
if err := validateContext(ctx, "exists by user id"); err != nil {
return false, err
}
if err := userID.Validate(); err != nil {
return false, fmt.Errorf("exists by user id: %w", err)
}
var response struct {
Exists bool `json:"exists"`
}
if err := c.doJSON(ctx, "exists by user id", http.MethodGet, fmt.Sprintf(existsByUserIDPath, url.PathEscape(userID.String())), nil, &response, true); err != nil {
return false, err
}
return response.Exists, nil
}
// EnsureUserByEmail returns an existing user for email, creates a new user
// when registration is allowed, or reports a blocked outcome.
func (c *RESTClient) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
if err := validateContext(ctx, "ensure user by email"); err != nil {
return ports.EnsureUserResult{}, err
}
if err := email.Validate(); err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
var response struct {
Outcome ports.EnsureUserOutcome `json:"outcome"`
UserID string `json:"user_id,omitempty"`
BlockReasonCode userresolution.BlockReasonCode `json:"block_reason_code,omitempty"`
}
if err := c.doJSON(ctx, "ensure user by email", http.MethodPost, ensureByEmailPath, map[string]string{
"email": email.String(),
}, &response, false); err != nil {
return ports.EnsureUserResult{}, err
}
result := ports.EnsureUserResult{
Outcome: response.Outcome,
UserID: common.UserID(response.UserID),
BlockReasonCode: response.BlockReasonCode,
}
if err := result.Validate(); err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
return result, nil
}
// BlockByUserID applies a block state to the user identified by input.UserID.
// Unknown user ids wrap ports.ErrNotFound.
func (c *RESTClient) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
if err := validateContext(ctx, "block by user id"); err != nil {
return ports.BlockUserResult{}, err
}
if err := input.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
}
payload, statusCode, err := c.doRequest(ctx, "block by user id", http.MethodPost, fmt.Sprintf(blockByUserIDPath, url.PathEscape(input.UserID.String())), map[string]string{
"reason_code": input.ReasonCode.String(),
}, false)
if err != nil {
return ports.BlockUserResult{}, err
}
if statusCode == http.StatusNotFound {
return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
}
if statusCode != http.StatusOK {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: unexpected HTTP status %d", statusCode)
}
var response struct {
Outcome ports.BlockUserOutcome `json:"outcome"`
UserID string `json:"user_id,omitempty"`
}
if err := decodeJSONPayload(payload, &response); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
}
result := ports.BlockUserResult{
Outcome: response.Outcome,
UserID: common.UserID(response.UserID),
}
if err := result.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
}
return result, nil
}
// BlockByEmail applies a block state to input.Email even when no user record
// currently exists for that e-mail address.
func (c *RESTClient) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
if err := validateContext(ctx, "block by email"); err != nil {
return ports.BlockUserResult{}, err
}
if err := input.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
}
var response struct {
Outcome ports.BlockUserOutcome `json:"outcome"`
UserID string `json:"user_id,omitempty"`
}
if err := c.doJSON(ctx, "block by email", http.MethodPost, blockByEmailPath, map[string]string{
"email": input.Email.String(),
"reason_code": input.ReasonCode.String(),
}, &response, false); err != nil {
return ports.BlockUserResult{}, err
}
result := ports.BlockUserResult{
Outcome: response.Outcome,
UserID: common.UserID(response.UserID),
}
if err := result.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
}
return result, nil
}
func (c *RESTClient) doJSON(ctx context.Context, operation string, method string, requestPath string, requestBody any, responseTarget any, retryRead bool) error {
payload, statusCode, err := c.doRequest(ctx, operation, method, requestPath, requestBody, retryRead)
if err != nil {
return err
}
if statusCode != http.StatusOK {
return fmt.Errorf("%s: unexpected HTTP status %d", operation, statusCode)
}
if err := decodeJSONPayload(payload, responseTarget); err != nil {
return fmt.Errorf("%s: %w", operation, err)
}
return nil
}
func (c *RESTClient) doRequest(ctx context.Context, operation string, method string, requestPath string, requestBody any, retryRead bool) ([]byte, int, error) {
bodyBytes, err := marshalOptionalRequestBody(requestBody)
if err != nil {
return nil, 0, fmt.Errorf("%s: %w", operation, err)
}
attempts := 1
if retryRead {
attempts = 2
}
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
attemptCtx, cancel := context.WithTimeout(ctx, c.requestTimeout)
request, err := http.NewRequestWithContext(attemptCtx, method, c.baseURL+requestPath, bytes.NewReader(bodyBytes))
if err != nil {
cancel()
return nil, 0, fmt.Errorf("%s: build request: %w", operation, err)
}
if method == http.MethodPost {
request.Header.Set("Content-Type", "application/json")
}
response, err := c.httpClient.Do(request)
if err != nil {
cancel()
lastErr = fmt.Errorf("%s: %w", operation, err)
if retryRead && attempt == 0 && ctx.Err() == nil {
continue
}
return nil, 0, lastErr
}
payload, readErr := io.ReadAll(response.Body)
closeErr := response.Body.Close()
cancel()
if readErr != nil {
lastErr = fmt.Errorf("%s: read response body: %w", operation, readErr)
if retryRead && attempt == 0 && ctx.Err() == nil {
continue
}
return nil, 0, lastErr
}
if closeErr != nil {
lastErr = fmt.Errorf("%s: close response body: %w", operation, closeErr)
if retryRead && attempt == 0 && ctx.Err() == nil {
continue
}
return nil, 0, lastErr
}
if retryRead && attempt == 0 && isRetriableUserServiceStatus(response.StatusCode) {
lastErr = fmt.Errorf("%s: unexpected HTTP status %d", operation, response.StatusCode)
continue
}
return payload, response.StatusCode, nil
}
return nil, 0, lastErr
}
func marshalOptionalRequestBody(value any) ([]byte, error) {
if value == nil {
return nil, nil
}
payload, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("marshal request body: %w", err)
}
return payload, nil
}
func decodeJSONPayload(payload []byte, target any) error {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return fmt.Errorf("decode response body: %w", err)
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return errors.New("decode response body: unexpected trailing JSON input")
}
return fmt.Errorf("decode response body: %w", err)
}
return nil
}
func isRetriableUserServiceStatus(statusCode int) bool {
switch statusCode {
case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
return true
default:
return false
}
}
var _ ports.UserDirectory = (*RESTClient)(nil)
@@ -0,0 +1,622 @@
package userservice
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewRESTClient(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cfg Config
wantErr string
}{
{
name: "valid config",
cfg: Config{
BaseURL: "http://127.0.0.1:8080",
RequestTimeout: time.Second,
},
},
{
name: "empty base url",
cfg: Config{
RequestTimeout: time.Second,
},
wantErr: "base URL must not be empty",
},
{
name: "relative base url",
cfg: Config{
BaseURL: "/relative",
RequestTimeout: time.Second,
},
wantErr: "base URL must be absolute",
},
{
name: "non positive timeout",
cfg: Config{
BaseURL: "http://127.0.0.1:8080",
},
wantErr: "request timeout must be positive",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client, err := NewRESTClient(tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
assert.NoError(t, client.Close())
})
}
}
func TestRESTClientEndpointSuccessCases(t *testing.T) {
t.Parallel()
tests := []struct {
name string
run func(*testing.T, *RESTClient)
}{
{
name: "resolve by email",
run: func(t *testing.T, client *RESTClient) {
result, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Case@example.com"))
require.NoError(t, err)
assert.Equal(t, userresolution.Result{
Kind: userresolution.KindExisting,
UserID: common.UserID("user-123"),
}, result)
},
},
{
name: "exists by user id",
run: func(t *testing.T, client *RESTClient) {
exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
require.NoError(t, err)
assert.True(t, exists)
},
},
{
name: "ensure user by email",
run: func(t *testing.T, client *RESTClient) {
result, err := client.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
require.NoError(t, err)
assert.Equal(t, ports.EnsureUserResult{
Outcome: ports.EnsureUserOutcomeCreated,
UserID: common.UserID("user-234"),
}, result)
},
},
{
name: "block by user id",
run: func(t *testing.T, client *RESTClient) {
result, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("user-123"),
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeBlocked,
UserID: common.UserID("user-123"),
}, result)
},
},
{
name: "block by email",
run: func(t *testing.T, client *RESTClient) {
result, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("blocked@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
UserID: common.UserID("user-345"),
}, result)
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var requestsMu sync.Mutex
var requests []capturedRequest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestsMu.Lock()
requests = append(requests, captureRequest(t, r))
requestsMu.Unlock()
switch {
case r.Method == http.MethodPost && r.URL.Path == resolveByEmailPath:
writeJSON(t, w, http.StatusOK, map[string]any{
"kind": "existing",
"user_id": "user-123",
})
case r.Method == http.MethodGet && r.URL.Path == "/api/v1/internal/users/user-123/exists":
writeJSON(t, w, http.StatusOK, map[string]any{"exists": true})
case r.Method == http.MethodPost && r.URL.Path == ensureByEmailPath:
writeJSON(t, w, http.StatusOK, map[string]any{
"outcome": "created",
"user_id": "user-234",
})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/internal/users/user-123/block":
writeJSON(t, w, http.StatusOK, map[string]any{
"outcome": "blocked",
"user_id": "user-123",
})
case r.Method == http.MethodPost && r.URL.Path == blockByEmailPath:
writeJSON(t, w, http.StatusOK, map[string]any{
"outcome": "already_blocked",
"user_id": "user-345",
})
default:
http.NotFound(w, r)
}
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
tt.run(t, client)
requestsMu.Lock()
defer requestsMu.Unlock()
require.Len(t, requests, 1)
switch tt.name {
case "resolve by email":
assert.Equal(t, capturedRequest{
Method: http.MethodPost,
Path: resolveByEmailPath,
ContentType: "application/json",
Body: `{"email":"Pilot+Case@example.com"}`,
}, requests[0])
case "exists by user id":
assert.Equal(t, capturedRequest{
Method: http.MethodGet,
Path: "/api/v1/internal/users/user-123/exists",
}, requests[0])
case "ensure user by email":
assert.Equal(t, capturedRequest{
Method: http.MethodPost,
Path: ensureByEmailPath,
ContentType: "application/json",
Body: `{"email":"created@example.com"}`,
}, requests[0])
case "block by user id":
assert.Equal(t, capturedRequest{
Method: http.MethodPost,
Path: "/api/v1/internal/users/user-123/block",
ContentType: "application/json",
Body: `{"reason_code":"policy_blocked"}`,
}, requests[0])
case "block by email":
assert.Equal(t, capturedRequest{
Method: http.MethodPost,
Path: blockByEmailPath,
ContentType: "application/json",
Body: `{"email":"blocked@example.com","reason_code":"policy_blocked"}`,
}, requests[0])
}
})
}
}
func TestRESTClientPreservesNormalizedEmailExactly(t *testing.T) {
t.Parallel()
var captured string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request := captureRequest(t, r)
captured = request.Body
writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
_, err := client.ResolveByEmail(context.Background(), common.Email("Pilot+Alias@Example.com"))
require.NoError(t, err)
assert.Equal(t, `{"email":"Pilot+Alias@Example.com"}`, captured)
}
func TestRESTClientBlockByUserIDNotFound(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("missing-user"),
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
})
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
}
func TestRESTClientReadMethodsRetryOnce(t *testing.T) {
t.Parallel()
t.Run("resolve by email retries on 503", func(t *testing.T) {
t.Parallel()
var calls int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls++
if calls == 1 {
http.Error(w, "temporary", http.StatusServiceUnavailable)
return
}
writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
result, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, err)
assert.Equal(t, userresolution.KindCreatable, result.Kind)
assert.Equal(t, 2, calls)
})
t.Run("exists by user id retries on transport failure", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeJSON(t, w, http.StatusOK, map[string]any{"exists": true})
}))
defer server.Close()
baseTransport := server.Client().Transport
client, err := newRESTClient(Config{
BaseURL: server.URL,
RequestTimeout: 250 * time.Millisecond,
}, &http.Client{
Transport: &failOnceRoundTripper{
next: baseTransport,
err: errors.New("temporary transport failure"),
},
})
require.NoError(t, err)
exists, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
require.NoError(t, err)
assert.True(t, exists)
})
}
func TestRESTClientMutationMethodsDoNotRetry(t *testing.T) {
t.Parallel()
tests := []struct {
name string
run func(*RESTClient) error
}{
{
name: "ensure user by email",
run: func(client *RESTClient) error {
_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
return err
},
},
{
name: "block by user id",
run: func(client *RESTClient) error {
_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("user-123"),
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
})
return err
},
},
{
name: "block by email",
run: func(client *RESTClient) error {
_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("pilot@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
})
return err
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var calls int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls++
http.Error(w, "temporary", http.StatusServiceUnavailable)
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
err := tt.run(client)
require.Error(t, err)
assert.Equal(t, 1, calls)
})
}
}
func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
body string
wantErrText string
run func(*RESTClient) error
}{
{
name: "resolve by email rejects unknown field",
statusCode: http.StatusOK,
body: `{"kind":"creatable","extra":true}`,
wantErrText: "decode response body",
run: func(client *RESTClient) error {
_, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
return err
},
},
{
name: "ensure user by email rejects malformed outcome",
statusCode: http.StatusOK,
body: `{"outcome":"mystery"}`,
wantErrText: "unsupported",
run: func(client *RESTClient) error {
_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
return err
},
},
{
name: "ensure user by email rejects missing user id for created outcome",
statusCode: http.StatusOK,
body: `{"outcome":"created"}`,
wantErrText: "user id",
run: func(client *RESTClient) error {
_, err := client.EnsureUserByEmail(context.Background(), common.Email("pilot@example.com"))
return err
},
},
{
name: "exists by user id rejects trailing json",
statusCode: http.StatusOK,
body: `{"exists":true}{}`,
wantErrText: "unexpected trailing JSON input",
run: func(client *RESTClient) error {
_, err := client.ExistsByUserID(context.Background(), common.UserID("user-123"))
return err
},
},
{
name: "block by email rejects unexpected status",
statusCode: http.StatusBadGateway,
body: `{"error":"temporary"}`,
wantErrText: "unexpected HTTP status 502",
run: func(client *RESTClient) error {
_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("pilot@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
})
return err
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.statusCode)
_, err := io.WriteString(w, tt.body)
require.NoError(t, err)
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
err := tt.run(client)
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErrText)
})
}
}
func TestRESTClientRequestTimeout(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(40 * time.Millisecond)
writeJSON(t, w, http.StatusOK, map[string]any{"kind": "creatable"})
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 10*time.Millisecond)
_, err := client.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.Error(t, err)
assert.ErrorContains(t, err, "context deadline exceeded")
}
func TestRESTClientContextAndValidation(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("unexpected upstream call")
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
tests := []struct {
name string
run func() error
}{
{
name: "nil context",
run: func() error {
_, err := client.ResolveByEmail(nil, common.Email("pilot@example.com"))
return err
},
},
{
name: "cancelled context",
run: func() error {
_, err := client.ExistsByUserID(cancelledCtx, common.UserID("user-123"))
return err
},
},
{
name: "invalid email",
run: func() error {
_, err := client.EnsureUserByEmail(context.Background(), common.Email(" bad@example.com "))
return err
},
},
{
name: "invalid user id",
run: func() error {
_, err := client.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID(" bad "),
ReasonCode: userresolution.BlockReasonCode("policy_blocked"),
})
return err
},
},
{
name: "invalid reason code",
run: func() error {
_, err := client.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("pilot@example.com"),
ReasonCode: userresolution.BlockReasonCode(" bad "),
})
return err
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
err := tt.run()
require.Error(t, err)
})
}
}
type capturedRequest struct {
Method string
Path string
ContentType string
Body string
}
func captureRequest(t *testing.T, request *http.Request) capturedRequest {
t.Helper()
body, err := io.ReadAll(request.Body)
require.NoError(t, err)
return capturedRequest{
Method: request.Method,
Path: request.URL.Path,
ContentType: request.Header.Get("Content-Type"),
Body: strings.TrimSpace(string(body)),
}
}
func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
t.Helper()
payload, err := json.Marshal(value)
require.NoError(t, err)
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(statusCode)
_, err = writer.Write(payload)
require.NoError(t, err)
}
func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
t.Helper()
client, err := NewRESTClient(Config{
BaseURL: baseURL,
RequestTimeout: timeout,
})
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return client
}
type failOnceRoundTripper struct {
mu sync.Mutex
next http.RoundTripper
err error
done bool
}
func (rt *failOnceRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
rt.mu.Lock()
if !rt.done {
rt.done = true
err := rt.err
rt.mu.Unlock()
return nil, err
}
next := rt.next
rt.mu.Unlock()
return next.RoundTrip(request)
}
@@ -0,0 +1,361 @@
// Package userservice provides runtime user-directory adapters for the
// auth/session service.
package userservice
import (
"context"
"fmt"
"sync"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
)
type entry struct {
userID common.UserID
blockReasonCode userresolution.BlockReasonCode
}
// StubDirectory is a concurrency-safe in-process UserDirectory stub intended
// for development, local integration, and explicit stub-based tests.
//
// The zero value is ready to use. Unknown e-mail addresses resolve as
// creatable, unknown user identifiers do not exist, and EnsureUserByEmail
// creates deterministic user ids such as "user-1", "user-2", and so on.
type StubDirectory struct {
mu sync.Mutex
byEmail map[common.Email]entry
emailByUserID map[common.UserID]common.Email
createdUserIDs []common.UserID
nextUserNumber int
}
// ResolveByEmail returns the current coarse user-resolution state for email
// without creating any new user record.
func (d *StubDirectory) ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error) {
if err := validateContext(ctx, "resolve by email"); err != nil {
return userresolution.Result{}, err
}
if err := email.Validate(); err != nil {
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
result, err := d.resolveLocked(email)
if err != nil {
return userresolution.Result{}, fmt.Errorf("resolve by email: %w", err)
}
return result, nil
}
// ExistsByUserID reports whether userID currently identifies a stored user
// record.
func (d *StubDirectory) ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error) {
if err := validateContext(ctx, "exists by user id"); err != nil {
return false, err
}
if err := userID.Validate(); err != nil {
return false, fmt.Errorf("exists by user id: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
_, ok := d.emailByUserID[userID]
return ok, nil
}
// EnsureUserByEmail returns an existing user for email, creates a new user
// when registration is allowed, or reports a blocked outcome.
func (d *StubDirectory) EnsureUserByEmail(ctx context.Context, email common.Email) (ports.EnsureUserResult, error) {
if err := validateContext(ctx, "ensure user by email"); err != nil {
return ports.EnsureUserResult{}, err
}
if err := email.Validate(); err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
d.ensureMapsLocked()
stored, ok := d.byEmail[email]
if ok {
if !stored.blockReasonCode.IsZero() {
result := ports.EnsureUserResult{
Outcome: ports.EnsureUserOutcomeBlocked,
BlockReasonCode: stored.blockReasonCode,
}
if err := result.Validate(); err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
return result, nil
}
result := ports.EnsureUserResult{
Outcome: ports.EnsureUserOutcomeExisting,
UserID: stored.userID,
}
if err := result.Validate(); err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
return result, nil
}
userID, err := d.nextCreatedUserIDLocked()
if err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
d.byEmail[email] = entry{userID: userID}
d.emailByUserID[userID] = email
result := ports.EnsureUserResult{
Outcome: ports.EnsureUserOutcomeCreated,
UserID: userID,
}
if err := result.Validate(); err != nil {
return ports.EnsureUserResult{}, fmt.Errorf("ensure user by email: %w", err)
}
return result, nil
}
// BlockByUserID applies a block state to the user identified by input.UserID.
// Unknown user ids wrap ports.ErrNotFound.
func (d *StubDirectory) BlockByUserID(ctx context.Context, input ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
if err := validateContext(ctx, "block by user id"); err != nil {
return ports.BlockUserResult{}, err
}
if err := input.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
email, ok := d.emailByUserID[input.UserID]
if !ok {
return ports.BlockUserResult{}, fmt.Errorf("block by user id %q: %w", input.UserID, ports.ErrNotFound)
}
stored := d.byEmail[email]
if !stored.blockReasonCode.IsZero() {
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
UserID: input.UserID,
}
if err := result.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
}
return result, nil
}
stored.blockReasonCode = input.ReasonCode
d.byEmail[email] = stored
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeBlocked,
UserID: input.UserID,
}
if err := result.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by user id: %w", err)
}
return result, nil
}
// BlockByEmail applies a block state to input.Email even when no user record
// currently exists for that e-mail address.
func (d *StubDirectory) BlockByEmail(ctx context.Context, input ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
if err := validateContext(ctx, "block by email"); err != nil {
return ports.BlockUserResult{}, err
}
if err := input.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
d.ensureMapsLocked()
stored := d.byEmail[input.Email]
if !stored.blockReasonCode.IsZero() {
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeAlreadyBlocked,
UserID: stored.userID,
}
if err := result.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
}
return result, nil
}
stored.blockReasonCode = input.ReasonCode
d.byEmail[input.Email] = stored
if !stored.userID.IsZero() {
d.emailByUserID[stored.userID] = input.Email
}
result := ports.BlockUserResult{
Outcome: ports.BlockUserOutcomeBlocked,
UserID: stored.userID,
}
if err := result.Validate(); err != nil {
return ports.BlockUserResult{}, fmt.Errorf("block by email: %w", err)
}
return result, nil
}
// SeedExisting preloads one existing unblocked user record into the runtime
// stub.
func (d *StubDirectory) SeedExisting(email common.Email, userID common.UserID) error {
if err := email.Validate(); err != nil {
return fmt.Errorf("seed existing email: %w", err)
}
if err := userID.Validate(); err != nil {
return fmt.Errorf("seed existing user id: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
d.ensureMapsLocked()
d.byEmail[email] = entry{userID: userID}
d.emailByUserID[userID] = email
return nil
}
// SeedBlockedEmail preloads one blocked e-mail address that does not
// necessarily belong to an existing user record.
func (d *StubDirectory) SeedBlockedEmail(email common.Email, reasonCode userresolution.BlockReasonCode) error {
if err := email.Validate(); err != nil {
return fmt.Errorf("seed blocked email: %w", err)
}
if err := reasonCode.Validate(); err != nil {
return fmt.Errorf("seed blocked email reason code: %w", err)
}
d.mu.Lock()
defer d.mu.Unlock()
d.ensureMapsLocked()
d.byEmail[email] = entry{blockReasonCode: reasonCode}
return nil
}
// SeedBlockedUser preloads one blocked existing user record into the runtime
// stub.
func (d *StubDirectory) SeedBlockedUser(email common.Email, userID common.UserID, reasonCode userresolution.BlockReasonCode) error {
if err := d.SeedExisting(email, userID); err != nil {
return err
}
d.mu.Lock()
defer d.mu.Unlock()
stored := d.byEmail[email]
stored.blockReasonCode = reasonCode
d.byEmail[email] = stored
return nil
}
// QueueCreatedUserIDs appends deterministic user identifiers that
// EnsureUserByEmail consumes before falling back to generated ids.
func (d *StubDirectory) QueueCreatedUserIDs(userIDs ...common.UserID) error {
for index, userID := range userIDs {
if err := userID.Validate(); err != nil {
return fmt.Errorf("queue created user id %d: %w", index, err)
}
}
d.mu.Lock()
defer d.mu.Unlock()
d.createdUserIDs = append(d.createdUserIDs, userIDs...)
return nil
}
func (d *StubDirectory) ensureMapsLocked() {
if d.byEmail == nil {
d.byEmail = make(map[common.Email]entry)
}
if d.emailByUserID == nil {
d.emailByUserID = make(map[common.UserID]common.Email)
}
}
func (d *StubDirectory) resolveLocked(email common.Email) (userresolution.Result, error) {
stored, ok := d.byEmail[email]
if !ok {
result := userresolution.Result{Kind: userresolution.KindCreatable}
if err := result.Validate(); err != nil {
return userresolution.Result{}, err
}
return result, nil
}
if !stored.blockReasonCode.IsZero() {
result := userresolution.Result{
Kind: userresolution.KindBlocked,
BlockReasonCode: stored.blockReasonCode,
}
if err := result.Validate(); err != nil {
return userresolution.Result{}, err
}
return result, nil
}
result := userresolution.Result{
Kind: userresolution.KindExisting,
UserID: stored.userID,
}
if err := result.Validate(); err != nil {
return userresolution.Result{}, err
}
return result, nil
}
func (d *StubDirectory) nextCreatedUserIDLocked() (common.UserID, error) {
if len(d.createdUserIDs) > 0 {
userID := d.createdUserIDs[0]
d.createdUserIDs = d.createdUserIDs[1:]
return userID, nil
}
d.nextUserNumber++
userID := common.UserID(fmt.Sprintf("user-%d", d.nextUserNumber))
if err := userID.Validate(); err != nil {
return "", err
}
return userID, nil
}
func validateContext(ctx context.Context, operation string) error {
if ctx == nil {
return fmt.Errorf("%s: nil context", operation)
}
if err := ctx.Err(); err != nil {
return fmt.Errorf("%s: %w", operation, err)
}
return nil
}
var _ ports.UserDirectory = (*StubDirectory)(nil)
@@ -0,0 +1,329 @@
package userservice
import (
"context"
"errors"
"testing"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStubDirectoryResolveByEmail(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))
require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")))
tests := []struct {
name string
email common.Email
wantKind userresolution.Kind
wantUserID common.UserID
wantReasonCode userresolution.BlockReasonCode
}{
{
name: "zero value unknown email is creatable",
email: common.Email("new@example.com"),
wantKind: userresolution.KindCreatable,
},
{
name: "existing email",
email: common.Email("existing@example.com"),
wantKind: userresolution.KindExisting,
wantUserID: common.UserID("user-existing"),
},
{
name: "blocked email",
email: common.Email("blocked@example.com"),
wantKind: userresolution.KindBlocked,
wantReasonCode: userresolution.BlockReasonCode("policy_block"),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := directory.ResolveByEmail(context.Background(), tt.email)
require.NoError(t, err)
assert.Equal(t, tt.wantKind, result.Kind)
assert.Equal(t, tt.wantUserID, result.UserID)
assert.Equal(t, tt.wantReasonCode, result.BlockReasonCode)
})
}
}
func TestStubDirectoryEnsureUserByEmail(t *testing.T) {
t.Parallel()
t.Run("existing", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))
result, err := directory.EnsureUserByEmail(context.Background(), common.Email("existing@example.com"))
require.NoError(t, err)
assert.Equal(t, ports.EnsureUserOutcomeExisting, result.Outcome)
assert.Equal(t, common.UserID("user-existing"), result.UserID)
})
t.Run("blocked", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block")))
result, err := directory.EnsureUserByEmail(context.Background(), common.Email("blocked@example.com"))
require.NoError(t, err)
assert.Equal(t, ports.EnsureUserOutcomeBlocked, result.Outcome)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), result.BlockReasonCode)
})
t.Run("created queued then existing", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.QueueCreatedUserIDs(common.UserID("user-created")))
first, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
require.NoError(t, err)
assert.Equal(t, ports.EnsureUserOutcomeCreated, first.Outcome)
assert.Equal(t, common.UserID("user-created"), first.UserID)
second, err := directory.EnsureUserByEmail(context.Background(), common.Email("created@example.com"))
require.NoError(t, err)
assert.Equal(t, ports.EnsureUserOutcomeExisting, second.Outcome)
assert.Equal(t, common.UserID("user-created"), second.UserID)
})
t.Run("created fallback id", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
result, err := directory.EnsureUserByEmail(context.Background(), common.Email("fallback@example.com"))
require.NoError(t, err)
assert.Equal(t, ports.EnsureUserOutcomeCreated, result.Outcome)
assert.Equal(t, common.UserID("user-1"), result.UserID)
})
}
func TestStubDirectoryExistsByUserID(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.SeedExisting(common.Email("existing@example.com"), common.UserID("user-existing")))
exists, err := directory.ExistsByUserID(context.Background(), common.UserID("user-existing"))
require.NoError(t, err)
assert.True(t, exists)
exists, err = directory.ExistsByUserID(context.Background(), common.UserID("missing"))
require.NoError(t, err)
assert.False(t, exists)
}
func TestStubDirectoryBlockByEmail(t *testing.T) {
t.Parallel()
t.Run("unknown email becomes blocked without user id", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("blocked@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserOutcomeBlocked, result.Outcome)
assert.True(t, result.UserID.IsZero())
resolution, err := directory.ResolveByEmail(context.Background(), common.Email("blocked@example.com"))
require.NoError(t, err)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
})
t.Run("existing user preserves linked user id and repeat is already blocked", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
first, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("pilot@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome)
assert.Equal(t, common.UserID("user-1"), first.UserID)
second, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("pilot@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome)
assert.Equal(t, common.UserID("user-1"), second.UserID)
})
}
func TestStubDirectoryBlockByUserID(t *testing.T) {
t.Parallel()
t.Run("unknown user wraps ErrNotFound", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
_, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("missing"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.Error(t, err)
assert.ErrorIs(t, err, ports.ErrNotFound)
})
t.Run("existing user blocks then returns already blocked", func(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
first, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("user-1"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserOutcomeBlocked, first.Outcome)
assert.Equal(t, common.UserID("user-1"), first.UserID)
second, err := directory.BlockByUserID(context.Background(), ports.BlockUserByIDInput{
UserID: common.UserID("user-1"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, second.Outcome)
assert.Equal(t, common.UserID("user-1"), second.UserID)
})
}
func TestStubDirectoryContextAndValidation(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
tests := []struct {
name string
run func() error
want string
}{
{
name: "resolve nil context",
run: func() error {
_, err := directory.ResolveByEmail(nil, common.Email("pilot@example.com"))
return err
},
want: "nil context",
},
{
name: "ensure cancelled context",
run: func() error {
_, err := directory.EnsureUserByEmail(cancelledCtx, common.Email("pilot@example.com"))
return err
},
want: context.Canceled.Error(),
},
{
name: "exists invalid user id",
run: func() error {
_, err := directory.ExistsByUserID(context.Background(), common.UserID(" bad "))
return err
},
want: "exists by user id",
},
{
name: "block by email invalid email",
run: func() error {
_, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("bad"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
return err
},
want: "block by email",
},
{
name: "seed invalid user id",
run: func() error {
return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID(" bad "))
},
want: "seed existing user id",
},
{
name: "queue invalid created user id",
run: func() error {
return directory.QueueCreatedUserIDs(common.UserID(" bad "))
},
want: "queue created user id 0",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.run()
require.Error(t, err)
assert.ErrorContains(t, err, tt.want)
})
}
}
func TestStubDirectorySeedBlockedUser(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
require.NoError(t, directory.SeedBlockedUser(
common.Email("pilot@example.com"),
common.UserID("user-1"),
userresolution.BlockReasonCode("policy_block"),
))
result, err := directory.BlockByEmail(context.Background(), ports.BlockUserByEmailInput{
Email: common.Email("pilot@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.NoError(t, err)
assert.Equal(t, ports.BlockUserOutcomeAlreadyBlocked, result.Outcome)
assert.Equal(t, common.UserID("user-1"), result.UserID)
}
func TestStubDirectoryCancelledContextWrapsContextError(t *testing.T) {
t.Parallel()
directory := &StubDirectory{}
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
_, err := directory.BlockByUserID(cancelledCtx, ports.BlockUserByIDInput{
UserID: common.UserID("user-1"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
})
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.ErrorContains(t, err, "block by user id")
}
@@ -0,0 +1,3 @@
// Package internalhttp exposes the trusted internal HTTP API used for session
// read, revoke, and block operations.
package internalhttp
@@ -0,0 +1,286 @@
package internalhttp
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/blockuser"
"galaxy/authsession/internal/service/getsession"
"galaxy/authsession/internal/service/listusersessions"
"galaxy/authsession/internal/service/revokeallusersessions"
"galaxy/authsession/internal/service/revokedevicesession"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInternalHTTPEndToEndGetSession(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
server := httptest.NewServer(app.handler)
defer server.Close()
response := getJSON(t, server.URL+"/api/v1/internal/sessions/device-session-1")
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"session":{"device_session_id":"device-session-1","user_id":"user-1","client_public_key":"`+validClientPublicKey+`","status":"active","created_at":"2026-04-05T12:00:00Z"}}`, response.Body)
}
func TestInternalHTTPEndToEndListUserSessions(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
key := testClientPublicKey(t, validClientPublicKey)
require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", key, time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-2", "user-1", key, time.Date(2026, 4, 5, 12, 1, 0, 0, time.UTC))))
server := httptest.NewServer(app.handler)
defer server.Close()
response := getJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions")
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.Contains(t, response.Body, `"device_session_id":"device-session-2"`)
assert.Contains(t, response.Body, `"device_session_id":"device-session-1"`)
assert.Less(t, bytes.Index([]byte(response.Body), []byte(`"device_session_id":"device-session-2"`)), bytes.Index([]byte(response.Body), []byte(`"device_session_id":"device-session-1"`)))
}
func TestInternalHTTPEndToEndListUserSessionsUnknownUserReturnsEmptyArray(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
server := httptest.NewServer(app.handler)
defer server.Close()
response := getJSON(t, server.URL+"/api/v1/internal/users/unknown-user/sessions")
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"sessions":[]}`, response.Body)
}
func TestInternalHTTPEndToEndGetSessionNotFound(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
server := httptest.NewServer(app.handler)
defer server.Close()
response := getJSON(t, server.URL+"/api/v1/internal/sessions/missing-session")
assert.Equal(t, http.StatusNotFound, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"session_not_found","message":"session not found"}}`, response.Body)
}
func TestInternalHTTPEndToEndRevokeDeviceSession(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/internal/sessions/device-session-1/revoke", `{"reason_code":"admin_revoke","actor":{"type":"system"}}`)
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"outcome":"revoked","device_session_id":"device-session-1","affected_session_count":1}`, response.Body)
}
func TestInternalHTTPEndToEndRevokeAllUserSessions(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
key := testClientPublicKey(t, validClientPublicKey)
require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", key, time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-2", "user-1", key, time.Date(2026, 4, 5, 12, 1, 0, 0, time.UTC))))
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`)
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"outcome":"revoked","user_id":"user-1","affected_session_count":2,"affected_device_session_ids":["device-session-2","device-session-1"]}`, response.Body)
}
func TestInternalHTTPEndToEndRevokeAllUserSessionsNoActiveSessions(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/internal/users/user-1/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`)
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-1","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body)
}
func TestInternalHTTPEndToEndRevokeAllUserSessionsUnknownUser(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/internal/users/missing-user/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`)
assert.Equal(t, http.StatusNotFound, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"subject_not_found","message":"subject not found"}}`, response.Body)
}
func TestInternalHTTPEndToEndBlockUserByEmail(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`)
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":0,"affected_device_session_ids":[]}`, response.Body)
}
func TestInternalHTTPEndToEndBlockUserByUserID(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
require.NoError(t, app.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, app.sessionStore.Create(context.Background(), activeSession("device-session-1", "user-1", testClientPublicKey(t, validClientPublicKey), time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC))))
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"user_id":"user-1","reason_code":"policy_blocked","actor":{"type":"admin"}}`)
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"user_id","subject_value":"user-1","affected_session_count":1,"affected_device_session_ids":["device-session-1"]}`, response.Body)
}
func TestInternalHTTPEndToEndBlockUserUnknownUserID(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t)
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/internal/user-blocks", `{"user_id":"missing-user","reason_code":"policy_blocked","actor":{"type":"admin"}}`)
assert.Equal(t, http.StatusNotFound, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"subject_not_found","message":"subject not found"}}`, response.Body)
}
type endToEndApp struct {
handler http.Handler
sessionStore *testkit.InMemorySessionStore
userDirectory *userservice.StubDirectory
}
func newEndToEndApp(t *testing.T) endToEndApp {
t.Helper()
sessionStore := &testkit.InMemorySessionStore{}
userDirectory := &userservice.StubDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
clock := testkit.FixedClock{Time: time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)}
getSessionService, err := getsession.New(sessionStore)
require.NoError(t, err)
listUserSessionsService, err := listusersessions.New(sessionStore)
require.NoError(t, err)
revokeDeviceSessionService, err := revokedevicesession.New(sessionStore, publisher, clock)
require.NoError(t, err)
revokeAllUserSessionsService, err := revokeallusersessions.New(sessionStore, userDirectory, publisher, clock)
require.NoError(t, err)
blockUserService, err := blockuser.New(userDirectory, sessionStore, publisher, clock)
require.NoError(t, err)
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionService,
ListUserSessions: listUserSessionsService,
RevokeDeviceSession: revokeDeviceSessionService,
RevokeAllUserSessions: revokeAllUserSessionsService,
BlockUser: blockUserService,
})
return endToEndApp{
handler: handler,
sessionStore: sessionStore,
userDirectory: userDirectory,
}
}
type httpResponse struct {
StatusCode int
Body string
}
func getJSON(t *testing.T, url string) httpResponse {
t.Helper()
response, err := http.Get(url)
require.NoError(t, err)
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
require.NoError(t, err)
return httpResponse{StatusCode: response.StatusCode, Body: string(payload)}
}
func postJSON(t *testing.T, url string, body string) httpResponse {
t.Helper()
response, err := http.Post(url, "application/json", bytes.NewBufferString(body))
require.NoError(t, err)
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
require.NoError(t, err)
return httpResponse{StatusCode: response.StatusCode, Body: string(payload)}
}
func postJSONValue(t *testing.T, url string, value any) httpResponse {
t.Helper()
body, err := json.Marshal(value)
require.NoError(t, err)
return postJSON(t, url, string(body))
}
func activeSession(id string, userID string, key common.ClientPublicKey, createdAt time.Time) devicesession.Session {
return devicesession.Session{
ID: common.DeviceSessionID(id),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
func testClientPublicKey(t *testing.T, encoded string) common.ClientPublicKey {
t.Helper()
decoded, err := base64.StdEncoding.DecodeString(encoded)
require.NoError(t, err)
key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded))
require.NoError(t, err)
return key
}
const validClientPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
@@ -0,0 +1,513 @@
package internalhttp
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"galaxy/authsession/internal/service/blockuser"
"galaxy/authsession/internal/service/getsession"
"galaxy/authsession/internal/service/listusersessions"
"galaxy/authsession/internal/service/revokeallusersessions"
"galaxy/authsession/internal/service/revokedevicesession"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
)
const jsonContentType = "application/json; charset=utf-8"
const internalHTTPServiceName = "galaxy-authsession-internal"
type errorResponse struct {
Error errorBody `json:"error"`
}
type errorBody struct {
Code string `json:"code"`
Message string `json:"message"`
}
type actorRequest struct {
Type string `json:"type"`
ID string `json:"id,omitempty"`
}
type sessionResponseDTO struct {
DeviceSessionID string `json:"device_session_id"`
UserID string `json:"user_id"`
ClientPublicKey string `json:"client_public_key"`
Status string `json:"status"`
CreatedAt string `json:"created_at"`
RevokedAt *string `json:"revoked_at,omitempty"`
RevokeReasonCode *string `json:"revoke_reason_code,omitempty"`
RevokeActorType *string `json:"revoke_actor_type,omitempty"`
RevokeActorID *string `json:"revoke_actor_id,omitempty"`
}
type getSessionResponse struct {
Session sessionResponseDTO `json:"session"`
}
type listUserSessionsResponse struct {
Sessions []sessionResponseDTO `json:"sessions"`
}
type revokeDeviceSessionRequest struct {
ReasonCode string `json:"reason_code"`
Actor actorRequest `json:"actor"`
}
type revokeDeviceSessionResponse struct {
Outcome string `json:"outcome"`
DeviceSessionID string `json:"device_session_id"`
AffectedSessionCount int64 `json:"affected_session_count"`
}
type revokeAllUserSessionsRequest struct {
ReasonCode string `json:"reason_code"`
Actor actorRequest `json:"actor"`
}
type revokeAllUserSessionsResponse struct {
Outcome string `json:"outcome"`
UserID string `json:"user_id"`
AffectedSessionCount int64 `json:"affected_session_count"`
AffectedDeviceSessionIDs []string `json:"affected_device_session_ids"`
}
type blockUserRequest struct {
UserID string `json:"user_id,omitempty"`
Email string `json:"email,omitempty"`
ReasonCode string `json:"reason_code"`
Actor actorRequest `json:"actor"`
}
type blockUserResponse struct {
Outcome string `json:"outcome"`
SubjectKind string `json:"subject_kind"`
SubjectValue string `json:"subject_value"`
AffectedSessionCount int64 `json:"affected_session_count"`
AffectedDeviceSessionIDs []string `json:"affected_device_session_ids"`
}
var configureGinModeOnce sync.Once
func newHandlerWithConfig(cfg Config, deps Dependencies) (http.Handler, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
normalizedDeps, err := normalizeDependencies(deps)
if err != nil {
return nil, err
}
configureGinModeOnce.Do(func() {
gin.SetMode(gin.ReleaseMode)
})
engine := gin.New()
engine.Use(newOTelMiddleware(normalizedDeps.Telemetry))
engine.Use(withInternalObservability(normalizedDeps.Logger, normalizedDeps.Telemetry))
engine.GET("/api/v1/internal/sessions/:device_session_id", handleGetSession(normalizedDeps.GetSession, cfg.RequestTimeout))
engine.GET("/api/v1/internal/users/:user_id/sessions", handleListUserSessions(normalizedDeps.ListUserSessions, cfg.RequestTimeout))
engine.POST("/api/v1/internal/sessions/:device_session_id/revoke", handleRevokeDeviceSession(normalizedDeps.RevokeDeviceSession, cfg.RequestTimeout))
engine.POST("/api/v1/internal/users/:user_id/sessions/revoke-all", handleRevokeAllUserSessions(normalizedDeps.RevokeAllUserSessions, cfg.RequestTimeout))
engine.POST("/api/v1/internal/user-blocks", handleBlockUser(normalizedDeps.BlockUser, cfg.RequestTimeout))
return engine, nil
}
func newOTelMiddleware(runtime *telemetry.Runtime) gin.HandlerFunc {
options := []otelgin.Option{}
if runtime != nil {
options = append(
options,
otelgin.WithTracerProvider(runtime.TracerProvider()),
otelgin.WithMeterProvider(runtime.MeterProvider()),
)
}
return otelgin.Middleware(internalHTTPServiceName, options...)
}
func handleGetSession(useCase GetSessionUseCase, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := useCase.Execute(callCtx, getsession.Input{
DeviceSessionID: c.Param("device_session_id"),
})
if err != nil {
abortWithProjection(c, projectInternalError(err))
return
}
if err := validateGetSessionResult(&result); err != nil {
abortWithProjection(c, internalErrorProjection(fmt.Errorf("get session response: %w", err)))
return
}
c.JSON(http.StatusOK, getSessionResponse{Session: toSessionResponseDTO(result.Session)})
}
}
func handleListUserSessions(useCase ListUserSessionsUseCase, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := useCase.Execute(callCtx, listusersessions.Input{
UserID: c.Param("user_id"),
})
if err != nil {
abortWithProjection(c, projectInternalError(err))
return
}
if err := validateListUserSessionsResult(&result); err != nil {
abortWithProjection(c, internalErrorProjection(fmt.Errorf("list user sessions response: %w", err)))
return
}
c.JSON(http.StatusOK, listUserSessionsResponse{Sessions: toSessionResponseDTOs(result.Sessions)})
}
}
func handleRevokeDeviceSession(useCase RevokeDeviceSessionUseCase, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
var request revokeDeviceSessionRequest
if err := decodeJSONRequest(c.Request, &request); err != nil {
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
return
}
if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil {
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
return
}
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := useCase.Execute(callCtx, revokedevicesession.Input{
DeviceSessionID: c.Param("device_session_id"),
ReasonCode: request.ReasonCode,
ActorType: request.Actor.Type,
ActorID: request.Actor.ID,
})
if err != nil {
abortWithProjection(c, projectInternalError(err))
return
}
if err := validateRevokeDeviceSessionResult(&result); err != nil {
abortWithProjection(c, internalErrorProjection(fmt.Errorf("revoke device session response: %w", err)))
return
}
c.JSON(http.StatusOK, revokeDeviceSessionResponse{
Outcome: result.Outcome,
DeviceSessionID: result.DeviceSessionID,
AffectedSessionCount: result.AffectedSessionCount,
})
}
}
func handleRevokeAllUserSessions(useCase RevokeAllUserSessionsUseCase, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
var request revokeAllUserSessionsRequest
if err := decodeJSONRequest(c.Request, &request); err != nil {
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
return
}
if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil {
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
return
}
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := useCase.Execute(callCtx, revokeallusersessions.Input{
UserID: c.Param("user_id"),
ReasonCode: request.ReasonCode,
ActorType: request.Actor.Type,
ActorID: request.Actor.ID,
})
if err != nil {
abortWithProjection(c, projectInternalError(err))
return
}
if err := validateRevokeAllUserSessionsResult(&result); err != nil {
abortWithProjection(c, internalErrorProjection(fmt.Errorf("revoke all user sessions response: %w", err)))
return
}
c.JSON(http.StatusOK, revokeAllUserSessionsResponse{
Outcome: result.Outcome,
UserID: result.UserID,
AffectedSessionCount: result.AffectedSessionCount,
AffectedDeviceSessionIDs: cloneStrings(result.AffectedDeviceSessionIDs),
})
}
}
func handleBlockUser(useCase BlockUserUseCase, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
var request blockUserRequest
if err := decodeJSONRequest(c.Request, &request); err != nil {
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
return
}
if err := validateBlockUserRequest(&request); err != nil {
abortWithProjection(c, projectInternalError(shared.InvalidRequest(err.Error())))
return
}
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := useCase.Execute(callCtx, blockuser.Input{
UserID: request.UserID,
Email: request.Email,
ReasonCode: request.ReasonCode,
ActorType: request.Actor.Type,
ActorID: request.Actor.ID,
})
if err != nil {
abortWithProjection(c, projectInternalError(err))
return
}
if err := validateBlockUserResult(&result); err != nil {
abortWithProjection(c, internalErrorProjection(fmt.Errorf("block user response: %w", err)))
return
}
c.JSON(http.StatusOK, blockUserResponse{
Outcome: result.Outcome,
SubjectKind: result.SubjectKind,
SubjectValue: result.SubjectValue,
AffectedSessionCount: result.AffectedSessionCount,
AffectedDeviceSessionIDs: cloneStrings(result.AffectedDeviceSessionIDs),
})
}
}
func toSessionResponseDTO(session shared.Session) sessionResponseDTO {
return sessionResponseDTO{
DeviceSessionID: session.DeviceSessionID,
UserID: session.UserID,
ClientPublicKey: session.ClientPublicKey,
Status: session.Status,
CreatedAt: session.CreatedAt,
RevokedAt: cloneStringPointer(session.RevokedAt),
RevokeReasonCode: cloneStringPointer(session.RevokeReasonCode),
RevokeActorType: cloneStringPointer(session.RevokeActorType),
RevokeActorID: cloneStringPointer(session.RevokeActorID),
}
}
func toSessionResponseDTOs(sessions []shared.Session) []sessionResponseDTO {
result := make([]sessionResponseDTO, 0, len(sessions))
for _, session := range sessions {
result = append(result, toSessionResponseDTO(session))
}
return result
}
func cloneStrings(values []string) []string {
result := make([]string, 0, len(values))
return append(result, values...)
}
func cloneStringPointer(value *string) *string {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
func validateAuditRequest(reasonCode string, actor actorRequest) error {
if strings.TrimSpace(reasonCode) == "" {
return errors.New("reason_code must not be empty")
}
if strings.TrimSpace(actor.Type) == "" {
return errors.New("actor.type must not be empty")
}
return nil
}
func validateBlockUserRequest(request *blockUserRequest) error {
if err := validateAuditRequest(request.ReasonCode, request.Actor); err != nil {
return err
}
hasUserID := strings.TrimSpace(request.UserID) != ""
hasEmail := strings.TrimSpace(request.Email) != ""
switch {
case hasUserID && hasEmail:
return errors.New("exactly one of user_id or email must be provided")
case !hasUserID && !hasEmail:
return errors.New("exactly one of user_id or email must be provided")
default:
return nil
}
}
func validateSessionDTO(session *shared.Session) error {
switch {
case strings.TrimSpace(session.DeviceSessionID) == "":
return errors.New("session.device_session_id must not be empty")
case strings.TrimSpace(session.UserID) == "":
return errors.New("session.user_id must not be empty")
case strings.TrimSpace(session.ClientPublicKey) == "":
return errors.New("session.client_public_key must not be empty")
case strings.TrimSpace(session.CreatedAt) == "":
return errors.New("session.created_at must not be empty")
}
if _, err := time.Parse(time.RFC3339, session.CreatedAt); err != nil {
return fmt.Errorf("session.created_at: %w", err)
}
switch session.Status {
case "active":
if session.RevokedAt != nil || session.RevokeReasonCode != nil || session.RevokeActorType != nil || session.RevokeActorID != nil {
return errors.New("active session must not contain revoke metadata")
}
case "revoked":
switch {
case session.RevokedAt == nil || strings.TrimSpace(*session.RevokedAt) == "":
return errors.New("revoked session must contain revoked_at")
case session.RevokeReasonCode == nil || strings.TrimSpace(*session.RevokeReasonCode) == "":
return errors.New("revoked session must contain revoke_reason_code")
case session.RevokeActorType == nil || strings.TrimSpace(*session.RevokeActorType) == "":
return errors.New("revoked session must contain revoke_actor_type")
}
if _, err := time.Parse(time.RFC3339, *session.RevokedAt); err != nil {
return fmt.Errorf("session.revoked_at: %w", err)
}
default:
return fmt.Errorf("session.status %q is unsupported", session.Status)
}
return nil
}
func validateGetSessionResult(result *getsession.Result) error {
return validateSessionDTO(&result.Session)
}
func validateListUserSessionsResult(result *listusersessions.Result) error {
if result.Sessions == nil {
return errors.New("sessions must not be null")
}
for index := range result.Sessions {
if err := validateSessionDTO(&result.Sessions[index]); err != nil {
return fmt.Errorf("sessions[%d]: %w", index, err)
}
}
return nil
}
func validateRevokeDeviceSessionResult(result *revokedevicesession.Result) error {
switch result.Outcome {
case "revoked":
if result.AffectedSessionCount != 1 {
return errors.New("revoked outcome must affect exactly one session")
}
case "already_revoked":
if result.AffectedSessionCount != 0 {
return errors.New("already_revoked outcome must affect zero sessions")
}
default:
return fmt.Errorf("revoke device session outcome %q is unsupported", result.Outcome)
}
if strings.TrimSpace(result.DeviceSessionID) == "" {
return errors.New("device_session_id must not be empty")
}
return nil
}
func validateRevokeAllUserSessionsResult(result *revokeallusersessions.Result) error {
switch result.Outcome {
case "revoked", "no_active_sessions":
default:
return fmt.Errorf("revoke all user sessions outcome %q is unsupported", result.Outcome)
}
if strings.TrimSpace(result.UserID) == "" {
return errors.New("user_id must not be empty")
}
if result.AffectedSessionCount < 0 {
return errors.New("affected_session_count must not be negative")
}
if result.AffectedDeviceSessionIDs == nil {
return errors.New("affected_device_session_ids must not be null")
}
if int64(len(result.AffectedDeviceSessionIDs)) != result.AffectedSessionCount {
return errors.New("affected_device_session_ids length must match affected_session_count")
}
for index, deviceSessionID := range result.AffectedDeviceSessionIDs {
if strings.TrimSpace(deviceSessionID) == "" {
return fmt.Errorf("affected_device_session_ids[%d] must not be empty", index)
}
}
return nil
}
func validateBlockUserResult(result *blockuser.Result) error {
switch result.Outcome {
case "blocked", "already_blocked":
default:
return fmt.Errorf("block user outcome %q is unsupported", result.Outcome)
}
switch result.SubjectKind {
case blockuser.SubjectKindUserID, blockuser.SubjectKindEmail:
default:
return fmt.Errorf("subject_kind %q is unsupported", result.SubjectKind)
}
if strings.TrimSpace(result.SubjectValue) == "" {
return errors.New("subject_value must not be empty")
}
if result.AffectedSessionCount < 0 {
return errors.New("affected_session_count must not be negative")
}
if result.AffectedDeviceSessionIDs == nil {
return errors.New("affected_device_session_ids must not be null")
}
if int64(len(result.AffectedDeviceSessionIDs)) != result.AffectedSessionCount {
return errors.New("affected_device_session_ids length must match affected_session_count")
}
for index, deviceSessionID := range result.AffectedDeviceSessionIDs {
if strings.TrimSpace(deviceSessionID) == "" {
return fmt.Errorf("affected_device_session_ids[%d] must not be empty", index)
}
}
return nil
}
func projectInternalError(err error) shared.InternalErrorProjection {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return shared.ProjectInternalError(shared.ServiceUnavailable(err))
}
return shared.ProjectInternalError(err)
}
func internalErrorProjection(err error) shared.InternalErrorProjection {
return shared.ProjectInternalError(shared.InternalError(err))
}
@@ -0,0 +1,784 @@
package internalhttp
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"galaxy/authsession/internal/service/blockuser"
"galaxy/authsession/internal/service/getsession"
"galaxy/authsession/internal/service/listusersessions"
"galaxy/authsession/internal/service/revokeallusersessions"
"galaxy/authsession/internal/service/revokedevicesession"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestGetSessionHandlerSuccess(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(func(_ context.Context, input getsession.Input) (getsession.Result, error) {
assert.Equal(t, getsession.Input{DeviceSessionID: "device-session-123"}, input)
return getsession.Result{
Session: validSessionDTO(),
}, nil
}),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil)
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"session":{"device_session_id":"device-session-123","user_id":"user-123","client_public_key":"public-key-material","status":"active","created_at":"2026-04-05T12:00:00Z"}}`, recorder.Body.String())
}
func TestListUserSessionsHandlerSuccess(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(func(_ context.Context, input listusersessions.Input) (listusersessions.Result, error) {
assert.Equal(t, listusersessions.Input{UserID: "user-123"}, input)
first := validSessionDTO()
second := validRevokedSessionDTO()
second.DeviceSessionID = "device-session-122"
return listusersessions.Result{Sessions: []shared.Session{first, second}}, nil
}),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/users/user-123/sessions", nil)
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.Contains(t, recorder.Body.String(), `"sessions":[`)
assert.Contains(t, recorder.Body.String(), `"device_session_id":"device-session-123"`)
assert.Contains(t, recorder.Body.String(), `"device_session_id":"device-session-122"`)
}
func TestListUserSessionsHandlerUnknownUserReturnsEmptyArray(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(func(_ context.Context, input listusersessions.Input) (listusersessions.Result, error) {
assert.Equal(t, listusersessions.Input{UserID: "unknown-user"}, input)
return listusersessions.Result{Sessions: []shared.Session{}}, nil
}),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/users/unknown-user/sessions", nil)
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"sessions":[]}`, recorder.Body.String())
}
func TestRevokeDeviceSessionHandlerAlreadyRevoked(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(func(_ context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error) {
assert.Equal(t, revokedevicesession.Input{
DeviceSessionID: "device-session-123",
ReasonCode: "admin_revoke",
ActorType: "system",
}, input)
return revokedevicesession.Result{
Outcome: "already_revoked",
DeviceSessionID: "device-session-123",
AffectedSessionCount: 0,
}, nil
}),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/internal/sessions/device-session-123/revoke",
bytes.NewBufferString(`{"reason_code":"admin_revoke","actor":{"type":"system"}}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"outcome":"already_revoked","device_session_id":"device-session-123","affected_session_count":0}`, recorder.Body.String())
}
func TestRevokeAllUserSessionsHandlerNoActiveSessions(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(func(_ context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) {
assert.Equal(t, revokeallusersessions.Input{
UserID: "user-123",
ReasonCode: "logout_all",
ActorType: "system",
}, input)
return revokeallusersessions.Result{
Outcome: "no_active_sessions",
UserID: "user-123",
AffectedSessionCount: 0,
AffectedDeviceSessionIDs: []string{},
}, nil
}),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/internal/users/user-123/sessions/revoke-all",
bytes.NewBufferString(`{"reason_code":"logout_all","actor":{"type":"system"}}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"outcome":"no_active_sessions","user_id":"user-123","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String())
}
func TestBlockUserHandlerSuccessByEmail(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(func(_ context.Context, input blockuser.Input) (blockuser.Result, error) {
assert.Equal(t, blockuser.Input{
Email: "pilot@example.com",
ReasonCode: "policy_blocked",
ActorType: "admin",
}, input)
return blockuser.Result{
Outcome: "blocked",
SubjectKind: blockuser.SubjectKindEmail,
SubjectValue: "pilot@example.com",
AffectedSessionCount: 0,
AffectedDeviceSessionIDs: []string{},
}, nil
}),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/internal/user-blocks",
bytes.NewBufferString(`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String())
}
func TestBlockUserHandlerSuccessByUserID(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(func(_ context.Context, input blockuser.Input) (blockuser.Result, error) {
assert.Equal(t, blockuser.Input{
UserID: "user-123",
ReasonCode: "policy_blocked",
ActorType: "admin",
}, input)
return blockuser.Result{
Outcome: "already_blocked",
SubjectKind: blockuser.SubjectKindUserID,
SubjectValue: "user-123",
AffectedSessionCount: 0,
AffectedDeviceSessionIDs: []string{},
}, nil
}),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/internal/user-blocks",
bytes.NewBufferString(`{"user_id":"user-123","reason_code":"policy_blocked","actor":{"type":"admin"}}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"outcome":"already_blocked","subject_kind":"user_id","subject_value":"user-123","affected_session_count":0,"affected_device_session_ids":[]}`, recorder.Body.String())
}
func TestInternalHandlersRejectInvalidPathParams(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
target string
body string
wantStatus int
wantBody string
}{
{
name: "get session empty device session id",
method: http.MethodGet,
target: "/api/v1/internal/sessions/%20",
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"device session id must not be empty"}}`,
},
{
name: "list sessions empty user id",
method: http.MethodGet,
target: "/api/v1/internal/users/%20/sessions",
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"user id must not be empty"}}`,
},
{
name: "revoke all empty user id",
method: http.MethodPost,
target: "/api/v1/internal/users/%20/sessions/revoke-all",
body: `{"reason_code":"logout_all","actor":{"type":"system"}}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"user id must not be empty"}}`,
},
}
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{}, shared.InvalidRequest("device session id must not be empty")
}),
ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) {
return listusersessions.Result{}, shared.InvalidRequest("user id must not be empty")
}),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
return revokeallusersessions.Result{}, shared.InvalidRequest("user id must not be empty")
}),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
if tt.body != "" {
request.Header.Set("Content-Type", "application/json")
}
handler.ServeHTTP(recorder, request)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestInternalMutationHandlersRejectInvalidRequests(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
target string
body string
wantStatus int
wantBody string
}{
{
name: "revoke device session empty body",
method: http.MethodPost,
target: "/api/v1/internal/sessions/device-session-123/revoke",
body: ``,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body must not be empty"}}`,
},
{
name: "revoke device session malformed json",
method: http.MethodPost,
target: "/api/v1/internal/sessions/device-session-123/revoke",
body: `{"reason_code":`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
},
{
name: "revoke device session multiple objects",
method: http.MethodPost,
target: "/api/v1/internal/sessions/device-session-123/revoke",
body: `{"reason_code":"admin_revoke","actor":{"type":"system"}}{"reason_code":"admin_revoke","actor":{"type":"system"}}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body must contain a single JSON object"}}`,
},
{
name: "revoke device session unknown field",
method: http.MethodPost,
target: "/api/v1/internal/sessions/device-session-123/revoke",
body: `{"reason_code":"admin_revoke","actor":{"type":"system"},"extra":true}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body contains unknown field \"extra\""}}`,
},
{
name: "revoke device session invalid json type",
method: http.MethodPost,
target: "/api/v1/internal/sessions/device-session-123/revoke",
body: `{"reason_code":123,"actor":{"type":"system"}}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body contains an invalid value for \"reason_code\""}}`,
},
{
name: "revoke all missing reason code",
method: http.MethodPost,
target: "/api/v1/internal/users/user-123/sessions/revoke-all",
body: `{"reason_code":" ","actor":{"type":"system"}}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"reason_code must not be empty"}}`,
},
{
name: "block user missing actor type",
method: http.MethodPost,
target: "/api/v1/internal/user-blocks",
body: `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":" "}}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"actor.type must not be empty"}}`,
},
{
name: "block user missing subject",
method: http.MethodPost,
target: "/api/v1/internal/user-blocks",
body: `{"reason_code":"policy_blocked","actor":{"type":"system"}}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"exactly one of user_id or email must be provided"}}`,
},
{
name: "block user conflicting subjects",
method: http.MethodPost,
target: "/api/v1/internal/user-blocks",
body: `{"user_id":"user-123","email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"system"}}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"exactly one of user_id or email must be provided"}}`,
},
}
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
if tt.body != "" {
request.Header.Set("Content-Type", "application/json")
}
handler.ServeHTTP(recorder, request)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestInternalHandlersMapServiceErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
target string
body string
deps Dependencies
wantStatus int
wantBody string
}{
{
name: "get session not found",
method: http.MethodGet,
target: "/api/v1/internal/sessions/missing",
deps: Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{}, shared.SessionNotFound()
}),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
},
wantStatus: http.StatusNotFound,
wantBody: `{"error":{"code":"session_not_found","message":"session not found"}}`,
},
{
name: "revoke all subject not found",
method: http.MethodPost,
target: "/api/v1/internal/users/missing/sessions/revoke-all",
body: `{"reason_code":"logout_all","actor":{"type":"system"}}`,
deps: Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
return revokeallusersessions.Result{}, shared.SubjectNotFound()
}),
BlockUser: blockUserFunc(unexpectedBlockUser),
},
wantStatus: http.StatusNotFound,
wantBody: `{"error":{"code":"subject_not_found","message":"subject not found"}}`,
},
{
name: "service unavailable",
method: http.MethodGet,
target: "/api/v1/internal/sessions/device-session-123",
deps: Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{}, shared.ServiceUnavailable(errors.New("redis timeout"))
}),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
},
wantStatus: http.StatusServiceUnavailable,
wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
},
{
name: "internal error",
method: http.MethodGet,
target: "/api/v1/internal/sessions/device-session-123",
deps: Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{}, shared.InternalError(errors.New("broken invariant"))
}),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
},
wantStatus: http.StatusInternalServerError,
wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`,
},
{
name: "unexpected error hidden",
method: http.MethodGet,
target: "/api/v1/internal/sessions/device-session-123",
deps: Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{}, errors.New("boom")
}),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
},
wantStatus: http.StatusInternalServerError,
wantBody: `{"error":{"code":"internal_error","message":"internal server error"}}`,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), tt.deps)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
if tt.body != "" {
request.Header.Set("Content-Type", "application/json")
}
handler.ServeHTTP(recorder, request)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestInternalHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.RequestTimeout = 5 * time.Millisecond
handler := mustNewHandler(t, cfg, Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{}, context.DeadlineExceeded
}),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil)
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
assert.JSONEq(t, `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, recorder.Body.String())
}
func TestInternalHandlersRejectInvalidSuccessPayloads(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
target string
body string
deps Dependencies
}{
{
name: "get session malformed response",
method: http.MethodGet,
target: "/api/v1/internal/sessions/device-session-123",
deps: Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
dto := validSessionDTO()
dto.DeviceSessionID = ""
return getsession.Result{Session: dto}, nil
}),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(unexpectedBlockUser),
},
},
{
name: "revoke all malformed response",
method: http.MethodPost,
target: "/api/v1/internal/users/user-123/sessions/revoke-all",
body: `{"reason_code":"logout_all","actor":{"type":"system"}}`,
deps: Dependencies{
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
return revokeallusersessions.Result{
Outcome: "revoked",
UserID: "user-123",
AffectedSessionCount: 2,
AffectedDeviceSessionIDs: []string{"device-session-1"},
}, nil
}),
BlockUser: blockUserFunc(unexpectedBlockUser),
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), tt.deps)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(tt.method, tt.target, bytes.NewBufferString(tt.body))
if tt.body != "" {
request.Header.Set("Content-Type", "application/json")
}
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.JSONEq(t, `{"error":{"code":"internal_error","message":"internal server error"}}`, recorder.Body.String())
})
}
}
func TestInternalHandlerLogsDoNotContainSensitiveFields(t *testing.T) {
t.Parallel()
logger, buffer := newObservedLogger()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
Logger: logger,
GetSession: getSessionFunc(unexpectedGetSession),
ListUserSessions: listUserSessionsFunc(unexpectedListUserSessions),
RevokeDeviceSession: revokeDeviceSessionFunc(unexpectedRevokeDeviceSession),
RevokeAllUserSessions: revokeAllUserSessionsFunc(unexpectedRevokeAllUserSessions),
BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) {
return blockuser.Result{
Outcome: "blocked",
SubjectKind: blockuser.SubjectKindEmail,
SubjectValue: "pilot@example.com",
AffectedSessionCount: 0,
AffectedDeviceSessionIDs: []string{},
}, nil
}),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/internal/user-blocks",
bytes.NewBufferString(`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin","id":"admin-1"}}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
require.Equal(t, http.StatusOK, recorder.Code)
logOutput := buffer.String()
assert.NotContains(t, logOutput, "pilot@example.com")
assert.NotContains(t, logOutput, "admin-1")
assert.NotContains(t, logOutput, "reason_code")
}
func mustNewHandler(t *testing.T, cfg Config, deps Dependencies) http.Handler {
t.Helper()
handler, err := newHandlerWithConfig(cfg, deps)
require.NoError(t, err)
return handler
}
type getSessionFunc func(ctx context.Context, input getsession.Input) (getsession.Result, error)
func (f getSessionFunc) Execute(ctx context.Context, input getsession.Input) (getsession.Result, error) {
return f(ctx, input)
}
type listUserSessionsFunc func(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error)
func (f listUserSessionsFunc) Execute(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error) {
return f(ctx, input)
}
type revokeDeviceSessionFunc func(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error)
func (f revokeDeviceSessionFunc) Execute(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error) {
return f(ctx, input)
}
type revokeAllUserSessionsFunc func(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error)
func (f revokeAllUserSessionsFunc) Execute(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error) {
return f(ctx, input)
}
type blockUserFunc func(ctx context.Context, input blockuser.Input) (blockuser.Result, error)
func (f blockUserFunc) Execute(ctx context.Context, input blockuser.Input) (blockuser.Result, error) {
return f(ctx, input)
}
func validSessionDTO() shared.Session {
return shared.Session{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-material",
Status: "active",
CreatedAt: "2026-04-05T12:00:00Z",
}
}
func validRevokedSessionDTO() shared.Session {
dto := validSessionDTO()
dto.Status = "revoked"
revokedAt := "2026-04-05T12:01:00Z"
reasonCode := "admin_revoke"
actorType := "admin"
actorID := "admin-1"
dto.RevokedAt = &revokedAt
dto.RevokeReasonCode = &reasonCode
dto.RevokeActorType = &actorType
dto.RevokeActorID = &actorID
return dto
}
func newObservedLogger() (*zap.Logger, *bytes.Buffer) {
buffer := &bytes.Buffer{}
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = ""
core := zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
zapcore.AddSync(buffer),
zap.DebugLevel,
)
return zap.New(core), buffer
}
func unexpectedGetSession(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{}, errors.New("unexpected call")
}
func unexpectedListUserSessions(context.Context, listusersessions.Input) (listusersessions.Result, error) {
return listusersessions.Result{}, errors.New("unexpected call")
}
func unexpectedRevokeDeviceSession(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) {
return revokedevicesession.Result{}, errors.New("unexpected call")
}
func unexpectedRevokeAllUserSessions(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
return revokeallusersessions.Result{}, errors.New("unexpected call")
}
func unexpectedBlockUser(context.Context, blockuser.Input) (blockuser.Result, error) {
return blockuser.Result{}, errors.New("unexpected call")
}
@@ -0,0 +1,93 @@
package internalhttp
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"galaxy/authsession/internal/service/shared"
"github.com/gin-gonic/gin"
)
const internalErrorCodeContextKey = "internal_error_code"
type malformedJSONRequestError struct {
message string
}
func (e *malformedJSONRequestError) Error() string {
if e == nil {
return ""
}
return e.message
}
func decodeJSONRequest(request *http.Request, target any) error {
if request == nil || request.Body == nil {
return &malformedJSONRequestError{message: "request body must not be empty"}
}
return decodeJSONReader(request.Body, target)
}
func decodeJSONReader(reader io.Reader, target any) error {
decoder := json.NewDecoder(reader)
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return describeJSONDecodeError(err)
}
if err := decoder.Decode(&struct{}{}); err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
}
return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
}
func describeJSONDecodeError(err error) error {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
switch {
case errors.Is(err, io.EOF):
return &malformedJSONRequestError{message: "request body must not be empty"}
case errors.As(err, &syntaxErr):
return &malformedJSONRequestError{message: "request body contains malformed JSON"}
case errors.Is(err, io.ErrUnexpectedEOF):
return &malformedJSONRequestError{message: "request body contains malformed JSON"}
case errors.As(err, &typeErr):
if strings.TrimSpace(typeErr.Field) != "" {
return &malformedJSONRequestError{
message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
}
}
return &malformedJSONRequestError{message: "request body contains an invalid JSON value"}
case strings.HasPrefix(err.Error(), "json: unknown field "):
return &malformedJSONRequestError{
message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
}
default:
return &malformedJSONRequestError{message: "request body contains invalid JSON"}
}
}
func abortWithProjection(c *gin.Context, projection shared.InternalErrorProjection) {
c.Set(internalErrorCodeContextKey, projection.Code)
c.AbortWithStatusJSON(projection.StatusCode, errorResponse{
Error: errorBody{
Code: projection.Code,
Message: projection.Message,
},
})
}
@@ -0,0 +1,86 @@
package internalhttp
import (
"time"
authlogging "galaxy/authsession/internal/logging"
"galaxy/authsession/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
type edgeOutcome string
const (
edgeOutcomeSuccess edgeOutcome = "success"
edgeOutcomeRejected edgeOutcome = "rejected"
edgeOutcomeFailed edgeOutcome = "failed"
)
func withInternalObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
if logger == nil {
logger = zap.NewNop()
}
return func(c *gin.Context) {
start := time.Now()
c.Next()
statusCode := c.Writer.Status()
route := c.FullPath()
if route == "" {
route = "unmatched"
}
errorCode, _ := c.Get(internalErrorCodeContextKey)
errorCodeValue, _ := errorCode.(string)
outcome := outcomeFromStatusCode(statusCode)
duration := time.Since(start)
fields := []zap.Field{
zap.String("component", "internal_http"),
zap.String("transport", "http"),
zap.String("route", route),
zap.String("method", c.Request.Method),
zap.Int("status_code", statusCode),
zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
zap.String("edge_outcome", string(outcome)),
}
if errorCodeValue != "" {
fields = append(fields, zap.String("error_code", errorCodeValue))
}
fields = append(fields, authlogging.TraceFieldsFromContext(c.Request.Context())...)
metricAttrs := []attribute.KeyValue{
attribute.String("route", route),
attribute.String("method", c.Request.Method),
attribute.String("edge_outcome", string(outcome)),
}
if errorCodeValue != "" {
metricAttrs = append(metricAttrs, attribute.String("error_code", errorCodeValue))
}
metrics.RecordInternalHTTPRequest(c.Request.Context(), metricAttrs, duration)
switch outcome {
case edgeOutcomeSuccess:
logger.Info("internal request completed", fields...)
case edgeOutcomeFailed:
logger.Error("internal request failed", fields...)
default:
logger.Warn("internal request rejected", fields...)
}
}
}
func outcomeFromStatusCode(statusCode int) edgeOutcome {
switch {
case statusCode >= 500:
return edgeOutcomeFailed
case statusCode >= 400:
return edgeOutcomeRejected
default:
return edgeOutcomeSuccess
}
}
@@ -0,0 +1,121 @@
package internalhttp
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"galaxy/authsession/internal/service/blockuser"
"galaxy/authsession/internal/service/getsession"
"galaxy/authsession/internal/service/listusersessions"
"galaxy/authsession/internal/service/revokeallusersessions"
"galaxy/authsession/internal/service/revokedevicesession"
"galaxy/authsession/internal/service/shared"
authtelemetry "galaxy/authsession/internal/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
)
func TestInternalHandlerEmitsTraceFieldsAndMetrics(t *testing.T) {
t.Parallel()
logger, buffer := newObservedLogger()
telemetryRuntime, reader, recorder := newObservedInternalTelemetryRuntime(t)
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
Logger: logger,
Telemetry: telemetryRuntime,
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{Session: validSessionDTO()}, nil
}),
ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) {
return listusersessions.Result{Sessions: []shared.Session{}}, nil
}),
RevokeDeviceSession: revokeDeviceSessionFunc(func(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) {
return revokedevicesession.Result{}, nil
}),
RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
return revokeallusersessions.Result{}, nil
}),
BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) {
return blockuser.Result{}, nil
}),
})
recorderHTTP := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/v1/internal/sessions/device-session-123", nil)
handler.ServeHTTP(recorderHTTP, request)
require.Equal(t, http.StatusOK, recorderHTTP.Code)
require.NotEmpty(t, recorder.Ended())
assert.Contains(t, buffer.String(), "otel_trace_id")
assert.Contains(t, buffer.String(), "otel_span_id")
assertMetricCount(t, reader, "authsession.internal_http.requests", map[string]string{
"route": "/api/v1/internal/sessions/:device_session_id",
"method": http.MethodGet,
"edge_outcome": "success",
}, 1)
}
func newObservedInternalTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader, *tracetest.SpanRecorder) {
t.Helper()
reader := sdkmetric.NewManualReader()
meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
recorder := tracetest.NewSpanRecorder()
tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
runtime, err := authtelemetry.NewWithProviders(meterProvider, tracerProvider)
require.NoError(t, err)
return runtime, reader, recorder
}
func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
t.Helper()
var resourceMetrics metricdata.ResourceMetrics
require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))
for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
for _, metric := range scopeMetrics.Metrics {
if metric.Name != metricName {
continue
}
sum, ok := metric.Data.(metricdata.Sum[int64])
require.True(t, ok)
for _, point := range sum.DataPoints {
if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
assert.Equal(t, wantValue, point.Value)
return
}
}
}
}
require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs)
}
func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
if len(values) != len(want) {
return false
}
for _, value := range values {
if want[string(value.Key)] != value.Value.AsString() {
return false
}
}
return true
}
@@ -0,0 +1,271 @@
package internalhttp
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
"galaxy/authsession/internal/service/blockuser"
"galaxy/authsession/internal/service/getsession"
"galaxy/authsession/internal/service/listusersessions"
"galaxy/authsession/internal/service/revokeallusersessions"
"galaxy/authsession/internal/service/revokedevicesession"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
const (
defaultAddr = ":8081"
defaultReadHeaderTimeout = 2 * time.Second
defaultReadTimeout = 10 * time.Second
defaultIdleTimeout = time.Minute
defaultRequestTimeout = 3 * time.Second
)
// GetSessionUseCase describes the trusted internal get-session service
// consumed by the HTTP transport layer.
type GetSessionUseCase interface {
// Execute loads one device session for trusted internal callers.
Execute(ctx context.Context, input getsession.Input) (getsession.Result, error)
}
// ListUserSessionsUseCase describes the trusted internal list-user-sessions
// service consumed by the HTTP transport layer.
type ListUserSessionsUseCase interface {
// Execute lists all sessions of one user for trusted internal callers.
Execute(ctx context.Context, input listusersessions.Input) (listusersessions.Result, error)
}
// RevokeDeviceSessionUseCase describes the trusted internal single-session
// revoke service consumed by the HTTP transport layer.
type RevokeDeviceSessionUseCase interface {
// Execute revokes one device session and returns the frozen
// acknowledgement.
Execute(ctx context.Context, input revokedevicesession.Input) (revokedevicesession.Result, error)
}
// RevokeAllUserSessionsUseCase describes the trusted internal bulk-revoke
// service consumed by the HTTP transport layer.
type RevokeAllUserSessionsUseCase interface {
// Execute revokes all active sessions of one user and returns the frozen
// acknowledgement.
Execute(ctx context.Context, input revokeallusersessions.Input) (revokeallusersessions.Result, error)
}
// BlockUserUseCase describes the trusted internal block-user service consumed
// by the HTTP transport layer.
type BlockUserUseCase interface {
// Execute applies a block state to one subject and returns the frozen
// acknowledgement.
Execute(ctx context.Context, input blockuser.Input) (blockuser.Result, error)
}
// Config describes the trusted internal HTTP listener owned by authsession.
type Config struct {
// Addr is the TCP listen address used by the trusted internal HTTP server.
Addr string
// ReadHeaderTimeout bounds how long the listener may spend reading request
// headers before the server rejects the connection.
ReadHeaderTimeout time.Duration
// ReadTimeout bounds how long the listener may spend reading one trusted
// internal request.
ReadTimeout time.Duration
// IdleTimeout bounds how long the listener keeps an idle keep-alive
// connection open.
IdleTimeout time.Duration
// RequestTimeout bounds one application-layer internal use-case call.
RequestTimeout time.Duration
}
// Validate reports whether cfg contains a usable internal HTTP listener
// configuration.
func (cfg Config) Validate() error {
switch {
case cfg.Addr == "":
return errors.New("internal HTTP addr must not be empty")
case cfg.ReadHeaderTimeout <= 0:
return errors.New("internal HTTP read header timeout must be positive")
case cfg.ReadTimeout <= 0:
return errors.New("internal HTTP read timeout must be positive")
case cfg.IdleTimeout <= 0:
return errors.New("internal HTTP idle timeout must be positive")
case cfg.RequestTimeout <= 0:
return errors.New("internal HTTP request timeout must be positive")
default:
return nil
}
}
// DefaultConfig returns the default trusted internal HTTP listener settings.
func DefaultConfig() Config {
return Config{
Addr: defaultAddr,
ReadHeaderTimeout: defaultReadHeaderTimeout,
ReadTimeout: defaultReadTimeout,
IdleTimeout: defaultIdleTimeout,
RequestTimeout: defaultRequestTimeout,
}
}
// Dependencies describes the collaborators used by the trusted internal HTTP
// transport layer.
type Dependencies struct {
// GetSession executes the trusted internal get-session use case.
GetSession GetSessionUseCase
// ListUserSessions executes the trusted internal list-user-sessions use
// case.
ListUserSessions ListUserSessionsUseCase
// RevokeDeviceSession executes the trusted internal single-session revoke
// use case.
RevokeDeviceSession RevokeDeviceSessionUseCase
// RevokeAllUserSessions executes the trusted internal bulk-revoke use case.
RevokeAllUserSessions RevokeAllUserSessionsUseCase
// BlockUser executes the trusted internal block-user use case.
BlockUser BlockUserUseCase
// Logger writes structured transport logs. When nil, a no-op logger is
// used.
Logger *zap.Logger
// Telemetry records OpenTelemetry spans and low-cardinality HTTP metrics.
// When nil, the transport still serves requests with no-op providers.
Telemetry *telemetry.Runtime
}
// Server owns the trusted internal HTTP listener exposed by authsession.
type Server struct {
cfg Config
handler http.Handler
logger *zap.Logger
stateMu sync.RWMutex
server *http.Server
listener net.Listener
}
// NewServer constructs one trusted internal HTTP server for cfg and deps.
func NewServer(cfg Config, deps Dependencies) (*Server, error) {
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("new internal HTTP server: %w", err)
}
handler, err := newHandlerWithConfig(cfg, deps)
if err != nil {
return nil, fmt.Errorf("new internal HTTP server: %w", err)
}
logger := deps.Logger
if logger == nil {
logger = zap.NewNop()
}
logger = logger.Named("internal_http")
return &Server{
cfg: cfg,
handler: handler,
logger: logger,
}, nil
}
// Run binds the configured listener and serves the trusted internal HTTP
// surface until Shutdown closes the server.
func (s *Server) Run(ctx context.Context) error {
if ctx == nil {
return errors.New("run internal HTTP server: nil context")
}
if err := ctx.Err(); err != nil {
return err
}
listener, err := net.Listen("tcp", s.cfg.Addr)
if err != nil {
return fmt.Errorf("run internal HTTP server: listen on %q: %w", s.cfg.Addr, err)
}
server := &http.Server{
Handler: s.handler,
ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
ReadTimeout: s.cfg.ReadTimeout,
IdleTimeout: s.cfg.IdleTimeout,
}
s.stateMu.Lock()
s.server = server
s.listener = listener
s.stateMu.Unlock()
s.logger.Info("internal HTTP server started", zap.String("addr", listener.Addr().String()))
defer func() {
s.stateMu.Lock()
s.server = nil
s.listener = nil
s.stateMu.Unlock()
}()
err = server.Serve(listener)
switch {
case err == nil:
return nil
case errors.Is(err, http.ErrServerClosed):
s.logger.Info("internal HTTP server stopped")
return nil
default:
return fmt.Errorf("run internal HTTP server: serve on %q: %w", s.cfg.Addr, err)
}
}
// Shutdown gracefully stops the trusted internal HTTP server within ctx.
func (s *Server) Shutdown(ctx context.Context) error {
if ctx == nil {
return errors.New("shutdown internal HTTP server: nil context")
}
s.stateMu.RLock()
server := s.server
s.stateMu.RUnlock()
if server == nil {
return nil
}
if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("shutdown internal HTTP server: %w", err)
}
return nil
}
func normalizeDependencies(deps Dependencies) (Dependencies, error) {
switch {
case deps.GetSession == nil:
return Dependencies{}, errors.New("get session use case must not be nil")
case deps.ListUserSessions == nil:
return Dependencies{}, errors.New("list user sessions use case must not be nil")
case deps.RevokeDeviceSession == nil:
return Dependencies{}, errors.New("revoke device session use case must not be nil")
case deps.RevokeAllUserSessions == nil:
return Dependencies{}, errors.New("revoke all user sessions use case must not be nil")
case deps.BlockUser == nil:
return Dependencies{}, errors.New("block user use case must not be nil")
case deps.Logger == nil:
deps.Logger = zap.NewNop()
}
deps.Logger = deps.Logger.Named("internal_http")
return deps, nil
}
@@ -0,0 +1,106 @@
package internalhttp
import (
"bytes"
"context"
"net/http"
"testing"
"time"
"galaxy/authsession/internal/service/blockuser"
"galaxy/authsession/internal/service/getsession"
"galaxy/authsession/internal/service/listusersessions"
"galaxy/authsession/internal/service/revokeallusersessions"
"galaxy/authsession/internal/service/revokedevicesession"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewServerRejectsInvalidConfiguration(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Addr = ""
_, err := NewServer(cfg, validDependencies())
require.Error(t, err)
assert.Contains(t, err.Error(), "addr")
}
func TestServerRunAndShutdown(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Addr = "127.0.0.1:0"
server, err := NewServer(cfg, validDependencies())
require.NoError(t, err)
runErr := make(chan error, 1)
go func() {
runErr <- server.Run(context.Background())
}()
require.Eventually(t, func() bool {
server.stateMu.RLock()
defer server.stateMu.RUnlock()
return server.listener != nil
}, time.Second, 10*time.Millisecond)
server.stateMu.RLock()
addr := server.listener.Addr().String()
server.stateMu.RUnlock()
response, err := http.Post(
"http://"+addr+"/api/v1/internal/sessions/device-session-123/revoke",
"application/json",
bytes.NewBufferString(`{"reason_code":"admin_revoke","actor":{"type":"system"}}`),
)
require.NoError(t, err)
defer response.Body.Close()
assert.Equal(t, http.StatusOK, response.StatusCode)
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
require.NoError(t, server.Shutdown(shutdownCtx))
require.NoError(t, <-runErr)
}
func validDependencies() Dependencies {
return Dependencies{
GetSession: getSessionFunc(func(context.Context, getsession.Input) (getsession.Result, error) {
return getsession.Result{Session: validSessionDTO()}, nil
}),
ListUserSessions: listUserSessionsFunc(func(context.Context, listusersessions.Input) (listusersessions.Result, error) {
return listusersessions.Result{Sessions: []shared.Session{validSessionDTO()}}, nil
}),
RevokeDeviceSession: revokeDeviceSessionFunc(func(context.Context, revokedevicesession.Input) (revokedevicesession.Result, error) {
return revokedevicesession.Result{
Outcome: "revoked",
DeviceSessionID: "device-session-123",
AffectedSessionCount: 1,
}, nil
}),
RevokeAllUserSessions: revokeAllUserSessionsFunc(func(context.Context, revokeallusersessions.Input) (revokeallusersessions.Result, error) {
return revokeallusersessions.Result{
Outcome: "revoked",
UserID: "user-123",
AffectedSessionCount: 1,
AffectedDeviceSessionIDs: []string{"device-session-123"},
}, nil
}),
BlockUser: blockUserFunc(func(context.Context, blockuser.Input) (blockuser.Result, error) {
return blockuser.Result{
Outcome: "blocked",
SubjectKind: blockuser.SubjectKindEmail,
SubjectValue: "pilot@example.com",
AffectedSessionCount: 0,
AffectedDeviceSessionIDs: []string{},
}, nil
}),
}
}
@@ -0,0 +1,3 @@
// Package publichttp exposes the public HTTP transport expected by the
// gateway-facing authentication flow.
package publichttp
@@ -0,0 +1,391 @@
package publichttp
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"galaxy/authsession/internal/adapters/mail"
"galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPublicHTTPEndToEndSendThenConfirm(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{})
server := httptest.NewServer(app.handler)
defer server.Close()
sendResponse := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
assert.Equal(t, http.StatusOK, sendResponse.StatusCode)
assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, sendResponse.Body)
attempts := app.mailSender.RecordedAttempts()
require.Len(t, attempts, 1)
confirmBody := map[string]string{
"challenge_id": "challenge-1",
"code": attempts[0].Input.Code,
"client_public_key": validClientPublicKey,
}
confirmResponse := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", confirmBody)
assert.Equal(t, http.StatusOK, confirmResponse.StatusCode)
assert.JSONEq(t, `{"device_session_id":"device-session-1"}`, confirmResponse.Body)
}
func TestPublicHTTPEndToEndBlockedSendReturnsChallengeID(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{
SeedBlockedEmail: true,
})
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
assert.Equal(t, http.StatusOK, response.StatusCode)
assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, response.Body)
assert.Empty(t, app.mailSender.RecordedAttempts())
}
func TestPublicHTTPEndToEndThrottledSendStillReturnsChallengeID(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{
AbuseProtector: &testkit.InMemorySendEmailCodeAbuseProtector{},
})
server := httptest.NewServer(app.handler)
defer server.Close()
first := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
assert.Equal(t, http.StatusOK, first.StatusCode)
assert.JSONEq(t, `{"challenge_id":"challenge-1"}`, first.Body)
second := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
assert.Equal(t, http.StatusOK, second.StatusCode)
assert.JSONEq(t, `{"challenge_id":"challenge-2"}`, second.Body)
assert.Len(t, app.mailSender.RecordedAttempts(), 1)
}
func TestPublicHTTPEndToEndInvalidClientPublicKey(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{
SeedChallenge: seedChallengeOptions{
ID: "challenge-123",
Code: "123456",
Status: challenge.StatusSent,
},
})
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSON(
t,
server.URL+"/api/v1/public/auth/confirm-email-code",
`{"challenge_id":"challenge-123","code":"123456","client_public_key":"invalid"}`,
)
assert.Equal(t, http.StatusBadRequest, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key"}}`, response.Body)
}
func TestPublicHTTPEndToEndChallengeNotFound(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{})
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
"challenge_id": "missing",
"code": "123456",
"client_public_key": validClientPublicKey,
})
assert.Equal(t, http.StatusNotFound, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"challenge_not_found","message":"challenge not found"}}`, response.Body)
}
func TestPublicHTTPEndToEndChallengeExpired(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{
SeedChallenge: seedChallengeOptions{
ID: "challenge-123",
Code: "123456",
Status: challenge.StatusSent,
ExpiresAt: time.Date(2026, 4, 5, 11, 59, 0, 0, time.UTC),
},
})
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
"challenge_id": "challenge-123",
"code": "123456",
"client_public_key": validClientPublicKey,
})
assert.Equal(t, http.StatusGone, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"challenge_expired","message":"challenge expired"}}`, response.Body)
}
func TestPublicHTTPEndToEndInvalidCode(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{
SeedChallenge: seedChallengeOptions{
ID: "challenge-123",
Code: "123456",
Status: challenge.StatusSent,
},
})
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
"challenge_id": "challenge-123",
"code": "654321",
"client_public_key": validClientPublicKey,
})
assert.Equal(t, http.StatusBadRequest, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"invalid_code","message":"confirmation code is invalid"}}`, response.Body)
}
func TestPublicHTTPEndToEndThrottledChallengeConfirmReturnsInvalidCode(t *testing.T) {
t.Parallel()
app := newEndToEndApp(t, endToEndOptions{
SeedChallenge: seedChallengeOptions{
ID: "challenge-123",
Code: "123456",
Status: challenge.StatusDeliveryThrottled,
},
})
server := httptest.NewServer(app.handler)
defer server.Close()
response := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
"challenge_id": "challenge-123",
"code": "123456",
"client_public_key": validClientPublicKey,
})
assert.Equal(t, http.StatusBadRequest, response.StatusCode)
assert.JSONEq(t, `{"error":{"code":"invalid_code","message":"confirmation code is invalid"}}`, response.Body)
}
func TestPublicHTTPEndToEndSessionLimitExceeded(t *testing.T) {
t.Parallel()
limit := 1
app := newEndToEndApp(t, endToEndOptions{
Config: ports.SessionLimitConfig{ActiveSessionLimit: &limit},
SeedExistingUser: true,
SeedActiveSession: &devicesession.Session{
ID: common.DeviceSessionID("device-session-existing"),
UserID: common.UserID("user-1"),
ClientPublicKey: mustClientPublicKey(t, secondValidClientPublicKey),
Status: devicesession.StatusActive,
CreatedAt: time.Date(2026, 4, 5, 11, 58, 0, 0, time.UTC),
},
})
server := httptest.NewServer(app.handler)
defer server.Close()
sendResponse := postJSON(t, server.URL+"/api/v1/public/auth/send-email-code", `{"email":"pilot@example.com"}`)
assert.Equal(t, http.StatusOK, sendResponse.StatusCode)
attempts := app.mailSender.RecordedAttempts()
require.Len(t, attempts, 1)
confirmResponse := postJSONValue(t, server.URL+"/api/v1/public/auth/confirm-email-code", map[string]string{
"challenge_id": "challenge-1",
"code": attempts[0].Input.Code,
"client_public_key": validClientPublicKey,
})
assert.Equal(t, http.StatusConflict, confirmResponse.StatusCode)
assert.JSONEq(t, `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`, confirmResponse.Body)
}
type endToEndOptions struct {
Config ports.SessionLimitConfig
AbuseProtector ports.SendEmailCodeAbuseProtector
SeedBlockedEmail bool
SeedExistingUser bool
SeedChallenge seedChallengeOptions
SeedActiveSession *devicesession.Session
}
type seedChallengeOptions struct {
ID string
Code string
Status challenge.Status
ExpiresAt time.Time
}
type endToEndApp struct {
handler http.Handler
mailSender *mail.StubSender
}
func newEndToEndApp(t *testing.T, options endToEndOptions) endToEndApp {
t.Helper()
now := time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC)
challengeStore := &testkit.InMemoryChallengeStore{}
sessionStore := &testkit.InMemorySessionStore{}
userDirectory := &userservice.StubDirectory{}
mailSender := &mail.StubSender{}
idGenerator := &testkit.SequenceIDGenerator{}
codeGenerator := testkit.FixedCodeGenerator{Code: "123456"}
codeHasher := testkit.DeterministicCodeHasher{}
clock := testkit.FixedClock{Time: now}
publisher := &testkit.RecordingProjectionPublisher{}
if options.SeedBlockedEmail {
require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_blocked")))
}
if options.SeedExistingUser {
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
}
if options.SeedActiveSession != nil {
require.NoError(t, sessionStore.Create(context.Background(), *options.SeedActiveSession))
}
if options.SeedChallenge.ID != "" {
expiresAt := options.SeedChallenge.ExpiresAt
if expiresAt.IsZero() {
expiresAt = now.Add(challenge.InitialTTL)
}
record := challenge.Challenge{
ID: common.ChallengeID(options.SeedChallenge.ID),
Email: common.Email("pilot@example.com"),
CodeHash: mustHashCode(t, options.SeedChallenge.Code),
Status: options.SeedChallenge.Status,
DeliveryState: deliveryStateForSeedChallenge(options.SeedChallenge.Status),
CreatedAt: now.Add(-time.Minute),
ExpiresAt: expiresAt,
}
require.NoError(t, challengeStore.Create(context.Background(), record))
}
sendService, err := sendemailcode.NewWithRuntime(
challengeStore,
userDirectory,
idGenerator,
codeGenerator,
codeHasher,
mailSender,
options.AbuseProtector,
clock,
nil,
)
require.NoError(t, err)
confirmService, err := confirmemailcode.New(
challengeStore,
sessionStore,
userDirectory,
testkit.StaticConfigProvider{Config: options.Config},
publisher,
idGenerator,
codeHasher,
clock,
)
require.NoError(t, err)
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
SendEmailCode: sendService,
ConfirmEmailCode: confirmService,
})
return endToEndApp{
handler: handler,
mailSender: mailSender,
}
}
func deliveryStateForSeedChallenge(status challenge.Status) challenge.DeliveryState {
switch status {
case challenge.StatusDeliverySuppressed:
return challenge.DeliverySuppressed
case challenge.StatusDeliveryThrottled:
return challenge.DeliveryThrottled
default:
return challenge.DeliverySent
}
}
type httpResponse struct {
StatusCode int
Body string
}
func postJSON(t *testing.T, url string, body string) httpResponse {
t.Helper()
response, err := http.Post(url, "application/json", bytes.NewBufferString(body))
require.NoError(t, err)
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
require.NoError(t, err)
return httpResponse{StatusCode: response.StatusCode, Body: string(payload)}
}
func postJSONValue(t *testing.T, url string, value any) httpResponse {
t.Helper()
body, err := json.Marshal(value)
require.NoError(t, err)
return postJSON(t, url, string(body))
}
func mustHashCode(t *testing.T, code string) []byte {
t.Helper()
sum := sha256.Sum256([]byte(code))
return sum[:]
}
func mustClientPublicKey(t *testing.T, encoded string) common.ClientPublicKey {
t.Helper()
decoded, err := base64.StdEncoding.DecodeString(encoded)
require.NoError(t, err)
key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded))
require.NoError(t, err)
return key
}
const (
validClientPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
secondValidClientPublicKey = "ICEiIyQlJicoKSorLC0uLzAxMjM0NTY3ODk6Ozw9Pj8="
)
@@ -0,0 +1,242 @@
package publichttp
import (
"context"
"errors"
"fmt"
"net/http"
"net/mail"
"strings"
"sync"
"time"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
)
const jsonContentType = "application/json; charset=utf-8"
const publicHTTPServiceName = "galaxy-authsession-public"
type sendEmailCodeRequest struct {
Email string `json:"email"`
}
type sendEmailCodeResponse struct {
ChallengeID string `json:"challenge_id"`
}
type confirmEmailCodeRequest struct {
ChallengeID string `json:"challenge_id"`
Code string `json:"code"`
ClientPublicKey string `json:"client_public_key"`
}
type confirmEmailCodeResponse struct {
DeviceSessionID string `json:"device_session_id"`
}
type errorResponse struct {
Error errorBody `json:"error"`
}
type errorBody struct {
Code string `json:"code"`
Message string `json:"message"`
}
var configureGinModeOnce sync.Once
func newHandlerWithConfig(cfg Config, deps Dependencies) (http.Handler, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
normalizedDeps, err := normalizeDependencies(deps)
if err != nil {
return nil, err
}
configureGinModeOnce.Do(func() {
gin.SetMode(gin.ReleaseMode)
})
engine := gin.New()
engine.Use(newOTelMiddleware(normalizedDeps.Telemetry))
engine.Use(withPublicObservability(normalizedDeps.Logger, normalizedDeps.Telemetry))
engine.POST(
"/api/v1/public/auth/send-email-code",
handleSendEmailCode(normalizedDeps.SendEmailCode, cfg.RequestTimeout),
)
engine.POST(
"/api/v1/public/auth/confirm-email-code",
handleConfirmEmailCode(normalizedDeps.ConfirmEmailCode, cfg.RequestTimeout),
)
return engine, nil
}
func newOTelMiddleware(runtime *telemetry.Runtime) gin.HandlerFunc {
options := []otelgin.Option{}
if runtime != nil {
options = append(
options,
otelgin.WithTracerProvider(runtime.TracerProvider()),
otelgin.WithMeterProvider(runtime.MeterProvider()),
)
}
return otelgin.Middleware(publicHTTPServiceName, options...)
}
func handleSendEmailCode(useCase SendEmailCodeUseCase, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
var request sendEmailCodeRequest
if err := decodeJSONRequest(c.Request, &request); err != nil {
abortWithProjection(c, projectSendEmailCodeError(shared.InvalidRequest(err.Error())))
return
}
if err := validateSendEmailCodeRequest(&request); err != nil {
abortWithProjection(c, projectSendEmailCodeError(shared.InvalidRequest(err.Error())))
return
}
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := useCase.Execute(callCtx, sendemailcode.Input{Email: request.Email})
if err != nil {
abortWithProjection(c, projectSendEmailCodeError(err))
return
}
if err := validateSendEmailCodeResult(&result); err != nil {
abortWithProjection(c, unavailableProjection(fmt.Errorf("send email code response: %w", err)))
return
}
c.JSON(http.StatusOK, sendEmailCodeResponse{ChallengeID: result.ChallengeID})
}
}
func handleConfirmEmailCode(useCase ConfirmEmailCodeUseCase, timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
var request confirmEmailCodeRequest
if err := decodeJSONRequest(c.Request, &request); err != nil {
abortWithProjection(c, projectConfirmEmailCodeError(shared.InvalidRequest(err.Error())))
return
}
if err := validateConfirmEmailCodeRequest(&request); err != nil {
abortWithProjection(c, projectConfirmEmailCodeError(shared.InvalidRequest(err.Error())))
return
}
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
result, err := useCase.Execute(callCtx, confirmemailcode.Input{
ChallengeID: request.ChallengeID,
Code: request.Code,
ClientPublicKey: request.ClientPublicKey,
})
if err != nil {
abortWithProjection(c, projectConfirmEmailCodeError(err))
return
}
if err := validateConfirmEmailCodeResult(&result); err != nil {
abortWithProjection(c, unavailableProjection(fmt.Errorf("confirm email code response: %w", err)))
return
}
c.JSON(http.StatusOK, confirmEmailCodeResponse{DeviceSessionID: result.DeviceSessionID})
}
}
func validateSendEmailCodeRequest(request *sendEmailCodeRequest) error {
request.Email = strings.TrimSpace(request.Email)
if request.Email == "" {
return errors.New("email must not be empty")
}
parsedAddress, err := mail.ParseAddress(request.Email)
if err != nil || parsedAddress.Name != "" || parsedAddress.Address != request.Email {
return errors.New("email must be a single valid email address")
}
return nil
}
func validateSendEmailCodeResult(result *sendemailcode.Result) error {
result.ChallengeID = strings.TrimSpace(result.ChallengeID)
if result.ChallengeID == "" {
return errors.New("challenge_id must not be empty")
}
return nil
}
func validateConfirmEmailCodeRequest(request *confirmEmailCodeRequest) error {
request.ChallengeID = strings.TrimSpace(request.ChallengeID)
if request.ChallengeID == "" {
return errors.New("challenge_id must not be empty")
}
request.Code = strings.TrimSpace(request.Code)
if request.Code == "" {
return errors.New("code must not be empty")
}
request.ClientPublicKey = strings.TrimSpace(request.ClientPublicKey)
if request.ClientPublicKey == "" {
return errors.New("client_public_key must not be empty")
}
return nil
}
func validateConfirmEmailCodeResult(result *confirmemailcode.Result) error {
result.DeviceSessionID = strings.TrimSpace(result.DeviceSessionID)
if result.DeviceSessionID == "" {
return errors.New("device_session_id must not be empty")
}
return nil
}
func projectSendEmailCodeError(err error) shared.PublicErrorProjection {
if isTimeoutOrCanceled(err) {
return unavailableProjection(err)
}
projection := shared.ProjectPublicError(err)
if !shared.IsSendEmailCodePublicErrorCode(projection.Code) {
return unavailableProjection(err)
}
return projection
}
func projectConfirmEmailCodeError(err error) shared.PublicErrorProjection {
if isTimeoutOrCanceled(err) {
return unavailableProjection(err)
}
projection := shared.ProjectPublicError(err)
if !shared.IsConfirmEmailCodePublicErrorCode(projection.Code) {
return unavailableProjection(err)
}
return projection
}
func unavailableProjection(err error) shared.PublicErrorProjection {
return shared.ProjectPublicError(shared.ServiceUnavailable(err))
}
func isTimeoutOrCanceled(err error) bool {
return errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)
}
@@ -0,0 +1,463 @@
package publichttp
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestSendEmailCodeHandlerSuccess(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{ChallengeID: "challenge-123"}, nil
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, errors.New("unexpected call")
}),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/send-email-code",
bytes.NewBufferString(`{"email":" pilot@example.com "}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String())
}
func TestConfirmEmailCodeHandlerSuccess(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(_ context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) {
assert.Equal(t, confirmemailcode.Input{
ChallengeID: "challenge-123",
Code: "123456",
ClientPublicKey: "public-key-material",
}, input)
return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil
}),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/confirm-email-code",
bytes.NewBufferString(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String())
}
func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) {
t.Parallel()
tests := []struct {
name string
target string
body string
wantStatus int
wantBody string
}{
{
name: "empty body",
target: "/api/v1/public/auth/send-email-code",
body: ``,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body must not be empty"}}`,
},
{
name: "malformed json",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
},
{
name: "multiple objects",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"pilot@example.com"}{"email":"next@example.com"}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body must contain a single JSON object"}}`,
},
{
name: "unknown field",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"pilot@example.com","extra":true}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body contains unknown field \"extra\""}}`,
},
{
name: "invalid json type",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":123}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"request body contains an invalid value for \"email\""}}`,
},
{
name: "invalid email",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"not-an-email"}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`,
},
{
name: "empty code",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`,
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_request","message":"code must not be empty"}}`,
},
}
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, errors.New("unexpected call")
}),
})
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body))
if tt.body != "" {
request.Header.Set("Content-Type", "application/json")
}
handler.ServeHTTP(recorder, request)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestPublicAuthHandlersMapServiceErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
target string
body string
deps Dependencies
wantStatus int
wantBody string
}{
{
name: "send route hides blocked by policy",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"pilot@example.com"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, shared.BlockedByPolicy()
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, errors.New("unexpected call")
}),
},
wantStatus: http.StatusServiceUnavailable,
wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
},
{
name: "confirm invalid client public key",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, shared.InvalidClientPublicKey()
}),
},
wantStatus: http.StatusBadRequest,
wantBody: `{"error":{"code":"invalid_client_public_key","message":"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key"}}`,
},
{
name: "confirm challenge not found",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, shared.ChallengeNotFound()
}),
},
wantStatus: http.StatusNotFound,
wantBody: `{"error":{"code":"challenge_not_found","message":"challenge not found"}}`,
},
{
name: "confirm challenge expired",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, shared.ChallengeExpired()
}),
},
wantStatus: http.StatusGone,
wantBody: `{"error":{"code":"challenge_expired","message":"challenge expired"}}`,
},
{
name: "confirm blocked by policy",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, shared.BlockedByPolicy()
}),
},
wantStatus: http.StatusForbidden,
wantBody: `{"error":{"code":"blocked_by_policy","message":"authentication is blocked by policy"}}`,
},
{
name: "confirm session limit exceeded",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, shared.SessionLimitExceeded()
}),
},
wantStatus: http.StatusConflict,
wantBody: `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`,
},
{
name: "confirm hides internal error",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, shared.InternalError(errors.New("broken invariant"))
}),
},
wantStatus: http.StatusServiceUnavailable,
wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), tt.deps)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body))
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, tt.wantStatus, recorder.Code)
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
assert.JSONEq(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.RequestTimeout = 5 * time.Millisecond
handler := mustNewHandler(t, cfg, Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, context.DeadlineExceeded
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, errors.New("unexpected call")
}),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/send-email-code",
bytes.NewBufferString(`{"email":"pilot@example.com"}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
assert.JSONEq(t, `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`, recorder.Body.String())
}
func TestPublicAuthHandlersRejectInvalidSuccessPayloads(t *testing.T) {
t.Parallel()
tests := []struct {
name string
target string
body string
deps Dependencies
wantBody string
}{
{
name: "send email blank challenge id",
target: "/api/v1/public/auth/send-email-code",
body: `{"email":"pilot@example.com"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{ChallengeID: " "}, nil
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, errors.New("unexpected call")
}),
},
wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
},
{
name: "confirm blank device session id",
target: "/api/v1/public/auth/confirm-email-code",
body: `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
deps: Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{DeviceSessionID: " "}, nil
}),
},
wantBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := mustNewHandler(t, DefaultConfig(), tt.deps)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, tt.target, bytes.NewBufferString(tt.body))
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
assert.JSONEq(t, tt.wantBody, recorder.Body.String())
})
}
}
func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) {
t.Parallel()
logger, buffer := newObservedLogger()
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
Logger: logger,
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, errors.New("unexpected call")
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil
}),
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/confirm-email-code",
bytes.NewBufferString(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorder, request)
require.Equal(t, http.StatusOK, recorder.Code)
logOutput := buffer.String()
assert.NotContains(t, logOutput, "challenge-123")
assert.NotContains(t, logOutput, "123456")
assert.NotContains(t, logOutput, "public-key-material")
assert.NotContains(t, logOutput, "pilot@example.com")
assert.NotContains(t, logOutput, "device-session-123")
}
func mustNewHandler(t *testing.T, cfg Config, deps Dependencies) http.Handler {
t.Helper()
handler, err := newHandlerWithConfig(cfg, deps)
require.NoError(t, err)
return handler
}
type sendEmailCodeFunc func(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error)
func (f sendEmailCodeFunc) Execute(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error) {
return f(ctx, input)
}
type confirmEmailCodeFunc func(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error)
func (f confirmEmailCodeFunc) Execute(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error) {
return f(ctx, input)
}
func newObservedLogger() (*zap.Logger, *bytes.Buffer) {
buffer := &bytes.Buffer{}
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = ""
core := zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
zapcore.AddSync(buffer),
zap.DebugLevel,
)
return zap.New(core), buffer
}
@@ -0,0 +1,93 @@
package publichttp
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"galaxy/authsession/internal/service/shared"
"github.com/gin-gonic/gin"
)
const publicErrorCodeContextKey = "public_error_code"
type malformedJSONRequestError struct {
message string
}
func (e *malformedJSONRequestError) Error() string {
if e == nil {
return ""
}
return e.message
}
func decodeJSONRequest(request *http.Request, target any) error {
if request == nil || request.Body == nil {
return &malformedJSONRequestError{message: "request body must not be empty"}
}
return decodeJSONReader(request.Body, target)
}
func decodeJSONReader(reader io.Reader, target any) error {
decoder := json.NewDecoder(reader)
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return describeJSONDecodeError(err)
}
if err := decoder.Decode(&struct{}{}); err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
}
return &malformedJSONRequestError{message: "request body must contain a single JSON object"}
}
func describeJSONDecodeError(err error) error {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
switch {
case errors.Is(err, io.EOF):
return &malformedJSONRequestError{message: "request body must not be empty"}
case errors.As(err, &syntaxErr):
return &malformedJSONRequestError{message: "request body contains malformed JSON"}
case errors.Is(err, io.ErrUnexpectedEOF):
return &malformedJSONRequestError{message: "request body contains malformed JSON"}
case errors.As(err, &typeErr):
if strings.TrimSpace(typeErr.Field) != "" {
return &malformedJSONRequestError{
message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
}
}
return &malformedJSONRequestError{message: "request body contains an invalid JSON value"}
case strings.HasPrefix(err.Error(), "json: unknown field "):
return &malformedJSONRequestError{
message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
}
default:
return &malformedJSONRequestError{message: "request body contains invalid JSON"}
}
}
func abortWithProjection(c *gin.Context, projection shared.PublicErrorProjection) {
c.Set(publicErrorCodeContextKey, projection.Code)
c.AbortWithStatusJSON(projection.StatusCode, errorResponse{
Error: errorBody{
Code: projection.Code,
Message: projection.Message,
},
})
}
@@ -0,0 +1,86 @@
package publichttp
import (
"time"
authlogging "galaxy/authsession/internal/logging"
"galaxy/authsession/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
type edgeOutcome string
const (
edgeOutcomeSuccess edgeOutcome = "success"
edgeOutcomeRejected edgeOutcome = "rejected"
edgeOutcomeFailed edgeOutcome = "failed"
)
func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
if logger == nil {
logger = zap.NewNop()
}
return func(c *gin.Context) {
start := time.Now()
c.Next()
statusCode := c.Writer.Status()
route := c.FullPath()
if route == "" {
route = "unmatched"
}
errorCode, _ := c.Get(publicErrorCodeContextKey)
errorCodeValue, _ := errorCode.(string)
outcome := outcomeFromStatusCode(statusCode)
duration := time.Since(start)
fields := []zap.Field{
zap.String("component", "public_http"),
zap.String("transport", "http"),
zap.String("route", route),
zap.String("method", c.Request.Method),
zap.Int("status_code", statusCode),
zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
zap.String("edge_outcome", string(outcome)),
}
if errorCodeValue != "" {
fields = append(fields, zap.String("error_code", errorCodeValue))
}
fields = append(fields, authlogging.TraceFieldsFromContext(c.Request.Context())...)
metricAttrs := []attribute.KeyValue{
attribute.String("route", route),
attribute.String("method", c.Request.Method),
attribute.String("edge_outcome", string(outcome)),
}
if errorCodeValue != "" {
metricAttrs = append(metricAttrs, attribute.String("error_code", errorCodeValue))
}
metrics.RecordPublicHTTPRequest(c.Request.Context(), metricAttrs, duration)
switch outcome {
case edgeOutcomeSuccess:
logger.Info("public request completed", fields...)
case edgeOutcomeFailed:
logger.Error("public request failed", fields...)
default:
logger.Warn("public request rejected", fields...)
}
}
}
func outcomeFromStatusCode(statusCode int) edgeOutcome {
switch {
case statusCode >= 500:
return edgeOutcomeFailed
case statusCode >= 400:
return edgeOutcomeRejected
default:
return edgeOutcomeSuccess
}
}
@@ -0,0 +1,114 @@
package publichttp
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"testing"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
authtelemetry "galaxy/authsession/internal/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
)
func TestPublicHandlerEmitsTraceFieldsAndMetrics(t *testing.T) {
t.Parallel()
logger, buffer := newObservedLogger()
telemetryRuntime, reader, recorder := newObservedPublicTelemetryRuntime(t)
handler := mustNewHandler(t, DefaultConfig(), Dependencies{
Logger: logger,
Telemetry: telemetryRuntime,
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{ChallengeID: "challenge-123"}, nil
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, nil
}),
})
recorderHTTP := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"/api/v1/public/auth/send-email-code",
bytes.NewBufferString(`{"email":"pilot@example.com"}`),
)
request.Header.Set("Content-Type", "application/json")
handler.ServeHTTP(recorderHTTP, request)
require.Equal(t, http.StatusOK, recorderHTTP.Code)
require.NotEmpty(t, recorder.Ended())
assert.Contains(t, buffer.String(), "otel_trace_id")
assert.Contains(t, buffer.String(), "otel_span_id")
assertMetricCount(t, reader, "authsession.public_http.requests", map[string]string{
"route": "/api/v1/public/auth/send-email-code",
"method": http.MethodPost,
"edge_outcome": "success",
}, 1)
}
func newObservedPublicTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader, *tracetest.SpanRecorder) {
t.Helper()
reader := sdkmetric.NewManualReader()
meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
recorder := tracetest.NewSpanRecorder()
tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
runtime, err := authtelemetry.NewWithProviders(meterProvider, tracerProvider)
require.NoError(t, err)
return runtime, reader, recorder
}
func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
t.Helper()
var resourceMetrics metricdata.ResourceMetrics
require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))
for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
for _, metric := range scopeMetrics.Metrics {
if metric.Name != metricName {
continue
}
sum, ok := metric.Data.(metricdata.Sum[int64])
require.True(t, ok)
for _, point := range sum.DataPoints {
if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
assert.Equal(t, wantValue, point.Value)
return
}
}
}
}
require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs)
}
func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
if len(values) != len(want) {
return false
}
for _, value := range values {
if want[string(value.Key)] != value.Value.AsString() {
return false
}
}
return true
}
@@ -0,0 +1,228 @@
package publichttp
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
const (
defaultAddr = ":8080"
defaultReadHeaderTimeout = 2 * time.Second
defaultReadTimeout = 10 * time.Second
defaultIdleTimeout = time.Minute
defaultRequestTimeout = 3 * time.Second
)
// SendEmailCodeUseCase describes the public send-email-code application
// service consumed by the HTTP transport layer.
type SendEmailCodeUseCase interface {
// Execute validates input and creates a new login challenge.
Execute(ctx context.Context, input sendemailcode.Input) (sendemailcode.Result, error)
}
// ConfirmEmailCodeUseCase describes the public confirm-email-code application
// service consumed by the HTTP transport layer.
type ConfirmEmailCodeUseCase interface {
// Execute validates input and completes an existing login challenge.
Execute(ctx context.Context, input confirmemailcode.Input) (confirmemailcode.Result, error)
}
// Config describes the public HTTP listener owned by authsession.
type Config struct {
// Addr is the TCP listen address used by the public HTTP server.
Addr string
// ReadHeaderTimeout bounds how long the listener may spend reading request
// headers before the server rejects the connection.
ReadHeaderTimeout time.Duration
// ReadTimeout bounds how long the listener may spend reading one public
// request.
ReadTimeout time.Duration
// IdleTimeout bounds how long the listener keeps an idle keep-alive
// connection open.
IdleTimeout time.Duration
// RequestTimeout bounds one application-layer public-auth use-case call.
RequestTimeout time.Duration
}
// Validate reports whether cfg contains a usable public HTTP listener
// configuration.
func (cfg Config) Validate() error {
switch {
case cfg.Addr == "":
return errors.New("public HTTP addr must not be empty")
case cfg.ReadHeaderTimeout <= 0:
return errors.New("public HTTP read header timeout must be positive")
case cfg.ReadTimeout <= 0:
return errors.New("public HTTP read timeout must be positive")
case cfg.IdleTimeout <= 0:
return errors.New("public HTTP idle timeout must be positive")
case cfg.RequestTimeout <= 0:
return errors.New("public HTTP request timeout must be positive")
default:
return nil
}
}
// DefaultConfig returns the default public HTTP listener settings aligned with
// the gateway public-auth transport timeouts.
func DefaultConfig() Config {
return Config{
Addr: defaultAddr,
ReadHeaderTimeout: defaultReadHeaderTimeout,
ReadTimeout: defaultReadTimeout,
IdleTimeout: defaultIdleTimeout,
RequestTimeout: defaultRequestTimeout,
}
}
// Dependencies describes the collaborators used by the public HTTP transport
// layer.
type Dependencies struct {
// SendEmailCode executes the public send-email-code use case.
SendEmailCode SendEmailCodeUseCase
// ConfirmEmailCode executes the public confirm-email-code use case.
ConfirmEmailCode ConfirmEmailCodeUseCase
// Logger writes structured transport logs. When nil, a no-op logger is
// used.
Logger *zap.Logger
// Telemetry records OpenTelemetry spans and low-cardinality HTTP metrics.
// When nil, the transport still serves requests with no-op providers.
Telemetry *telemetry.Runtime
}
// Server owns the public auth HTTP listener exposed by authsession.
type Server struct {
cfg Config
handler http.Handler
logger *zap.Logger
stateMu sync.RWMutex
server *http.Server
listener net.Listener
}
// NewServer constructs one public auth HTTP server for cfg and deps.
func NewServer(cfg Config, deps Dependencies) (*Server, error) {
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("new public HTTP server: %w", err)
}
handler, err := newHandlerWithConfig(cfg, deps)
if err != nil {
return nil, fmt.Errorf("new public HTTP server: %w", err)
}
logger := deps.Logger
if logger == nil {
logger = zap.NewNop()
}
logger = logger.Named("public_http")
return &Server{
cfg: cfg,
handler: handler,
logger: logger,
}, nil
}
// Run binds the configured listener and serves the public auth HTTP surface
// until Shutdown closes the server.
func (s *Server) Run(ctx context.Context) error {
if ctx == nil {
return errors.New("run public HTTP server: nil context")
}
if err := ctx.Err(); err != nil {
return err
}
listener, err := net.Listen("tcp", s.cfg.Addr)
if err != nil {
return fmt.Errorf("run public HTTP server: listen on %q: %w", s.cfg.Addr, err)
}
server := &http.Server{
Handler: s.handler,
ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
ReadTimeout: s.cfg.ReadTimeout,
IdleTimeout: s.cfg.IdleTimeout,
}
s.stateMu.Lock()
s.server = server
s.listener = listener
s.stateMu.Unlock()
s.logger.Info("public HTTP server started", zap.String("addr", listener.Addr().String()))
defer func() {
s.stateMu.Lock()
s.server = nil
s.listener = nil
s.stateMu.Unlock()
}()
err = server.Serve(listener)
switch {
case err == nil:
return nil
case errors.Is(err, http.ErrServerClosed):
s.logger.Info("public HTTP server stopped")
return nil
default:
return fmt.Errorf("run public HTTP server: serve on %q: %w", s.cfg.Addr, err)
}
}
// Shutdown gracefully stops the public HTTP server within ctx.
func (s *Server) Shutdown(ctx context.Context) error {
if ctx == nil {
return errors.New("shutdown public HTTP server: nil context")
}
s.stateMu.RLock()
server := s.server
s.stateMu.RUnlock()
if server == nil {
return nil
}
if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("shutdown public HTTP server: %w", err)
}
return nil
}
func normalizeDependencies(deps Dependencies) (Dependencies, error) {
switch {
case deps.SendEmailCode == nil:
return Dependencies{}, errors.New("send email code use case must not be nil")
case deps.ConfirmEmailCode == nil:
return Dependencies{}, errors.New("confirm email code use case must not be nil")
case deps.Logger == nil:
deps.Logger = zap.NewNop()
}
deps.Logger = deps.Logger.Named("public_http")
return deps, nil
}
@@ -0,0 +1,81 @@
package publichttp
import (
"bytes"
"context"
"net/http"
"testing"
"time"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewServerRejectsInvalidConfiguration(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Addr = ""
_, err := NewServer(cfg, Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{}, nil
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{}, nil
}),
})
require.Error(t, err)
assert.Contains(t, err.Error(), "addr")
}
func TestServerRunAndShutdown(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Addr = "127.0.0.1:0"
server, err := NewServer(cfg, Dependencies{
SendEmailCode: sendEmailCodeFunc(func(context.Context, sendemailcode.Input) (sendemailcode.Result, error) {
return sendemailcode.Result{ChallengeID: "challenge-123"}, nil
}),
ConfirmEmailCode: confirmEmailCodeFunc(func(context.Context, confirmemailcode.Input) (confirmemailcode.Result, error) {
return confirmemailcode.Result{DeviceSessionID: "device-session-123"}, nil
}),
})
require.NoError(t, err)
runErr := make(chan error, 1)
go func() {
runErr <- server.Run(context.Background())
}()
require.Eventually(t, func() bool {
server.stateMu.RLock()
defer server.stateMu.RUnlock()
return server.listener != nil
}, time.Second, 10*time.Millisecond)
server.stateMu.RLock()
addr := server.listener.Addr().String()
server.stateMu.RUnlock()
response, err := http.Post(
"http://"+addr+"/api/v1/public/auth/send-email-code",
"application/json",
bytes.NewBufferString(`{"email":"pilot@example.com"}`),
)
require.NoError(t, err)
defer response.Body.Close()
assert.Equal(t, http.StatusOK, response.StatusCode)
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
require.NoError(t, server.Shutdown(shutdownCtx))
require.NoError(t, <-runErr)
}
+168
View File
@@ -0,0 +1,168 @@
// Package app wires the authsession process lifecycle and coordinates
// component startup and graceful shutdown.
package app
import (
"context"
"errors"
"fmt"
"sync"
"galaxy/authsession/internal/config"
)
// Component is a long-lived authsession subsystem that participates in
// coordinated startup and graceful shutdown.
type Component interface {
// Run starts the component and blocks until it stops.
Run(context.Context) error
// Shutdown stops the component within the provided timeout-bounded context.
Shutdown(context.Context) error
}
// App owns the process-level lifecycle of authsession and its registered
// components.
type App struct {
cfg config.Config
components []Component
}
// New constructs an App with a defensive copy of the supplied components.
func New(cfg config.Config, components ...Component) *App {
clonedComponents := append([]Component(nil), components...)
return &App{
cfg: cfg,
components: clonedComponents,
}
}
// Run starts all configured components, waits for cancellation or the first
// component failure, and then executes best-effort graceful shutdown.
func (a *App) Run(ctx context.Context) error {
if ctx == nil {
return errors.New("run authsession app: nil context")
}
if err := a.validate(); err != nil {
return err
}
if len(a.components) == 0 {
<-ctx.Done()
return nil
}
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
results := make(chan componentResult, len(a.components))
var runWG sync.WaitGroup
for idx, component := range a.components {
runWG.Add(1)
go func(index int, component Component) {
defer runWG.Done()
results <- componentResult{
index: index,
err: component.Run(runCtx),
}
}(idx, component)
}
var runErr error
select {
case <-ctx.Done():
case result := <-results:
runErr = classifyComponentResult(ctx, result)
}
cancel()
shutdownErr := a.shutdownComponents()
waitErr := a.waitForComponents(&runWG)
return errors.Join(runErr, shutdownErr, waitErr)
}
type componentResult struct {
index int
err error
}
func (a *App) validate() error {
if a.cfg.ShutdownTimeout <= 0 {
return fmt.Errorf("run authsession app: shutdown timeout must be positive, got %s", a.cfg.ShutdownTimeout)
}
for idx, component := range a.components {
if component == nil {
return fmt.Errorf("run authsession app: component %d is nil", idx)
}
}
return nil
}
func classifyComponentResult(parentCtx context.Context, result componentResult) error {
switch {
case result.err == nil:
if parentCtx.Err() != nil {
return nil
}
return fmt.Errorf("run authsession app: component %d exited without error before shutdown", result.index)
case errors.Is(result.err, context.Canceled) && parentCtx.Err() != nil:
return nil
default:
return fmt.Errorf("run authsession app: component %d: %w", result.index, result.err)
}
}
func (a *App) shutdownComponents() error {
var shutdownWG sync.WaitGroup
errs := make(chan error, len(a.components))
for idx, component := range a.components {
shutdownWG.Add(1)
go func(index int, component Component) {
defer shutdownWG.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
defer cancel()
if err := component.Shutdown(shutdownCtx); err != nil {
errs <- fmt.Errorf("shutdown authsession component %d: %w", index, err)
}
}(idx, component)
}
shutdownWG.Wait()
close(errs)
var joined error
for err := range errs {
joined = errors.Join(joined, err)
}
return joined
}
func (a *App) waitForComponents(runWG *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
runWG.Wait()
close(done)
}()
waitCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
defer cancel()
select {
case <-done:
return nil
case <-waitCtx.Done():
return fmt.Errorf("wait for authsession components: %w", waitCtx.Err())
}
}
+284
View File
@@ -0,0 +1,284 @@
package app
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/adapters/local"
"galaxy/authsession/internal/adapters/mail"
"galaxy/authsession/internal/adapters/redis/challengestore"
"galaxy/authsession/internal/adapters/redis/configprovider"
"galaxy/authsession/internal/adapters/redis/projectionpublisher"
"galaxy/authsession/internal/adapters/redis/sendemailcodeabuse"
"galaxy/authsession/internal/adapters/redis/sessionstore"
"galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/api/internalhttp"
"galaxy/authsession/internal/api/publichttp"
"galaxy/authsession/internal/config"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/blockuser"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/getsession"
"galaxy/authsession/internal/service/listusersessions"
"galaxy/authsession/internal/service/revokeallusersessions"
"galaxy/authsession/internal/service/revokedevicesession"
"galaxy/authsession/internal/service/sendemailcode"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
type pinger interface {
Ping(context.Context) error
}
type closer interface {
Close() error
}
// Runtime owns the runnable authsession application plus the adapter cleanup
// functions that must run after the process stops.
type Runtime struct {
// App coordinates the long-lived HTTP listeners.
App *App
cleanupFns []func() error
}
// NewRuntime constructs the runnable authsession process from cfg using the
// Stage 18 Redis adapters, local runtime helpers, and the selectable mail and
// user-service runtime adapters from Stages 20 and 21.
func NewRuntime(ctx context.Context, cfg config.Config, logger *zap.Logger, telemetryRuntime *telemetry.Runtime) (*Runtime, error) {
if ctx == nil {
return nil, errors.New("new authsession runtime: nil context")
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("new authsession runtime: %w", err)
}
if logger == nil {
logger = zap.NewNop()
}
runtime := &Runtime{}
cleanupOnError := func(err error) (*Runtime, error) {
return nil, errors.Join(err, runtime.Close())
}
challengeStore, err := challengestore.New(challengestore.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
KeyPrefix: cfg.Redis.ChallengeKeyPrefix,
OperationTimeout: cfg.Redis.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: challenge store: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, challengeStore.Close)
sessionStore, err := sessionstore.New(sessionstore.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
SessionKeyPrefix: cfg.Redis.SessionKeyPrefix,
UserSessionsKeyPrefix: cfg.Redis.UserSessionsKeyPrefix,
UserActiveSessionsKeyPrefix: cfg.Redis.UserActiveSessionsKeyPrefix,
OperationTimeout: cfg.Redis.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: session store: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, sessionStore.Close)
configStore, err := configprovider.New(configprovider.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
SessionLimitKey: cfg.Redis.SessionLimitKey,
OperationTimeout: cfg.Redis.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: config provider: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, configStore.Close)
publisher, err := projectionpublisher.New(projectionpublisher.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
SessionCacheKeyPrefix: cfg.Redis.GatewaySessionCacheKeyPrefix,
SessionEventsStream: cfg.Redis.GatewaySessionEventsStream,
StreamMaxLen: cfg.Redis.GatewaySessionEventsStreamMaxLen,
OperationTimeout: cfg.Redis.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: projection publisher: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, publisher.Close)
abuseProtector, err := sendemailcodeabuse.New(sendemailcodeabuse.Config{
Addr: cfg.Redis.Addr,
Username: cfg.Redis.Username,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
TLSEnabled: cfg.Redis.TLSEnabled,
KeyPrefix: cfg.Redis.SendEmailCodeThrottleKeyPrefix,
OperationTimeout: cfg.Redis.OperationTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: send email code abuse protector: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, abuseProtector.Close)
for name, dependency := range map[string]pinger{
"challenge store": challengeStore,
"session store": sessionStore,
"config provider": configStore,
"projection publisher": publisher,
"send email code abuse protector": abuseProtector,
} {
if err := dependency.Ping(ctx); err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: ping %s: %w", name, err))
}
}
clock := local.Clock{}
idGenerator := local.IDGenerator{}
codeGenerator := local.CodeGenerator{}
codeHasher := local.CodeHasher{}
var mailSender ports.MailSender
switch cfg.MailService.Mode {
case "stub":
mailSender = &mail.StubSender{}
case "rest":
restClient, err := mail.NewRESTClient(mail.Config{
BaseURL: cfg.MailService.BaseURL,
RequestTimeout: cfg.MailService.RequestTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: mail service REST client: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, restClient.Close)
mailSender = restClient
default:
return cleanupOnError(fmt.Errorf("new authsession runtime: unsupported mail service mode %q", cfg.MailService.Mode))
}
var userDirectory ports.UserDirectory
switch cfg.UserService.Mode {
case "stub":
userDirectory = &userservice.StubDirectory{}
case "rest":
restClient, err := userservice.NewRESTClient(userservice.Config{
BaseURL: cfg.UserService.BaseURL,
RequestTimeout: cfg.UserService.RequestTimeout,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: user service REST client: %w", err))
}
runtime.cleanupFns = append(runtime.cleanupFns, restClient.Close)
userDirectory = restClient
default:
return cleanupOnError(fmt.Errorf("new authsession runtime: unsupported user service mode %q", cfg.UserService.Mode))
}
sendEmailCodeService, err := sendemailcode.NewWithObservability(
challengeStore,
userDirectory,
idGenerator,
codeGenerator,
codeHasher,
mailSender,
abuseProtector,
clock,
logger,
telemetryRuntime,
)
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: send email code service: %w", err))
}
confirmEmailCodeService, err := confirmemailcode.NewWithObservability(
challengeStore,
sessionStore,
userDirectory,
configStore,
publisher,
idGenerator,
codeHasher,
clock,
logger,
telemetryRuntime,
)
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: confirm email code service: %w", err))
}
getSessionService, err := getsession.New(sessionStore)
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: get session service: %w", err))
}
listUserSessionsService, err := listusersessions.New(sessionStore)
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: list user sessions service: %w", err))
}
revokeDeviceSessionService, err := revokedevicesession.NewWithObservability(sessionStore, publisher, clock, logger, telemetryRuntime)
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: revoke device session service: %w", err))
}
revokeAllUserSessionsService, err := revokeallusersessions.NewWithObservability(sessionStore, userDirectory, publisher, clock, logger, telemetryRuntime)
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: revoke all user sessions service: %w", err))
}
blockUserService, err := blockuser.NewWithObservability(userDirectory, sessionStore, publisher, clock, logger, telemetryRuntime)
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: block user service: %w", err))
}
publicServer, err := publichttp.NewServer(cfg.PublicHTTP, publichttp.Dependencies{
SendEmailCode: sendEmailCodeService,
ConfirmEmailCode: confirmEmailCodeService,
Logger: logger,
Telemetry: telemetryRuntime,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: public HTTP server: %w", err))
}
internalServer, err := internalhttp.NewServer(cfg.InternalHTTP, internalhttp.Dependencies{
GetSession: getSessionService,
ListUserSessions: listUserSessionsService,
RevokeDeviceSession: revokeDeviceSessionService,
RevokeAllUserSessions: revokeAllUserSessionsService,
BlockUser: blockUserService,
Logger: logger,
Telemetry: telemetryRuntime,
})
if err != nil {
return cleanupOnError(fmt.Errorf("new authsession runtime: internal HTTP server: %w", err))
}
runtime.App = New(cfg, publicServer, internalServer)
return runtime, nil
}
// Close releases the runtime-managed adapter resources. Close is idempotent in
// practice because every underlying adapter Close method is idempotent.
func (r *Runtime) Close() error {
if r == nil {
return nil
}
var joined error
for index := len(r.cleanupFns) - 1; index >= 0; index-- {
joined = errors.Join(joined, r.cleanupFns[index]())
}
return joined
}
+212
View File
@@ -0,0 +1,212 @@
package app
import (
"bytes"
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"galaxy/authsession/internal/config"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestNewRuntimeStartsAndStopsHTTPServers(t *testing.T) {
t.Parallel()
redisServer := miniredis.RunT(t)
cfg := config.DefaultConfig()
cfg.Redis.Addr = redisServer.Addr()
cfg.PublicHTTP.Addr = mustFreeAddr(t)
cfg.InternalHTTP.Addr = mustFreeAddr(t)
runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, runtime.Close())
}()
runCtx, cancel := context.WithCancel(context.Background())
defer cancel()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- runtime.App.Run(runCtx)
}()
require.Eventually(t, func() bool {
response, err := http.Post(
"http://"+cfg.PublicHTTP.Addr+"/api/v1/public/auth/send-email-code",
"application/json",
bytes.NewBufferString(`{"email":"pilot@example.com"}`),
)
if err != nil {
return false
}
defer response.Body.Close()
_, _ = io.ReadAll(response.Body)
return response.StatusCode == http.StatusOK
}, 5*time.Second, 25*time.Millisecond)
require.Eventually(t, func() bool {
response, err := http.Get("http://" + cfg.InternalHTTP.Addr + "/api/v1/internal/sessions/missing")
if err != nil {
return false
}
defer response.Body.Close()
_, _ = io.ReadAll(response.Body)
return response.StatusCode == http.StatusNotFound
}, 5*time.Second, 25*time.Millisecond)
cancel()
require.NoError(t, <-runErrCh)
}
func TestNewRuntimeUsesRESTUserDirectoryWhenConfigured(t *testing.T) {
t.Parallel()
redisServer := miniredis.RunT(t)
userService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && r.URL.Path == "/api/v1/internal/users/user-1/exists" {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"exists":true}`)
return
}
http.NotFound(w, r)
}))
defer userService.Close()
cfg := config.DefaultConfig()
cfg.Redis.Addr = redisServer.Addr()
cfg.PublicHTTP.Addr = mustFreeAddr(t)
cfg.InternalHTTP.Addr = mustFreeAddr(t)
cfg.UserService.Mode = "rest"
cfg.UserService.BaseURL = userService.URL
cfg.UserService.RequestTimeout = 250 * time.Millisecond
runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, runtime.Close())
}()
runCtx, cancel := context.WithCancel(context.Background())
defer cancel()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- runtime.App.Run(runCtx)
}()
require.Eventually(t, func() bool {
response, err := http.Post(
"http://"+cfg.InternalHTTP.Addr+"/api/v1/internal/users/user-1/sessions/revoke-all",
"application/json",
bytes.NewBufferString(`{"reason_code":"logout_all","actor":{"type":"system"}}`),
)
if err != nil {
return false
}
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
if err != nil {
return false
}
return response.StatusCode == http.StatusOK &&
bytes.Contains(payload, []byte(`"outcome":"no_active_sessions"`)) &&
bytes.Contains(payload, []byte(`"user_id":"user-1"`))
}, 5*time.Second, 25*time.Millisecond)
cancel()
require.NoError(t, <-runErrCh)
}
func TestNewRuntimeUsesRESTMailSenderWhenConfigured(t *testing.T) {
t.Parallel()
redisServer := miniredis.RunT(t)
var calls atomic.Int64
mailService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && r.URL.Path == "/api/v1/internal/login-code-deliveries" {
calls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"outcome":"suppressed"}`)
return
}
http.NotFound(w, r)
}))
defer mailService.Close()
cfg := config.DefaultConfig()
cfg.Redis.Addr = redisServer.Addr()
cfg.PublicHTTP.Addr = mustFreeAddr(t)
cfg.InternalHTTP.Addr = mustFreeAddr(t)
cfg.MailService.Mode = "rest"
cfg.MailService.BaseURL = mailService.URL
cfg.MailService.RequestTimeout = 250 * time.Millisecond
runtime, err := NewRuntime(context.Background(), cfg, zap.NewNop(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, runtime.Close())
}()
runCtx, cancel := context.WithCancel(context.Background())
defer cancel()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- runtime.App.Run(runCtx)
}()
require.Eventually(t, func() bool {
response, err := http.Post(
"http://"+cfg.PublicHTTP.Addr+"/api/v1/public/auth/send-email-code",
"application/json",
bytes.NewBufferString(`{"email":"pilot@example.com"}`),
)
if err != nil {
return false
}
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
if err != nil {
return false
}
return response.StatusCode == http.StatusOK &&
bytes.Contains(payload, []byte(`"challenge_id":"`)) &&
calls.Load() == 1
}, 5*time.Second, 25*time.Millisecond)
cancel()
require.NoError(t, <-runErrCh)
}
func mustFreeAddr(t *testing.T) string {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
assert.NoError(t, listener.Close())
}()
return listener.Addr().String()
}
+610
View File
@@ -0,0 +1,610 @@
// Package config loads the authsession process configuration from environment
// variables.
package config
import (
"fmt"
"os"
"strconv"
"strings"
"time"
"galaxy/authsession/internal/api/internalhttp"
"galaxy/authsession/internal/api/publichttp"
"go.uber.org/zap/zapcore"
)
const (
shutdownTimeoutEnvVar = "AUTHSESSION_SHUTDOWN_TIMEOUT"
logLevelEnvVar = "AUTHSESSION_LOG_LEVEL"
publicHTTPAddrEnvVar = "AUTHSESSION_PUBLIC_HTTP_ADDR"
publicHTTPReadHeaderTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_READ_HEADER_TIMEOUT"
publicHTTPReadTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_READ_TIMEOUT"
publicHTTPIdleTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_IDLE_TIMEOUT"
publicHTTPRequestTimeoutEnvVar = "AUTHSESSION_PUBLIC_HTTP_REQUEST_TIMEOUT"
internalHTTPAddrEnvVar = "AUTHSESSION_INTERNAL_HTTP_ADDR"
internalHTTPReadHeaderTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_READ_HEADER_TIMEOUT"
internalHTTPReadTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_READ_TIMEOUT"
internalHTTPIdleTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_IDLE_TIMEOUT"
internalHTTPRequestTimeoutEnvVar = "AUTHSESSION_INTERNAL_HTTP_REQUEST_TIMEOUT"
redisAddrEnvVar = "AUTHSESSION_REDIS_ADDR"
redisUsernameEnvVar = "AUTHSESSION_REDIS_USERNAME"
redisPasswordEnvVar = "AUTHSESSION_REDIS_PASSWORD"
redisDBEnvVar = "AUTHSESSION_REDIS_DB"
redisTLSEnabledEnvVar = "AUTHSESSION_REDIS_TLS_ENABLED"
redisOperationTimeoutEnvVar = "AUTHSESSION_REDIS_OPERATION_TIMEOUT"
redisChallengeKeyPrefixEnvVar = "AUTHSESSION_REDIS_CHALLENGE_KEY_PREFIX"
redisSessionKeyPrefixEnvVar = "AUTHSESSION_REDIS_SESSION_KEY_PREFIX"
redisUserSessionsKeyPrefixEnvVar = "AUTHSESSION_REDIS_USER_SESSIONS_KEY_PREFIX"
redisUserActiveSessionsKeyPrefixEnvVar = "AUTHSESSION_REDIS_USER_ACTIVE_SESSIONS_KEY_PREFIX"
redisSessionLimitKeyEnvVar = "AUTHSESSION_REDIS_SESSION_LIMIT_KEY"
redisGatewaySessionCacheKeyPrefixEnvVar = "AUTHSESSION_REDIS_GATEWAY_SESSION_CACHE_KEY_PREFIX"
redisGatewaySessionEventsStreamEnvVar = "AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM"
redisGatewaySessionEventsStreamMaxLenEnvVar = "AUTHSESSION_REDIS_GATEWAY_SESSION_EVENTS_STREAM_MAX_LEN"
redisSendEmailCodeThrottleKeyPrefixEnvVar = "AUTHSESSION_REDIS_SEND_EMAIL_CODE_THROTTLE_KEY_PREFIX"
userServiceModeEnvVar = "AUTHSESSION_USER_SERVICE_MODE"
userServiceBaseURLEnvVar = "AUTHSESSION_USER_SERVICE_BASE_URL"
userServiceRequestTimeoutEnvVar = "AUTHSESSION_USER_SERVICE_REQUEST_TIMEOUT"
mailServiceModeEnvVar = "AUTHSESSION_MAIL_SERVICE_MODE"
mailServiceBaseURLEnvVar = "AUTHSESSION_MAIL_SERVICE_BASE_URL"
mailServiceRequestTimeoutEnvVar = "AUTHSESSION_MAIL_SERVICE_REQUEST_TIMEOUT"
otelServiceNameEnvVar = "OTEL_SERVICE_NAME"
otelTracesExporterEnvVar = "OTEL_TRACES_EXPORTER"
otelMetricsExporterEnvVar = "OTEL_METRICS_EXPORTER"
otelExporterOTLPProtocolEnvVar = "OTEL_EXPORTER_OTLP_PROTOCOL"
otelExporterOTLPTracesProtocolEnvVar = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"
otelExporterOTLPMetricsProtocolEnvVar = "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL"
otelStdoutTracesEnabledEnvVar = "AUTHSESSION_OTEL_STDOUT_TRACES_ENABLED"
otelStdoutMetricsEnabledEnvVar = "AUTHSESSION_OTEL_STDOUT_METRICS_ENABLED"
defaultShutdownTimeout = 5 * time.Second
defaultLogLevel = "info"
defaultRedisDB = 0
defaultRedisOperationTimeout = 250 * time.Millisecond
defaultChallengeKeyPrefix = "authsession:challenge:"
defaultSessionKeyPrefix = "authsession:session:"
defaultUserSessionsKeyPrefix = "authsession:user-sessions:"
defaultUserActiveSessionsKeyPrefix = "authsession:user-active-sessions:"
defaultSessionLimitKey = "authsession:config:active-session-limit"
defaultGatewaySessionCacheKeyPrefix = "gateway:session:"
defaultGatewaySessionEventsStream = "gateway:session_events"
defaultGatewaySessionEventsStreamMaxLen = 1024
defaultSendEmailCodeThrottleKeyPrefix = "authsession:send-email-code-throttle:"
defaultUserServiceMode = userServiceModeStub
defaultUserServiceRequestTimeout = time.Second
defaultMailServiceMode = mailServiceModeStub
defaultMailServiceRequestTimeout = time.Second
defaultOTelServiceName = "galaxy-authsession"
otelExporterNone = "none"
otelExporterOTLP = "otlp"
otelProtocolHTTPProtobuf = "http/protobuf"
otelProtocolGRPC = "grpc"
userServiceModeStub = "stub"
userServiceModeREST = "rest"
mailServiceModeStub = "stub"
mailServiceModeREST = "rest"
)
// Config stores the full process-level authsession configuration.
type Config struct {
// ShutdownTimeout bounds graceful shutdown of every long-lived component.
ShutdownTimeout time.Duration
// Logging configures the process-wide structured logger.
Logging LoggingConfig
// PublicHTTP configures the public HTTP listener.
PublicHTTP publichttp.Config
// InternalHTTP configures the trusted internal HTTP listener.
InternalHTTP internalhttp.Config
// Redis configures the Redis-backed adapters.
Redis RedisConfig
// UserService configures the selectable runtime user-directory adapter.
UserService UserServiceConfig
// MailService configures the selectable runtime mail-delivery adapter.
MailService MailServiceConfig
// Telemetry configures the process-wide OpenTelemetry runtime.
Telemetry TelemetryConfig
}
// LoggingConfig configures the process-wide structured logger.
type LoggingConfig struct {
// Level stores the zap-compatible log level string.
Level string
}
// RedisConfig configures the Redis-backed authsession adapters.
type RedisConfig struct {
// Addr is the shared Redis address used by the authsession adapters.
Addr string
// Username is the optional Redis ACL username.
Username string
// Password is the optional Redis ACL password.
Password string
// DB is the Redis logical database index.
DB int
// TLSEnabled configures whether Redis connections use TLS.
TLSEnabled bool
// OperationTimeout bounds each adapter Redis round trip.
OperationTimeout time.Duration
// ChallengeKeyPrefix namespaces the challenge source-of-truth records.
ChallengeKeyPrefix string
// SessionKeyPrefix namespaces the primary session records.
SessionKeyPrefix string
// UserSessionsKeyPrefix namespaces the all-session user index.
UserSessionsKeyPrefix string
// UserActiveSessionsKeyPrefix namespaces the active-session user index.
UserActiveSessionsKeyPrefix string
// SessionLimitKey stores the exact session-limit Redis key.
SessionLimitKey string
// GatewaySessionCacheKeyPrefix namespaces the projected gateway session
// cache keys.
GatewaySessionCacheKeyPrefix string
// GatewaySessionEventsStream stores the projected gateway session-events
// Redis Stream key.
GatewaySessionEventsStream string
// GatewaySessionEventsStreamMaxLen bounds the projected gateway session
// event stream with approximate trimming.
GatewaySessionEventsStreamMaxLen int64
// SendEmailCodeThrottleKeyPrefix namespaces the resend-throttle TTL keys.
SendEmailCodeThrottleKeyPrefix string
}
// UserServiceConfig configures the runtime user-directory integration mode.
type UserServiceConfig struct {
// Mode selects the runtime adapter implementation. Supported values are
// `stub` and `rest`.
Mode string
// BaseURL is the absolute base URL of the REST-backed user-service when
// Mode is `rest`.
BaseURL string
// RequestTimeout bounds each outbound user-service request when Mode is
// `rest`.
RequestTimeout time.Duration
}
// MailServiceConfig configures the runtime mail-delivery integration mode.
type MailServiceConfig struct {
// Mode selects the runtime adapter implementation. Supported values are
// `stub` and `rest`.
Mode string
// BaseURL is the absolute base URL of the REST-backed mail service when
// Mode is `rest`.
BaseURL string
// RequestTimeout bounds each outbound mail-service request when Mode is
// `rest`.
RequestTimeout time.Duration
}
// TelemetryConfig configures the authsession OpenTelemetry runtime.
type TelemetryConfig struct {
// ServiceName overrides the default OpenTelemetry service name.
ServiceName string
// TracesExporter selects the external traces exporter. Supported values are
// `none` and `otlp`.
TracesExporter string
// MetricsExporter selects the external metrics exporter. Supported values
// are `none` and `otlp`.
MetricsExporter string
// TracesProtocol selects the OTLP traces protocol when TracesExporter is
// `otlp`.
TracesProtocol string
// MetricsProtocol selects the OTLP metrics protocol when MetricsExporter is
// `otlp`.
MetricsProtocol string
// StdoutTracesEnabled enables the additional stdout trace exporter used for
// local development and debugging.
StdoutTracesEnabled bool
// StdoutMetricsEnabled enables the additional stdout metric exporter used
// for local development and debugging.
StdoutMetricsEnabled bool
}
// DefaultConfig returns the default authsession process configuration with all
// optional values filled.
func DefaultConfig() Config {
return Config{
ShutdownTimeout: defaultShutdownTimeout,
Logging: LoggingConfig{
Level: defaultLogLevel,
},
PublicHTTP: publichttp.DefaultConfig(),
InternalHTTP: internalhttp.DefaultConfig(),
Redis: RedisConfig{
DB: defaultRedisDB,
OperationTimeout: defaultRedisOperationTimeout,
ChallengeKeyPrefix: defaultChallengeKeyPrefix,
SessionKeyPrefix: defaultSessionKeyPrefix,
UserSessionsKeyPrefix: defaultUserSessionsKeyPrefix,
UserActiveSessionsKeyPrefix: defaultUserActiveSessionsKeyPrefix,
SessionLimitKey: defaultSessionLimitKey,
GatewaySessionCacheKeyPrefix: defaultGatewaySessionCacheKeyPrefix,
GatewaySessionEventsStream: defaultGatewaySessionEventsStream,
GatewaySessionEventsStreamMaxLen: defaultGatewaySessionEventsStreamMaxLen,
SendEmailCodeThrottleKeyPrefix: defaultSendEmailCodeThrottleKeyPrefix,
},
UserService: UserServiceConfig{
Mode: defaultUserServiceMode,
RequestTimeout: defaultUserServiceRequestTimeout,
},
MailService: MailServiceConfig{
Mode: defaultMailServiceMode,
RequestTimeout: defaultMailServiceRequestTimeout,
},
Telemetry: TelemetryConfig{
ServiceName: defaultOTelServiceName,
TracesExporter: otelExporterNone,
MetricsExporter: otelExporterNone,
},
}
}
// LoadFromEnv loads the authsession process configuration from environment
// variables, applying documented defaults where appropriate.
func LoadFromEnv() (Config, error) {
cfg := DefaultConfig()
var err error
cfg.ShutdownTimeout, err = loadDurationEnvWithDefault(shutdownTimeoutEnvVar, cfg.ShutdownTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Logging.Level = loadStringEnvWithDefault(logLevelEnvVar, cfg.Logging.Level)
if err := validateLogLevel(cfg.Logging.Level); err != nil {
return Config{}, fmt.Errorf("load authsession config: %s: %w", logLevelEnvVar, err)
}
cfg.PublicHTTP.Addr = loadStringEnvWithDefault(publicHTTPAddrEnvVar, cfg.PublicHTTP.Addr)
cfg.PublicHTTP.ReadHeaderTimeout, err = loadDurationEnvWithDefault(publicHTTPReadHeaderTimeoutEnvVar, cfg.PublicHTTP.ReadHeaderTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.PublicHTTP.ReadTimeout, err = loadDurationEnvWithDefault(publicHTTPReadTimeoutEnvVar, cfg.PublicHTTP.ReadTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.PublicHTTP.IdleTimeout, err = loadDurationEnvWithDefault(publicHTTPIdleTimeoutEnvVar, cfg.PublicHTTP.IdleTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.PublicHTTP.RequestTimeout, err = loadDurationEnvWithDefault(publicHTTPRequestTimeoutEnvVar, cfg.PublicHTTP.RequestTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.InternalHTTP.Addr = loadStringEnvWithDefault(internalHTTPAddrEnvVar, cfg.InternalHTTP.Addr)
cfg.InternalHTTP.ReadHeaderTimeout, err = loadDurationEnvWithDefault(internalHTTPReadHeaderTimeoutEnvVar, cfg.InternalHTTP.ReadHeaderTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.InternalHTTP.ReadTimeout, err = loadDurationEnvWithDefault(internalHTTPReadTimeoutEnvVar, cfg.InternalHTTP.ReadTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.InternalHTTP.IdleTimeout, err = loadDurationEnvWithDefault(internalHTTPIdleTimeoutEnvVar, cfg.InternalHTTP.IdleTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.InternalHTTP.RequestTimeout, err = loadDurationEnvWithDefault(internalHTTPRequestTimeoutEnvVar, cfg.InternalHTTP.RequestTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Redis.Addr = loadStringEnvWithDefault(redisAddrEnvVar, cfg.Redis.Addr)
cfg.Redis.Username = os.Getenv(redisUsernameEnvVar)
cfg.Redis.Password = os.Getenv(redisPasswordEnvVar)
cfg.Redis.DB, err = loadIntEnvWithDefault(redisDBEnvVar, cfg.Redis.DB)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Redis.TLSEnabled, err = loadBoolEnvWithDefault(redisTLSEnabledEnvVar, cfg.Redis.TLSEnabled)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Redis.OperationTimeout, err = loadDurationEnvWithDefault(redisOperationTimeoutEnvVar, cfg.Redis.OperationTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Redis.ChallengeKeyPrefix = loadStringEnvWithDefault(redisChallengeKeyPrefixEnvVar, cfg.Redis.ChallengeKeyPrefix)
cfg.Redis.SessionKeyPrefix = loadStringEnvWithDefault(redisSessionKeyPrefixEnvVar, cfg.Redis.SessionKeyPrefix)
cfg.Redis.UserSessionsKeyPrefix = loadStringEnvWithDefault(redisUserSessionsKeyPrefixEnvVar, cfg.Redis.UserSessionsKeyPrefix)
cfg.Redis.UserActiveSessionsKeyPrefix = loadStringEnvWithDefault(redisUserActiveSessionsKeyPrefixEnvVar, cfg.Redis.UserActiveSessionsKeyPrefix)
cfg.Redis.SessionLimitKey = loadStringEnvWithDefault(redisSessionLimitKeyEnvVar, cfg.Redis.SessionLimitKey)
cfg.Redis.GatewaySessionCacheKeyPrefix = loadStringEnvWithDefault(redisGatewaySessionCacheKeyPrefixEnvVar, cfg.Redis.GatewaySessionCacheKeyPrefix)
cfg.Redis.GatewaySessionEventsStream = loadStringEnvWithDefault(redisGatewaySessionEventsStreamEnvVar, cfg.Redis.GatewaySessionEventsStream)
streamMaxLen, err := loadInt64EnvWithDefault(redisGatewaySessionEventsStreamMaxLenEnvVar, cfg.Redis.GatewaySessionEventsStreamMaxLen)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Redis.GatewaySessionEventsStreamMaxLen = streamMaxLen
cfg.Redis.SendEmailCodeThrottleKeyPrefix = loadStringEnvWithDefault(redisSendEmailCodeThrottleKeyPrefixEnvVar, cfg.Redis.SendEmailCodeThrottleKeyPrefix)
cfg.UserService.Mode = strings.TrimSpace(loadStringEnvWithDefault(userServiceModeEnvVar, cfg.UserService.Mode))
cfg.UserService.BaseURL = loadStringEnvWithDefault(userServiceBaseURLEnvVar, cfg.UserService.BaseURL)
cfg.UserService.RequestTimeout, err = loadDurationEnvWithDefault(userServiceRequestTimeoutEnvVar, cfg.UserService.RequestTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.MailService.Mode = strings.TrimSpace(loadStringEnvWithDefault(mailServiceModeEnvVar, cfg.MailService.Mode))
cfg.MailService.BaseURL = loadStringEnvWithDefault(mailServiceBaseURLEnvVar, cfg.MailService.BaseURL)
cfg.MailService.RequestTimeout, err = loadDurationEnvWithDefault(mailServiceRequestTimeoutEnvVar, cfg.MailService.RequestTimeout)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Telemetry.ServiceName = loadStringEnvWithDefault(otelServiceNameEnvVar, cfg.Telemetry.ServiceName)
cfg.Telemetry.TracesExporter = normalizeExporterValue(loadStringEnvWithDefault(otelTracesExporterEnvVar, cfg.Telemetry.TracesExporter))
cfg.Telemetry.MetricsExporter = normalizeExporterValue(loadStringEnvWithDefault(otelMetricsExporterEnvVar, cfg.Telemetry.MetricsExporter))
cfg.Telemetry.TracesProtocol = loadOTLPProtocol(
os.Getenv(otelExporterOTLPTracesProtocolEnvVar),
os.Getenv(otelExporterOTLPProtocolEnvVar),
cfg.Telemetry.TracesExporter,
)
cfg.Telemetry.MetricsProtocol = loadOTLPProtocol(
os.Getenv(otelExporterOTLPMetricsProtocolEnvVar),
os.Getenv(otelExporterOTLPProtocolEnvVar),
cfg.Telemetry.MetricsExporter,
)
cfg.Telemetry.StdoutTracesEnabled, err = loadBoolEnvWithDefault(otelStdoutTracesEnabledEnvVar, cfg.Telemetry.StdoutTracesEnabled)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
cfg.Telemetry.StdoutMetricsEnabled, err = loadBoolEnvWithDefault(otelStdoutMetricsEnabledEnvVar, cfg.Telemetry.StdoutMetricsEnabled)
if err != nil {
return Config{}, fmt.Errorf("load authsession config: %w", err)
}
if err := cfg.Validate(); err != nil {
return Config{}, err
}
return cfg, nil
}
// Validate reports whether cfg contains a consistent authsession process
// configuration.
func (cfg Config) Validate() error {
switch {
case cfg.ShutdownTimeout <= 0:
return fmt.Errorf("load authsession config: %s must be positive", shutdownTimeoutEnvVar)
case strings.TrimSpace(cfg.Redis.Addr) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisAddrEnvVar)
case cfg.Redis.DB < 0:
return fmt.Errorf("load authsession config: %s must not be negative", redisDBEnvVar)
case cfg.Redis.OperationTimeout <= 0:
return fmt.Errorf("load authsession config: %s must be positive", redisOperationTimeoutEnvVar)
case strings.TrimSpace(cfg.Redis.ChallengeKeyPrefix) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisChallengeKeyPrefixEnvVar)
case strings.TrimSpace(cfg.Redis.SessionKeyPrefix) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisSessionKeyPrefixEnvVar)
case strings.TrimSpace(cfg.Redis.UserSessionsKeyPrefix) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisUserSessionsKeyPrefixEnvVar)
case strings.TrimSpace(cfg.Redis.UserActiveSessionsKeyPrefix) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisUserActiveSessionsKeyPrefixEnvVar)
case strings.TrimSpace(cfg.Redis.SessionLimitKey) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisSessionLimitKeyEnvVar)
case strings.TrimSpace(cfg.Redis.GatewaySessionCacheKeyPrefix) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisGatewaySessionCacheKeyPrefixEnvVar)
case strings.TrimSpace(cfg.Redis.GatewaySessionEventsStream) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisGatewaySessionEventsStreamEnvVar)
case cfg.Redis.GatewaySessionEventsStreamMaxLen <= 0:
return fmt.Errorf("load authsession config: %s must be positive", redisGatewaySessionEventsStreamMaxLenEnvVar)
case strings.TrimSpace(cfg.Redis.SendEmailCodeThrottleKeyPrefix) == "":
return fmt.Errorf("load authsession config: %s must not be empty", redisSendEmailCodeThrottleKeyPrefixEnvVar)
}
if err := cfg.PublicHTTP.Validate(); err != nil {
return fmt.Errorf("load authsession config: public HTTP: %w", err)
}
if err := cfg.InternalHTTP.Validate(); err != nil {
return fmt.Errorf("load authsession config: internal HTTP: %w", err)
}
if err := cfg.UserService.Validate(); err != nil {
return fmt.Errorf("load authsession config: %w", err)
}
if err := cfg.MailService.Validate(); err != nil {
return fmt.Errorf("load authsession config: %w", err)
}
if err := cfg.Telemetry.Validate(); err != nil {
return fmt.Errorf("load authsession config: %w", err)
}
return nil
}
// Validate reports whether cfg contains a supported user-service runtime
// configuration.
func (cfg UserServiceConfig) Validate() error {
switch cfg.Mode {
case userServiceModeStub:
return nil
case userServiceModeREST:
if strings.TrimSpace(cfg.BaseURL) == "" {
return fmt.Errorf("%s must not be empty in rest mode", userServiceBaseURLEnvVar)
}
if cfg.RequestTimeout <= 0 {
return fmt.Errorf("%s must be positive in rest mode", userServiceRequestTimeoutEnvVar)
}
return nil
default:
return fmt.Errorf("%s %q is unsupported", userServiceModeEnvVar, cfg.Mode)
}
}
// Validate reports whether cfg contains a supported mail-service runtime
// configuration.
func (cfg MailServiceConfig) Validate() error {
switch cfg.Mode {
case mailServiceModeStub:
return nil
case mailServiceModeREST:
if strings.TrimSpace(cfg.BaseURL) == "" {
return fmt.Errorf("%s must not be empty in rest mode", mailServiceBaseURLEnvVar)
}
if cfg.RequestTimeout <= 0 {
return fmt.Errorf("%s must be positive in rest mode", mailServiceRequestTimeoutEnvVar)
}
return nil
default:
return fmt.Errorf("%s %q is unsupported", mailServiceModeEnvVar, cfg.Mode)
}
}
// Validate reports whether cfg contains a supported OpenTelemetry exporter
// configuration.
func (cfg TelemetryConfig) Validate() error {
switch cfg.TracesExporter {
case otelExporterNone, otelExporterOTLP:
default:
return fmt.Errorf("%s %q is unsupported", otelTracesExporterEnvVar, cfg.TracesExporter)
}
switch cfg.MetricsExporter {
case otelExporterNone, otelExporterOTLP:
default:
return fmt.Errorf("%s %q is unsupported", otelMetricsExporterEnvVar, cfg.MetricsExporter)
}
if cfg.TracesProtocol != "" && cfg.TracesProtocol != otelProtocolHTTPProtobuf && cfg.TracesProtocol != otelProtocolGRPC {
return fmt.Errorf("%s %q is unsupported", otelExporterOTLPTracesProtocolEnvVar, cfg.TracesProtocol)
}
if cfg.MetricsProtocol != "" && cfg.MetricsProtocol != otelProtocolHTTPProtobuf && cfg.MetricsProtocol != otelProtocolGRPC {
return fmt.Errorf("%s %q is unsupported", otelExporterOTLPMetricsProtocolEnvVar, cfg.MetricsProtocol)
}
return nil
}
func loadStringEnvWithDefault(name string, value string) string {
if raw, ok := os.LookupEnv(name); ok {
return strings.TrimSpace(raw)
}
return value
}
func loadDurationEnvWithDefault(name string, value time.Duration) (time.Duration, error) {
raw, ok := os.LookupEnv(name)
if !ok {
return value, nil
}
parsed, err := time.ParseDuration(strings.TrimSpace(raw))
if err != nil {
return 0, fmt.Errorf("%s: %w", name, err)
}
return parsed, nil
}
func loadIntEnvWithDefault(name string, value int) (int, error) {
raw, ok := os.LookupEnv(name)
if !ok {
return value, nil
}
parsed, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil {
return 0, fmt.Errorf("%s: %w", name, err)
}
return parsed, nil
}
func loadInt64EnvWithDefault(name string, value int64) (int64, error) {
raw, ok := os.LookupEnv(name)
if !ok {
return value, nil
}
parsed, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
if err != nil {
return 0, fmt.Errorf("%s: %w", name, err)
}
return parsed, nil
}
func loadBoolEnvWithDefault(name string, value bool) (bool, error) {
raw, ok := os.LookupEnv(name)
if !ok {
return value, nil
}
parsed, err := strconv.ParseBool(strings.TrimSpace(raw))
if err != nil {
return false, fmt.Errorf("%s: %w", name, err)
}
return parsed, nil
}
func validateLogLevel(value string) error {
var level zapcore.Level
if err := level.UnmarshalText([]byte(strings.TrimSpace(value))); err != nil {
return err
}
return nil
}
func normalizeExporterValue(value string) string {
switch strings.TrimSpace(value) {
case "", otelExporterNone:
return otelExporterNone
default:
return strings.TrimSpace(value)
}
}
func loadOTLPProtocol(primary string, fallback string, exporter string) string {
protocol := strings.TrimSpace(primary)
if protocol == "" {
protocol = strings.TrimSpace(fallback)
}
if protocol == "" && exporter == otelExporterOTLP {
return otelProtocolHTTPProtobuf
}
return protocol
}
+161
View File
@@ -0,0 +1,161 @@
package config
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoadFromEnvUsesDefaults(t *testing.T) {
t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
cfg, err := LoadFromEnv()
require.NoError(t, err)
defaults := DefaultConfig()
assert.Equal(t, defaults.ShutdownTimeout, cfg.ShutdownTimeout)
assert.Equal(t, defaults.Logging.Level, cfg.Logging.Level)
assert.Equal(t, defaults.PublicHTTP, cfg.PublicHTTP)
assert.Equal(t, defaults.InternalHTTP, cfg.InternalHTTP)
assert.Equal(t, "127.0.0.1:6379", cfg.Redis.Addr)
assert.Equal(t, defaults.Redis.DB, cfg.Redis.DB)
assert.Equal(t, defaults.Redis.OperationTimeout, cfg.Redis.OperationTimeout)
assert.Equal(t, defaults.UserService, cfg.UserService)
assert.Equal(t, defaults.MailService, cfg.MailService)
assert.Equal(t, defaults.Telemetry.ServiceName, cfg.Telemetry.ServiceName)
assert.Equal(t, defaults.Telemetry.TracesExporter, cfg.Telemetry.TracesExporter)
assert.Equal(t, defaults.Telemetry.MetricsExporter, cfg.Telemetry.MetricsExporter)
assert.False(t, cfg.Telemetry.StdoutTracesEnabled)
assert.False(t, cfg.Telemetry.StdoutMetricsEnabled)
}
func TestLoadFromEnvAppliesOverrides(t *testing.T) {
t.Setenv(shutdownTimeoutEnvVar, "9s")
t.Setenv(logLevelEnvVar, "debug")
t.Setenv(publicHTTPAddrEnvVar, "127.0.0.1:18080")
t.Setenv(internalHTTPAddrEnvVar, "127.0.0.1:18081")
t.Setenv(redisAddrEnvVar, "127.0.0.1:6380")
t.Setenv(redisUsernameEnvVar, "alice")
t.Setenv(redisPasswordEnvVar, "secret")
t.Setenv(redisDBEnvVar, "3")
t.Setenv(redisTLSEnabledEnvVar, "true")
t.Setenv(redisOperationTimeoutEnvVar, "750ms")
t.Setenv(userServiceModeEnvVar, "rest")
t.Setenv(userServiceBaseURLEnvVar, "http://127.0.0.1:19090")
t.Setenv(userServiceRequestTimeoutEnvVar, "900ms")
t.Setenv(mailServiceModeEnvVar, "rest")
t.Setenv(mailServiceBaseURLEnvVar, "http://127.0.0.1:19091")
t.Setenv(mailServiceRequestTimeoutEnvVar, "950ms")
t.Setenv(otelServiceNameEnvVar, "custom-authsession")
t.Setenv(otelTracesExporterEnvVar, "otlp")
t.Setenv(otelMetricsExporterEnvVar, "otlp")
t.Setenv(otelExporterOTLPProtocolEnvVar, "grpc")
t.Setenv(otelStdoutTracesEnabledEnvVar, "true")
t.Setenv(otelStdoutMetricsEnabledEnvVar, "true")
cfg, err := LoadFromEnv()
require.NoError(t, err)
assert.Equal(t, 9*time.Second, cfg.ShutdownTimeout)
assert.Equal(t, "debug", cfg.Logging.Level)
assert.Equal(t, "127.0.0.1:18080", cfg.PublicHTTP.Addr)
assert.Equal(t, "127.0.0.1:18081", cfg.InternalHTTP.Addr)
assert.Equal(t, "127.0.0.1:6380", cfg.Redis.Addr)
assert.Equal(t, "alice", cfg.Redis.Username)
assert.Equal(t, "secret", cfg.Redis.Password)
assert.Equal(t, 3, cfg.Redis.DB)
assert.True(t, cfg.Redis.TLSEnabled)
assert.Equal(t, 750*time.Millisecond, cfg.Redis.OperationTimeout)
assert.Equal(t, UserServiceConfig{
Mode: "rest",
BaseURL: "http://127.0.0.1:19090",
RequestTimeout: 900 * time.Millisecond,
}, cfg.UserService)
assert.Equal(t, MailServiceConfig{
Mode: "rest",
BaseURL: "http://127.0.0.1:19091",
RequestTimeout: 950 * time.Millisecond,
}, cfg.MailService)
assert.Equal(t, "custom-authsession", cfg.Telemetry.ServiceName)
assert.Equal(t, "otlp", cfg.Telemetry.TracesExporter)
assert.Equal(t, "otlp", cfg.Telemetry.MetricsExporter)
assert.Equal(t, "grpc", cfg.Telemetry.TracesProtocol)
assert.Equal(t, "grpc", cfg.Telemetry.MetricsProtocol)
assert.True(t, cfg.Telemetry.StdoutTracesEnabled)
assert.True(t, cfg.Telemetry.StdoutMetricsEnabled)
}
func TestLoadFromEnvRejectsInvalidValues(t *testing.T) {
tests := []struct {
name string
envName string
envVal string
}{
{name: "invalid duration", envName: shutdownTimeoutEnvVar, envVal: "later"},
{name: "invalid bool", envName: otelStdoutTracesEnabledEnvVar, envVal: "sometimes"},
{name: "invalid log level", envName: logLevelEnvVar, envVal: "verbose"},
{name: "invalid traces protocol", envName: otelExporterOTLPTracesProtocolEnvVar, envVal: "udp"},
{name: "invalid user service mode", envName: userServiceModeEnvVar, envVal: "grpc"},
{name: "invalid user service timeout", envName: userServiceRequestTimeoutEnvVar, envVal: "never"},
{name: "invalid mail service mode", envName: mailServiceModeEnvVar, envVal: "grpc"},
{name: "invalid mail service timeout", envName: mailServiceRequestTimeoutEnvVar, envVal: "never"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
t.Setenv(tt.envName, tt.envVal)
if tt.envName == otelExporterOTLPTracesProtocolEnvVar {
t.Setenv(otelTracesExporterEnvVar, "otlp")
}
_, err := LoadFromEnv()
require.Error(t, err)
assert.Contains(t, err.Error(), tt.envName)
})
}
}
func TestLoadFromEnvRejectsInvalidRESTUserServiceConfiguration(t *testing.T) {
t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
t.Setenv(userServiceModeEnvVar, "rest")
t.Run("missing base url", func(t *testing.T) {
_, err := LoadFromEnv()
require.Error(t, err)
assert.Contains(t, err.Error(), userServiceBaseURLEnvVar)
})
t.Run("non positive timeout", func(t *testing.T) {
t.Setenv(userServiceBaseURLEnvVar, "http://127.0.0.1:19090")
t.Setenv(userServiceRequestTimeoutEnvVar, "0s")
_, err := LoadFromEnv()
require.Error(t, err)
assert.Contains(t, err.Error(), userServiceRequestTimeoutEnvVar)
})
}
func TestLoadFromEnvRejectsInvalidRESTMailServiceConfiguration(t *testing.T) {
t.Setenv(redisAddrEnvVar, "127.0.0.1:6379")
t.Setenv(mailServiceModeEnvVar, "rest")
t.Run("missing base url", func(t *testing.T) {
_, err := LoadFromEnv()
require.Error(t, err)
assert.Contains(t, err.Error(), mailServiceBaseURLEnvVar)
})
t.Run("non positive timeout", func(t *testing.T) {
t.Setenv(mailServiceBaseURLEnvVar, "http://127.0.0.1:19091")
t.Setenv(mailServiceRequestTimeoutEnvVar, "0s")
_, err := LoadFromEnv()
require.Error(t, err)
assert.Contains(t, err.Error(), mailServiceRequestTimeoutEnvVar)
})
}
@@ -0,0 +1,342 @@
// Package challenge defines the source-of-truth domain model for one e-mail
// confirmation challenge.
package challenge
import (
"errors"
"fmt"
"time"
"galaxy/authsession/internal/domain/common"
)
// Status identifies the coarse lifecycle state of one challenge.
type Status string
const (
// StatusPendingSend reports that the challenge has been created but its
// delivery outcome has not been recorded yet.
StatusPendingSend Status = "pending_send"
// StatusSent reports that the confirmation code was delivered successfully.
StatusSent Status = "sent"
// StatusDeliverySuppressed reports that outward send succeeded but actual
// delivery was intentionally suppressed by policy.
StatusDeliverySuppressed Status = "delivery_suppressed"
// StatusDeliveryThrottled reports that a fresh challenge was created but
// delivery was skipped because the auth-side resend cooldown is still
// active.
StatusDeliveryThrottled Status = "delivery_throttled"
// StatusConfirmedPendingExpire reports that the challenge was confirmed
// successfully and is temporarily retained for idempotent retry handling.
StatusConfirmedPendingExpire Status = "confirmed_pending_expire"
// StatusExpired reports that the challenge can no longer be confirmed.
StatusExpired Status = "expired"
// StatusFailed reports that the challenge reached a terminal failure state.
StatusFailed Status = "failed"
// StatusCancelled reports that the challenge was cancelled explicitly.
StatusCancelled Status = "cancelled"
)
// IsKnown reports whether Status is one of the challenge states supported by
// the current domain model.
func (s Status) IsKnown() bool {
switch s {
case StatusPendingSend,
StatusSent,
StatusDeliverySuppressed,
StatusDeliveryThrottled,
StatusConfirmedPendingExpire,
StatusExpired,
StatusFailed,
StatusCancelled:
return true
default:
return false
}
}
// IsTerminal reports whether Status can no longer accept any lifecycle
// transition in the v1 challenge state machine.
func (s Status) IsTerminal() bool {
switch s {
case StatusExpired, StatusFailed, StatusCancelled:
return true
default:
return false
}
}
// AcceptsFreshConfirm reports whether Status may still consume a first
// successful confirmation attempt.
func (s Status) AcceptsFreshConfirm() bool {
switch s {
case StatusSent, StatusDeliverySuppressed:
return true
default:
return false
}
}
// IsConfirmedRetryState reports whether Status should use the idempotent retry
// path for a previously successful confirmation.
func (s Status) IsConfirmedRetryState() bool {
return s == StatusConfirmedPendingExpire
}
// CanTransitionTo reports whether the current challenge Status may move to
// next under the coarse lifecycle rules fixed by Stage 2.
func (s Status) CanTransitionTo(next Status) bool {
switch s {
case StatusPendingSend:
switch next {
case StatusSent, StatusDeliverySuppressed, StatusDeliveryThrottled, StatusFailed, StatusCancelled, StatusExpired:
return true
}
case StatusSent, StatusDeliverySuppressed:
switch next {
case StatusConfirmedPendingExpire, StatusFailed, StatusCancelled, StatusExpired:
return true
}
case StatusConfirmedPendingExpire:
return next == StatusExpired
}
return false
}
// DeliveryState identifies the recorded delivery result of one challenge.
type DeliveryState string
const (
// DeliveryPending reports that no delivery outcome has been recorded yet.
DeliveryPending DeliveryState = "pending"
// DeliverySent reports that the challenge code was sent successfully.
DeliverySent DeliveryState = "sent"
// DeliverySuppressed reports that the outward flow stays success-shaped
// while actual delivery is intentionally skipped.
DeliverySuppressed DeliveryState = "suppressed"
// DeliveryThrottled reports that the outward flow stays success-shaped
// while actual delivery is skipped because the resend cooldown is active.
DeliveryThrottled DeliveryState = "throttled"
// DeliveryFailed reports that delivery was attempted and failed explicitly.
DeliveryFailed DeliveryState = "failed"
)
// IsKnown reports whether DeliveryState is one of the delivery states
// supported by the current domain model.
func (s DeliveryState) IsKnown() bool {
switch s {
case DeliveryPending, DeliverySent, DeliverySuppressed, DeliveryThrottled, DeliveryFailed:
return true
default:
return false
}
}
// CanTransitionTo reports whether the current DeliveryState may move to next
// under the coarse delivery rules fixed by Stage 2.
func (s DeliveryState) CanTransitionTo(next DeliveryState) bool {
if s != DeliveryPending {
return false
}
switch next {
case DeliverySent, DeliverySuppressed, DeliveryThrottled, DeliveryFailed:
return true
default:
return false
}
}
// AttemptCounters groups the mutable send and confirm counters tracked by one
// challenge aggregate.
type AttemptCounters struct {
// Send counts delivery attempts initiated for the challenge.
Send int
// Confirm counts confirmation attempts evaluated against the challenge.
Confirm int
}
// Validate reports whether AttemptCounters contains only non-negative values.
func (c AttemptCounters) Validate() error {
if c.Send < 0 {
return errors.New("challenge send attempt count must not be negative")
}
if c.Confirm < 0 {
return errors.New("challenge confirm attempt count must not be negative")
}
return nil
}
// AbuseMetadata stores minimal abuse-related timestamps without fixing later
// anti-abuse policy details too early.
type AbuseMetadata struct {
// LastAttemptAt optionally records the last send or confirm attempt time
// associated with the challenge.
LastAttemptAt *time.Time
}
// Validate reports whether AbuseMetadata contains structurally valid values.
func (m AbuseMetadata) Validate() error {
if m.LastAttemptAt != nil && m.LastAttemptAt.IsZero() {
return errors.New("challenge abuse metadata last attempt time must not be zero")
}
return nil
}
// Confirmation stores the idempotency metadata recorded after a successful
// challenge confirmation.
type Confirmation struct {
// SessionID is the created device session returned by the successful
// confirmation.
SessionID common.DeviceSessionID
// ClientPublicKey is the validated client key bound to SessionID.
ClientPublicKey common.ClientPublicKey
// ConfirmedAt records when the successful confirmation happened.
ConfirmedAt time.Time
}
// Validate reports whether Confirmation contains all metadata required for a
// confirmed challenge.
func (c Confirmation) Validate() error {
if err := c.SessionID.Validate(); err != nil {
return fmt.Errorf("challenge confirmation session id: %w", err)
}
if err := c.ClientPublicKey.Validate(); err != nil {
return fmt.Errorf("challenge confirmation client public key: %w", err)
}
if c.ConfirmedAt.IsZero() {
return errors.New("challenge confirmation time must not be zero")
}
return nil
}
// Challenge is the minimal source-of-truth aggregate shape fixed by Stage 2.
type Challenge struct {
// ID identifies the challenge.
ID common.ChallengeID
// Email stores the normalized target e-mail address.
Email common.Email
// CodeHash stores only the hashed confirmation code.
CodeHash []byte
// Status reports the coarse challenge lifecycle state.
Status Status
// DeliveryState reports the recorded delivery outcome.
DeliveryState DeliveryState
// CreatedAt reports when the challenge was created.
CreatedAt time.Time
// ExpiresAt reports when the challenge becomes unusable.
ExpiresAt time.Time
// Attempts groups the send and confirm counters.
Attempts AttemptCounters
// Abuse stores minimal abuse-related timestamps.
Abuse AbuseMetadata
// Confirmation is present only after a successful confirm transition.
Confirmation *Confirmation
}
// IsExpiredAt reports whether the challenge is unusable at now either because
// it is already marked expired or because its expiration timestamp has passed.
func (c Challenge) IsExpiredAt(now time.Time) bool {
return c.Status == StatusExpired || !c.ExpiresAt.After(now)
}
// Validate reports whether Challenge satisfies the Stage-2 structural and
// lifecycle invariants.
func (c Challenge) Validate() error {
if err := c.ID.Validate(); err != nil {
return fmt.Errorf("challenge id: %w", err)
}
if err := c.Email.Validate(); err != nil {
return fmt.Errorf("challenge email: %w", err)
}
if len(c.CodeHash) == 0 {
return errors.New("challenge code hash must not be empty")
}
if !c.Status.IsKnown() {
return fmt.Errorf("challenge status %q is unsupported", c.Status)
}
if !c.DeliveryState.IsKnown() {
return fmt.Errorf("challenge delivery state %q is unsupported", c.DeliveryState)
}
if c.CreatedAt.IsZero() {
return errors.New("challenge creation time must not be zero")
}
if c.ExpiresAt.IsZero() {
return errors.New("challenge expiration time must not be zero")
}
if c.ExpiresAt.Before(c.CreatedAt) {
return errors.New("challenge expiration time must not be before creation time")
}
if err := c.Attempts.Validate(); err != nil {
return err
}
if err := c.Abuse.Validate(); err != nil {
return err
}
switch c.Status {
case StatusPendingSend:
if c.DeliveryState != DeliveryPending {
return errors.New("pending_send challenge must keep pending delivery state")
}
case StatusSent:
if c.DeliveryState != DeliverySent {
return errors.New("sent challenge must keep sent delivery state")
}
case StatusDeliverySuppressed:
if c.DeliveryState != DeliverySuppressed {
return errors.New("delivery_suppressed challenge must keep suppressed delivery state")
}
case StatusDeliveryThrottled:
if c.DeliveryState != DeliveryThrottled {
return errors.New("delivery_throttled challenge must keep throttled delivery state")
}
case StatusConfirmedPendingExpire:
if c.DeliveryState != DeliverySent && c.DeliveryState != DeliverySuppressed {
return errors.New("confirmed_pending_expire challenge must come from sent or suppressed delivery state")
}
}
if c.Status == StatusConfirmedPendingExpire {
if c.Confirmation == nil {
return errors.New("confirmed_pending_expire challenge must contain confirmation metadata")
}
if err := c.Confirmation.Validate(); err != nil {
return fmt.Errorf("challenge confirmation: %w", err)
}
return nil
}
if c.Confirmation != nil {
return errors.New("only confirmed_pending_expire challenge may contain confirmation metadata")
}
return nil
}
@@ -0,0 +1,439 @@
package challenge
import (
"crypto/ed25519"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
)
func TestPolicyConstants(t *testing.T) {
t.Parallel()
if InitialTTL != 5*time.Minute {
require.Failf(t, "test failed", "InitialTTL = %s, want %s", InitialTTL, 5*time.Minute)
}
if ResendThrottleCooldown != time.Minute {
require.Failf(t, "test failed", "ResendThrottleCooldown = %s, want %s", ResendThrottleCooldown, time.Minute)
}
if ConfirmedRetention != 5*time.Minute {
require.Failf(t, "test failed", "ConfirmedRetention = %s, want %s", ConfirmedRetention, 5*time.Minute)
}
if MaxInvalidConfirmAttempts != 5 {
require.Failf(t, "test failed", "MaxInvalidConfirmAttempts = %d, want %d", MaxInvalidConfirmAttempts, 5)
}
}
func TestStatusIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Status
want bool
}{
{name: "pending send", value: StatusPendingSend, want: true},
{name: "sent", value: StatusSent, want: true},
{name: "suppressed", value: StatusDeliverySuppressed, want: true},
{name: "throttled", value: StatusDeliveryThrottled, want: true},
{name: "confirmed", value: StatusConfirmedPendingExpire, want: true},
{name: "expired", value: StatusExpired, want: true},
{name: "failed", value: StatusFailed, want: true},
{name: "cancelled", value: StatusCancelled, want: true},
{name: "unknown", value: Status("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestStatusIsTerminal(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Status
want bool
}{
{name: "pending send", value: StatusPendingSend, want: false},
{name: "sent", value: StatusSent, want: false},
{name: "delivery suppressed", value: StatusDeliverySuppressed, want: false},
{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
{name: "confirmed pending expire", value: StatusConfirmedPendingExpire, want: false},
{name: "expired", value: StatusExpired, want: true},
{name: "failed", value: StatusFailed, want: true},
{name: "cancelled", value: StatusCancelled, want: true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsTerminal(); got != tt.want {
require.Failf(t, "test failed", "IsTerminal() = %v, want %v", got, tt.want)
}
})
}
}
func TestStatusAcceptsFreshConfirm(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Status
want bool
}{
{name: "pending send", value: StatusPendingSend, want: false},
{name: "sent", value: StatusSent, want: true},
{name: "delivery suppressed", value: StatusDeliverySuppressed, want: true},
{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
{name: "confirmed", value: StatusConfirmedPendingExpire, want: false},
{name: "expired", value: StatusExpired, want: false},
{name: "failed", value: StatusFailed, want: false},
{name: "cancelled", value: StatusCancelled, want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.AcceptsFreshConfirm(); got != tt.want {
require.Failf(t, "test failed", "AcceptsFreshConfirm() = %v, want %v", got, tt.want)
}
})
}
}
func TestStatusIsConfirmedRetryState(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Status
want bool
}{
{name: "sent", value: StatusSent, want: false},
{name: "delivery suppressed", value: StatusDeliverySuppressed, want: false},
{name: "delivery throttled", value: StatusDeliveryThrottled, want: false},
{name: "confirmed", value: StatusConfirmedPendingExpire, want: true},
{name: "expired", value: StatusExpired, want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsConfirmedRetryState(); got != tt.want {
require.Failf(t, "test failed", "IsConfirmedRetryState() = %v, want %v", got, tt.want)
}
})
}
}
func TestStatusCanTransitionTo(t *testing.T) {
t.Parallel()
tests := []struct {
name string
from Status
to Status
want bool
}{
{name: "pending to sent", from: StatusPendingSend, to: StatusSent, want: true},
{name: "pending to suppressed", from: StatusPendingSend, to: StatusDeliverySuppressed, want: true},
{name: "pending to throttled", from: StatusPendingSend, to: StatusDeliveryThrottled, want: true},
{name: "pending to failed", from: StatusPendingSend, to: StatusFailed, want: true},
{name: "pending to cancelled", from: StatusPendingSend, to: StatusCancelled, want: true},
{name: "pending to expired", from: StatusPendingSend, to: StatusExpired, want: true},
{name: "pending to confirmed", from: StatusPendingSend, to: StatusConfirmedPendingExpire, want: false},
{name: "sent to confirmed", from: StatusSent, to: StatusConfirmedPendingExpire, want: true},
{name: "sent to failed", from: StatusSent, to: StatusFailed, want: true},
{name: "suppressed to confirmed", from: StatusDeliverySuppressed, to: StatusConfirmedPendingExpire, want: true},
{name: "throttled to confirmed", from: StatusDeliveryThrottled, to: StatusConfirmedPendingExpire, want: false},
{name: "confirmed to expired", from: StatusConfirmedPendingExpire, to: StatusExpired, want: true},
{name: "confirmed to failed", from: StatusConfirmedPendingExpire, to: StatusFailed, want: false},
{name: "expired terminal", from: StatusExpired, to: StatusCancelled, want: false},
{name: "failed terminal", from: StatusFailed, to: StatusExpired, want: false},
{name: "cancelled terminal", from: StatusCancelled, to: StatusExpired, want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
}
})
}
}
func TestDeliveryStateIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value DeliveryState
want bool
}{
{name: "pending", value: DeliveryPending, want: true},
{name: "sent", value: DeliverySent, want: true},
{name: "suppressed", value: DeliverySuppressed, want: true},
{name: "throttled", value: DeliveryThrottled, want: true},
{name: "failed", value: DeliveryFailed, want: true},
{name: "unknown", value: DeliveryState("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestDeliveryStateCanTransitionTo(t *testing.T) {
t.Parallel()
tests := []struct {
name string
from DeliveryState
to DeliveryState
want bool
}{
{name: "pending to sent", from: DeliveryPending, to: DeliverySent, want: true},
{name: "pending to suppressed", from: DeliveryPending, to: DeliverySuppressed, want: true},
{name: "pending to throttled", from: DeliveryPending, to: DeliveryThrottled, want: true},
{name: "pending to failed", from: DeliveryPending, to: DeliveryFailed, want: true},
{name: "sent terminal", from: DeliverySent, to: DeliveryFailed, want: false},
{name: "suppressed terminal", from: DeliverySuppressed, to: DeliverySent, want: false},
{name: "failed terminal", from: DeliveryFailed, to: DeliverySent, want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
}
})
}
}
func TestChallengeIsExpiredAt(t *testing.T) {
t.Parallel()
now := time.Unix(1_775_121_700, 0).UTC()
tests := []struct {
name string
mutate func(*Challenge)
want bool
}{
{name: "active before expiration", want: false},
{
name: "expired status",
mutate: func(c *Challenge) {
c.Status = StatusExpired
},
want: true,
},
{
name: "expiration timestamp passed",
mutate: func(c *Challenge) {
c.ExpiresAt = now
},
want: true,
},
{
name: "confirmed retained before expiration",
mutate: func(c *Challenge) {
c.Status = StatusConfirmedPendingExpire
c.DeliveryState = DeliverySent
c.Confirmation = validConfirmation(t)
c.ExpiresAt = now.Add(time.Second)
},
want: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
challenge := validChallenge(t)
challenge.CreatedAt = now.Add(-time.Minute)
challenge.ExpiresAt = now.Add(time.Minute)
if tt.mutate != nil {
tt.mutate(&challenge)
}
if err := challenge.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
if got := challenge.IsExpiredAt(now); got != tt.want {
require.Failf(t, "test failed", "IsExpiredAt() = %v, want %v", got, tt.want)
}
})
}
}
func TestChallengeValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mutate func(*Challenge)
wantErr bool
}{
{name: "valid pending"},
{
name: "valid confirmed",
mutate: func(c *Challenge) {
c.Status = StatusConfirmedPendingExpire
c.DeliveryState = DeliverySent
c.Confirmation = validConfirmation(t)
},
},
{
name: "confirmed requires metadata",
mutate: func(c *Challenge) {
c.Status = StatusConfirmedPendingExpire
c.DeliveryState = DeliverySent
},
wantErr: true,
},
{
name: "unconfirmed rejects metadata",
mutate: func(c *Challenge) {
c.Confirmation = validConfirmation(t)
},
wantErr: true,
},
{
name: "pending requires pending delivery",
mutate: func(c *Challenge) {
c.DeliveryState = DeliverySent
},
wantErr: true,
},
{
name: "sent requires sent delivery",
mutate: func(c *Challenge) {
c.Status = StatusSent
c.DeliveryState = DeliverySuppressed
},
wantErr: true,
},
{
name: "throttled requires throttled delivery",
mutate: func(c *Challenge) {
c.Status = StatusDeliveryThrottled
c.DeliveryState = DeliverySent
},
wantErr: true,
},
{
name: "expiration before creation",
mutate: func(c *Challenge) {
c.ExpiresAt = c.CreatedAt.Add(-time.Second)
},
wantErr: true,
},
{
name: "negative confirm attempts",
mutate: func(c *Challenge) {
c.Attempts.Confirm = -1
},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
challenge := validChallenge(t)
if tt.mutate != nil {
tt.mutate(&challenge)
}
err := challenge.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
func validChallenge(t *testing.T) Challenge {
t.Helper()
return Challenge{
ID: common.ChallengeID("challenge-123"),
Email: common.Email("pilot@example.com"),
CodeHash: []byte("hash-123"),
Status: StatusPendingSend,
DeliveryState: DeliveryPending,
CreatedAt: time.Unix(1_775_121_600, 0).UTC(),
ExpiresAt: time.Unix(1_775_121_900, 0).UTC(),
Attempts: AttemptCounters{
Send: 0,
Confirm: 0,
},
}
}
func validConfirmation(t *testing.T) *Confirmation {
t.Helper()
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
for index := range raw {
raw[index] = byte(index + 1)
}
key, err := common.NewClientPublicKey(raw)
if err != nil {
require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
}
return &Confirmation{
SessionID: common.DeviceSessionID("device-session-123"),
ClientPublicKey: key,
ConfirmedAt: time.Unix(1_775_121_700, 0).UTC(),
}
}
@@ -0,0 +1,26 @@
package challenge
import "time"
const (
// InitialTTL is the v1 lifetime of a newly created challenge before it
// becomes expired.
InitialTTL = 5 * time.Minute
// ResendThrottleCooldown is the fixed Stage-17 cooldown applied to repeated
// public send-email-code requests for the same normalized e-mail address.
ResendThrottleCooldown = time.Minute
// ConfirmedRetention is the v1 idempotency window kept after a successful
// challenge confirmation.
ConfirmedRetention = 5 * time.Minute
// MaxInvalidConfirmAttempts is the v1 threshold after which repeated invalid
// confirmation codes move a challenge into the failed state.
MaxInvalidConfirmAttempts = 5
)
// V1 resend policy keeps every public send-email-code request independent:
// each call creates a fresh challenge, existing challenges are not reused or
// deduplicated, and Stage 17 adds a fixed auth-side resend cooldown that may
// record the fresh challenge as delivery_throttled.
+201
View File
@@ -0,0 +1,201 @@
// Package common defines small shared domain primitives used by auth/session
// aggregates and integration models.
package common
import (
"bytes"
"crypto/ed25519"
"encoding/base64"
"errors"
"fmt"
"net/mail"
"strings"
)
// ChallengeID identifies one auth confirmation challenge owned by the service.
type ChallengeID string
// String returns ChallengeID as a plain string identifier.
func (id ChallengeID) String() string {
return string(id)
}
// IsZero reports whether ChallengeID does not contain a usable identifier.
func (id ChallengeID) IsZero() bool {
return strings.TrimSpace(string(id)) == ""
}
// Validate reports whether ChallengeID is non-empty and already normalized for
// domain use.
func (id ChallengeID) Validate() error {
return validateToken("challenge id", string(id))
}
// DeviceSessionID identifies one persisted device session.
type DeviceSessionID string
// String returns DeviceSessionID as a plain string identifier.
func (id DeviceSessionID) String() string {
return string(id)
}
// IsZero reports whether DeviceSessionID does not contain a usable identifier.
func (id DeviceSessionID) IsZero() bool {
return strings.TrimSpace(string(id)) == ""
}
// Validate reports whether DeviceSessionID is non-empty and already
// normalized for domain use.
func (id DeviceSessionID) Validate() error {
return validateToken("device session id", string(id))
}
// UserID identifies one user resolved through the user-service boundary.
type UserID string
// String returns UserID as a plain string identifier.
func (id UserID) String() string {
return string(id)
}
// IsZero reports whether UserID does not contain a usable identifier.
func (id UserID) IsZero() bool {
return strings.TrimSpace(string(id)) == ""
}
// Validate reports whether UserID is non-empty and already normalized for
// domain use.
func (id UserID) Validate() error {
return validateToken("user id", string(id))
}
// Email stores one already-normalized e-mail address used by the auth domain.
type Email string
// String returns Email as the stored canonical e-mail string.
func (e Email) String() string {
return string(e)
}
// IsZero reports whether Email does not contain a usable e-mail value.
func (e Email) IsZero() bool {
return strings.TrimSpace(string(e)) == ""
}
// Validate reports whether Email is non-empty, does not contain surrounding
// whitespace, and matches the same single-address syntax expected by the
// public gateway contract.
func (e Email) Validate() error {
raw := string(e)
if err := validateToken("email", raw); err != nil {
return err
}
parsedAddress, err := mail.ParseAddress(raw)
if err != nil || parsedAddress.Name != "" || parsedAddress.Address != raw {
return fmt.Errorf("email %q must be a single valid email address", raw)
}
return nil
}
// RevokeReasonCode stores one machine-readable revoke reason code.
type RevokeReasonCode string
// String returns RevokeReasonCode as its stored code value.
func (code RevokeReasonCode) String() string {
return string(code)
}
// IsZero reports whether RevokeReasonCode is empty.
func (code RevokeReasonCode) IsZero() bool {
return strings.TrimSpace(string(code)) == ""
}
// Validate reports whether RevokeReasonCode is non-empty and normalized for
// domain use.
func (code RevokeReasonCode) Validate() error {
return validateToken("revoke reason code", string(code))
}
// RevokeActorType stores one machine-readable actor type for revoke audit.
type RevokeActorType string
// String returns RevokeActorType as its stored type value.
func (actorType RevokeActorType) String() string {
return string(actorType)
}
// IsZero reports whether RevokeActorType is empty.
func (actorType RevokeActorType) IsZero() bool {
return strings.TrimSpace(string(actorType)) == ""
}
// Validate reports whether RevokeActorType is non-empty and normalized for
// domain use.
func (actorType RevokeActorType) Validate() error {
return validateToken("revoke actor type", string(actorType))
}
// ClientPublicKey stores one validated Ed25519 public key in parsed binary
// form inside the domain model.
type ClientPublicKey struct {
value ed25519.PublicKey
}
// NewClientPublicKey validates value and returns a defensive copy suitable for
// storing inside domain aggregates.
func NewClientPublicKey(value ed25519.PublicKey) (ClientPublicKey, error) {
key := ClientPublicKey{
value: bytes.Clone(value),
}
if err := key.Validate(); err != nil {
return ClientPublicKey{}, err
}
return key, nil
}
// String returns ClientPublicKey as the standard base64-encoded raw 32-byte
// Ed25519 public key string.
func (key ClientPublicKey) String() string {
if key.IsZero() {
return ""
}
return base64.StdEncoding.EncodeToString(key.value)
}
// IsZero reports whether ClientPublicKey does not contain key material.
func (key ClientPublicKey) IsZero() bool {
return len(key.value) == 0
}
// Validate reports whether ClientPublicKey contains exactly one Ed25519 public
// key.
func (key ClientPublicKey) Validate() error {
switch len(key.value) {
case 0:
return errors.New("client public key must not be empty")
case ed25519.PublicKeySize:
return nil
default:
return fmt.Errorf("client public key must contain exactly %d bytes", ed25519.PublicKeySize)
}
}
// PublicKey returns a defensive copy of the parsed Ed25519 public key.
func (key ClientPublicKey) PublicKey() ed25519.PublicKey {
return bytes.Clone(key.value)
}
func validateToken(name string, value string) error {
switch {
case strings.TrimSpace(value) == "":
return fmt.Errorf("%s must not be empty", name)
case strings.TrimSpace(value) != value:
return fmt.Errorf("%s must not contain surrounding whitespace", name)
default:
return nil
}
}
@@ -0,0 +1,133 @@
package common
import (
"crypto/ed25519"
"github.com/stretchr/testify/require"
"testing"
)
func TestChallengeIDValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value ChallengeID
wantErr bool
}{
{name: "valid", value: ChallengeID("challenge-123")},
{name: "empty", value: ChallengeID(""), wantErr: true},
{name: "whitespace", value: ChallengeID(" challenge-123 "), wantErr: true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.value.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
func TestEmailValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Email
wantErr bool
}{
{name: "valid", value: Email("pilot@example.com")},
{name: "invalid", value: Email("pilot"), wantErr: true},
{name: "surrounding whitespace", value: Email(" pilot@example.com "), wantErr: true},
{name: "display name", value: Email("Pilot <pilot@example.com>"), wantErr: true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.value.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
func TestNewClientPublicKey(t *testing.T) {
t.Parallel()
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
for i := range raw {
raw[i] = byte(i)
}
key, err := NewClientPublicKey(raw)
if err != nil {
require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
}
if key.IsZero() {
require.FailNow(t, "IsZero() = true, want false")
}
cloned := key.PublicKey()
if len(cloned) != ed25519.PublicKeySize {
require.Failf(t, "test failed", "PublicKey() length = %d, want %d", len(cloned), ed25519.PublicKeySize)
}
raw[0] = 99
if key.PublicKey()[0] == 99 {
require.FailNow(t, "PublicKey() was mutated through constructor input")
}
}
func TestClientPublicKeyValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value ClientPublicKey
wantErr bool
}{
{name: "empty", value: ClientPublicKey{}, wantErr: true},
{
name: "short",
value: ClientPublicKey{value: make(ed25519.PublicKey, ed25519.PublicKeySize-1)},
wantErr: true,
},
{
name: "valid",
value: ClientPublicKey{value: make(ed25519.PublicKey, ed25519.PublicKeySize)},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.value.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
@@ -0,0 +1,162 @@
// Package devicesession defines the source-of-truth domain model for one
// authenticated device session.
package devicesession
import (
"errors"
"fmt"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
)
// Status identifies the coarse lifecycle state of one device session.
type Status string
const (
// StatusActive reports that the session may be used for authenticated
// request verification.
StatusActive Status = "active"
// StatusRevoked reports that the session has been revoked and must no
// longer authenticate requests.
StatusRevoked Status = "revoked"
)
// RevokeReasonDeviceLogout reports that one device logged itself out.
const RevokeReasonDeviceLogout common.RevokeReasonCode = "device_logout"
// RevokeReasonLogoutAll reports that the session was revoked by a
// user-scoped logout-all action.
const RevokeReasonLogoutAll common.RevokeReasonCode = "logout_all"
// RevokeReasonAdminRevoke reports that the session was revoked
// administratively.
const RevokeReasonAdminRevoke common.RevokeReasonCode = "admin_revoke"
// RevokeReasonUserBlocked reports that the session was revoked because future
// auth flow for the user or e-mail was blocked.
const RevokeReasonUserBlocked common.RevokeReasonCode = "user_blocked"
// IsKnown reports whether Status is one of the device-session states
// supported by the current domain model.
func (s Status) IsKnown() bool {
switch s {
case StatusActive, StatusRevoked:
return true
default:
return false
}
}
// CanTransitionTo reports whether the current device-session Status may move
// to next under the Stage-2 lifecycle rules.
func (s Status) CanTransitionTo(next Status) bool {
return s == StatusActive && next == StatusRevoked
}
// IsKnownRevokeReasonCode reports whether code is one of the built-in revoke
// reasons fixed by the Stage-2 domain model.
func IsKnownRevokeReasonCode(code common.RevokeReasonCode) bool {
switch code {
case RevokeReasonDeviceLogout,
RevokeReasonLogoutAll,
RevokeReasonAdminRevoke,
RevokeReasonUserBlocked:
return true
default:
return false
}
}
// Revocation stores the audit metadata recorded when a session is revoked.
type Revocation struct {
// At reports when the revoke took effect.
At time.Time
// ReasonCode stores one machine-readable revoke reason code.
ReasonCode common.RevokeReasonCode
// ActorType stores one machine-readable initiator type.
ActorType common.RevokeActorType
// ActorID optionally stores a stable initiator identifier.
ActorID string
}
// Validate reports whether Revocation contains all metadata required for a
// revoked session.
func (r Revocation) Validate() error {
if r.At.IsZero() {
return errors.New("session revocation time must not be zero")
}
if err := r.ReasonCode.Validate(); err != nil {
return fmt.Errorf("session revocation reason code: %w", err)
}
if err := r.ActorType.Validate(); err != nil {
return fmt.Errorf("session revocation actor type: %w", err)
}
if strings.TrimSpace(r.ActorID) != r.ActorID {
return errors.New("session revocation actor id must not contain surrounding whitespace")
}
return nil
}
// Session is the minimal source-of-truth aggregate shape fixed by Stage 2.
type Session struct {
// ID identifies the device session.
ID common.DeviceSessionID
// UserID identifies the durable user linkage for the session.
UserID common.UserID
// ClientPublicKey stores the validated device public key in parsed form.
ClientPublicKey common.ClientPublicKey
// Status reports the coarse lifecycle state of the session.
Status Status
// CreatedAt reports when the session was created.
CreatedAt time.Time
// Revocation is present only when Status is StatusRevoked.
Revocation *Revocation
}
// Validate reports whether Session satisfies the Stage-2 structural and
// lifecycle invariants.
func (s Session) Validate() error {
if err := s.ID.Validate(); err != nil {
return fmt.Errorf("session id: %w", err)
}
if err := s.UserID.Validate(); err != nil {
return fmt.Errorf("session user id: %w", err)
}
if err := s.ClientPublicKey.Validate(); err != nil {
return fmt.Errorf("session client public key: %w", err)
}
if !s.Status.IsKnown() {
return fmt.Errorf("session status %q is unsupported", s.Status)
}
if s.CreatedAt.IsZero() {
return errors.New("session creation time must not be zero")
}
switch s.Status {
case StatusActive:
if s.Revocation != nil {
return errors.New("active session must not contain revocation metadata")
}
case StatusRevoked:
if s.Revocation == nil {
return errors.New("revoked session must contain revocation metadata")
}
if err := s.Revocation.Validate(); err != nil {
return fmt.Errorf("session revocation: %w", err)
}
}
return nil
}
@@ -0,0 +1,186 @@
package devicesession
import (
"crypto/ed25519"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
)
func TestStatusIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Status
want bool
}{
{name: "active", value: StatusActive, want: true},
{name: "revoked", value: StatusRevoked, want: true},
{name: "unknown", value: Status("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestStatusCanTransitionTo(t *testing.T) {
t.Parallel()
tests := []struct {
name string
from Status
to Status
want bool
}{
{name: "active to revoked", from: StatusActive, to: StatusRevoked, want: true},
{name: "active to active", from: StatusActive, to: StatusActive, want: false},
{name: "revoked terminal", from: StatusRevoked, to: StatusActive, want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.from.CanTransitionTo(tt.to); got != tt.want {
require.Failf(t, "test failed", "CanTransitionTo() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsKnownRevokeReasonCode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value common.RevokeReasonCode
want bool
}{
{name: "device logout", value: RevokeReasonDeviceLogout, want: true},
{name: "logout all", value: RevokeReasonLogoutAll, want: true},
{name: "admin revoke", value: RevokeReasonAdminRevoke, want: true},
{name: "user blocked", value: RevokeReasonUserBlocked, want: true},
{name: "custom code", value: common.RevokeReasonCode("custom_policy"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := IsKnownRevokeReasonCode(tt.value); got != tt.want {
require.Failf(t, "test failed", "IsKnownRevokeReasonCode() = %v, want %v", got, tt.want)
}
})
}
}
func TestSessionValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mutate func(*Session)
wantErr bool
}{
{name: "active valid"},
{
name: "revoked valid",
mutate: func(s *Session) {
s.Status = StatusRevoked
s.Revocation = validRevocation()
},
},
{
name: "active rejects revocation",
mutate: func(s *Session) {
s.Revocation = validRevocation()
},
wantErr: true,
},
{
name: "revoked requires revocation",
mutate: func(s *Session) {
s.Status = StatusRevoked
},
wantErr: true,
},
{
name: "revoked requires complete metadata",
mutate: func(s *Session) {
s.Status = StatusRevoked
revocation := validRevocation()
revocation.ReasonCode = ""
s.Revocation = revocation
},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
session := validSession(t)
if tt.mutate != nil {
tt.mutate(&session)
}
err := session.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
func validSession(t *testing.T) Session {
t.Helper()
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
for index := range raw {
raw[index] = byte(index + 7)
}
key, err := common.NewClientPublicKey(raw)
if err != nil {
require.Failf(t, "test failed", "NewClientPublicKey() returned error: %v", err)
}
return Session{
ID: common.DeviceSessionID("device-session-123"),
UserID: common.UserID("user-123"),
ClientPublicKey: key,
Status: StatusActive,
CreatedAt: time.Unix(1_775_121_600, 0).UTC(),
}
}
func validRevocation() *Revocation {
return &Revocation{
At: time.Unix(1_775_121_800, 0).UTC(),
ReasonCode: RevokeReasonAdminRevoke,
ActorType: common.RevokeActorType("admin"),
ActorID: "admin-123",
}
}
@@ -0,0 +1,141 @@
// Package gatewayprojection defines the gateway-facing integration snapshot
// model that stays separate from source-of-truth session entities.
package gatewayprojection
import (
"crypto/ed25519"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
)
// Status identifies the coarse lifecycle state projected to the gateway.
type Status string
const (
// StatusActive reports that the projected session may authenticate
// requests on the gateway hot path.
StatusActive Status = "active"
// StatusRevoked reports that the projected session must be rejected on the
// gateway hot path.
StatusRevoked Status = "revoked"
)
// IsKnown reports whether Status is one of the projection states supported by
// the current integration model.
func (s Status) IsKnown() bool {
switch s {
case StatusActive, StatusRevoked:
return true
default:
return false
}
}
// Snapshot stores the gateway-facing session projection without exposing any
// Redis-specific field naming or storage encoding.
type Snapshot struct {
// DeviceSessionID identifies the projected device session.
DeviceSessionID common.DeviceSessionID
// UserID identifies the projected user.
UserID common.UserID
// ClientPublicKey stores the standard base64-encoded raw 32-byte Ed25519
// public key string expected by the gateway.
ClientPublicKey string
// Status reports whether the projected session is active or revoked.
Status Status
// RevokedAt optionally reports when the revoke took effect.
RevokedAt *time.Time
// RevokeReasonCode optionally stores the machine-readable revoke reason.
RevokeReasonCode common.RevokeReasonCode
// RevokeActorType optionally stores the machine-readable revoke actor type.
RevokeActorType common.RevokeActorType
// RevokeActorID optionally stores a stable revoke actor identifier.
RevokeActorID string
}
// Validate reports whether Snapshot satisfies the Stage-2 structural
// invariants.
func (s Snapshot) Validate() error {
if err := s.DeviceSessionID.Validate(); err != nil {
return fmt.Errorf("gateway projection device session id: %w", err)
}
if err := s.UserID.Validate(); err != nil {
return fmt.Errorf("gateway projection user id: %w", err)
}
if err := validateClientPublicKey(s.ClientPublicKey); err != nil {
return fmt.Errorf("gateway projection client public key: %w", err)
}
if !s.Status.IsKnown() {
return fmt.Errorf("gateway projection status %q is unsupported", s.Status)
}
if s.Status == StatusActive {
if s.RevokedAt != nil {
return errors.New("active gateway projection must not contain revoked time")
}
if !s.RevokeReasonCode.IsZero() {
return errors.New("active gateway projection must not contain revoke reason code")
}
if !s.RevokeActorType.IsZero() {
return errors.New("active gateway projection must not contain revoke actor type")
}
if s.RevokeActorID != "" {
return errors.New("active gateway projection must not contain revoke actor id")
}
return nil
}
if s.RevokedAt != nil && s.RevokedAt.IsZero() {
return errors.New("gateway projection revoked time must not be zero")
}
if !s.RevokeReasonCode.IsZero() {
if err := s.RevokeReasonCode.Validate(); err != nil {
return fmt.Errorf("gateway projection revoke reason code: %w", err)
}
}
if !s.RevokeActorType.IsZero() {
if err := s.RevokeActorType.Validate(); err != nil {
return fmt.Errorf("gateway projection revoke actor type: %w", err)
}
}
if s.RevokeActorType.IsZero() && s.RevokeActorID != "" {
return errors.New("gateway projection revoke actor id requires revoke actor type")
}
if strings.TrimSpace(s.RevokeActorID) != s.RevokeActorID {
return errors.New("gateway projection revoke actor id must not contain surrounding whitespace")
}
return nil
}
func validateClientPublicKey(value string) error {
switch {
case strings.TrimSpace(value) == "":
return errors.New("client public key must not be empty")
case strings.TrimSpace(value) != value:
return errors.New("client public key must not contain surrounding whitespace")
}
decoded, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return fmt.Errorf("client public key must be valid base64: %w", err)
}
if len(decoded) != ed25519.PublicKeySize {
return fmt.Errorf("client public key must contain exactly %d bytes", ed25519.PublicKeySize)
}
return nil
}
@@ -0,0 +1,146 @@
package gatewayprojection
import (
"crypto/ed25519"
"encoding/base64"
"github.com/stretchr/testify/require"
"reflect"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
)
func TestStatusIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Status
want bool
}{
{name: "active", value: StatusActive, want: true},
{name: "revoked", value: StatusRevoked, want: true},
{name: "unknown", value: Status("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestSnapshotValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mutate func(*Snapshot)
wantErr bool
}{
{name: "active valid"},
{
name: "revoked valid",
mutate: func(snapshot *Snapshot) {
snapshot.Status = StatusRevoked
revokedAt := time.Unix(1_775_121_900, 0).UTC()
snapshot.RevokedAt = &revokedAt
snapshot.RevokeReasonCode = common.RevokeReasonCode("admin_revoke")
snapshot.RevokeActorType = common.RevokeActorType("admin")
snapshot.RevokeActorID = "admin-123"
},
},
{
name: "active rejects revoke metadata",
mutate: func(snapshot *Snapshot) {
snapshot.RevokeReasonCode = common.RevokeReasonCode("admin_revoke")
},
wantErr: true,
},
{
name: "invalid key encoding",
mutate: func(snapshot *Snapshot) {
snapshot.ClientPublicKey = "not-base64"
},
wantErr: true,
},
{
name: "actor id requires actor type",
mutate: func(snapshot *Snapshot) {
snapshot.Status = StatusRevoked
snapshot.RevokeActorID = "admin-123"
},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
snapshot := validSnapshot()
if tt.mutate != nil {
tt.mutate(&snapshot)
}
err := snapshot.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
func TestSnapshotStaysSeparateFromSessionDomainShape(t *testing.T) {
t.Parallel()
snapshotType := reflect.TypeOf(Snapshot{})
sessionType := reflect.TypeOf(devicesession.Session{})
clientPublicKeyField, ok := snapshotType.FieldByName("ClientPublicKey")
if !ok {
require.FailNow(t, "Snapshot is missing ClientPublicKey field")
}
if clientPublicKeyField.Type.Kind() != reflect.String {
require.Failf(t, "test failed", "Snapshot.ClientPublicKey kind = %s, want string", clientPublicKeyField.Type.Kind())
}
sessionClientPublicKeyField, ok := sessionType.FieldByName("ClientPublicKey")
if !ok {
require.FailNow(t, "devicesession.Session is missing ClientPublicKey field")
}
if clientPublicKeyField.Type == sessionClientPublicKeyField.Type {
require.FailNow(t, "Snapshot.ClientPublicKey must stay separate from devicesession.Session.ClientPublicKey type")
}
if _, ok := snapshotType.FieldByName("RevokedAtMS"); ok {
require.FailNow(t, "Snapshot must not expose Redis-specific RevokedAtMS field")
}
}
func validSnapshot() Snapshot {
raw := make(ed25519.PublicKey, ed25519.PublicKeySize)
for index := range raw {
raw[index] = byte(index + 17)
}
return Snapshot{
DeviceSessionID: common.DeviceSessionID("device-session-123"),
UserID: common.UserID("user-123"),
ClientPublicKey: base64.StdEncoding.EncodeToString(raw),
Status: StatusActive,
}
}
@@ -0,0 +1,89 @@
// Package sessionlimit defines the domain decision shape used for active
// device-session limit evaluation.
package sessionlimit
import (
"errors"
"fmt"
)
// Kind identifies the coarse outcome of evaluating the active-session limit.
type Kind string
const (
// KindDisabled reports that no configured limit is currently active.
KindDisabled Kind = "disabled"
// KindAllowed reports that creating the next session is allowed.
KindAllowed Kind = "allowed"
// KindExceeded reports that creating the next session would exceed the
// configured limit.
KindExceeded Kind = "exceeded"
)
// IsKnown reports whether Kind is one of the session-limit outcomes supported
// by the current domain model.
func (k Kind) IsKnown() bool {
switch k {
case KindDisabled, KindAllowed, KindExceeded:
return true
default:
return false
}
}
// Decision stores the result of evaluating one possible next session creation.
type Decision struct {
// Kind reports the coarse decision outcome.
Kind Kind
// ConfiguredLimit stores the active configured limit when one exists.
ConfiguredLimit *int
// ActiveSessionCount stores the current active-session count before create.
ActiveSessionCount int
// NextSessionCount stores the count that would exist after creating the next
// session.
NextSessionCount int
}
// Validate reports whether Decision satisfies the Stage-2 structural
// invariants.
func (d Decision) Validate() error {
if !d.Kind.IsKnown() {
return fmt.Errorf("session-limit decision kind %q is unsupported", d.Kind)
}
if d.ActiveSessionCount < 0 {
return errors.New("session-limit active session count must not be negative")
}
if d.NextSessionCount < 0 {
return errors.New("session-limit next session count must not be negative")
}
if d.NextSessionCount != d.ActiveSessionCount+1 {
return errors.New("session-limit next session count must equal active session count plus one")
}
switch d.Kind {
case KindDisabled:
if d.ConfiguredLimit != nil {
return errors.New("disabled session-limit decision must not contain configured limit")
}
case KindAllowed, KindExceeded:
if d.ConfiguredLimit == nil {
return errors.New("limited session-limit decision must contain configured limit")
}
if *d.ConfiguredLimit <= 0 {
return errors.New("session-limit configured limit must be positive")
}
if d.Kind == KindAllowed && d.NextSessionCount > *d.ConfiguredLimit {
return errors.New("allowed session-limit decision must not exceed configured limit")
}
if d.Kind == KindExceeded && d.NextSessionCount <= *d.ConfiguredLimit {
return errors.New("exceeded session-limit decision must be above configured limit")
}
}
return nil
}
@@ -0,0 +1,128 @@
package sessionlimit
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestKindIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Kind
want bool
}{
{name: "disabled", value: KindDisabled, want: true},
{name: "allowed", value: KindAllowed, want: true},
{name: "exceeded", value: KindExceeded, want: true},
{name: "unknown", value: Kind("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestDecisionValidate(t *testing.T) {
t.Parallel()
limitTwo := 2
limitThree := 3
tests := []struct {
name string
value Decision
wantErr bool
}{
{
name: "disabled valid",
value: Decision{
Kind: KindDisabled,
ActiveSessionCount: 0,
NextSessionCount: 1,
},
},
{
name: "allowed valid",
value: Decision{
Kind: KindAllowed,
ConfiguredLimit: &limitThree,
ActiveSessionCount: 1,
NextSessionCount: 2,
},
},
{
name: "exceeded valid",
value: Decision{
Kind: KindExceeded,
ConfiguredLimit: &limitTwo,
ActiveSessionCount: 2,
NextSessionCount: 3,
},
},
{
name: "disabled rejects limit",
value: Decision{
Kind: KindDisabled,
ConfiguredLimit: &limitTwo,
ActiveSessionCount: 0,
NextSessionCount: 1,
},
wantErr: true,
},
{
name: "allowed requires limit",
value: Decision{
Kind: KindAllowed,
ActiveSessionCount: 0,
NextSessionCount: 1,
},
wantErr: true,
},
{
name: "allowed rejects overflow",
value: Decision{
Kind: KindAllowed,
ConfiguredLimit: &limitTwo,
ActiveSessionCount: 2,
NextSessionCount: 3,
},
wantErr: true,
},
{
name: "next count must be active plus one",
value: Decision{
Kind: KindDisabled,
ActiveSessionCount: 2,
NextSessionCount: 2,
},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.value.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
@@ -0,0 +1,110 @@
// Package userresolution defines the domain result returned by the user
// resolution boundary before session creation.
package userresolution
import (
"errors"
"fmt"
"strings"
"galaxy/authsession/internal/domain/common"
)
// Kind identifies the coarse user-resolution result for one normalized e-mail.
type Kind string
const (
// KindExisting reports that the e-mail belongs to an existing user.
KindExisting Kind = "existing"
// KindCreatable reports that the e-mail is free and user creation is
// allowed.
KindCreatable Kind = "creatable"
// KindBlocked reports that the e-mail or subject is blocked from login or
// registration.
KindBlocked Kind = "blocked"
)
// IsKnown reports whether Kind is one of the user-resolution kinds supported
// by the current domain model.
func (k Kind) IsKnown() bool {
switch k {
case KindExisting, KindCreatable, KindBlocked:
return true
default:
return false
}
}
// BlockReasonCode stores one machine-readable user-block reason.
type BlockReasonCode string
// String returns BlockReasonCode as its stored code value.
func (code BlockReasonCode) String() string {
return string(code)
}
// IsZero reports whether BlockReasonCode is empty.
func (code BlockReasonCode) IsZero() bool {
return strings.TrimSpace(string(code)) == ""
}
// Validate reports whether BlockReasonCode is non-empty and normalized for
// domain use.
func (code BlockReasonCode) Validate() error {
switch {
case code.IsZero():
return errors.New("block reason code must not be empty")
case strings.TrimSpace(string(code)) != string(code):
return errors.New("block reason code must not contain surrounding whitespace")
default:
return nil
}
}
// Result stores the coarse user-resolution outcome consumed by later auth
// workflow stages.
type Result struct {
// Kind reports the coarse resolution outcome.
Kind Kind
// UserID is set only when Kind is KindExisting.
UserID common.UserID
// BlockReasonCode is set only when Kind is KindBlocked.
BlockReasonCode BlockReasonCode
}
// Validate reports whether Result satisfies the Stage-2 structural invariants.
func (r Result) Validate() error {
if !r.Kind.IsKnown() {
return fmt.Errorf("user resolution kind %q is unsupported", r.Kind)
}
switch r.Kind {
case KindExisting:
if err := r.UserID.Validate(); err != nil {
return fmt.Errorf("user resolution user id: %w", err)
}
if !r.BlockReasonCode.IsZero() {
return errors.New("existing user resolution must not contain block reason code")
}
case KindCreatable:
if !r.UserID.IsZero() {
return errors.New("creatable user resolution must not contain user id")
}
if !r.BlockReasonCode.IsZero() {
return errors.New("creatable user resolution must not contain block reason code")
}
case KindBlocked:
if !r.UserID.IsZero() {
return errors.New("blocked user resolution must not contain user id")
}
if err := r.BlockReasonCode.Validate(); err != nil {
return fmt.Errorf("user resolution block reason code: %w", err)
}
}
return nil
}
@@ -0,0 +1,113 @@
package userresolution
import (
"github.com/stretchr/testify/require"
"testing"
"galaxy/authsession/internal/domain/common"
)
func TestKindIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Kind
want bool
}{
{name: "existing", value: KindExisting, want: true},
{name: "creatable", value: KindCreatable, want: true},
{name: "blocked", value: KindBlocked, want: true},
{name: "unknown", value: Kind("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestResultValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value Result
wantErr bool
}{
{
name: "existing valid",
value: Result{
Kind: KindExisting,
UserID: common.UserID("user-123"),
},
},
{
name: "creatable valid",
value: Result{
Kind: KindCreatable,
},
},
{
name: "blocked valid",
value: Result{
Kind: KindBlocked,
BlockReasonCode: BlockReasonCode("policy_blocked"),
},
},
{
name: "existing requires user id",
value: Result{
Kind: KindExisting,
},
wantErr: true,
},
{
name: "creatable rejects user id",
value: Result{
Kind: KindCreatable,
UserID: common.UserID("user-123"),
},
wantErr: true,
},
{
name: "blocked requires reason",
value: Result{
Kind: KindBlocked,
},
wantErr: true,
},
{
name: "blocked rejects user id",
value: Result{
Kind: KindBlocked,
UserID: common.UserID("user-123"),
BlockReasonCode: BlockReasonCode("policy_blocked"),
},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.value.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
+82
View File
@@ -0,0 +1,82 @@
// Package logging configures the authsession structured logger and provides
// context-aware helpers for attaching OpenTelemetry trace identifiers.
package logging
import (
"context"
"strings"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// New constructs the process-wide JSON logger from level.
func New(level string) (*zap.Logger, error) {
atomicLevel := zap.NewAtomicLevel()
if err := atomicLevel.UnmarshalText([]byte(strings.TrimSpace(level))); err != nil {
return nil, err
}
zapCfg := zap.NewProductionConfig()
zapCfg.Level = atomicLevel
zapCfg.Sampling = nil
zapCfg.Encoding = "json"
zapCfg.EncoderConfig.TimeKey = "timestamp"
zapCfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
zapCfg.OutputPaths = []string{"stdout"}
zapCfg.ErrorOutputPaths = []string{"stderr"}
return zapCfg.Build()
}
// TraceFieldsFromContext returns zap fields for the active OpenTelemetry span
// when ctx carries a valid span context.
func TraceFieldsFromContext(ctx context.Context) []zap.Field {
if ctx == nil {
return nil
}
spanContext := trace.SpanContextFromContext(ctx)
if !spanContext.IsValid() {
return nil
}
return []zap.Field{
zap.String("otel_trace_id", spanContext.TraceID().String()),
zap.String("otel_span_id", spanContext.SpanID().String()),
}
}
// Sync flushes logger and ignores the benign stdout or stderr sync errors
// commonly returned by containerized or redirected process outputs.
func Sync(logger *zap.Logger) error {
if logger == nil {
return nil
}
err := logger.Sync()
if err == nil || isIgnorableSyncError(err) {
return nil
}
return err
}
func isIgnorableSyncError(err error) bool {
if err == nil {
return false
}
message := strings.ToLower(err.Error())
switch {
case strings.Contains(message, "invalid argument"):
return true
case strings.Contains(message, "bad file descriptor"):
return true
case strings.Contains(message, "inappropriate ioctl for device"):
return true
default:
return false
}
}
@@ -0,0 +1,37 @@
package logging
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
)
func TestNewRejectsInvalidLogLevel(t *testing.T) {
t.Parallel()
_, err := New("verbose")
require.Error(t, err)
}
func TestTraceFieldsFromContextReturnsTraceAndSpanIDs(t *testing.T) {
t.Parallel()
recorder := tracetest.NewSpanRecorder()
provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
ctx, span := provider.Tracer("test").Start(context.Background(), "operation")
defer span.End()
fields := TraceFieldsFromContext(ctx)
require.Len(t, fields, 2)
assert.Equal(t, "otel_trace_id", fields[0].Key)
assert.Equal(t, "otel_span_id", fields[1].Key)
assert.NotEmpty(t, fields[0].String)
assert.NotEmpty(t, fields[1].String)
}
@@ -0,0 +1,43 @@
package ports
import (
"context"
"fmt"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
)
// ChallengeStore provides source-of-truth persistence for auth confirmation
// challenges without exposing storage-specific primitives.
type ChallengeStore interface {
// Get returns the stored challenge for challengeID. Implementations must
// wrap ErrNotFound when challengeID does not exist.
Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error)
// Create persists record as a new challenge. Implementations must wrap
// ErrConflict when record.ID already exists.
Create(ctx context.Context, record challenge.Challenge) error
// CompareAndSwap replaces previous with next when the currently stored
// challenge matches previous exactly. Implementations must wrap ErrConflict
// when the stored challenge differs from previous and wrap ErrNotFound when
// previous.ID does not exist.
CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error
}
// ValidateComparableChallenges reports whether previous and next are suitable
// for one ChallengeStore compare-and-swap call.
func ValidateComparableChallenges(previous challenge.Challenge, next challenge.Challenge) error {
if err := previous.Validate(); err != nil {
return fmt.Errorf("previous challenge: %w", err)
}
if err := next.Validate(); err != nil {
return fmt.Errorf("next challenge: %w", err)
}
if previous.ID != next.ID {
return fmt.Errorf("challenge compare-and-swap ids must match: %q != %q", previous.ID, next.ID)
}
return nil
}
+9
View File
@@ -0,0 +1,9 @@
package ports
import "time"
// Clock returns current UTC time for the auth/session application layer.
type Clock interface {
// Now returns the current service time.
Now() time.Time
}
@@ -0,0 +1,8 @@
package ports
// CodeGenerator generates cleartext confirmation codes for new auth
// challenges.
type CodeGenerator interface {
// Generate returns one fresh cleartext confirmation code.
Generate() (string, error)
}
+11
View File
@@ -0,0 +1,11 @@
package ports
// CodeHasher hashes cleartext confirmation codes and compares later user input
// against stored hashes.
type CodeHasher interface {
// Hash returns the stored representation for code.
Hash(code string) ([]byte, error)
// Compare reports whether hash matches code.
Compare(hash []byte, code string) (bool, error)
}
@@ -0,0 +1,42 @@
package ports
import (
"context"
"errors"
"fmt"
)
// ConfigProvider returns dynamic auth/session configuration required by later
// service workflows.
type ConfigProvider interface {
// LoadSessionLimit returns the current active-session-limit configuration.
// A nil ActiveSessionLimit means that the limit is disabled.
LoadSessionLimit(ctx context.Context) (SessionLimitConfig, error)
}
// SessionLimitConfig stores the active-session-limit configuration in a form
// that preserves “limit absent” as a first-class state.
type SessionLimitConfig struct {
// ActiveSessionLimit stores the configured limit when one is present. Nil
// means that no active-session limit is configured.
ActiveSessionLimit *int
}
// Validate reports whether SessionLimitConfig contains a valid limit value
// when one is configured.
func (c SessionLimitConfig) Validate() error {
if c.ActiveSessionLimit != nil && *c.ActiveSessionLimit <= 0 {
return errors.New("session limit config active session limit must be positive when configured")
}
return nil
}
// String returns a debug-friendly representation of SessionLimitConfig.
func (c SessionLimitConfig) String() string {
if c.ActiveSessionLimit == nil {
return "session_limit=disabled"
}
return fmt.Sprintf("session_limit=%d", *c.ActiveSessionLimit)
}
+16
View File
@@ -0,0 +1,16 @@
// Package ports defines the storage-agnostic and transport-agnostic service
// boundaries used by the auth/session application layer.
package ports
import "errors"
var (
// ErrNotFound reports that a requested source-of-truth record or remote
// subject does not exist in the dependency behind the port.
ErrNotFound = errors.New("ports: record not found")
// ErrConflict reports that a create or compare-and-swap style mutation
// cannot be applied because the current dependency state no longer matches
// the caller expectation.
ErrConflict = errors.New("ports: conflict")
)
@@ -0,0 +1,13 @@
package ports
import "galaxy/authsession/internal/domain/common"
// IDGenerator generates stable domain identifiers for new challenges and
// device sessions.
type IDGenerator interface {
// NewChallengeID returns a fresh challenge identifier.
NewChallengeID() (common.ChallengeID, error)
// NewDeviceSessionID returns a fresh device-session identifier.
NewDeviceSessionID() (common.DeviceSessionID, error)
}
+86
View File
@@ -0,0 +1,86 @@
package ports
import (
"context"
"errors"
"fmt"
"strings"
"galaxy/authsession/internal/domain/common"
)
// MailSender delivers the public login code or intentionally suppresses
// outward delivery while keeping the auth flow success-shaped.
type MailSender interface {
// SendLoginCode attempts delivery for one generated login code. Explicit
// delivery failure is reported through error, while sent vs suppressed is
// returned in the result.
SendLoginCode(ctx context.Context, input SendLoginCodeInput) (SendLoginCodeResult, error)
}
// SendLoginCodeInput describes one mail-delivery request generated by the auth
// flow.
type SendLoginCodeInput struct {
// Email identifies the normalized target e-mail address.
Email common.Email
// Code stores the cleartext login code that should be delivered to Email.
Code string
}
// Validate reports whether SendLoginCodeInput contains a complete delivery
// request.
func (i SendLoginCodeInput) Validate() error {
if err := i.Email.Validate(); err != nil {
return fmt.Errorf("send login code input email: %w", err)
}
switch {
case strings.TrimSpace(i.Code) == "":
return errors.New("send login code input code must not be empty")
case strings.TrimSpace(i.Code) != i.Code:
return errors.New("send login code input code must not contain surrounding whitespace")
default:
return nil
}
}
// SendLoginCodeOutcome identifies the coarse mail-delivery outcome reported
// back to the auth flow.
type SendLoginCodeOutcome string
const (
// SendLoginCodeOutcomeSent reports that delivery was attempted and accepted.
SendLoginCodeOutcomeSent SendLoginCodeOutcome = "sent"
// SendLoginCodeOutcomeSuppressed reports that outward behavior remains
// success-shaped while actual delivery is intentionally skipped.
SendLoginCodeOutcomeSuppressed SendLoginCodeOutcome = "suppressed"
)
// IsKnown reports whether SendLoginCodeOutcome is supported by the current
// mail-sender contract.
func (o SendLoginCodeOutcome) IsKnown() bool {
switch o {
case SendLoginCodeOutcomeSent, SendLoginCodeOutcomeSuppressed:
return true
default:
return false
}
}
// SendLoginCodeResult describes the stable outcome returned by MailSender for
// one delivery request.
type SendLoginCodeResult struct {
// Outcome reports whether delivery was sent or intentionally suppressed.
Outcome SendLoginCodeOutcome
}
// Validate reports whether SendLoginCodeResult satisfies the mail-sender
// contract invariants.
func (r SendLoginCodeResult) Validate() error {
if !r.Outcome.IsKnown() {
return fmt.Errorf("send login code result outcome %q is unsupported", r.Outcome)
}
return nil
}
+371
View File
@@ -0,0 +1,371 @@
package ports
import (
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
)
func TestRevokeSessionOutcomeIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value RevokeSessionOutcome
want bool
}{
{name: "revoked", value: RevokeSessionOutcomeRevoked, want: true},
{name: "already revoked", value: RevokeSessionOutcomeAlreadyRevoked, want: true},
{name: "unknown", value: RevokeSessionOutcome("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestRevokeUserSessionsOutcomeIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value RevokeUserSessionsOutcome
want bool
}{
{name: "revoked", value: RevokeUserSessionsOutcomeRevoked, want: true},
{name: "no active sessions", value: RevokeUserSessionsOutcomeNoActiveSessions, want: true},
{name: "unknown", value: RevokeUserSessionsOutcome("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestEnsureUserOutcomeIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value EnsureUserOutcome
want bool
}{
{name: "existing", value: EnsureUserOutcomeExisting, want: true},
{name: "created", value: EnsureUserOutcomeCreated, want: true},
{name: "blocked", value: EnsureUserOutcomeBlocked, want: true},
{name: "unknown", value: EnsureUserOutcome("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestBlockUserOutcomeIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value BlockUserOutcome
want bool
}{
{name: "blocked", value: BlockUserOutcomeBlocked, want: true},
{name: "already blocked", value: BlockUserOutcomeAlreadyBlocked, want: true},
{name: "unknown", value: BlockUserOutcome("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestSendLoginCodeOutcomeIsKnown(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value SendLoginCodeOutcome
want bool
}{
{name: "sent", value: SendLoginCodeOutcomeSent, want: true},
{name: "suppressed", value: SendLoginCodeOutcomeSuppressed, want: true},
{name: "unknown", value: SendLoginCodeOutcome("unknown"), want: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := tt.value.IsKnown(); got != tt.want {
require.Failf(t, "test failed", "IsKnown() = %v, want %v", got, tt.want)
}
})
}
}
func TestSessionLimitConfigValidate(t *testing.T) {
t.Parallel()
positive := 3
zero := 0
tests := []struct {
name string
value SessionLimitConfig
wantErr bool
}{
{name: "absent", value: SessionLimitConfig{}},
{name: "positive", value: SessionLimitConfig{ActiveSessionLimit: &positive}},
{name: "zero", value: SessionLimitConfig{ActiveSessionLimit: &zero}, wantErr: true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.value.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
func TestRevokeSessionInputValidate(t *testing.T) {
t.Parallel()
input := RevokeSessionInput{
DeviceSessionID: common.DeviceSessionID("device-session-1"),
Revocation: devicesession.Revocation{
At: time.Unix(10, 0).UTC(),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
},
}
if err := input.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
}
func TestRevokeSessionResultValidate(t *testing.T) {
t.Parallel()
result := RevokeSessionResult{
Outcome: RevokeSessionOutcomeRevoked,
Session: revokedSessionFixture(),
}
if err := result.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
}
func TestRevokeUserSessionsResultValidate(t *testing.T) {
t.Parallel()
result := RevokeUserSessionsResult{
Outcome: RevokeUserSessionsOutcomeRevoked,
UserID: common.UserID("user-1"),
Sessions: []devicesession.Session{
revokedSessionFixture(),
},
}
if err := result.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
}
func TestEnsureUserResultValidate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value EnsureUserResult
wantErr bool
}{
{
name: "existing",
value: EnsureUserResult{
Outcome: EnsureUserOutcomeExisting,
UserID: common.UserID("user-1"),
},
},
{
name: "created",
value: EnsureUserResult{
Outcome: EnsureUserOutcomeCreated,
UserID: common.UserID("user-2"),
},
},
{
name: "blocked",
value: EnsureUserResult{
Outcome: EnsureUserOutcomeBlocked,
BlockReasonCode: userresolution.BlockReasonCode("policy_block"),
},
},
{
name: "blocked with user id",
value: EnsureUserResult{
Outcome: EnsureUserOutcomeBlocked,
UserID: common.UserID("user-1"),
BlockReasonCode: userresolution.BlockReasonCode("policy_block"),
},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.value.Validate()
if tt.wantErr && err == nil {
require.FailNow(t, "Validate() returned nil error")
}
if !tt.wantErr && err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
})
}
}
func TestBlockUserInputsAndResultValidate(t *testing.T) {
t.Parallel()
byID := BlockUserByIDInput{
UserID: common.UserID("user-1"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
}
if err := byID.Validate(); err != nil {
require.Failf(t, "test failed", "BlockUserByIDInput.Validate() returned error: %v", err)
}
byEmail := BlockUserByEmailInput{
Email: common.Email("pilot@example.com"),
ReasonCode: userresolution.BlockReasonCode("policy_block"),
}
if err := byEmail.Validate(); err != nil {
require.Failf(t, "test failed", "BlockUserByEmailInput.Validate() returned error: %v", err)
}
result := BlockUserResult{
Outcome: BlockUserOutcomeBlocked,
UserID: common.UserID("user-1"),
}
if err := result.Validate(); err != nil {
require.Failf(t, "test failed", "BlockUserResult.Validate() returned error: %v", err)
}
}
func TestSendLoginCodeInputAndResultValidate(t *testing.T) {
t.Parallel()
input := SendLoginCodeInput{
Email: common.Email("pilot@example.com"),
Code: "654321",
}
if err := input.Validate(); err != nil {
require.Failf(t, "test failed", "SendLoginCodeInput.Validate() returned error: %v", err)
}
result := SendLoginCodeResult{Outcome: SendLoginCodeOutcomeSent}
if err := result.Validate(); err != nil {
require.Failf(t, "test failed", "SendLoginCodeResult.Validate() returned error: %v", err)
}
}
func TestValidateComparableChallenges(t *testing.T) {
t.Parallel()
previous := challengeFixture()
next := challengeFixture()
next.Status = challenge.StatusSent
next.DeliveryState = challenge.DeliverySent
if err := ValidateComparableChallenges(previous, next); err != nil {
require.Failf(t, "test failed", "ValidateComparableChallenges() returned error: %v", err)
}
}
func challengeFixture() challenge.Challenge {
timestamp := time.Unix(10, 0).UTC()
return challenge.Challenge{
ID: common.ChallengeID("challenge-1"),
Email: common.Email("pilot@example.com"),
CodeHash: []byte("hash"),
Status: challenge.StatusPendingSend,
DeliveryState: challenge.DeliveryPending,
CreatedAt: timestamp,
ExpiresAt: timestamp.Add(5 * time.Minute),
}
}
func revokedSessionFixture() devicesession.Session {
timestamp := time.Unix(10, 0).UTC()
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID("device-session-1"),
UserID: common.UserID("user-1"),
ClientPublicKey: key,
Status: devicesession.StatusRevoked,
CreatedAt: timestamp.Add(-time.Minute),
Revocation: &devicesession.Revocation{
At: timestamp,
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
},
}
}
@@ -0,0 +1,15 @@
package ports
import (
"context"
"galaxy/authsession/internal/domain/gatewayprojection"
)
// GatewaySessionProjectionPublisher publishes gateway-facing session snapshots
// after source-of-truth session changes.
type GatewaySessionProjectionPublisher interface {
// PublishSession writes or propagates snapshot in the gateway-facing
// projection model.
PublishSession(ctx context.Context, snapshot gatewayprojection.Snapshot) error
}
@@ -0,0 +1,100 @@
package ports
import (
"context"
"fmt"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
)
// SendEmailCodeAbuseProtector decides whether one public send-email-code
// attempt may proceed immediately or must be throttled by the auth-side resend
// cooldown.
type SendEmailCodeAbuseProtector interface {
// CheckAndReserve validates input, checks the current resend cooldown
// decision for input.Email, and reserves a new cooldown window immediately
// when the outcome is allowed.
CheckAndReserve(ctx context.Context, input SendEmailCodeAbuseInput) (SendEmailCodeAbuseResult, error)
}
// SendEmailCodeAbuseInput describes one resend-throttle decision request for
// a normalized public send-email-code attempt.
type SendEmailCodeAbuseInput struct {
// Email identifies the normalized e-mail address addressed by the public
// request.
Email common.Email
// Now records when the send attempt is being evaluated.
Now time.Time
}
// Validate reports whether SendEmailCodeAbuseInput contains a complete resend
// cooldown decision request.
func (i SendEmailCodeAbuseInput) Validate() error {
if err := i.Email.Validate(); err != nil {
return fmt.Errorf("send email code abuse input email: %w", err)
}
if i.Now.IsZero() {
return fmt.Errorf("send email code abuse input now must not be zero")
}
return nil
}
// SendEmailCodeAbuseOutcome identifies the coarse resend-throttle decision for
// one public send-email-code attempt.
type SendEmailCodeAbuseOutcome string
const (
// SendEmailCodeAbuseOutcomeAllowed reports that the attempt may proceed and
// that the cooldown window has been reserved immediately.
SendEmailCodeAbuseOutcomeAllowed SendEmailCodeAbuseOutcome = "allowed"
// SendEmailCodeAbuseOutcomeThrottled reports that the cooldown window is
// still active and that the caller must not extend it.
SendEmailCodeAbuseOutcomeThrottled SendEmailCodeAbuseOutcome = "throttled"
)
// IsKnown reports whether SendEmailCodeAbuseOutcome belongs to the stable
// Stage-17 resend-throttle contract.
func (o SendEmailCodeAbuseOutcome) IsKnown() bool {
switch o {
case SendEmailCodeAbuseOutcomeAllowed, SendEmailCodeAbuseOutcomeThrottled:
return true
default:
return false
}
}
// SendEmailCodeAbuseResult describes one resend-throttle decision returned by
// SendEmailCodeAbuseProtector.
type SendEmailCodeAbuseResult struct {
// Outcome reports whether the current send attempt may proceed or must be
// throttled.
Outcome SendEmailCodeAbuseOutcome
}
// Validate reports whether SendEmailCodeAbuseResult satisfies the resend
// cooldown contract.
func (r SendEmailCodeAbuseResult) Validate() error {
if !r.Outcome.IsKnown() {
return fmt.Errorf("send email code abuse result outcome %q is unsupported", r.Outcome)
}
return nil
}
// SendEmailCodeThrottleStatusToChallengeStatus maps one resend-throttle
// outcome to the challenge lifecycle state used by sendemailcode.
func SendEmailCodeThrottleStatusToChallengeStatus(outcome SendEmailCodeAbuseOutcome) (challenge.Status, challenge.DeliveryState, error) {
switch outcome {
case SendEmailCodeAbuseOutcomeAllowed:
return challenge.StatusPendingSend, challenge.DeliveryPending, nil
case SendEmailCodeAbuseOutcomeThrottled:
return challenge.StatusDeliveryThrottled, challenge.DeliveryThrottled, nil
default:
return "", "", fmt.Errorf("map send email code abuse outcome %q: unsupported outcome", outcome)
}
}
@@ -0,0 +1,47 @@
package ports
import (
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSendEmailCodeAbuseOutcomeIsKnown(t *testing.T) {
t.Parallel()
assert.True(t, SendEmailCodeAbuseOutcomeAllowed.IsKnown())
assert.True(t, SendEmailCodeAbuseOutcomeThrottled.IsKnown())
assert.False(t, SendEmailCodeAbuseOutcome("unknown").IsKnown())
}
func TestSendEmailCodeAbuseInputAndResultValidate(t *testing.T) {
t.Parallel()
input := SendEmailCodeAbuseInput{
Email: common.Email("pilot@example.com"),
Now: time.Unix(10, 0).UTC(),
}
require.NoError(t, input.Validate())
result := SendEmailCodeAbuseResult{Outcome: SendEmailCodeAbuseOutcomeThrottled}
require.NoError(t, result.Validate())
}
func TestSendEmailCodeThrottleStatusToChallengeStatus(t *testing.T) {
t.Parallel()
status, deliveryState, err := SendEmailCodeThrottleStatusToChallengeStatus(SendEmailCodeAbuseOutcomeAllowed)
require.NoError(t, err)
assert.Equal(t, challenge.StatusPendingSend, status)
assert.Equal(t, challenge.DeliveryPending, deliveryState)
status, deliveryState, err = SendEmailCodeThrottleStatusToChallengeStatus(SendEmailCodeAbuseOutcomeThrottled)
require.NoError(t, err)
assert.Equal(t, challenge.StatusDeliveryThrottled, status)
assert.Equal(t, challenge.DeliveryThrottled, deliveryState)
}
+214
View File
@@ -0,0 +1,214 @@
package ports
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
)
// SessionStore provides source-of-truth persistence for device sessions
// without exposing storage-specific encoding or transaction primitives.
type SessionStore interface {
// Get returns the stored session for deviceSessionID. Implementations must
// wrap ErrNotFound when deviceSessionID does not exist.
Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error)
// ListByUserID returns every stored session for userID in newest-first
// order. Implementations must return an empty slice, not ErrNotFound, when
// userID has no stored sessions.
ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error)
// CountActiveByUserID returns the number of active sessions currently stored
// for userID.
CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error)
// Create persists record as a new device session. Implementations must wrap
// ErrConflict when record.ID already exists.
Create(ctx context.Context, record devicesession.Session) error
// Revoke stores a revoked view of one target session. Implementations must
// wrap ErrNotFound when input.DeviceSessionID does not exist.
Revoke(ctx context.Context, input RevokeSessionInput) (RevokeSessionResult, error)
// RevokeAllByUserID stores revoked views for all currently active sessions
// owned by input.UserID.
RevokeAllByUserID(ctx context.Context, input RevokeUserSessionsInput) (RevokeUserSessionsResult, error)
}
// RevokeSessionInput describes one single-session revoke mutation requested
// from SessionStore.
type RevokeSessionInput struct {
// DeviceSessionID identifies the session that should be revoked.
DeviceSessionID common.DeviceSessionID
// Revocation stores the audit metadata that must be attached to the revoked
// session.
Revocation devicesession.Revocation
}
// Validate reports whether RevokeSessionInput contains a complete revoke
// request.
func (i RevokeSessionInput) Validate() error {
if err := i.DeviceSessionID.Validate(); err != nil {
return fmt.Errorf("revoke session input device session id: %w", err)
}
if err := i.Revocation.Validate(); err != nil {
return fmt.Errorf("revoke session input revocation: %w", err)
}
return nil
}
// RevokeSessionOutcome identifies the coarse outcome of revoking one device
// session.
type RevokeSessionOutcome string
const (
// RevokeSessionOutcomeRevoked reports that an active session was moved to
// the revoked state by the current mutation.
RevokeSessionOutcomeRevoked RevokeSessionOutcome = "revoked"
// RevokeSessionOutcomeAlreadyRevoked reports that the requested session had
// already been revoked before the current mutation.
RevokeSessionOutcomeAlreadyRevoked RevokeSessionOutcome = "already_revoked"
)
// IsKnown reports whether RevokeSessionOutcome is supported by the current
// session-store contract.
func (o RevokeSessionOutcome) IsKnown() bool {
switch o {
case RevokeSessionOutcomeRevoked, RevokeSessionOutcomeAlreadyRevoked:
return true
default:
return false
}
}
// RevokeSessionResult describes the stable outcome returned by SessionStore
// after a single-session revoke attempt.
type RevokeSessionResult struct {
// Outcome reports whether the session was revoked just now or had already
// been revoked.
Outcome RevokeSessionOutcome
// Session stores the current source-of-truth session state after the revoke
// attempt.
Session devicesession.Session
}
// Validate reports whether RevokeSessionResult satisfies the session-store
// contract invariants.
func (r RevokeSessionResult) Validate() error {
if !r.Outcome.IsKnown() {
return fmt.Errorf("revoke session result outcome %q is unsupported", r.Outcome)
}
if err := r.Session.Validate(); err != nil {
return fmt.Errorf("revoke session result session: %w", err)
}
if r.Session.Status != devicesession.StatusRevoked {
return errors.New("revoke session result session must be revoked")
}
return nil
}
// RevokeUserSessionsInput describes one bulk user-session revoke mutation
// requested from SessionStore.
type RevokeUserSessionsInput struct {
// UserID identifies the owner whose active sessions should be revoked.
UserID common.UserID
// Revocation stores the audit metadata that must be attached to every
// revoked session.
Revocation devicesession.Revocation
}
// Validate reports whether RevokeUserSessionsInput contains a complete bulk
// revoke request.
func (i RevokeUserSessionsInput) Validate() error {
if err := i.UserID.Validate(); err != nil {
return fmt.Errorf("revoke user sessions input user id: %w", err)
}
if err := i.Revocation.Validate(); err != nil {
return fmt.Errorf("revoke user sessions input revocation: %w", err)
}
return nil
}
// RevokeUserSessionsOutcome identifies the coarse outcome of revoking all
// active sessions of one user.
type RevokeUserSessionsOutcome string
const (
// RevokeUserSessionsOutcomeRevoked reports that one or more active sessions
// were revoked by the current mutation.
RevokeUserSessionsOutcomeRevoked RevokeUserSessionsOutcome = "revoked"
// RevokeUserSessionsOutcomeNoActiveSessions reports that the target user did
// not currently own any active sessions.
RevokeUserSessionsOutcomeNoActiveSessions RevokeUserSessionsOutcome = "no_active_sessions"
)
// IsKnown reports whether RevokeUserSessionsOutcome is supported by the
// current session-store contract.
func (o RevokeUserSessionsOutcome) IsKnown() bool {
switch o {
case RevokeUserSessionsOutcomeRevoked, RevokeUserSessionsOutcomeNoActiveSessions:
return true
default:
return false
}
}
// RevokeUserSessionsResult describes the stable outcome returned by
// SessionStore after one bulk revoke attempt.
type RevokeUserSessionsResult struct {
// Outcome reports whether at least one active session was revoked.
Outcome RevokeUserSessionsOutcome
// UserID identifies the owner whose sessions were evaluated.
UserID common.UserID
// Sessions stores the current source-of-truth session states for every
// session affected by the bulk revoke operation.
Sessions []devicesession.Session
}
// Validate reports whether RevokeUserSessionsResult satisfies the bulk
// session-store contract invariants.
func (r RevokeUserSessionsResult) Validate() error {
if !r.Outcome.IsKnown() {
return fmt.Errorf("revoke user sessions result outcome %q is unsupported", r.Outcome)
}
if err := r.UserID.Validate(); err != nil {
return fmt.Errorf("revoke user sessions result user id: %w", err)
}
for index, session := range r.Sessions {
if err := session.Validate(); err != nil {
return fmt.Errorf("revoke user sessions result session %d: %w", index, err)
}
if session.Status != devicesession.StatusRevoked {
return fmt.Errorf("revoke user sessions result session %d must be revoked", index)
}
if session.UserID != r.UserID {
return fmt.Errorf("revoke user sessions result session %d belongs to %q, want %q", index, session.UserID, r.UserID)
}
}
switch r.Outcome {
case RevokeUserSessionsOutcomeRevoked:
if len(r.Sessions) == 0 {
return errors.New("revoke user sessions result must include sessions when outcome is revoked")
}
case RevokeUserSessionsOutcomeNoActiveSessions:
if len(r.Sessions) != 0 {
return errors.New("revoke user sessions result must not include sessions when outcome is no_active_sessions")
}
}
return nil
}
@@ -0,0 +1,203 @@
package ports
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
)
// UserDirectory provides the auth/session boundary to user ownership,
// registration, and block-policy decisions.
type UserDirectory interface {
// ResolveByEmail returns the current resolution state for email without
// creating any new user record.
ResolveByEmail(ctx context.Context, email common.Email) (userresolution.Result, error)
// ExistsByUserID reports whether userID currently identifies a stored user
// record.
ExistsByUserID(ctx context.Context, userID common.UserID) (bool, error)
// EnsureUserByEmail returns an existing user for email, creates a new user
// when registration is allowed, or reports a blocked outcome when the
// address may not continue through confirm flow.
EnsureUserByEmail(ctx context.Context, email common.Email) (EnsureUserResult, error)
// BlockByUserID applies a block state to the user identified by
// input.UserID. Implementations must wrap ErrNotFound when input.UserID does
// not exist.
BlockByUserID(ctx context.Context, input BlockUserByIDInput) (BlockUserResult, error)
// BlockByEmail applies a block state to input.Email, even when no user
// record currently exists for that e-mail address.
BlockByEmail(ctx context.Context, input BlockUserByEmailInput) (BlockUserResult, error)
}
// EnsureUserOutcome identifies the coarse outcome of ensuring a user record
// for one normalized e-mail address.
type EnsureUserOutcome string
const (
// EnsureUserOutcomeExisting reports that the e-mail already belonged to a
// stored user.
EnsureUserOutcomeExisting EnsureUserOutcome = "existing"
// EnsureUserOutcomeCreated reports that a new user was created for the
// e-mail address.
EnsureUserOutcomeCreated EnsureUserOutcome = "created"
// EnsureUserOutcomeBlocked reports that the e-mail cannot be used for login
// or registration.
EnsureUserOutcomeBlocked EnsureUserOutcome = "blocked"
)
// IsKnown reports whether EnsureUserOutcome is supported by the current
// user-directory contract.
func (o EnsureUserOutcome) IsKnown() bool {
switch o {
case EnsureUserOutcomeExisting, EnsureUserOutcomeCreated, EnsureUserOutcomeBlocked:
return true
default:
return false
}
}
// EnsureUserResult describes the stable outcome returned by UserDirectory
// after one ensure-user attempt.
type EnsureUserResult struct {
// Outcome reports whether the user already existed, was created, or is
// blocked by policy.
Outcome EnsureUserOutcome
// UserID is present when Outcome is EnsureUserOutcomeExisting or
// EnsureUserOutcomeCreated.
UserID common.UserID
// BlockReasonCode is present only when Outcome is EnsureUserOutcomeBlocked.
BlockReasonCode userresolution.BlockReasonCode
}
// Validate reports whether EnsureUserResult satisfies the user-directory
// contract invariants.
func (r EnsureUserResult) Validate() error {
if !r.Outcome.IsKnown() {
return fmt.Errorf("ensure user result outcome %q is unsupported", r.Outcome)
}
switch r.Outcome {
case EnsureUserOutcomeExisting, EnsureUserOutcomeCreated:
if err := r.UserID.Validate(); err != nil {
return fmt.Errorf("ensure user result user id: %w", err)
}
if !r.BlockReasonCode.IsZero() {
return errors.New("ensure user result must not contain block reason code for existing or created outcomes")
}
case EnsureUserOutcomeBlocked:
if !r.UserID.IsZero() {
return errors.New("ensure user result must not contain user id for blocked outcome")
}
if err := r.BlockReasonCode.Validate(); err != nil {
return fmt.Errorf("ensure user result block reason code: %w", err)
}
}
return nil
}
// BlockUserByIDInput describes one block mutation targeted by stable user id.
type BlockUserByIDInput struct {
// UserID identifies the user that should be blocked.
UserID common.UserID
// ReasonCode stores the machine-readable block reason to apply.
ReasonCode userresolution.BlockReasonCode
}
// Validate reports whether BlockUserByIDInput contains a complete block
// request.
func (i BlockUserByIDInput) Validate() error {
if err := i.UserID.Validate(); err != nil {
return fmt.Errorf("block user by id input user id: %w", err)
}
if err := i.ReasonCode.Validate(); err != nil {
return fmt.Errorf("block user by id input reason code: %w", err)
}
return nil
}
// BlockUserByEmailInput describes one block mutation targeted by normalized
// e-mail address.
type BlockUserByEmailInput struct {
// Email identifies the e-mail address that should be blocked.
Email common.Email
// ReasonCode stores the machine-readable block reason to apply.
ReasonCode userresolution.BlockReasonCode
}
// Validate reports whether BlockUserByEmailInput contains a complete block
// request.
func (i BlockUserByEmailInput) Validate() error {
if err := i.Email.Validate(); err != nil {
return fmt.Errorf("block user by email input email: %w", err)
}
if err := i.ReasonCode.Validate(); err != nil {
return fmt.Errorf("block user by email input reason code: %w", err)
}
return nil
}
// BlockUserOutcome identifies the coarse outcome of blocking one user or
// e-mail subject.
type BlockUserOutcome string
const (
// BlockUserOutcomeBlocked reports that the current mutation applied a new
// block state.
BlockUserOutcomeBlocked BlockUserOutcome = "blocked"
// BlockUserOutcomeAlreadyBlocked reports that the target subject had already
// been blocked before the current mutation.
BlockUserOutcomeAlreadyBlocked BlockUserOutcome = "already_blocked"
)
// IsKnown reports whether BlockUserOutcome is supported by the current
// user-directory contract.
func (o BlockUserOutcome) IsKnown() bool {
switch o {
case BlockUserOutcomeBlocked, BlockUserOutcomeAlreadyBlocked:
return true
default:
return false
}
}
// BlockUserResult describes the stable outcome returned by UserDirectory after
// one block attempt.
type BlockUserResult struct {
// Outcome reports whether the current mutation applied a new block state.
Outcome BlockUserOutcome
// UserID optionally stores the stable user identifier resolved for the
// blocked subject when one exists.
UserID common.UserID
}
// Validate reports whether BlockUserResult satisfies the user-directory
// contract invariants.
func (r BlockUserResult) Validate() error {
if !r.Outcome.IsKnown() {
return fmt.Errorf("block user result outcome %q is unsupported", r.Outcome)
}
if !r.UserID.IsZero() {
if err := r.UserID.Validate(); err != nil {
return fmt.Errorf("block user result user id: %w", err)
}
}
return nil
}
@@ -0,0 +1,88 @@
package blockuser
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRetriesProjectionPublishesForBlockFlow(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{errors.New("publish failed"), nil},
}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
require.Len(t, publisher.PublishedSnapshots(), 2)
}
func TestExecuteRepairsProjectionOnRepeatedAlreadyBlockedRequest(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
sessionRecord, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, sessionRecord.Revocation)
assert.Equal(t, devicesession.StatusRevoked, sessionRecord.Status)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, sessionRecord.Revocation.ReasonCode)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "already_blocked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}
@@ -0,0 +1,91 @@
package blockuser
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const blockFlowPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
func TestBlockUserAffectsLaterSendAndConfirmFlows(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
sessionStore := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
idGenerator := &testkit.SequenceIDGenerator{
ChallengeIDs: []common.ChallengeID{"challenge-1"},
DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"},
}
hasher := testkit.DeterministicCodeHasher{}
mailSender := &testkit.RecordingMailSender{}
now := time.Unix(20, 0).UTC()
clock := testkit.FixedClock{Time: now}
blockService, err := New(userDirectory, sessionStore, publisher, clock)
require.NoError(t, err)
_, err = blockService.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
sendService, err := sendemailcode.New(
challengeStore,
userDirectory,
idGenerator,
testkit.FixedCodeGenerator{Code: "654321"},
hasher,
mailSender,
clock,
)
require.NoError(t, err)
sendResult, err := sendService.Execute(context.Background(), sendemailcode.Input{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Equal(t, "challenge-1", sendResult.ChallengeID)
assert.Empty(t, mailSender.RecordedInputs())
challengeRecord, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, err)
assert.Equal(t, challenge.StatusDeliverySuppressed, challengeRecord.Status)
assert.Equal(t, challenge.DeliverySuppressed, challengeRecord.DeliveryState)
confirmService, err := confirmemailcode.New(
challengeStore,
sessionStore,
userDirectory,
testkit.StaticConfigProvider{},
publisher,
idGenerator,
hasher,
clock,
)
require.NoError(t, err)
_, err = confirmService.Execute(context.Background(), confirmemailcode.Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: blockFlowPublicKey,
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err))
updatedChallenge, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusFailed, updatedChallenge.Status)
}
@@ -0,0 +1,64 @@
package blockuser
import (
"bytes"
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestExecuteLogsSafeOutcomeFields(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
sessionStore := &testkit.InMemorySessionStore{}
require.NoError(t, sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
logger, buffer := newObservedServiceLogger()
service, err := NewWithObservability(
userDirectory,
sessionStore,
&testkit.RecordingProjectionPublisher{},
testkit.FixedClock{Time: time.Unix(20, 0).UTC()},
logger,
nil,
)
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
logOutput := buffer.String()
assert.Contains(t, logOutput, "block_user")
assert.Contains(t, logOutput, "\"user_id\":\"user-1\"")
assert.Contains(t, logOutput, "\"reason_code\":\"policy_block\"")
assert.NotContains(t, logOutput, "pilot@example.com")
}
func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) {
buffer := &bytes.Buffer{}
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = ""
core := zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
zapcore.AddSync(buffer),
zap.DebugLevel,
)
return zap.New(core), buffer
}
@@ -0,0 +1,294 @@
// Package blockuser implements the trusted internal block-user use case.
package blockuser
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
const (
// SubjectKindUserID identifies a block request addressed by stable user id.
SubjectKindUserID = "user_id"
// SubjectKindEmail identifies a block request addressed by normalized e-mail
// address.
SubjectKindEmail = "email"
)
// Input describes one trusted internal block-user request.
type Input struct {
// UserID identifies the subject to block when the request is user-id based.
UserID string
// Email identifies the subject to block when the request is e-mail based.
Email string
// ReasonCode stores the machine-readable block reason code applied to the
// user directory.
ReasonCode string
// ActorType stores the machine-readable actor type for any derived session
// revocation.
ActorType string
// ActorID stores the optional stable actor identifier for any derived
// session revocation.
ActorID string
}
// Result describes the frozen internal block-user acknowledgement.
type Result struct {
// Outcome reports whether the block state was newly applied or already
// existed.
Outcome string
// SubjectKind reports whether the request targeted `user_id` or `email`.
SubjectKind string
// SubjectValue stores the normalized subject value addressed by the
// operation.
SubjectValue string
// AffectedSessionCount reports how many sessions changed state during the
// current call.
AffectedSessionCount int64
// AffectedDeviceSessionIDs lists every session identifier affected during
// the current call.
AffectedDeviceSessionIDs []string
}
// Service executes the trusted internal block-user use case.
type Service struct {
userDirectory ports.UserDirectory
sessionStore ports.SessionStore
publisher ports.GatewaySessionProjectionPublisher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a block-user service wired to the required ports.
func New(userDirectory ports.UserDirectory, sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
return NewWithObservability(userDirectory, sessionStore, publisher, clock, nil, nil)
}
// NewWithObservability returns a block-user service wired to the required
// ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
userDirectory ports.UserDirectory,
sessionStore ports.SessionStore,
publisher ports.GatewaySessionProjectionPublisher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case userDirectory == nil:
return nil, fmt.Errorf("blockuser: user directory must not be nil")
case sessionStore == nil:
return nil, fmt.Errorf("blockuser: session store must not be nil")
case publisher == nil:
return nil, fmt.Errorf("blockuser: projection publisher must not be nil")
case clock == nil:
return nil, fmt.Errorf("blockuser: clock must not be nil")
default:
return &Service{
userDirectory: userDirectory,
sessionStore: sessionStore,
publisher: publisher,
clock: clock,
logger: namedLogger(logger, "block_user"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute applies the requested block state and revokes any active sessions of
// the resolved user when one exists.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "block_user"),
}
defer func() {
if result.Outcome != "" {
logFields = append(logFields, zap.String("outcome", result.Outcome))
}
if result.SubjectKind != "" {
logFields = append(logFields, zap.String("subject_kind", result.SubjectKind))
}
if result.AffectedSessionCount > 0 {
logFields = append(logFields, zap.Int64("affected_session_count", result.AffectedSessionCount))
}
shared.LogServiceOutcome(s.logger, ctx, "block user completed", err, logFields...)
}()
subjectKind, subjectValue, storeResult, err := s.blockSubject(ctx, input)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("reason_code", shared.NormalizeString(input.ReasonCode)))
if !storeResult.UserID.IsZero() {
logFields = append(logFields, zap.String("user_id", storeResult.UserID.String()))
}
affectedDeviceSessionIDs := []string{}
affectedSessionCount := int64(0)
if !storeResult.UserID.IsZero() {
revocation, err := shared.BuildRevocation(
devicesession.RevokeReasonUserBlocked.String(),
input.ActorType,
input.ActorID,
s.clock.Now(),
)
if err != nil {
return Result{}, err
}
revokeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{
UserID: storeResult.UserID,
Revocation: revocation,
})
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := revokeResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
for _, record := range revokeResult.Sessions {
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user"); err != nil {
return Result{}, err
}
affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String())
}
if revokeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions {
if err := s.republishCurrentRevokedSessions(ctx, storeResult.UserID); err != nil {
return Result{}, err
}
}
affectedSessionCount = int64(len(revokeResult.Sessions))
if affectedSessionCount > 0 {
s.telemetry.RecordSessionRevocations(ctx, "block_user", devicesession.RevokeReasonUserBlocked.String(), affectedSessionCount)
}
}
result = Result{
Outcome: string(storeResult.Outcome),
SubjectKind: subjectKind,
SubjectValue: subjectValue,
AffectedSessionCount: affectedSessionCount,
AffectedDeviceSessionIDs: affectedDeviceSessionIDs,
}
return result, nil
}
func (s *Service) blockSubject(ctx context.Context, input Input) (string, string, ports.BlockUserResult, error) {
userID := shared.NormalizeString(input.UserID)
email := shared.NormalizeString(input.Email)
switch {
case userID == "" && email == "":
return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
case userID != "" && email != "":
return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
case userID != "":
parsedUserID, err := shared.ParseUserID(userID)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
reasonCode, err := parseBlockReasonCode(input.ReasonCode)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
result, err := s.userDirectory.BlockByUserID(ctx, ports.BlockUserByIDInput{
UserID: parsedUserID,
ReasonCode: reasonCode,
})
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return "", "", ports.BlockUserResult{}, shared.SubjectNotFound()
default:
return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
}
}
if err := result.Validate(); err != nil {
return "", "", ports.BlockUserResult{}, shared.InternalError(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_user_id", string(result.Outcome))
return SubjectKindUserID, parsedUserID.String(), result, nil
default:
parsedEmail, err := shared.ParseEmail(email)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
reasonCode, err := parseBlockReasonCode(input.ReasonCode)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
result, err := s.userDirectory.BlockByEmail(ctx, ports.BlockUserByEmailInput{
Email: parsedEmail,
ReasonCode: reasonCode,
})
if err != nil {
return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
}
if err := result.Validate(); err != nil {
return "", "", ports.BlockUserResult{}, shared.InternalError(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_email", string(result.Outcome))
return SubjectKindEmail, parsedEmail.String(), result, nil
}
}
func parseBlockReasonCode(value string) (userresolution.BlockReasonCode, error) {
reasonCode := userresolution.BlockReasonCode(shared.NormalizeString(value))
if err := reasonCode.Validate(); err != nil {
return "", shared.InvalidRequest(err.Error())
}
return reasonCode, nil
}
func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error {
records, err := s.sessionStore.ListByUserID(ctx, userID)
if err != nil {
return shared.ServiceUnavailable(err)
}
for _, record := range records {
if record.Status != devicesession.StatusRevoked {
continue
}
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user_repair"); err != nil {
return err
}
}
return nil
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,237 @@
package blockuser
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteBlocksByUserIDAndRevokesSessions(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, SubjectKindUserID, result.SubjectKind)
assert.Equal(t, "user-1", result.SubjectValue)
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), published[0].RevokeActorType)
}
func TestExecuteBlocksByEmailWithoutExistingUser(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
service, err := New(userDirectory, &testkit.InMemorySessionStore{}, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
assert.Equal(t, SubjectKindEmail, result.SubjectKind)
assert.Equal(t, "pilot@example.com", result.SubjectValue)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
assert.Empty(t, publisher.PublishedSnapshots())
}
func TestExecuteBlocksByEmailWithExistingUserAndRevokesSessions(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
}
func TestExecuteReturnsSubjectNotFoundForUnknownUserID(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemoryUserDirectory{}, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{
UserID: "missing",
ReasonCode: "policy_block",
ActorType: "admin",
})
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
}
func TestExecuteAlreadyBlockedStillRevokesLingeringSessions(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedBlockedUser(common.Email("pilot@example.com"), common.UserID("user-1"), userresolution.BlockReasonCode("policy_block")); err != nil {
require.Failf(t, "test failed", "SeedBlockedUser() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "already_blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
}
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,60 @@
package blockuser
import (
"context"
"testing"
"time"
stubuserservice "galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
t.Parallel()
t.Run("blocks by email through runtime stub", func(t *testing.T) {
t.Parallel()
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
store := &testkit.InMemorySessionStore{}
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(userDirectory, store, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, SubjectKindEmail, result.SubjectKind)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
})
t.Run("blocks by user id through runtime stub", func(t *testing.T) {
t.Parallel()
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
service, err := New(userDirectory, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, SubjectKindUserID, result.SubjectKind)
assert.Equal(t, "blocked", result.Outcome)
})
}
@@ -0,0 +1,39 @@
package confirmemailcode
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteReturnsInvalidCodeForThrottledChallengeWithoutConsumingAttempts(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = challenge.StatusDeliveryThrottled
record.DeliveryState = challenge.DeliveryThrottled
require.NoError(t, record.Validate())
require.NoError(t, deps.challengeStore.Create(context.Background(), record))
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeInvalidCode, shared.CodeOf(err))
updated, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, 0, updated.Attempts.Confirm)
assert.Equal(t, challenge.StatusDeliveryThrottled, updated.Status)
}
@@ -0,0 +1,106 @@
package confirmemailcode
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteConfirmsChallengeAfterTransientProjectionPublishFailures(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Errors = []error{errors.New("publish failed"), nil}
require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, deps.challengeStore.Create(
context.Background(),
sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
))
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
require.Len(t, deps.publisher.PublishedSnapshots(), 2)
}
func TestExecuteConfirmedRetryRepublishesAfterTransientProjectionPublishFailures(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Errors = []error{errors.New("publish failed"), nil}
key := mustClientPublicKey(t, publicKeyString())
require.NoError(t, deps.challengeStore.Create(
context.Background(),
confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
))
require.NoError(t, deps.sessionStore.Create(
context.Background(),
activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute)),
))
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
require.Len(t, deps.publisher.PublishedSnapshots(), 2)
}
func TestExecuteRepairsProjectionOnIdenticalRetryAfterExhaustedPublishRetries(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Err = errors.New("publish failed")
require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, deps.challengeStore.Create(
context.Background(),
sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
))
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
sessionRecord, getErr := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
assert.Equal(t, devicesession.StatusActive, sessionRecord.Status)
challengeRecord, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusConfirmedPendingExpire, challengeRecord.Status)
require.NotNil(t, challengeRecord.Confirmation)
deps.publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}
@@ -0,0 +1,588 @@
// Package confirmemailcode implements the public confirm-email-code use case.
package confirmemailcode
import (
"context"
"errors"
"fmt"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/sessionlimit"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
const (
revokeReasonConfirmRace common.RevokeReasonCode = "confirm_race_repair"
revokeActorTypeService common.RevokeActorType = "service"
revokeActorIDService = "confirmemailcode"
)
// Input describes one public confirm-email-code request.
type Input struct {
// ChallengeID identifies the challenge that should be confirmed.
ChallengeID string
// Code is the cleartext confirmation code submitted by the caller.
Code string
// ClientPublicKey is the base64-encoded raw 32-byte Ed25519 public key that
// should be registered for the created device session.
ClientPublicKey string
}
// Result describes one public confirm-email-code response.
type Result struct {
// DeviceSessionID is the stable identifier of the created or idempotently
// recovered device session.
DeviceSessionID string
}
// Service executes the public confirm-email-code use case.
type Service struct {
challengeStore ports.ChallengeStore
sessionStore ports.SessionStore
userDirectory ports.UserDirectory
configProvider ports.ConfigProvider
publisher ports.GatewaySessionProjectionPublisher
idGenerator ports.IDGenerator
codeHasher ports.CodeHasher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a confirm-email-code service wired to the required ports.
func New(
challengeStore ports.ChallengeStore,
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
configProvider ports.ConfigProvider,
publisher ports.GatewaySessionProjectionPublisher,
idGenerator ports.IDGenerator,
codeHasher ports.CodeHasher,
clock ports.Clock,
) (*Service, error) {
return NewWithTelemetry(
challengeStore,
sessionStore,
userDirectory,
configProvider,
publisher,
idGenerator,
codeHasher,
clock,
nil,
)
}
// NewWithTelemetry returns a confirm-email-code service wired to the required
// ports plus the optional Stage-17 telemetry runtime.
func NewWithTelemetry(
challengeStore ports.ChallengeStore,
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
configProvider ports.ConfigProvider,
publisher ports.GatewaySessionProjectionPublisher,
idGenerator ports.IDGenerator,
codeHasher ports.CodeHasher,
clock ports.Clock,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
return NewWithObservability(
challengeStore,
sessionStore,
userDirectory,
configProvider,
publisher,
idGenerator,
codeHasher,
clock,
nil,
telemetryRuntime,
)
}
// NewWithObservability returns a confirm-email-code service wired to the
// required ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
challengeStore ports.ChallengeStore,
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
configProvider ports.ConfigProvider,
publisher ports.GatewaySessionProjectionPublisher,
idGenerator ports.IDGenerator,
codeHasher ports.CodeHasher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case challengeStore == nil:
return nil, fmt.Errorf("confirmemailcode: challenge store must not be nil")
case sessionStore == nil:
return nil, fmt.Errorf("confirmemailcode: session store must not be nil")
case userDirectory == nil:
return nil, fmt.Errorf("confirmemailcode: user directory must not be nil")
case configProvider == nil:
return nil, fmt.Errorf("confirmemailcode: config provider must not be nil")
case publisher == nil:
return nil, fmt.Errorf("confirmemailcode: projection publisher must not be nil")
case idGenerator == nil:
return nil, fmt.Errorf("confirmemailcode: id generator must not be nil")
case codeHasher == nil:
return nil, fmt.Errorf("confirmemailcode: code hasher must not be nil")
case clock == nil:
return nil, fmt.Errorf("confirmemailcode: clock must not be nil")
default:
return &Service{
challengeStore: challengeStore,
sessionStore: sessionStore,
userDirectory: userDirectory,
configProvider: configProvider,
publisher: publisher,
idGenerator: idGenerator,
codeHasher: codeHasher,
clock: clock,
logger: namedLogger(logger, "confirm_email_code"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute validates one challenge confirmation attempt, creates a device
// session when policy allows it, and handles short-window idempotent retries.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
}
defer func() {
outcome := string(telemetry.ConfirmEmailCodeOutcomeSuccess)
if err != nil {
outcome = shared.CodeOf(err)
if outcome == "" {
outcome = shared.ErrorCodeServiceUnavailable
}
}
s.telemetry.RecordConfirmEmailCode(ctx, outcome)
logFields = append(logFields, zap.String("outcome", outcome))
if result.DeviceSessionID != "" {
logFields = append(logFields, zap.String("device_session_id", result.DeviceSessionID))
}
shared.LogServiceOutcome(s.logger, ctx, "confirm email code completed", err, logFields...)
}()
challengeID, err := shared.ParseChallengeID(input.ChallengeID)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("challenge_id", challengeID.String()))
code, err := shared.ParseRequiredCode(input.Code)
if err != nil {
return Result{}, err
}
clientPublicKey, err := shared.ParseClientPublicKey(input.ClientPublicKey)
if err != nil {
return Result{}, err
}
for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
current, err := s.challengeStore.Get(ctx, challengeID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.ChallengeNotFound()
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
now := s.clock.Now().UTC()
if expired, err := s.ensureChallengeNotExpired(ctx, current, now); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return Result{}, err
} else if expired {
return Result{}, shared.ChallengeExpired()
}
switch {
case current.Status.IsConfirmedRetryState():
return s.handleConfirmedRetry(ctx, current, code, clientPublicKey)
case !current.Status.AcceptsFreshConfirm():
return Result{}, shared.InvalidCode()
}
match, err := s.codeHasher.Compare(current.CodeHash, code)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if !match {
if err := s.recordInvalidConfirmAttempt(ctx, current, now); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return Result{}, err
}
return Result{}, shared.InvalidCode()
}
ensureUserResult, err := s.userDirectory.EnsureUserByEmail(ctx, current.Email)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := ensureUserResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "ensure_user_by_email", string(ensureUserResult.Outcome))
if !ensureUserResult.UserID.IsZero() {
logFields = append(logFields, zap.String("user_id", ensureUserResult.UserID.String()))
}
if ensureUserResult.Outcome == ports.EnsureUserOutcomeBlocked {
if err := s.markChallengeFailed(ctx, current, now); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return Result{}, err
}
return Result{}, shared.BlockedByPolicy()
}
limitConfig, err := s.configProvider.LoadSessionLimit(ctx)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
decision, err := s.evaluateSessionLimit(ctx, ensureUserResult.UserID, limitConfig)
if err != nil {
return Result{}, err
}
if decision.Kind == sessionlimit.KindExceeded {
s.telemetry.RecordSessionLimitRejection(ctx)
return Result{}, shared.SessionLimitExceeded()
}
sessionRecord, err := s.createSession(ctx, ensureUserResult.UserID, clientPublicKey, now)
if err != nil {
return Result{}, err
}
next := current
next.Status = challenge.StatusConfirmedPendingExpire
next.ExpiresAt = now.Add(challenge.ConfirmedRetention)
next.Abuse.LastAttemptAt = &now
next.Confirmation = &challenge.Confirmation{
SessionID: sessionRecord.ID,
ClientPublicKey: clientPublicKey,
ConfirmedAt: now,
}
if err := next.Validate(); err != nil {
s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
return Result{}, shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
if errors.Is(err, ports.ErrConflict) {
return s.handleCreateSessionCASConflict(ctx, challengeID, code, clientPublicKey, sessionRecord)
}
s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
return Result{}, shared.ServiceUnavailable(err)
}
// Publish the currently stored session view so a concurrent revoke/block
// cannot overwrite source of truth with a stale active projection.
currentSession, err := s.sessionStore.Get(ctx, sessionRecord.ID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: newly created session %q was not found", sessionRecord.ID))
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := s.publishSession(ctx, currentSession, "confirm_email_code"); err != nil {
return Result{}, err
}
return Result{DeviceSessionID: currentSession.ID.String()}, nil
}
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: compare-and-swap retry limit exceeded"))
}
func (s *Service) ensureChallengeNotExpired(ctx context.Context, current challenge.Challenge, now time.Time) (bool, error) {
if current.IsExpiredAt(now) {
if current.Status != challenge.StatusExpired && current.Status.CanTransitionTo(challenge.StatusExpired) {
next := current
next.Status = challenge.StatusExpired
next.Abuse.LastAttemptAt = &now
next.Confirmation = nil
if err := next.Validate(); err != nil {
return true, shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
if !errors.Is(err, ports.ErrConflict) {
return true, shared.ServiceUnavailable(err)
}
return false, err
}
}
return true, nil
}
return false, nil
}
func (s *Service) handleConfirmedRetry(ctx context.Context, current challenge.Challenge, code string, clientPublicKey common.ClientPublicKey) (Result, error) {
match, err := s.codeHasher.Compare(current.CodeHash, code)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if !match {
return Result{}, shared.InvalidCode()
}
if current.Confirmation == nil {
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed challenge is missing confirmation metadata"))
}
if current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
return Result{}, shared.InvalidCode()
}
record, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed session %q was not found", current.Confirmation.SessionID))
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := s.publishSession(ctx, record, "confirm_email_code_retry"); err != nil {
return Result{}, err
}
return Result{DeviceSessionID: record.ID.String()}, nil
}
func (s *Service) recordInvalidConfirmAttempt(ctx context.Context, current challenge.Challenge, now time.Time) error {
next := current
next.Attempts.Confirm++
next.Abuse.LastAttemptAt = &now
if next.Attempts.Confirm >= challenge.MaxInvalidConfirmAttempts {
next.Status = challenge.StatusFailed
}
if err := next.Validate(); err != nil {
return shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
switch {
case errors.Is(err, ports.ErrConflict):
return err
default:
return shared.ServiceUnavailable(err)
}
}
return nil
}
func (s *Service) markChallengeFailed(ctx context.Context, current challenge.Challenge, now time.Time) error {
next := current
next.Status = challenge.StatusFailed
next.Abuse.LastAttemptAt = &now
if err := next.Validate(); err != nil {
return shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
switch {
case errors.Is(err, ports.ErrConflict):
return err
default:
return shared.ServiceUnavailable(err)
}
}
return nil
}
func (s *Service) evaluateSessionLimit(ctx context.Context, userID common.UserID, config ports.SessionLimitConfig) (sessionlimit.Decision, error) {
activeSessionCount, err := s.sessionStore.CountActiveByUserID(ctx, userID)
if err != nil {
return sessionlimit.Decision{}, shared.ServiceUnavailable(err)
}
decision, err := shared.EvaluateSessionLimit(config, activeSessionCount)
if err != nil {
return sessionlimit.Decision{}, err
}
return decision, nil
}
func (s *Service) createSession(ctx context.Context, userID common.UserID, clientPublicKey common.ClientPublicKey, now time.Time) (devicesession.Session, error) {
for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
deviceSessionID, err := s.idGenerator.NewDeviceSessionID()
if err != nil {
return devicesession.Session{}, shared.ServiceUnavailable(err)
}
record := devicesession.Session{
ID: deviceSessionID,
UserID: userID,
ClientPublicKey: clientPublicKey,
Status: devicesession.StatusActive,
CreatedAt: now,
}
if err := record.Validate(); err != nil {
return devicesession.Session{}, shared.InternalError(err)
}
if err := s.sessionStore.Create(ctx, record); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return devicesession.Session{}, shared.ServiceUnavailable(err)
}
s.telemetry.RecordSessionCreated(ctx)
return record, nil
}
return devicesession.Session{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: session id conflict retry limit exceeded"))
}
func (s *Service) handleCreateSessionCASConflict(
ctx context.Context,
challengeID common.ChallengeID,
code string,
clientPublicKey common.ClientPublicKey,
createdSession devicesession.Session,
) (Result, error) {
defer s.bestEffortRevokeSupersededSession(ctx, createdSession)
current, err := s.challengeStore.Get(ctx, challengeID)
if err != nil {
if errors.Is(err, ports.ErrNotFound) {
return Result{}, shared.ServiceUnavailable(err)
}
return Result{}, shared.ServiceUnavailable(err)
}
if current.Status != challenge.StatusConfirmedPendingExpire || current.Confirmation == nil {
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q changed to unexpected status %q after create", challengeID, current.Status))
}
match, err := s.codeHasher.Compare(current.CodeHash, code)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if !match || current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q was confirmed by a different payload", challengeID))
}
winningSession, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: winning session %q was not found", current.Confirmation.SessionID))
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := s.publishSession(ctx, winningSession, "confirm_email_code_race_winner"); err != nil {
return Result{}, err
}
return Result{DeviceSessionID: winningSession.ID.String()}, nil
}
func (s *Service) bestEffortRevokeSupersededSession(ctx context.Context, record devicesession.Session) {
revocation := devicesession.Revocation{
At: s.clock.Now().UTC(),
ReasonCode: revokeReasonConfirmRace,
ActorType: revokeActorTypeService,
ActorID: revokeActorIDService,
}
if err := revocation.Validate(); err != nil {
return
}
revokeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: revocation,
})
if err != nil {
s.logger.Warn(
"best-effort superseded session revoke failed",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", record.ID.String()),
zap.String("reason_code", revocation.ReasonCode.String()),
zap.Error(err),
)
return
}
if err := revokeResult.Validate(); err != nil {
s.logger.Warn(
"best-effort superseded session revoke produced invalid result",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", record.ID.String()),
zap.Error(err),
)
return
}
if revokeResult.Outcome == ports.RevokeSessionOutcomeRevoked {
s.telemetry.RecordSessionRevocations(ctx, "confirm_email_code_race_cleanup", revocation.ReasonCode.String(), 1)
}
snapshot, err := shared.ToGatewayProjectionSnapshot(revokeResult.Session)
if err != nil {
s.logger.Warn(
"best-effort superseded session snapshot mapping failed",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", revokeResult.Session.ID.String()),
zap.Error(err),
)
return
}
if err := shared.PublishProjectionSnapshotWithTelemetry(ctx, s.publisher, snapshot, s.telemetry, "confirm_email_code_race_cleanup"); err != nil {
s.logger.Warn(
"best-effort superseded session publish failed",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", revokeResult.Session.ID.String()),
zap.Error(err),
)
}
}
func (s *Service) publishSession(ctx context.Context, record devicesession.Session, operation string) error {
return shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, operation)
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,682 @@
package confirmemailcode
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
)
func TestExecuteConfirmsChallengeForExistingUser(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != devicesession.StatusActive {
require.Failf(t, "test failed", "session status = %q, want %q", record.Status, devicesession.StatusActive)
}
challengeRecord, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if challengeRecord.Status != challenge.StatusConfirmedPendingExpire || challengeRecord.Confirmation == nil {
require.Failf(t, "test failed", "challenge status = %q, confirmation = %+v", challengeRecord.Status, challengeRecord.Confirmation)
}
if len(deps.publisher.PublishedSnapshots()) != 1 {
require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(deps.publisher.PublishedSnapshots()))
}
}
func TestExecuteConfirmsChallengeByCreatingUser(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.QueueCreatedUserIDs(common.UserID("user-created")); err != nil {
require.Failf(t, "test failed", "QueueCreatedUserIDs() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "new@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.UserID != common.UserID("user-created") {
require.Failf(t, "test failed", "session user id = %q, want %q", record.UserID, common.UserID("user-created"))
}
}
func TestExecuteConfirmsSuppressedChallenge(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = challenge.StatusDeliverySuppressed
record.DeliveryState = challenge.DeliverySuppressed
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
}
func TestExecuteReturnsChallengeNotFound(t *testing.T) {
t.Parallel()
service := mustNewConfirmService(t, newConfirmDeps(t))
_, err := service.Execute(context.Background(), Input{
ChallengeID: "missing",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeChallengeNotFound {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeNotFound)
}
}
func TestExecuteReturnsChallengeExpiredAndMarksExpired(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-2*time.Minute), deps.now.Add(-time.Second))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeChallengeExpired {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeExpired)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusExpired {
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusExpired)
}
}
func TestExecuteReturnsChallengeExpiredForConfirmedChallengeAfterRetentionWindow(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
key, err := shared.ParseClientPublicKey(publicKeyString())
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
record := confirmedChallengeFixture(
t,
deps.hasher,
"challenge-1",
"pilot@example.com",
"654321",
"device-session-1",
key,
deps.now.Add(-2*challenge.ConfirmedRetention),
deps.now.Add(-time.Second),
)
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeChallengeExpired {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeExpired)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Status != challenge.StatusExpired {
require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusExpired)
}
if updated.Confirmation != nil {
require.Failf(t, "test failed", "Confirmation = %+v, want nil after expiration", updated.Confirmation)
}
}
func TestExecuteReturnsInvalidClientPublicKey(t *testing.T) {
t.Parallel()
service := mustNewConfirmService(t, newConfirmDeps(t))
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: "invalid",
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidClientPublicKey {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidClientPublicKey)
}
}
func TestExecuteInvalidCodeIncrementsAttempts(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "000000",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Attempts.Confirm != 1 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 1", record.Attempts.Confirm)
}
}
func TestExecuteFifthInvalidAttemptMarksChallengeFailed(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Attempts.Confirm = 4
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "000000",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Status != challenge.StatusFailed {
require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusFailed)
}
}
func TestExecuteDoesNotCreateSessionAfterTooManyAttempts(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Attempts.Confirm = challenge.MaxInvalidConfirmAttempts
record.Status = challenge.StatusFailed
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
if got, err := deps.sessionStore.CountActiveByUserID(context.Background(), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "CountActiveByUserID() returned error: %v", err)
} else if got != 0 {
require.Failf(t, "test failed", "CountActiveByUserID() = %d, want 0", got)
}
}
func TestExecuteReturnsSameSessionIDForIdempotentRetryAndRepublishes(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
key, err := shared.ParseClientPublicKey(publicKeyString())
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
if len(deps.publisher.PublishedSnapshots()) != 1 {
require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(deps.publisher.PublishedSnapshots()))
}
}
func TestExecuteReturnsInvalidCodeForDifferentKeyDuringIdempotentRetry(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
key, err := shared.ParseClientPublicKey(publicKeyString())
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: alternatePublicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Attempts.Confirm != 0 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", updated.Attempts.Confirm)
}
if updated.Confirmation == nil {
require.FailNow(t, "Confirmation = nil, want metadata to stay intact")
}
if updated.Confirmation.SessionID != common.DeviceSessionID("device-session-1") {
require.Failf(t, "test failed", "Confirmation.SessionID = %q, want %q", updated.Confirmation.SessionID, common.DeviceSessionID("device-session-1"))
}
}
func TestExecuteReturnsInvalidCodeForNonConfirmableStates(t *testing.T) {
t.Parallel()
tests := []struct {
name string
status challenge.Status
deliveryState challenge.DeliveryState
}{
{name: "pending send", status: challenge.StatusPendingSend, deliveryState: challenge.DeliveryPending},
{name: "failed", status: challenge.StatusFailed, deliveryState: challenge.DeliveryFailed},
{name: "cancelled", status: challenge.StatusCancelled, deliveryState: challenge.DeliverySent},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = tt.status
record.DeliveryState = tt.deliveryState
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Attempts.Confirm != 0 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", updated.Attempts.Confirm)
}
})
}
}
func TestExecuteMarksChallengeFailedAndReturnsBlockedByPolicy(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeBlockedByPolicy {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeBlockedByPolicy)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusFailed {
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusFailed)
}
}
func TestExecuteReturnsSessionLimitExceededWithoutConsumingChallenge(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-existing", "user-1", mustClientPublicKey(t, publicKeyString()), deps.now.Add(-2*time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
limit := 1
deps.configProvider.Config.ActiveSessionLimit = &limit
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeSessionLimitExceeded {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionLimitExceeded)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusSent {
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusSent)
}
if record.Attempts.Confirm != 0 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", record.Attempts.Confirm)
}
}
func TestExecuteReturnsServiceUnavailableThenSucceedsIdempotentlyAfterPublishFailure(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Err = errors.New("publish failed")
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeServiceUnavailable {
require.Failf(t, "test failed", "first Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeServiceUnavailable)
}
deps.publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "second Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "second Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
}
type confirmDeps struct {
challengeStore *testkit.InMemoryChallengeStore
sessionStore *testkit.InMemorySessionStore
userDirectory *testkit.InMemoryUserDirectory
configProvider testkit.StaticConfigProvider
publisher *testkit.RecordingProjectionPublisher
idGenerator *testkit.SequenceIDGenerator
hasher testkit.DeterministicCodeHasher
now time.Time
}
func newConfirmDeps(t *testing.T) confirmDeps {
t.Helper()
return confirmDeps{
challengeStore: &testkit.InMemoryChallengeStore{},
sessionStore: &testkit.InMemorySessionStore{},
userDirectory: &testkit.InMemoryUserDirectory{},
configProvider: testkit.StaticConfigProvider{},
publisher: &testkit.RecordingProjectionPublisher{},
idGenerator: &testkit.SequenceIDGenerator{
DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"},
},
hasher: testkit.DeterministicCodeHasher{},
now: time.Unix(20, 0).UTC(),
}
}
func mustNewConfirmService(t *testing.T, deps confirmDeps) *Service {
t.Helper()
service, err := New(
deps.challengeStore,
deps.sessionStore,
deps.userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
testkit.FixedClock{Time: deps.now},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
return service
}
func sentChallengeFixture(
t *testing.T,
hasher testkit.DeterministicCodeHasher,
challengeID string,
email string,
code string,
createdAt time.Time,
expiresAt time.Time,
) challenge.Challenge {
t.Helper()
codeHash, err := hasher.Hash(code)
if err != nil {
require.Failf(t, "test failed", "Hash() returned error: %v", err)
}
record := challenge.Challenge{
ID: common.ChallengeID(challengeID),
Email: common.Email(email),
CodeHash: codeHash,
Status: challenge.StatusSent,
DeliveryState: challenge.DeliverySent,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
}
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
return record
}
func confirmedChallengeFixture(
t *testing.T,
hasher testkit.DeterministicCodeHasher,
challengeID string,
email string,
code string,
deviceSessionID string,
clientPublicKey common.ClientPublicKey,
createdAt time.Time,
expiresAt time.Time,
) challenge.Challenge {
t.Helper()
record := sentChallengeFixture(t, hasher, challengeID, email, code, createdAt, expiresAt)
record.Status = challenge.StatusConfirmedPendingExpire
record.Confirmation = &challenge.Confirmation{
SessionID: common.DeviceSessionID(deviceSessionID),
ClientPublicKey: clientPublicKey,
ConfirmedAt: createdAt.Add(time.Minute),
}
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
return record
}
func activeSessionFixture(deviceSessionID string, userID string, clientPublicKey common.ClientPublicKey, createdAt time.Time) devicesession.Session {
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: clientPublicKey,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
func mustClientPublicKey(t *testing.T, value string) common.ClientPublicKey {
t.Helper()
key, err := shared.ParseClientPublicKey(value)
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
return key
}
func publicKeyString() string {
return "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
}
func alternatePublicKeyString() string {
return "AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE="
}
@@ -0,0 +1,109 @@
package confirmemailcode
import (
"context"
"testing"
"time"
stubuserservice "galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
t.Parallel()
t.Run("creates user through EnsureUserByEmail", func(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.QueueCreatedUserIDs(common.UserID("user-created")))
deps.userDirectory = nil
require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(
t,
deps.hasher,
"challenge-1",
"pilot@example.com",
"654321",
deps.now.Add(-time.Minute),
deps.now.Add(time.Minute),
)))
service, err := New(
deps.challengeStore,
deps.sessionStore,
userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
fixedClock(deps.now),
)
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
sessionRecord, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, err)
assert.Equal(t, common.UserID("user-created"), sessionRecord.UserID)
})
t.Run("blocked email returns blocked by policy", func(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(
t,
deps.hasher,
"challenge-1",
"pilot@example.com",
"654321",
deps.now.Add(-time.Minute),
deps.now.Add(time.Minute),
)))
service, err := New(
deps.challengeStore,
deps.sessionStore,
userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
fixedClock(deps.now),
)
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err))
record, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusFailed, record.Status)
})
}
type fixedClock time.Time
func (c fixedClock) Now() time.Time {
return time.Time(c)
}
@@ -0,0 +1,104 @@
package confirmemailcode
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
authtelemetry "galaxy/authsession/internal/telemetry"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
)
func TestExecuteRecordsInvalidCodeMetricForThrottledChallenge(t *testing.T) {
t.Parallel()
runtime, reader := newObservedConfirmTelemetryRuntime(t)
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = challenge.StatusDeliveryThrottled
record.DeliveryState = challenge.DeliveryThrottled
require.NoError(t, record.Validate())
require.NoError(t, deps.challengeStore.Create(context.Background(), record))
service, err := NewWithTelemetry(
deps.challengeStore,
deps.sessionStore,
deps.userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
testkit.FixedClock{Time: deps.now},
runtime,
)
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assertConfirmMetricCount(t, reader, map[string]string{"outcome": "invalid_code"}, 1)
}
func newObservedConfirmTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader) {
t.Helper()
reader := sdkmetric.NewManualReader()
provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
runtime, err := authtelemetry.New(provider)
require.NoError(t, err)
return runtime, reader
}
func assertConfirmMetricCount(t *testing.T, reader *sdkmetric.ManualReader, wantAttrs map[string]string, wantValue int64) {
t.Helper()
var resourceMetrics metricdata.ResourceMetrics
require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))
for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
for _, metric := range scopeMetrics.Metrics {
if metric.Name != "authsession.confirm_email_code.attempts" {
continue
}
sum, ok := metric.Data.(metricdata.Sum[int64])
require.True(t, ok)
for _, point := range sum.DataPoints {
if hasConfirmMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
assert.Equal(t, wantValue, point.Value)
return
}
}
}
}
require.Failf(t, "test failed", "confirm metric with attrs %v not found", wantAttrs)
}
func hasConfirmMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
if len(values) != len(want) {
return false
}
for _, value := range values {
if want[string(value.Key)] != value.Value.AsString() {
return false
}
}
return true
}
@@ -0,0 +1,65 @@
// Package getsession implements the trusted internal read use case for one
// device session.
package getsession
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
)
// Input describes one trusted internal get-session request.
type Input struct {
// DeviceSessionID identifies the session that should be read.
DeviceSessionID string
}
// Result describes one trusted internal get-session response.
type Result struct {
// Session stores the frozen internal read-model DTO.
Session shared.Session
}
// Service executes the trusted internal get-session use case against the
// configured ports.
type Service struct {
sessionStore ports.SessionStore
}
// New returns a get-session service wired to sessionStore.
func New(sessionStore ports.SessionStore) (*Service, error) {
if sessionStore == nil {
return nil, fmt.Errorf("getsession: session store must not be nil")
}
return &Service{sessionStore: sessionStore}, nil
}
// Execute loads one source-of-truth session and projects it into the frozen
// internal read DTO shape.
func (s *Service) Execute(ctx context.Context, input Input) (Result, error) {
deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID)
if err != nil {
return Result{}, err
}
record, err := s.sessionStore.Get(ctx, deviceSessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.SessionNotFound()
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
session, err := shared.ToSession(record)
if err != nil {
return Result{}, shared.InternalError(err)
}
return Result{Session: session}, nil
}
@@ -0,0 +1,68 @@
package getsession
import (
"context"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
)
func TestExecuteReturnsMappedSession(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{DeviceSessionID: " device-session-1 "})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.Session.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().Session.DeviceSessionID = %q, want %q", result.Session.DeviceSessionID, "device-session-1")
}
if result.Session.CreatedAt != time.Unix(10, 0).UTC().Format(time.RFC3339) {
require.Failf(t, "test failed", "Execute().Session.CreatedAt = %q", result.Session.CreatedAt)
}
}
func TestExecuteReturnsSessionNotFound(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{DeviceSessionID: "missing"})
if shared.CodeOf(err) != shared.ErrorCodeSessionNotFound {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionNotFound)
}
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,58 @@
// Package listusersessions implements the trusted internal read use case for
// listing all sessions of one user.
package listusersessions
import (
"context"
"fmt"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
)
// Input describes one trusted internal list-user-sessions request.
type Input struct {
// UserID identifies the owner whose sessions should be listed.
UserID string
}
// Result describes one trusted internal list-user-sessions response.
type Result struct {
// Sessions stores the frozen internal read-model DTO slice.
Sessions []shared.Session
}
// Service executes the trusted internal list-user-sessions use case.
type Service struct {
sessionStore ports.SessionStore
}
// New returns a list-user-sessions service wired to sessionStore.
func New(sessionStore ports.SessionStore) (*Service, error) {
if sessionStore == nil {
return nil, fmt.Errorf("listusersessions: session store must not be nil")
}
return &Service{sessionStore: sessionStore}, nil
}
// Execute loads all source-of-truth sessions for one user and projects them
// into the frozen internal read DTO shape.
func (s *Service) Execute(ctx context.Context, input Input) (Result, error) {
userID, err := shared.ParseUserID(input.UserID)
if err != nil {
return Result{}, err
}
records, err := s.sessionStore.ListByUserID(ctx, userID)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
sessions, err := shared.ToSessions(records)
if err != nil {
return Result{}, shared.InternalError(err)
}
return Result{Sessions: sessions}, nil
}
@@ -0,0 +1,73 @@
package listusersessions
import (
"context"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/testkit"
)
func TestExecutePreservesNewestFirstOrder(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
older := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
newer := activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
for _, record := range []devicesession.Session{older, newer} {
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
}
service, err := New(store)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{UserID: "user-1"})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if len(result.Sessions) != 2 {
require.Failf(t, "test failed", "Execute().Sessions length = %d, want 2", len(result.Sessions))
}
if result.Sessions[0].DeviceSessionID != "device-session-2" || result.Sessions[1].DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().Sessions order = [%q %q]", result.Sessions[0].DeviceSessionID, result.Sessions[1].DeviceSessionID)
}
}
func TestExecuteReturnsEmptyForUnknownUser(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{UserID: "missing"})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if len(result.Sessions) != 0 {
require.Failf(t, "test failed", "Execute().Sessions length = %d, want 0", len(result.Sessions))
}
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,106 @@
package revokeallusersessions
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRetriesProjectionPublishesForBulkRevoke(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{
errors.New("publish failed"),
nil,
errors.New("publish failed"),
nil,
},
}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())))
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
assert.EqualValues(t, 2, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs)
require.Len(t, publisher.PublishedSnapshots(), 4)
}
func TestExecuteRepublishesCurrentRevokedSessionsOnNoActiveSessionsRetry(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{
nil,
errors.New("publish failed"),
errors.New("publish failed"),
errors.New("publish failed"),
},
}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())))
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, publisher.PublishedSnapshots(), 4)
for _, deviceSessionID := range []common.DeviceSessionID{"device-session-1", "device-session-2"} {
record, getErr := store.Get(context.Background(), deviceSessionID)
require.NoError(t, getErr)
require.NotNil(t, record.Revocation)
assert.Equal(t, devicesession.StatusRevoked, record.Status)
}
publisher.Errors = nil
publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "no_active_sessions", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
published := publisher.PublishedSnapshots()
require.Len(t, published, 6)
assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{
published[4].DeviceSessionID,
published[5].DeviceSessionID,
})
}
@@ -0,0 +1,200 @@
// Package revokeallusersessions implements the trusted internal bulk revoke
// use case for all sessions of one user.
package revokeallusersessions
import (
"context"
"fmt"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
// Input describes one trusted internal revoke-all-user-sessions request.
type Input struct {
// UserID identifies the owner whose sessions should be revoked.
UserID string
// ReasonCode stores the machine-readable revoke reason code.
ReasonCode string
// ActorType stores the machine-readable revoke actor type.
ActorType string
// ActorID stores the optional stable revoke actor identifier.
ActorID string
}
// Result describes the frozen internal bulk revoke acknowledgement.
type Result struct {
// Outcome reports whether active sessions were revoked during the current
// call.
Outcome string
// UserID identifies the user addressed by the operation.
UserID string
// AffectedSessionCount reports how many sessions changed state during the
// current call.
AffectedSessionCount int64
// AffectedDeviceSessionIDs lists every session identifier affected during
// the current call.
AffectedDeviceSessionIDs []string
}
// Service executes the trusted internal revoke-all-user-sessions use case.
type Service struct {
sessionStore ports.SessionStore
userDirectory ports.UserDirectory
publisher ports.GatewaySessionProjectionPublisher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a revoke-all-user-sessions service wired to the required ports.
func New(sessionStore ports.SessionStore, userDirectory ports.UserDirectory, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
return NewWithObservability(sessionStore, userDirectory, publisher, clock, nil, nil)
}
// NewWithObservability returns a revoke-all-user-sessions service wired to the
// required ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
publisher ports.GatewaySessionProjectionPublisher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case sessionStore == nil:
return nil, fmt.Errorf("revokeallusersessions: session store must not be nil")
case userDirectory == nil:
return nil, fmt.Errorf("revokeallusersessions: user directory must not be nil")
case publisher == nil:
return nil, fmt.Errorf("revokeallusersessions: projection publisher must not be nil")
case clock == nil:
return nil, fmt.Errorf("revokeallusersessions: clock must not be nil")
default:
return &Service{
sessionStore: sessionStore,
userDirectory: userDirectory,
publisher: publisher,
clock: clock,
logger: namedLogger(logger, "revoke_all_user_sessions"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute revokes all active sessions of one user and republishes revoked
// gateway projections for every affected session.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "revoke_all_user_sessions"),
}
defer func() {
shared.LogServiceOutcome(s.logger, ctx, "revoke all user sessions completed", err, logFields...)
}()
userID, err := shared.ParseUserID(input.UserID)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("user_id", userID.String()))
revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now())
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String()))
exists, err := s.userDirectory.ExistsByUserID(ctx, userID)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "exists_by_user_id", boolOutcome(exists))
if !exists {
return Result{}, shared.SubjectNotFound()
}
storeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{
UserID: userID,
Revocation: revocation,
})
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := storeResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome)))
affectedDeviceSessionIDs := make([]string, 0, len(storeResult.Sessions))
for _, record := range storeResult.Sessions {
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions"); err != nil {
return Result{}, err
}
affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String())
}
if storeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions {
if err := s.republishCurrentRevokedSessions(ctx, userID); err != nil {
return Result{}, err
}
}
affectedSessionCount := int64(len(storeResult.Sessions))
if affectedSessionCount > 0 {
s.telemetry.RecordSessionRevocations(ctx, "revoke_all_user_sessions", revocation.ReasonCode.String(), affectedSessionCount)
}
logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount))
return Result{
Outcome: string(storeResult.Outcome),
UserID: storeResult.UserID.String(),
AffectedSessionCount: affectedSessionCount,
AffectedDeviceSessionIDs: affectedDeviceSessionIDs,
}, nil
}
func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error {
records, err := s.sessionStore.ListByUserID(ctx, userID)
if err != nil {
return shared.ServiceUnavailable(err)
}
for _, record := range records {
if record.Status != devicesession.StatusRevoked {
continue
}
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions_repair"); err != nil {
return err
}
}
return nil
}
func boolOutcome(value bool) string {
if value {
return "exists"
}
return "missing"
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,162 @@
package revokeallusersessions
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRevokesExistingUserSessionsAndPublishes(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
for _, record := range []devicesession.Session{
activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()),
activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC()),
} {
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
}
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
assert.EqualValues(t, 2, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs)
for _, deviceSessionID := range result.AffectedDeviceSessionIDs {
stored, getErr := store.Get(context.Background(), common.DeviceSessionID(deviceSessionID))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
assert.Empty(t, stored.Revocation.ActorID)
assert.Equal(t, time.Unix(30, 0).UTC(), stored.Revocation.At)
}
published := publisher.PublishedSnapshots()
require.Len(t, published, 2)
assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{
published[0].DeviceSessionID,
published[1].DeviceSessionID,
})
for _, snapshot := range published {
assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, snapshot.RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("system"), snapshot.RevokeActorType)
require.NotNil(t, snapshot.RevokedAt)
assert.Equal(t, time.Unix(30, 0).UTC(), *snapshot.RevokedAt)
}
}
func TestExecuteReturnsNoActiveSessionsForExistingUserWithoutActiveSessions(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "no_active_sessions", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, publisher.PublishedSnapshots())
}
func TestExecuteReturnsSubjectNotFoundForUnknownUser(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{}, &testkit.InMemoryUserDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{
UserID: "missing",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
}
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,53 @@
package revokeallusersessions
import (
"context"
"testing"
"time"
stubuserservice "galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
t.Parallel()
t.Run("existing user uses ExistsByUserID and returns no active sessions", func(t *testing.T) {
t.Parallel()
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
service, err := New(&testkit.InMemorySessionStore{}, userDirectory, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "no_active_sessions", result.Outcome)
assert.Zero(t, result.AffectedSessionCount)
})
t.Run("unknown user returns subject not found", func(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{}, &stubuserservice.StubDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "missing",
ReasonCode: "logout_all",
ActorType: "system",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
})
}
@@ -0,0 +1,75 @@
package revokedevicesession
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRetriesProjectionPublishUntilSuccess(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{errors.New("publish failed"), nil},
}
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
require.Len(t, publisher.PublishedSnapshots(), 2)
}
func TestExecuteRepairsProjectionOnRepeatedAlreadyRevokedRequest(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "already_revoked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}
@@ -0,0 +1,151 @@
// Package revokedevicesession implements the trusted internal single-session
// revoke use case.
package revokedevicesession
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
// Input describes one trusted internal revoke-device-session request.
type Input struct {
// DeviceSessionID identifies the session that should be revoked.
DeviceSessionID string
// ReasonCode stores the machine-readable revoke reason code.
ReasonCode string
// ActorType stores the machine-readable revoke actor type.
ActorType string
// ActorID stores the optional stable revoke actor identifier.
ActorID string
}
// Result describes the frozen internal revoke-device-session acknowledgement.
type Result struct {
// Outcome reports whether the current call revoked the session or found it
// already revoked.
Outcome string
// DeviceSessionID identifies the session addressed by the operation.
DeviceSessionID string
// AffectedSessionCount reports how many sessions changed state during the
// current call.
AffectedSessionCount int64
}
// Service executes the trusted internal revoke-device-session use case.
type Service struct {
sessionStore ports.SessionStore
publisher ports.GatewaySessionProjectionPublisher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a revoke-device-session service wired to the required ports.
func New(sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
return NewWithObservability(sessionStore, publisher, clock, nil, nil)
}
// NewWithObservability returns a revoke-device-session service wired to the
// required ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
sessionStore ports.SessionStore,
publisher ports.GatewaySessionProjectionPublisher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case sessionStore == nil:
return nil, fmt.Errorf("revokedevicesession: session store must not be nil")
case publisher == nil:
return nil, fmt.Errorf("revokedevicesession: projection publisher must not be nil")
case clock == nil:
return nil, fmt.Errorf("revokedevicesession: clock must not be nil")
default:
return &Service{
sessionStore: sessionStore,
publisher: publisher,
clock: clock,
logger: namedLogger(logger, "revoke_device_session"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute revokes one device session and republishes the current gateway
// projection for the resulting source-of-truth session state.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "revoke_device_session"),
}
defer func() {
shared.LogServiceOutcome(s.logger, ctx, "revoke device session completed", err, logFields...)
}()
deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("device_session_id", deviceSessionID.String()))
revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now())
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String()))
storeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{
DeviceSessionID: deviceSessionID,
Revocation: revocation,
})
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.SessionNotFound()
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := storeResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome)))
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, storeResult.Session, s.telemetry, "revoke_device_session"); err != nil {
return Result{}, err
}
affectedSessionCount := int64(0)
if storeResult.Outcome == ports.RevokeSessionOutcomeRevoked {
affectedSessionCount = 1
s.telemetry.RecordSessionRevocations(ctx, "revoke_device_session", revocation.ReasonCode.String(), affectedSessionCount)
}
logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount))
return Result{
Outcome: string(storeResult.Outcome),
DeviceSessionID: storeResult.Session.ID.String(),
AffectedSessionCount: affectedSessionCount,
}, nil
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,166 @@
package revokedevicesession
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRevokesActiveSessionAndPublishes(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, err)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
assert.Empty(t, stored.Revocation.ActorID)
assert.Equal(t, time.Unix(20, 0).UTC(), stored.Revocation.At)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
assert.Equal(t, common.DeviceSessionID("device-session-1"), published[0].DeviceSessionID)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
require.NotNil(t, published[0].RevokedAt)
assert.Equal(t, time.Unix(20, 0).UTC(), published[0].RevokedAt.UTC())
}
func TestExecuteAlreadyRevokedReturnsZeroAffectedAndRepublishes(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "already_revoked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, err)
require.NotNil(t, stored.Revocation)
assert.Equal(t, *record.Revocation, *stored.Revocation)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
require.NotNil(t, published[0].RevokedAt)
assert.Equal(t, record.Revocation.At, *published[0].RevokedAt)
}
func TestExecuteReturnsSessionNotFound(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{
DeviceSessionID: "missing",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeSessionNotFound, shared.CodeOf(err))
}
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
record := activeSessionFixture(deviceSessionID, userID, createdAt)
record.Status = devicesession.StatusRevoked
record.Revocation = &devicesession.Revocation{
At: createdAt.Add(time.Minute),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
}
return record
}

Some files were not shown because too many files have changed in this diff Show More